Compare commits

..

1 Commits

Author SHA1 Message Date
Teknium 61c1fd08cc fix(api-server): share one Docker container across all API conversations
The API server's _run_agent() was not passing task_id to
run_conversation(), causing a fresh random UUID per request. This meant
every Open WebUI message spun up a new Docker container and tore it down
afterward — making persistent filesystem state impossible.

Two fixes:

1. Pass task_id="default" so all API server conversations share the same
   Docker container (matching the design intent: one configured Docker
   environment, always the same container).

2. Derive a stable session_id from the system prompt + first user message
   hash instead of uuid4(). This stops hermes sessions list from being
   polluted with single-message throwaway sessions.

Fixes #3438.
2026-04-10 03:56:33 -07:00
210 changed files with 3781 additions and 10539 deletions
-3
View File
@@ -36,7 +36,6 @@ from acp.schema import (
SessionCapabilities,
SessionForkCapabilities,
SessionListCapabilities,
SessionResumeCapabilities,
SessionInfo,
TextContentBlock,
UnstructuredCommandInput,
@@ -246,11 +245,9 @@ class HermesACPAgent(acp.Agent):
protocol_version=acp.PROTOCOL_VERSION,
agent_info=Implementation(name="hermes-agent", version=HERMES_VERSION),
agent_capabilities=AgentCapabilities(
load_session=True,
session_capabilities=SessionCapabilities(
fork=SessionForkCapabilities(),
list=SessionListCapabilities(),
resume=SessionResumeCapabilities(),
),
),
auth_methods=auth_methods,
+77
View File
@@ -511,6 +511,35 @@ def _prefer_refreshable_claude_code_token(env_token: str, creds: Optional[Dict[s
return None
def get_anthropic_token_source(token: Optional[str] = None) -> str:
"""Best-effort source classification for an Anthropic credential token."""
token = (token or "").strip()
if not token:
return "none"
env_token = os.getenv("ANTHROPIC_TOKEN", "").strip()
if env_token and env_token == token:
return "anthropic_token_env"
cc_env_token = os.getenv("CLAUDE_CODE_OAUTH_TOKEN", "").strip()
if cc_env_token and cc_env_token == token:
return "claude_code_oauth_token_env"
creds = read_claude_code_credentials()
if creds and creds.get("accessToken") == token:
return str(creds.get("source") or "claude_code_credentials")
managed_key = read_claude_managed_key()
if managed_key and managed_key == token:
return "claude_json_primary_api_key"
api_key = os.getenv("ANTHROPIC_API_KEY", "").strip()
if api_key and api_key == token:
return "anthropic_api_key_env"
return "unknown"
def resolve_anthropic_token() -> Optional[str]:
"""Resolve an Anthropic token from all available sources.
@@ -717,6 +746,21 @@ def run_hermes_oauth_login_pure() -> Optional[Dict[str, Any]]:
}
def _save_hermes_oauth_credentials(access_token: str, refresh_token: str, expires_at_ms: int) -> None:
"""Save OAuth credentials to ~/.hermes/.anthropic_oauth.json."""
data = {
"accessToken": access_token,
"refreshToken": refresh_token,
"expiresAt": expires_at_ms,
}
try:
_HERMES_OAUTH_FILE.parent.mkdir(parents=True, exist_ok=True)
_HERMES_OAUTH_FILE.write_text(json.dumps(data, indent=2), encoding="utf-8")
_HERMES_OAUTH_FILE.chmod(0o600)
except (OSError, IOError) as e:
logger.debug("Failed to save Hermes OAuth credentials: %s", e)
def read_hermes_oauth_credentials() -> Optional[Dict[str, Any]]:
"""Read Hermes-managed OAuth credentials from ~/.hermes/.anthropic_oauth.json."""
if _HERMES_OAUTH_FILE.exists():
@@ -765,6 +809,39 @@ def _sanitize_tool_id(tool_id: str) -> str:
return sanitized or "tool_0"
def _convert_openai_image_part_to_anthropic(part: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Convert an OpenAI-style image block to Anthropic's image source format."""
image_data = part.get("image_url", {})
url = image_data.get("url", "") if isinstance(image_data, dict) else str(image_data)
if not isinstance(url, str) or not url.strip():
return None
url = url.strip()
if url.startswith("data:"):
header, sep, data = url.partition(",")
if sep and ";base64" in header:
media_type = header[5:].split(";", 1)[0] or "image/png"
return {
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": data,
},
}
if url.startswith(("http://", "https://")):
return {
"type": "image",
"source": {
"type": "url",
"url": url,
},
}
return None
def convert_tools_to_anthropic(tools: List[Dict]) -> List[Dict]:
"""Convert OpenAI tool definitions to Anthropic format."""
if not tools:
+79 -68
View File
@@ -687,15 +687,6 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
if pconfig.auth_type != "api_key":
continue
if provider_id == "anthropic":
# Only try anthropic when the user has explicitly configured it.
# Without this gate, Claude Code credentials get silently used
# as auxiliary fallback when the user's primary provider fails.
try:
from hermes_cli.auth import is_provider_explicitly_configured
if not is_provider_explicitly_configured("anthropic"):
continue
except ImportError:
pass
return _try_anthropic()
pool_present, entry = _select_pool_entry(provider_id)
@@ -857,7 +848,7 @@ def _read_main_provider() -> str:
return ""
def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str], Optional[str]]:
def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str]]:
"""Resolve the active custom/main endpoint the same way the main CLI does.
This covers both env-driven OPENAI_BASE_URL setups and config-saved custom
@@ -870,29 +861,18 @@ def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str], Optional[st
runtime = resolve_runtime_provider(requested="custom")
except Exception as exc:
logger.debug("Auxiliary client: custom runtime resolution failed: %s", exc)
runtime = None
if not isinstance(runtime, dict):
openai_base = os.getenv("OPENAI_BASE_URL", "").strip().rstrip("/")
openai_key = os.getenv("OPENAI_API_KEY", "").strip()
if not openai_base:
return None, None, None
runtime = {
"base_url": openai_base,
"api_key": openai_key,
}
return None, None
custom_base = runtime.get("base_url")
custom_key = runtime.get("api_key")
custom_mode = runtime.get("api_mode")
if not isinstance(custom_base, str) or not custom_base.strip():
return None, None, None
return None, None
custom_base = custom_base.strip().rstrip("/")
if "openrouter.ai" in custom_base.lower():
# requested='custom' falls back to OpenRouter when no custom endpoint is
# configured. Treat that as "no custom endpoint" for auxiliary routing.
return None, None, None
return None, None
# Local servers (Ollama, llama.cpp, vLLM, LM Studio) don't require auth.
# Use a placeholder key — the OpenAI SDK requires a non-empty string but
@@ -901,33 +881,20 @@ def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str], Optional[st
if not isinstance(custom_key, str) or not custom_key.strip():
custom_key = "no-key-required"
if not isinstance(custom_mode, str) or not custom_mode.strip():
custom_mode = None
return custom_base, custom_key.strip(), custom_mode
return custom_base, custom_key.strip()
def _current_custom_base_url() -> str:
custom_base, _, _ = _resolve_custom_runtime()
custom_base, _ = _resolve_custom_runtime()
return custom_base or ""
def _try_custom_endpoint() -> Tuple[Optional[OpenAI], Optional[str]]:
runtime = _resolve_custom_runtime()
if len(runtime) == 2:
custom_base, custom_key = runtime
custom_mode = None
else:
custom_base, custom_key, custom_mode = runtime
custom_base, custom_key = _resolve_custom_runtime()
if not custom_base or not custom_key:
return None, None
if custom_base.lower().startswith(_CODEX_AUX_BASE_URL.lower()):
return None, None
model = _read_main_model() or "gpt-4o-mini"
logger.debug("Auxiliary client: custom endpoint (%s, api_mode=%s)", model, custom_mode or "chat_completions")
if custom_mode == "codex_responses":
real_client = OpenAI(api_key=custom_key, base_url=custom_base)
return CodexAuxiliaryClient(real_client, model), model
logger.debug("Auxiliary client: custom endpoint (%s)", model)
return OpenAI(api_key=custom_key, base_url=custom_base), model
@@ -1000,6 +967,40 @@ def _try_anthropic() -> Tuple[Optional[Any], Optional[str]]:
return AnthropicAuxiliaryClient(real_client, model, token, base_url, is_oauth=is_oauth), model
def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[str]]:
"""Resolve a specific forced provider. Returns (None, None) if creds missing."""
if forced == "openrouter":
client, model = _try_openrouter()
if client is None:
logger.warning("auxiliary.provider=openrouter but OPENROUTER_API_KEY not set")
return client, model
if forced == "nous":
client, model = _try_nous()
if client is None:
logger.warning("auxiliary.provider=nous but Nous Portal not configured (run: hermes auth)")
return client, model
if forced == "codex":
client, model = _try_codex()
if client is None:
logger.warning("auxiliary.provider=codex but no Codex OAuth token found (run: hermes model)")
return client, model
if forced == "main":
# "main" = skip OpenRouter/Nous, use the main chat model's credentials.
for try_fn in (_try_custom_endpoint, _try_codex, _resolve_api_key_provider):
client, model = try_fn()
if client is not None:
return client, model
logger.warning("auxiliary.provider=main but no main endpoint credentials found")
return None, None
# Unknown provider name — fall through to auto
logger.warning("Unknown auxiliary.provider=%r, falling back to auto", forced)
return None, None
_AUTO_PROVIDER_LABELS = {
"_try_openrouter": "openrouter",
"_try_nous": "nous",
@@ -1198,18 +1199,6 @@ def _to_async_client(sync_client, model: str):
return AsyncOpenAI(**async_kwargs), model
def _normalize_resolved_model(model_name: Optional[str], provider: str) -> Optional[str]:
"""Normalize a resolved model for the provider that will receive it."""
if not model_name:
return model_name
try:
from hermes_cli.model_normalize import normalize_model_for_provider
return normalize_model_for_provider(model_name, provider)
except Exception:
return model_name
def resolve_provider_client(
provider: str,
model: str = None,
@@ -1272,7 +1261,7 @@ def resolve_provider_client(
logger.warning("resolve_provider_client: openrouter requested "
"but OPENROUTER_API_KEY not set")
return None, None
final_model = _normalize_resolved_model(model or default, provider)
final_model = model or default
return (_to_async_client(client, final_model) if async_mode
else (client, final_model))
@@ -1283,7 +1272,7 @@ def resolve_provider_client(
logger.warning("resolve_provider_client: nous requested "
"but Nous Portal not configured (run: hermes auth)")
return None, None
final_model = _normalize_resolved_model(model or default, provider)
final_model = model or default
return (_to_async_client(client, final_model) if async_mode
else (client, final_model))
@@ -1297,7 +1286,7 @@ def resolve_provider_client(
logger.warning("resolve_provider_client: openai-codex requested "
"but no Codex OAuth token found (run: hermes model)")
return None, None
final_model = _normalize_resolved_model(model or _CODEX_AUX_MODEL, provider)
final_model = model or _CODEX_AUX_MODEL
raw_client = OpenAI(api_key=codex_token, base_url=_CODEX_AUX_BASE_URL)
return (raw_client, final_model)
# Standard path: wrap in CodexAuxiliaryClient adapter
@@ -1306,7 +1295,7 @@ def resolve_provider_client(
logger.warning("resolve_provider_client: openai-codex requested "
"but no Codex OAuth token found (run: hermes model)")
return None, None
final_model = _normalize_resolved_model(model or default, provider)
final_model = model or default
return (_to_async_client(client, final_model) if async_mode
else (client, final_model))
@@ -1325,10 +1314,7 @@ def resolve_provider_client(
"but base_url is empty"
)
return None, None
final_model = _normalize_resolved_model(
model or _read_main_model() or "gpt-4o-mini",
provider,
)
final_model = model or _read_main_model() or "gpt-4o-mini"
extra = {}
if "api.kimi.com" in custom_base.lower():
extra["default_headers"] = {"User-Agent": "KimiCLI/1.30.0"}
@@ -1343,7 +1329,7 @@ def resolve_provider_client(
_resolve_api_key_provider):
client, default = try_fn()
if client is not None:
final_model = _normalize_resolved_model(model or default, provider)
final_model = model or default
return (_to_async_client(client, final_model) if async_mode
else (client, final_model))
logger.warning("resolve_provider_client: custom/main requested "
@@ -1358,10 +1344,7 @@ def resolve_provider_client(
custom_base = custom_entry.get("base_url", "").strip()
custom_key = custom_entry.get("api_key", "").strip() or "no-key-required"
if custom_base:
final_model = _normalize_resolved_model(
model or _read_main_model() or "gpt-4o-mini",
provider,
)
final_model = model or _read_main_model() or "gpt-4o-mini"
client = OpenAI(api_key=custom_key, base_url=custom_base)
logger.debug(
"resolve_provider_client: named custom provider %r (%s)",
@@ -1393,7 +1376,7 @@ def resolve_provider_client(
if client is None:
logger.warning("resolve_provider_client: anthropic requested but no Anthropic credentials found")
return None, None
final_model = _normalize_resolved_model(model or default_model, provider)
final_model = model or default_model
return (_to_async_client(client, final_model) if async_mode else (client, final_model))
creds = resolve_api_key_provider_credentials(provider)
@@ -1412,7 +1395,7 @@ def resolve_provider_client(
)
default_model = _API_KEY_PROVIDER_AUX_MODELS.get(provider, "")
final_model = _normalize_resolved_model(model or default_model, provider)
final_model = model or default_model
# Provider-specific headers
headers = {}
@@ -1512,6 +1495,22 @@ def _strict_vision_backend_available(provider: str) -> bool:
return _resolve_strict_vision_backend(provider)[0] is not None
def _preferred_main_vision_provider() -> Optional[str]:
"""Return the selected main provider when it is also a supported vision backend."""
try:
from hermes_cli.config import load_config
config = load_config()
model_cfg = config.get("model", {})
if isinstance(model_cfg, dict):
provider = _normalize_vision_provider(model_cfg.get("provider", ""))
if provider in _VISION_AUTO_PROVIDER_ORDER:
return provider
except Exception:
pass
return None
def get_available_vision_backends() -> List[str]:
"""Return the currently available vision backends in auto-selection order.
@@ -1625,6 +1624,18 @@ def resolve_vision_provider_client(
return requested, client, final_model
def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
"""Return (client, default_model_slug) for vision/multimodal auxiliary tasks."""
_, client, final_model = resolve_vision_provider_client(async_mode=False)
return client, final_model
def get_async_vision_auxiliary_client():
"""Return (async_client, model_slug) for async vision consumers."""
_, client, final_model = resolve_vision_provider_client(async_mode=True)
return client, final_model
def get_auxiliary_extra_body() -> dict:
"""Return extra_body kwargs for auxiliary API calls.
+114
View File
@@ -0,0 +1,114 @@
"""BuiltinMemoryProvider — wraps MEMORY.md / USER.md as a MemoryProvider.
Always registered as the first provider. Cannot be disabled or removed.
This is the existing Hermes memory system exposed through the provider
interface for compatibility with the MemoryManager.
The actual storage logic lives in tools/memory_tool.py (MemoryStore).
This provider is a thin adapter that delegates to MemoryStore and
exposes the memory tool schema.
"""
from __future__ import annotations
import json
import logging
from typing import Any, Dict, List
from agent.memory_provider import MemoryProvider
from tools.registry import tool_error
logger = logging.getLogger(__name__)
class BuiltinMemoryProvider(MemoryProvider):
"""Built-in file-backed memory (MEMORY.md + USER.md).
Always active, never disabled by other providers. The `memory` tool
is handled by run_agent.py's agent-level tool interception (not through
the normal registry), so get_tool_schemas() returns an empty list —
the memory tool is already wired separately.
"""
def __init__(
self,
memory_store=None,
memory_enabled: bool = False,
user_profile_enabled: bool = False,
):
self._store = memory_store
self._memory_enabled = memory_enabled
self._user_profile_enabled = user_profile_enabled
@property
def name(self) -> str:
return "builtin"
def is_available(self) -> bool:
"""Built-in memory is always available."""
return True
def initialize(self, session_id: str, **kwargs) -> None:
"""Load memory from disk if not already loaded."""
if self._store is not None:
self._store.load_from_disk()
def system_prompt_block(self) -> str:
"""Return MEMORY.md and USER.md content for the system prompt.
Uses the frozen snapshot captured at load time. This ensures the
system prompt stays stable throughout a session (preserving the
prompt cache), even though the live entries may change via tool calls.
"""
if not self._store:
return ""
parts = []
if self._memory_enabled:
mem_block = self._store.format_for_system_prompt("memory")
if mem_block:
parts.append(mem_block)
if self._user_profile_enabled:
user_block = self._store.format_for_system_prompt("user")
if user_block:
parts.append(user_block)
return "\n\n".join(parts)
def prefetch(self, query: str, *, session_id: str = "") -> str:
"""Built-in memory doesn't do query-based recall — it's injected via system_prompt_block."""
return ""
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
"""Built-in memory doesn't auto-sync turns — writes happen via the memory tool."""
def get_tool_schemas(self) -> List[Dict[str, Any]]:
"""Return empty list.
The `memory` tool is an agent-level intercepted tool, handled
specially in run_agent.py before normal tool dispatch. It's not
part of the standard tool registry. We don't duplicate it here.
"""
return []
def handle_tool_call(self, tool_name: str, args: Dict[str, Any], **kwargs) -> str:
"""Not used — the memory tool is intercepted in run_agent.py."""
return tool_error("Built-in memory tool is handled by the agent loop")
def shutdown(self) -> None:
"""No cleanup needed — files are saved on every write."""
# -- Property access for backward compatibility --------------------------
@property
def store(self):
"""Access the underlying MemoryStore for legacy code paths."""
return self._store
@property
def memory_enabled(self) -> bool:
return self._memory_enabled
@property
def user_profile_enabled(self) -> bool:
return self._user_profile_enabled
+17
View File
@@ -114,6 +114,7 @@ class ContextCompressor:
self.last_prompt_tokens = 0
self.last_completion_tokens = 0
self.last_total_tokens = 0
self.summary_model = summary_model_override or ""
@@ -125,12 +126,28 @@ class ContextCompressor:
"""Update tracked token usage from API response."""
self.last_prompt_tokens = usage.get("prompt_tokens", 0)
self.last_completion_tokens = usage.get("completion_tokens", 0)
self.last_total_tokens = usage.get("total_tokens", 0)
def should_compress(self, prompt_tokens: int = None) -> bool:
"""Check if context exceeds the compression threshold."""
tokens = prompt_tokens if prompt_tokens is not None else self.last_prompt_tokens
return tokens >= self.threshold_tokens
def should_compress_preflight(self, messages: List[Dict[str, Any]]) -> bool:
"""Quick pre-flight check using rough estimate (before API call)."""
rough_estimate = estimate_messages_tokens_rough(messages)
return rough_estimate >= self.threshold_tokens
def get_status(self) -> Dict[str, Any]:
"""Get current compression status for display/logging."""
return {
"last_prompt_tokens": self.last_prompt_tokens,
"threshold_tokens": self.threshold_tokens,
"context_length": self.context_length,
"usage_percent": min(100, (self.last_prompt_tokens / self.context_length * 100)) if self.context_length else 0,
"compression_count": self.compression_count,
}
# ------------------------------------------------------------------
# Tool output pruning (cheap pre-pass, no LLM call)
# ------------------------------------------------------------------
+7 -36
View File
@@ -13,9 +13,8 @@ from typing import Awaitable, Callable
from agent.model_metadata import estimate_tokens_rough
_QUOTED_REFERENCE_VALUE = r'(?:`[^`\n]+`|"[^"\n]+"|\'[^\'\n]+\')'
REFERENCE_PATTERN = re.compile(
rf"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>{_QUOTED_REFERENCE_VALUE}(?::\d+(?:-\d+)?)?|\S+))"
r"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>\S+))"
)
TRAILING_PUNCTUATION = ",.;!?"
_SENSITIVE_HOME_DIRS = (".ssh", ".aws", ".gnupg", ".kube", ".docker", ".azure", ".config/gh")
@@ -82,10 +81,14 @@ def parse_context_references(message: str) -> list[ContextReference]:
value = _strip_trailing_punctuation(match.group("value") or "")
line_start = None
line_end = None
target = _strip_reference_wrappers(value)
target = value
if kind == "file":
target, line_start, line_end = _parse_file_reference_value(value)
range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value)
if range_match:
target = range_match.group("path")
line_start = int(range_match.group("start"))
line_end = int(range_match.group("end") or range_match.group("start"))
refs.append(
ContextReference(
@@ -372,38 +375,6 @@ def _strip_trailing_punctuation(value: str) -> str:
return stripped
def _strip_reference_wrappers(value: str) -> str:
if len(value) >= 2 and value[0] == value[-1] and value[0] in "`\"'":
return value[1:-1]
return value
def _parse_file_reference_value(value: str) -> tuple[str, int | None, int | None]:
quoted_match = re.match(
r'^(?P<quote>`|"|\')(?P<path>.+?)(?P=quote)(?::(?P<start>\d+)(?:-(?P<end>\d+))?)?$',
value,
)
if quoted_match:
line_start = quoted_match.group("start")
line_end = quoted_match.group("end")
return (
quoted_match.group("path"),
int(line_start) if line_start is not None else None,
int(line_end or line_start) if line_start is not None else None,
)
range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value)
if range_match:
line_start = int(range_match.group("start"))
return (
range_match.group("path"),
line_start,
int(range_match.group("end") or range_match.group("start")),
)
return _strip_reference_wrappers(value), None, None
def _remove_reference_tokens(message: str, refs: list[ContextReference]) -> str:
pieces: list[str] = []
cursor = 0
+16 -18
View File
@@ -739,6 +739,17 @@ class CredentialPool:
return False
return False
def mark_used(self, entry_id: Optional[str] = None) -> None:
"""Increment request_count for tracking. Used by least_used strategy."""
target_id = entry_id or self._current_id
if not target_id:
return
with self._lock:
for idx, entry in enumerate(self._entries):
if entry.id == target_id:
self._entries[idx] = replace(entry, request_count=entry.request_count + 1)
return
def select(self) -> Optional[PooledCredential]:
with self._lock:
return self._select_unlocked()
@@ -900,6 +911,11 @@ class CredentialPool:
else:
self._active_leases[credential_id] = count - 1
def active_lease_count(self, credential_id: str) -> int:
"""Return the number of active leases for a credential."""
with self._lock:
return self._active_leases.get(credential_id, 0)
def try_refresh_current(self) -> Optional[PooledCredential]:
with self._lock:
return self._try_refresh_current_unlocked()
@@ -1059,17 +1075,6 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
auth_store = _load_auth_store()
if provider == "anthropic":
# Only auto-discover external credentials (Claude Code, Hermes PKCE)
# when the user has explicitly configured anthropic as their provider.
# Without this gate, auxiliary client fallback chains silently read
# ~/.claude/.credentials.json without user consent. See PR #4210.
try:
from hermes_cli.auth import is_provider_explicitly_configured
if not is_provider_explicitly_configured("anthropic"):
return changed, active_sources
except ImportError:
pass
from agent.anthropic_adapter import read_claude_code_credentials, read_hermes_oauth_credentials
for source_name, creds in (
@@ -1077,13 +1082,6 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
("claude_code", read_claude_code_credentials()),
):
if creds and creds.get("accessToken"):
# Check if user explicitly removed this source
try:
from hermes_cli.auth import is_source_suppressed
if is_source_suppressed(provider, source_name):
continue
except ImportError:
pass
active_sources.add(source_name)
changed |= _upsert_entry(
entries,
+76
View File
@@ -67,6 +67,26 @@ def _get_skin():
return None
def get_skin_faces(key: str, default: list) -> list:
"""Get spinner face list from active skin, falling back to default."""
skin = _get_skin()
if skin:
faces = skin.get_spinner_list(key)
if faces:
return faces
return default
def get_skin_verbs() -> list:
"""Get thinking verbs from active skin."""
skin = _get_skin()
if skin:
verbs = skin.get_spinner_list("thinking_verbs")
if verbs:
return verbs
return KawaiiSpinner.THINKING_VERBS
def get_skin_tool_prefix() -> str:
"""Get tool output prefix character from active skin."""
skin = _get_skin()
@@ -703,6 +723,46 @@ class KawaiiSpinner:
return False
# =========================================================================
# Kawaii face arrays (used by AIAgent._execute_tool_calls for spinner text)
# =========================================================================
KAWAII_SEARCH = [
"♪(´ε` )", "(。◕‿◕。)", "ヾ(^∇^)", "(◕ᴗ◕✿)", "( ˘▽˘)っ",
"٩(◕‿◕。)۶", "(✿◠‿◠)", "♪~(´ε` )", "(ノ´ヮ`)*:・゚✧", "(◎o◎)",
]
KAWAII_READ = [
"φ(゜▽゜*)♪", "( ˘▽˘)っ", "(⌐■_■)", "٩(。•́‿•̀。)۶", "(◕‿◕✿)",
"ヾ(@⌒ー⌒@)", "(✧ω✧)", "♪(๑ᴖ◡ᴖ๑)♪", "(≧◡≦)", "( ´ ▽ ` )",
]
KAWAII_TERMINAL = [
"ヽ(>∀<☆)", "(ノ°∀°)", "٩(^ᴗ^)۶", "ヾ(⌐■_■)ノ♪", "(•̀ᴗ•́)و",
"┗(0)┓", "(`・ω・´)", "( ̄▽ ̄)", "(ง •̀_•́)ง", "ヽ(´▽`)/",
]
KAWAII_BROWSER = [
"(ノ°∀°)", "(☞゚ヮ゚)☞", "( ͡° ͜ʖ ͡°)", "┌( ಠ_ಠ)┘", "(⊙_⊙)",
"ヾ(•ω•`)o", "( ̄ω ̄)", "( ˇωˇ )", "(ᵔᴥᵔ)", "(◎o◎)",
]
KAWAII_CREATE = [
"✧*。٩(ˊᗜˋ*)و✧", "(ノ◕ヮ◕)ノ*:・゚✧", "ヽ(>∀<☆)", "٩(♡ε♡)۶", "(◕‿◕)♡",
"✿◕ ‿ ◕✿", "(*≧▽≦)", "ヾ(-)", "(☆▽☆)", "°˖✧◝(⁰▿⁰)◜✧˖°",
]
KAWAII_SKILL = [
"ヾ(@⌒ー⌒@)", "(๑˃ᴗ˂)ﻭ", "٩(◕‿◕。)۶", "(✿╹◡╹)", "ヽ(・∀・)",
"(ノ´ヮ`)*:・゚✧", "♪(๑ᴖ◡ᴖ๑)♪", "(◠‿◠)", "٩(ˊᗜˋ*)و", "(^▽^)",
"ヾ(^∇^)", "(★ω★)/", "٩(。•́‿•̀。)۶", "(◕ᴗ◕✿)", "(◎o◎)",
"(✧ω✧)", "ヽ(>∀<☆)", "( ˘▽˘)っ", "(≧◡≦) ♡", "ヾ( ̄▽ ̄)",
]
KAWAII_THINK = [
"(っ°Д°;)っ", "(;′⌒`)", "(・_・ヾ", "( ´_ゝ`)", "( ̄ヘ ̄)",
"(。-`ω´-)", "( ˘︹˘ )", "(¬_¬)", "ヽ(ー_ー )", "(;一_一)",
]
KAWAII_GENERIC = [
"♪(´ε` )", "(◕‿◕✿)", "ヾ(^∇^)", "٩(◕‿◕。)۶", "(✿◠‿◠)",
"(ノ´ヮ`)*:・゚✧", "ヽ(>∀<☆)", "(☆▽☆)", "( ˘▽˘)っ", "(≧◡≦)",
]
# =========================================================================
# Cute tool message (completion line that replaces the spinner)
# =========================================================================
@@ -910,6 +970,22 @@ _SKY_BLUE = "\033[38;5;117m"
_ANSI_RESET = "\033[0m"
def honcho_session_url(workspace: str, session_name: str) -> str:
"""Build a Honcho app URL for a session."""
from urllib.parse import quote
return (
f"https://app.honcho.dev/explore"
f"?workspace={quote(workspace, safe='')}"
f"&view=sessions"
f"&session={quote(session_name, safe='')}"
)
def _osc8_link(url: str, text: str) -> str:
"""OSC 8 terminal hyperlink (clickable in iTerm2, Ghostty, WezTerm, etc.)."""
return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"
# =========================================================================
# Context pressure display (CLI user-facing warnings)
# =========================================================================
+10 -1
View File
@@ -82,6 +82,16 @@ class ClassifiedError:
def is_auth(self) -> bool:
return self.reason in (FailoverReason.auth, FailoverReason.auth_permanent)
@property
def is_transient(self) -> bool:
"""Error is expected to resolve on retry (with or without backoff)."""
return self.reason in (
FailoverReason.rate_limit,
FailoverReason.overloaded,
FailoverReason.server_error,
FailoverReason.timeout,
FailoverReason.unknown,
)
# ── Provider-specific patterns ──────────────────────────────────────────
@@ -112,7 +122,6 @@ _RATE_LIMIT_PATTERNS = [
"try again in",
"please retry after",
"resource_exhausted",
"rate increased too quickly", # Alibaba/DashScope throttling
]
# Usage-limit patterns that need disambiguation (could be billing OR rate_limit)
+9
View File
@@ -39,6 +39,15 @@ def _has_known_pricing(model_name: str, provider: str = None, base_url: str = No
return has_known_pricing(model_name, provider=provider, base_url=base_url)
def _get_pricing(model_name: str) -> Dict[str, float]:
"""Look up pricing for a model. Uses fuzzy matching on model name.
Returns _DEFAULT_PRICING (zero cost) for unknown/custom models
we can't assume costs for self-hosted endpoints, local inference, etc.
"""
return get_pricing(model_name)
def _estimate_cost(
session_or_model: Dict[str, Any] | str,
input_tokens: int = 0,
+5
View File
@@ -134,6 +134,11 @@ class MemoryManager:
"""All registered providers in order."""
return list(self._providers)
@property
def provider_names(self) -> List[str]:
"""Names of all registered providers."""
return [p.name for p in self._providers]
def get_provider(self, name: str) -> Optional[MemoryProvider]:
"""Get a provider by name, or None if not registered."""
for p in self._providers:
-1
View File
@@ -213,7 +213,6 @@ _URL_TO_PROVIDER: Dict[str, str] = {
"models.github.ai": "copilot",
"api.fireworks.ai": "fireworks",
"opencode.ai": "opencode-go",
"api.x.ai": "xai",
}
+111
View File
@@ -135,6 +135,9 @@ class ProviderInfo:
doc: str = "" # documentation URL
model_count: int = 0
def has_api_url(self) -> bool:
return bool(self.api)
# ---------------------------------------------------------------------------
# Provider ID mapping: Hermes ↔ models.dev
@@ -631,6 +634,43 @@ def get_provider_info(provider_id: str) -> Optional[ProviderInfo]:
return _parse_provider_info(mdev_id, raw)
def list_all_providers() -> Dict[str, ProviderInfo]:
"""Return all providers from models.dev as {provider_id: ProviderInfo}.
Returns the full catalog 109+ providers. For providers that have
a Hermes alias, both the models.dev ID and the Hermes ID are included.
"""
data = fetch_models_dev()
result: Dict[str, ProviderInfo] = {}
for pid, pdata in data.items():
if isinstance(pdata, dict):
info = _parse_provider_info(pid, pdata)
result[pid] = info
return result
def get_providers_for_env_var(env_var: str) -> List[str]:
"""Reverse lookup: find all providers that use a given env var.
Useful for auto-detection: "user has ANTHROPIC_API_KEY set, which
providers does that enable?"
Returns list of models.dev provider IDs.
"""
data = fetch_models_dev()
matches: List[str] = []
for pid, pdata in data.items():
if isinstance(pdata, dict):
env = pdata.get("env", [])
if isinstance(env, list) and env_var in env:
matches.append(pid)
return matches
# ---------------------------------------------------------------------------
# Model-level queries (rich ModelInfo)
# ---------------------------------------------------------------------------
@@ -668,3 +708,74 @@ def get_model_info(
return None
def get_model_info_any_provider(model_id: str) -> Optional[ModelInfo]:
"""Search all providers for a model by ID.
Useful when you have a full slug like "anthropic/claude-sonnet-4.6" or
a bare name and want to find it anywhere. Checks Hermes-mapped providers
first, then falls back to all models.dev providers.
"""
data = fetch_models_dev()
# Try Hermes-mapped providers first (more likely what the user wants)
for hermes_id, mdev_id in PROVIDER_TO_MODELS_DEV.items():
pdata = data.get(mdev_id)
if not isinstance(pdata, dict):
continue
models = pdata.get("models", {})
if not isinstance(models, dict):
continue
raw = models.get(model_id)
if isinstance(raw, dict):
return _parse_model_info(model_id, raw, mdev_id)
# Case-insensitive
model_lower = model_id.lower()
for mid, mdata in models.items():
if mid.lower() == model_lower and isinstance(mdata, dict):
return _parse_model_info(mid, mdata, mdev_id)
# Fall back to ALL providers
for pid, pdata in data.items():
if pid in _get_reverse_mapping():
continue # already checked
if not isinstance(pdata, dict):
continue
models = pdata.get("models", {})
if not isinstance(models, dict):
continue
raw = models.get(model_id)
if isinstance(raw, dict):
return _parse_model_info(model_id, raw, pid)
return None
def list_provider_model_infos(provider_id: str) -> List[ModelInfo]:
"""Return all models for a provider as ModelInfo objects.
Filters out deprecated models by default.
"""
mdev_id = PROVIDER_TO_MODELS_DEV.get(provider_id, provider_id)
data = fetch_models_dev()
pdata = data.get(mdev_id)
if not isinstance(pdata, dict):
return []
models = pdata.get("models", {})
if not isinstance(models, dict):
return []
result: List[ModelInfo] = []
for mid, mdata in models.items():
if not isinstance(mdata, dict):
continue
status = mdata.get("status", "")
if status == "deprecated":
continue
result.append(_parse_model_info(mid, mdata, mdev_id))
return result
+14 -12
View File
@@ -356,14 +356,6 @@ PLATFORM_HINTS = {
"MEDIA:/absolute/path/to/file in your response. Images (.jpg, .png, "
".heic) appear as photos and other files arrive as attachments."
),
"weixin": (
"You are on Weixin/WeChat. Markdown formatting is supported, so you may use it when "
"it improves readability, but keep the message compact and chat-friendly. You can send media files natively: "
"include MEDIA:/absolute/path/to/file in your response. Images are sent as native "
"photos, videos play inline when supported, and other files arrive as downloadable "
"documents. You can also include image URLs in markdown format ![alt](url) and they "
"will be downloaded and sent as native media when possible."
),
}
CONTEXT_FILE_MAX_CHARS = 20_000
@@ -487,7 +479,7 @@ def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]:
(True, {}, "") to err on the side of showing the skill.
"""
try:
raw = skill_file.read_text(encoding="utf-8")
raw = skill_file.read_text(encoding="utf-8")[:2000]
frontmatter, _ = parse_frontmatter(raw)
if not skill_matches_platform(frontmatter):
@@ -495,10 +487,21 @@ def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]:
return True, frontmatter, extract_skill_description(frontmatter)
except Exception as e:
logger.warning("Failed to parse skill file %s: %s", skill_file, e)
logger.debug("Failed to parse skill file %s: %s", skill_file, e)
return True, {}, ""
def _read_skill_conditions(skill_file: Path) -> dict:
"""Extract conditional activation fields from SKILL.md frontmatter."""
try:
raw = skill_file.read_text(encoding="utf-8")[:2000]
frontmatter, _ = parse_frontmatter(raw)
return extract_skill_conditions(frontmatter)
except Exception as e:
logger.debug("Failed to read skill conditions from %s: %s", skill_file, e)
return {}
def _skill_should_show(
conditions: dict,
available_tools: "set[str] | None",
@@ -558,10 +561,9 @@ def build_skills_system_prompt(
# ── Layer 1: in-process LRU cache ─────────────────────────────────
# Include the resolved platform so per-platform disabled-skill lists
# produce distinct cache entries (gateway serves multiple platforms).
from gateway.session_context import get_session_env
_platform_hint = (
os.environ.get("HERMES_PLATFORM")
or get_session_env("HERMES_SESSION_PLATFORM")
or os.environ.get("HERMES_SESSION_PLATFORM")
or ""
)
cache_key = (
+4 -8
View File
@@ -97,12 +97,8 @@ def parse_rate_limit_headers(
Returns None if no rate limit headers are present.
"""
# Normalize to lowercase so lookups work regardless of how the server
# capitalises headers (HTTP header names are case-insensitive per RFC 7230).
lowered = {k.lower(): v for k, v in headers.items()}
# Quick check: at least one rate limit header must exist
has_any = any(k.startswith("x-ratelimit-") for k in lowered)
has_any = any(k.lower().startswith("x-ratelimit-") for k in headers)
if not has_any:
return None
@@ -113,9 +109,9 @@ def parse_rate_limit_headers(
# resource="tokens", suffix="-1h" -> per-hour
tag = f"{resource}{suffix}"
return RateLimitBucket(
limit=_safe_int(lowered.get(f"x-ratelimit-limit-{tag}")),
remaining=_safe_int(lowered.get(f"x-ratelimit-remaining-{tag}")),
reset_seconds=_safe_float(lowered.get(f"x-ratelimit-reset-{tag}")),
limit=_safe_int(headers.get(f"x-ratelimit-limit-{tag}")),
remaining=_safe_int(headers.get(f"x-ratelimit-remaining-{tag}")),
reset_seconds=_safe_float(headers.get(f"x-ratelimit-reset-{tag}")),
captured_at=now,
)
+1 -2
View File
@@ -145,11 +145,10 @@ def get_disabled_skill_names(platform: str | None = None) -> Set[str]:
if not isinstance(skills_cfg, dict):
return set()
from gateway.session_context import get_session_env
resolved_platform = (
platform
or os.getenv("HERMES_PLATFORM")
or get_session_env("HERMES_SESSION_PLATFORM")
or os.getenv("HERMES_SESSION_PLATFORM")
)
if resolved_platform:
platform_disabled = (skills_cfg.get("platform_disabled") or {}).get(
-1
View File
@@ -181,7 +181,6 @@ def resolve_turn_route(user_message: str, routing_config: Optional[Dict[str, Any
"api_mode": runtime.get("api_mode"),
"command": runtime.get("command"),
"args": list(runtime.get("args") or []),
"credential_pool": runtime.get("credential_pool"),
},
"label": f"smart route → {route.get('model')} ({runtime.get('provider')})",
"signature": (
+24
View File
@@ -595,6 +595,30 @@ def get_pricing(
}
def estimate_cost_usd(
model: str,
input_tokens: int,
output_tokens: int,
*,
provider: Optional[str] = None,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> float:
"""Backward-compatible helper for legacy callers.
This uses non-cached input/output only. New code should call
`estimate_usage_cost()` with canonical usage buckets.
"""
result = estimate_usage_cost(
model,
CanonicalUsage(input_tokens=input_tokens, output_tokens=output_tokens),
provider=provider,
base_url=base_url,
api_key=api_key,
)
return float(result.amount_usd or _ZERO)
def format_duration_compact(seconds: float) -> str:
if seconds < 60:
return f"{seconds:.0f}s"
+23 -108
View File
@@ -319,7 +319,7 @@ def load_cli_config() -> Dict[str, Any]:
# Load from file if exists
if config_path.exists():
try:
with open(config_path, "r", encoding="utf-8") as f:
with open(config_path, "r") as f:
file_config = yaml.safe_load(f) or {}
_file_has_terminal_config = "terminal" in file_config
@@ -1048,7 +1048,7 @@ def _termux_example_image_path(filename: str = "cat.png") -> str:
def _split_path_input(raw: str) -> tuple[str, str]:
r"""Split a leading file path token from trailing free-form text.
"""Split a leading file path token from trailing free-form text.
Supports quoted paths and backslash-escaped spaces so callers can accept
inputs like:
@@ -1292,6 +1292,14 @@ HERMES_CADUCEUS = """[#CD7F32]⠀⢀⣀⡀⠀⣀⣀
[#B8860B]⠀⠈⠳⠈⣡⠞⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[/]
[#B8860B]⠀⠈⠀[/]"""
# Compact banner for smaller terminals (fallback)
# Note: built dynamically by _build_compact_banner() to fit terminal width
COMPACT_BANNER = """
[bold #FFD700]╔══════════════════════════════════════════════════════════════╗[/]
[bold #FFD700]║[/] [#FFBF00]⚕ NOUS HERMES[/] [dim #B8860B]- AI Agent Framework[/] [bold #FFD700]║[/]
[bold #FFD700]║[/] [#CD7F32]Messenger of the Digital Gods[/] [dim #B8860B]Nous Research[/] [bold #FFD700]║[/]
[bold #FFD700]╚══════════════════════════════════════════════════════════════╝[/]
"""
def _build_compact_banner() -> str:
@@ -1537,6 +1545,7 @@ class HermesCLI:
self._stream_buf = "" # Partial line buffer for line-buffered rendering
self._stream_started = False # True once first delta arrives
self._stream_box_opened = False # True once the response box header is printed
self._reasoning_stream_started = False # True once live reasoning starts streaming
self._reasoning_preview_buf = "" # Coalesce tiny reasoning chunks for [thinking] output
self._pending_edit_snapshots = {}
@@ -1594,6 +1603,8 @@ class HermesCLI:
self.api_key = api_key or os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY")
else:
self.api_key = api_key or os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY")
self._nous_key_expires_at: Optional[str] = None
self._nous_key_source: Optional[str] = None
# Max turns priority: CLI arg > config file > env var > default
if max_turns is not None: # CLI arg was explicitly set
self.max_turns = max_turns
@@ -1719,7 +1730,6 @@ class HermesCLI:
self._secret_state = None
self._secret_deadline = 0
self._spinner_text: str = "" # thinking spinner text for TUI
self._tool_start_time: float = 0.0 # monotonic timestamp when current tool started (for live elapsed)
self._command_running = False
self._command_status = ""
self._attached_images: list[Path] = []
@@ -2028,25 +2038,6 @@ class HermesCLI:
current_model = (self.model or "").strip()
changed = False
try:
from hermes_cli.model_normalize import (
_AGGREGATOR_PROVIDERS,
normalize_model_for_provider,
)
if resolved_provider not in _AGGREGATOR_PROVIDERS:
normalized_model = normalize_model_for_provider(current_model, resolved_provider)
if normalized_model and normalized_model != current_model:
if not self._model_is_default:
self.console.print(
f"[yellow]⚠️ Normalized model '{current_model}' to '{normalized_model}' for {resolved_provider}.[/]"
)
self.model = normalized_model
current_model = normalized_model
changed = True
except Exception:
pass
if resolved_provider == "copilot":
try:
from hermes_cli.models import copilot_model_api_mode, normalize_copilot_model_id
@@ -2092,7 +2083,7 @@ class HermesCLI:
return changed
if resolved_provider != "openai-codex":
return changed
return False
# 1. Strip provider prefix ("openai/gpt-5.4" → "gpt-5.4")
if "/" in current_model:
@@ -2131,7 +2122,6 @@ class HermesCLI:
if not text:
self._flush_reasoning_preview(force=True)
self._spinner_text = text or ""
self._tool_start_time = 0.0 # clear tool timer when switching to thinking
self._invalidate()
# ── Streaming display ────────────────────────────────────────────────
@@ -2244,6 +2234,7 @@ class HermesCLI:
"""
if not text:
return
self._reasoning_stream_started = True
self._reasoning_shown_this_turn = True
if getattr(self, "_stream_box_opened", False):
return
@@ -2504,6 +2495,7 @@ class HermesCLI:
self._stream_buf = ""
self._stream_started = False
self._stream_box_opened = False
self._reasoning_stream_started = False
self._stream_text_ansi = ""
self._stream_prefilt = ""
self._in_reasoning_block = False
@@ -3381,22 +3373,22 @@ class HermesCLI:
pass # Don't crash on import errors
def _show_status(self):
"""Show compact startup status line."""
"""Show current status bar."""
# Get tool count
tools = get_tool_definitions(enabled_toolsets=self.enabled_toolsets, quiet_mode=True)
tool_count = len(tools) if tools else 0
# Format model name (shorten if needed)
model_short = self.model.split("/")[-1] if "/" in self.model else self.model
if len(model_short) > 30:
model_short = model_short[:27] + "..."
# Get API status indicator
if self.api_key:
api_indicator = "[green bold]●[/]"
else:
api_indicator = "[red bold]●[/]"
# Build status line with proper markup
toolsets_info = ""
if self.enabled_toolsets and "all" not in self.enabled_toolsets:
@@ -3411,59 +3403,6 @@ class HermesCLI:
f"[dim #B8860B]·[/] [bold cyan]{tool_count} tools[/]"
f"{toolsets_info}{provider_info}"
)
def _show_session_status(self):
"""Show gateway-style status for the current CLI session."""
session_meta = {}
if self._session_db:
try:
session_meta = self._session_db.get_session(self.session_id) or {}
except Exception:
session_meta = {}
title = (session_meta.get("title") or "").strip()
created_at = self.session_start
started_at = session_meta.get("started_at")
if started_at:
try:
created_at = datetime.fromtimestamp(float(started_at))
except Exception:
created_at = self.session_start
updated_at = created_at
for field in ("updated_at", "last_updated_at", "last_activity_at"):
value = session_meta.get(field)
if not value:
continue
try:
updated_at = datetime.fromtimestamp(float(value))
break
except Exception:
pass
agent = getattr(self, "agent", None)
total_tokens = getattr(agent, "session_total_tokens", 0) or 0
provider = getattr(self, "provider", None) or "unknown"
model = getattr(self, "model", None) or "(unknown)"
is_running = bool(getattr(self, "_agent_running", False))
lines = [
"Hermes CLI Status",
"",
f"Session ID: {self.session_id}",
f"Path: {display_hermes_home()}",
]
if title:
lines.append(f"Title: {title}")
lines.extend([
f"Model: {model} ({provider})",
f"Created: {created_at.strftime('%Y-%m-%d %H:%M')}",
f"Last Activity: {updated_at.strftime('%Y-%m-%d %H:%M')}",
f"Tokens: {total_tokens:,}",
f"Agent Running: {'Yes' if is_running else 'No'}",
])
self.console.print("\n".join(lines), highlight=False, markup=False)
def _fast_command_available(self) -> bool:
try:
@@ -4947,8 +4886,6 @@ class HermesCLI:
self._handle_skills_command(cmd_original)
elif canonical == "platforms":
self._show_gateway_status()
elif canonical == "status":
self._show_session_status()
elif canonical == "statusbar":
self._status_bar_visible = not self._status_bar_visible
state = "visible" if self._status_bar_visible else "hidden"
@@ -5838,7 +5775,7 @@ class HermesCLI:
approx_tokens = estimate_messages_tokens_rough(self.conversation_history)
print(f"🗜️ Compressing {original_count} messages (~{approx_tokens:,} tokens)...")
compressed, _new_system = self.agent._compress_context(
compressed, new_system = self.agent._compress_context(
self.conversation_history,
self.agent._cached_system_prompt or "",
approx_tokens=approx_tokens,
@@ -6147,20 +6084,11 @@ class HermesCLI:
Updates the TUI spinner widget so the user can see what the agent
is doing during tool execution (fills the gap between thinking
spinner and next response). Also plays audio cue in voice mode.
On tool.started, records a monotonic timestamp so get_spinner_text()
can show a live elapsed timer (the TUI poll loop already invalidates
every ~0.15s, so the counter updates automatically).
"""
if event_type == "tool.completed":
import time as _time
self._tool_start_time = 0.0
self._invalidate()
return
# Only act on tool.started; ignore tool.completed, reasoning.available, etc.
if event_type != "tool.started":
return
if function_name and not function_name.startswith("_"):
import time as _time
from agent.display import get_tool_emoji
emoji = get_tool_emoji(function_name)
label = preview or function_name
@@ -6169,7 +6097,6 @@ class HermesCLI:
if _pl > 0 and len(label) > _pl:
label = label[:_pl - 3] + "..."
self._spinner_text = f"{emoji} {label}"
self._tool_start_time = _time.monotonic()
self._invalidate()
if not self._voice_mode:
@@ -8011,7 +7938,7 @@ class HermesCLI:
agent_name = get_active_skin().get_branding("agent_name", "Hermes Agent")
msg = f"\n{agent_name} has been suspended. Run `fg` to bring {agent_name} back."
def _suspend():
os.write(1, msg.encode("utf-8", errors="replace"))
os.write(1, msg.encode())
os.kill(0, _sig.SIGTSTP)
run_in_terminal(_suspend)
@@ -8371,17 +8298,6 @@ class HermesCLI:
txt = cli_ref._spinner_text
if not txt:
return []
# Append live elapsed timer when a tool is running
t0 = cli_ref._tool_start_time
if t0 > 0:
import time as _time
elapsed = _time.monotonic() - t0
if elapsed >= 60:
_m, _s = int(elapsed // 60), int(elapsed % 60)
elapsed_str = f"{_m}m {_s}s"
else:
elapsed_str = f"{elapsed:.1f}s"
return [('class:hint', f' {txt} ({elapsed_str})')]
return [('class:hint', f' {txt}')]
def get_spinner_height():
@@ -8916,7 +8832,6 @@ class HermesCLI:
finally:
self._agent_running = False
self._spinner_text = ""
self._tool_start_time = 0.0
app.invalidate() # Refresh status line
+7 -10
View File
@@ -31,7 +31,7 @@ except ImportError:
# Configuration
# =============================================================================
HERMES_DIR = get_hermes_home().resolve()
HERMES_DIR = get_hermes_home()
CRON_DIR = HERMES_DIR / "cron"
JOBS_FILE = CRON_DIR / "jobs.json"
OUTPUT_DIR = CRON_DIR / "output"
@@ -338,12 +338,10 @@ def load_jobs() -> List[Dict[str, Any]]:
save_jobs(jobs)
logger.warning("Auto-repaired jobs.json (had invalid control characters)")
return jobs
except Exception as e:
logger.error("Failed to auto-repair jobs.json: %s", e)
raise RuntimeError(f"Cron database corrupted and unrepairable: {e}") from e
except IOError as e:
logger.error("IOError reading jobs.json: %s", e)
raise RuntimeError(f"Failed to read cron database: {e}") from e
except Exception:
return []
except IOError:
return []
def save_jobs(jobs: List[Dict[str, Any]]):
@@ -454,7 +452,6 @@ def create_job(
"last_run_at": None,
"last_status": None,
"last_error": None,
"last_delivery_error": None,
# Delivery configuration
"deliver": deliver,
"origin": origin, # Tracks where job was created for "origin" delivery
@@ -623,8 +620,8 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None,
save_jobs(jobs)
return
logger.warning("mark_job_run: job_id %s not found, skipping save", job_id)
save_jobs(jobs)
def advance_next_run(job_id: str) -> bool:
+2 -3
View File
@@ -44,7 +44,7 @@ logger = logging.getLogger(__name__)
_KNOWN_DELIVERY_PLATFORMS = frozenset({
"telegram", "discord", "slack", "whatsapp", "signal",
"matrix", "mattermost", "homeassistant", "dingtalk", "feishu",
"wecom", "weixin", "sms", "email", "webhook", "bluebubbles",
"wecom", "sms", "email", "webhook", "bluebubbles",
})
from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run
@@ -234,7 +234,6 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option
"dingtalk": Platform.DINGTALK,
"feishu": Platform.FEISHU,
"wecom": Platform.WECOM,
"weixin": Platform.WEIXIN,
"email": Platform.EMAIL,
"sms": Platform.SMS,
"bluebubbles": Platform.BLUEBUBBLES,
@@ -769,7 +768,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
_cron_pool.shutdown(wait=False, cancel_futures=True)
raise
finally:
_cron_pool.shutdown(wait=False, cancel_futures=True)
_cron_pool.shutdown(wait=False)
if _inactivity_timeout:
# Build diagnostic summary from the agent's activity tracker.
+1 -4
View File
@@ -9,10 +9,7 @@ INSTALL_DIR="/opt/hermes"
# (cache/images, cache/audio, platforms/whatsapp, etc.) are created on
# demand by the application — don't pre-create them here so new installs
# get the consolidated layout from get_hermes_dir().
# The "home/" subdirectory is a per-profile HOME for subprocesses (git,
# ssh, gh, npm …). Without it those tools write to /root which is
# ephemeral and shared across profiles. See issue #4426.
mkdir -p "$HERMES_HOME"/{cron,sessions,logs,hooks,memories,skills,skins,plans,workspace,home}
mkdir -p "$HERMES_HOME"/{cron,sessions,logs,hooks,memories,skills}
# .env
if [ ! -f "$HERMES_HOME/.env" ]; then
+4 -9
View File
@@ -76,15 +76,10 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
except Exception as e:
logger.warning("Channel directory: failed to build %s: %s", platform.value, e)
# Platforms that don't support direct channel enumeration get session-based
# discovery automatically. Skip infrastructure entries that aren't messaging
# platforms — everything else falls through to _build_from_sessions().
_SKIP_SESSION_DISCOVERY = frozenset({"local", "api_server", "webhook"})
for plat in Platform:
plat_name = plat.value
if plat_name in _SKIP_SESSION_DISCOVERY or plat_name in platforms:
continue
platforms[plat_name] = _build_from_sessions(plat_name)
# Telegram, WhatsApp & Signal can't enumerate chats -- pull from session history
for plat_name in ("telegram", "whatsapp", "signal", "email", "sms", "bluebubbles"):
if plat_name not in platforms:
platforms[plat_name] = _build_from_sessions(plat_name)
directory = {
"updated_at": datetime.now().isoformat(),
-49
View File
@@ -63,7 +63,6 @@ class Platform(Enum):
WEBHOOK = "webhook"
FEISHU = "feishu"
WECOM = "wecom"
WEIXIN = "weixin"
BLUEBUBBLES = "bluebubbles"
@@ -262,11 +261,6 @@ class GatewayConfig:
for platform, config in self.platforms.items():
if not config.enabled:
continue
# Weixin requires both a token and an account_id
if platform == Platform.WEIXIN:
if config.extra.get("account_id") and (config.token or config.extra.get("token")):
connected.append(platform)
continue
# Platforms that use token/api_key auth
if config.token or config.api_key:
connected.append(platform)
@@ -542,8 +536,6 @@ def load_gateway_config() -> GatewayConfig:
bridged["free_response_channels"] = platform_cfg["free_response_channels"]
if "mention_patterns" in platform_cfg:
bridged["mention_patterns"] = platform_cfg["mention_patterns"]
if plat == Platform.DISCORD and "channel_skill_bindings" in platform_cfg:
bridged["channel_skill_bindings"] = platform_cfg["channel_skill_bindings"]
if not bridged:
continue
plat_data = platforms_data.setdefault(plat.value, {})
@@ -642,8 +634,6 @@ def load_gateway_config() -> GatewayConfig:
os.environ["MATRIX_FREE_RESPONSE_ROOMS"] = str(frc)
if "auto_thread" in matrix_cfg and not os.getenv("MATRIX_AUTO_THREAD"):
os.environ["MATRIX_AUTO_THREAD"] = str(matrix_cfg["auto_thread"]).lower()
if "dm_mention_threads" in matrix_cfg and not os.getenv("MATRIX_DM_MENTION_THREADS"):
os.environ["MATRIX_DM_MENTION_THREADS"] = str(matrix_cfg["dm_mention_threads"]).lower()
except Exception as e:
logger.warning(
@@ -682,7 +672,6 @@ def load_gateway_config() -> GatewayConfig:
Platform.SLACK: "SLACK_BOT_TOKEN",
Platform.MATTERMOST: "MATTERMOST_TOKEN",
Platform.MATRIX: "MATRIX_ACCESS_TOKEN",
Platform.WEIXIN: "WEIXIN_TOKEN",
}
for platform, pconfig in config.platforms.items():
if not pconfig.enabled:
@@ -987,44 +976,6 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
name=os.getenv("WECOM_HOME_CHANNEL_NAME", "Home"),
)
# Weixin (personal WeChat via iLink Bot API)
weixin_token = os.getenv("WEIXIN_TOKEN")
weixin_account_id = os.getenv("WEIXIN_ACCOUNT_ID")
if weixin_token or weixin_account_id:
if Platform.WEIXIN not in config.platforms:
config.platforms[Platform.WEIXIN] = PlatformConfig()
config.platforms[Platform.WEIXIN].enabled = True
if weixin_token:
config.platforms[Platform.WEIXIN].token = weixin_token
extra = config.platforms[Platform.WEIXIN].extra
if weixin_account_id:
extra["account_id"] = weixin_account_id
weixin_base_url = os.getenv("WEIXIN_BASE_URL", "").strip()
if weixin_base_url:
extra["base_url"] = weixin_base_url.rstrip("/")
weixin_cdn_base_url = os.getenv("WEIXIN_CDN_BASE_URL", "").strip()
if weixin_cdn_base_url:
extra["cdn_base_url"] = weixin_cdn_base_url.rstrip("/")
weixin_dm_policy = os.getenv("WEIXIN_DM_POLICY", "").strip().lower()
if weixin_dm_policy:
extra["dm_policy"] = weixin_dm_policy
weixin_group_policy = os.getenv("WEIXIN_GROUP_POLICY", "").strip().lower()
if weixin_group_policy:
extra["group_policy"] = weixin_group_policy
weixin_allowed_users = os.getenv("WEIXIN_ALLOWED_USERS", "").strip()
if weixin_allowed_users:
extra["allow_from"] = weixin_allowed_users
weixin_group_allowed_users = os.getenv("WEIXIN_GROUP_ALLOWED_USERS", "").strip()
if weixin_group_allowed_users:
extra["group_allow_from"] = weixin_group_allowed_users
weixin_home = os.getenv("WEIXIN_HOME_CHANNEL", "").strip()
if weixin_home:
config.platforms[Platform.WEIXIN].home_channel = HomeChannel(
platform=Platform.WEIXIN,
chat_id=weixin_home,
name=os.getenv("WEIXIN_HOME_CHANNEL_NAME", "Home"),
)
# BlueBubbles (iMessage)
bluebubbles_server_url = os.getenv("BLUEBUBBLES_SERVER_URL")
bluebubbles_password = os.getenv("BLUEBUBBLES_PASSWORD")
+61
View File
@@ -124,6 +124,53 @@ class DeliveryRouter:
self.adapters = adapters or {}
self.output_dir = get_hermes_home() / "cron" / "output"
def resolve_targets(
self,
deliver: Union[str, List[str]],
origin: Optional[SessionSource] = None
) -> List[DeliveryTarget]:
"""
Resolve delivery specification to concrete targets.
Args:
deliver: Delivery spec - "origin", "telegram", ["local", "discord"], etc.
origin: The source where the request originated (for "origin" target)
Returns:
List of resolved delivery targets
"""
if isinstance(deliver, str):
deliver = [deliver]
targets = []
seen_platforms = set()
for target_str in deliver:
target = DeliveryTarget.parse(target_str, origin)
# Resolve home channel if needed
if target.chat_id is None and target.platform != Platform.LOCAL:
home = self.config.get_home_channel(target.platform)
if home:
target.chat_id = home.chat_id
else:
# No home channel configured, skip this platform
continue
# Deduplicate
key = (target.platform, target.chat_id, target.thread_id)
if key not in seen_platforms:
seen_platforms.add(key)
targets.append(target)
# Always include local if configured
if self.config.always_log_local:
local_key = (Platform.LOCAL, None, None)
if local_key not in seen_platforms:
targets.append(DeliveryTarget(platform=Platform.LOCAL))
return targets
async def deliver(
self,
content: str,
@@ -252,5 +299,19 @@ class DeliveryRouter:
return await adapter.send(target.chat_id, content, metadata=send_metadata or None)
def parse_deliver_spec(
deliver: Optional[Union[str, List[str]]],
origin: Optional[SessionSource] = None,
default: str = "origin"
) -> Union[str, List[str]]:
"""
Normalize a delivery specification.
If None or empty, returns the default.
"""
if not deliver:
return default
return deliver
+2 -13
View File
@@ -25,7 +25,6 @@ import hmac
import json
import logging
import os
import socket as _socket
import re
import sqlite3
import time
@@ -43,7 +42,6 @@ from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import (
BasePlatformAdapter,
SendResult,
is_network_accessible,
)
logger = logging.getLogger(__name__)
@@ -408,8 +406,7 @@ class APIServerAdapter(BasePlatformAdapter):
Validate Bearer token from Authorization header.
Returns None if auth is OK, or a 401 web.Response on failure.
If no API key is configured, all requests are allowed (only when API
server is local).
If no API key is configured, all requests are allowed.
"""
if not self._api_key:
return None # No key configured — allow all (local-only use)
@@ -1716,16 +1713,8 @@ class APIServerAdapter(BasePlatformAdapter):
if hasattr(sweep_task, "add_done_callback"):
sweep_task.add_done_callback(self._background_tasks.discard)
# Refuse to start network-accessible without authentication
if is_network_accessible(self._host) and not self._api_key:
logger.error(
"[%s] Refusing to start: binding to %s requires API_SERVER_KEY. "
"Set API_SERVER_KEY or use the default 127.0.0.1.",
self.name, self._host,
)
return False
# Port conflict detection — fail fast if port is already in use
import socket as _socket
try:
with _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) as _s:
_s.settimeout(1)
+15 -128
View File
@@ -6,12 +6,10 @@ and implement the required methods.
"""
import asyncio
import ipaddress
import logging
import os
import random
import re
import socket as _socket
import subprocess
import sys
import uuid
@@ -21,41 +19,6 @@ from urllib.parse import urlsplit
logger = logging.getLogger(__name__)
def is_network_accessible(host: str) -> bool:
"""Return True if *host* would expose the server beyond loopback.
Loopback addresses (127.0.0.1, ::1, IPv4-mapped ::ffff:127.0.0.1)
are local-only. Unspecified addresses (0.0.0.0, ::) bind all
interfaces. Hostnames are resolved; DNS failure fails closed.
"""
try:
addr = ipaddress.ip_address(host)
if addr.is_loopback:
return False
# ::ffff:127.0.0.1 — Python reports is_loopback=False for mapped
# addresses, so check the underlying IPv4 explicitly.
if getattr(addr, "ipv4_mapped", None) and addr.ipv4_mapped.is_loopback:
return False
return True
except ValueError:
# when host variable is a hostname, we should try to resolve below
pass
try:
resolved = _socket.getaddrinfo(
host, None, _socket.AF_UNSPEC, _socket.SOCK_STREAM,
)
# if the hostname resolves into at least one non-loopback address,
# then we consider it to be network accessible
for _family, _type, _proto, _canonname, sockaddr in resolved:
addr = ipaddress.ip_address(sockaddr[0])
if not addr.is_loopback:
return True
return False
except (_socket.gaierror, OSError):
return True
def _detect_macos_system_proxy() -> str | None:
"""Read the macOS system HTTP(S) proxy via ``scutil --proxy``.
@@ -197,7 +160,7 @@ GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE = (
)
def safe_url_for_log(url: str, max_len: int = 80) -> str:
def _safe_url_for_log(url: str, max_len: int = 80) -> str:
"""Return a URL string safe for logs (no query/fragment/userinfo)."""
if max_len <= 0:
return ""
@@ -234,23 +197,6 @@ def safe_url_for_log(url: str, max_len: int = 80) -> str:
return f"{safe[:max_len - 3]}..."
async def _ssrf_redirect_guard(response):
"""Re-validate each redirect target to prevent redirect-based SSRF.
Without this, an attacker can host a public URL that 302-redirects to
http://169.254.169.254/ and bypass the pre-flight is_safe_url() check.
Must be async because httpx.AsyncClient awaits response event hooks.
"""
if response.is_redirect and response.next_request:
redirect_url = str(response.next_request.url)
from tools.url_safety import is_safe_url
if not is_safe_url(redirect_url):
raise ValueError(
f"Blocked redirect to private/internal address: {safe_url_for_log(redirect_url)}"
)
# ---------------------------------------------------------------------------
# Image cache utilities
#
@@ -270,23 +216,6 @@ def get_image_cache_dir() -> Path:
return IMAGE_CACHE_DIR
def _looks_like_image(data: bytes) -> bool:
"""Return True if *data* starts with a known image magic-byte sequence."""
if len(data) < 4:
return False
if data[:8] == b"\x89PNG\r\n\x1a\n":
return True
if data[:3] == b"\xff\xd8\xff":
return True
if data[:6] in (b"GIF87a", b"GIF89a"):
return True
if data[:2] == b"BM":
return True
if data[:4] == b"RIFF" and len(data) >= 12 and data[8:12] == b"WEBP":
return True
return False
def cache_image_from_bytes(data: bytes, ext: str = ".jpg") -> str:
"""
Save raw image bytes to the cache and return the absolute file path.
@@ -297,17 +226,7 @@ def cache_image_from_bytes(data: bytes, ext: str = ".jpg") -> str:
Returns:
Absolute path to the cached image file as a string.
Raises:
ValueError: If *data* does not look like a valid image (e.g. an HTML
error page returned by the upstream server).
"""
if not _looks_like_image(data):
snippet = data[:80].decode("utf-8", errors="replace")
raise ValueError(
f"Refusing to cache non-image data as {ext} "
f"(starts with: {snippet!r})"
)
cache_dir = get_image_cache_dir()
filename = f"img_{uuid.uuid4().hex[:12]}{ext}"
filepath = cache_dir / filename
@@ -335,7 +254,7 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
"""
from tools.url_safety import is_safe_url
if not is_safe_url(url):
raise ValueError(f"Blocked unsafe URL (SSRF protection): {safe_url_for_log(url)}")
raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}")
import asyncio
import httpx
@@ -343,11 +262,7 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
_log = _logging.getLogger(__name__)
last_exc = None
async with httpx.AsyncClient(
timeout=30.0,
follow_redirects=True,
event_hooks={"response": [_ssrf_redirect_guard]},
) as client:
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
for attempt in range(retries + 1):
try:
response = await client.get(
@@ -369,7 +284,7 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
"Media cache retry %d/%d for %s (%.1fs): %s",
attempt + 1,
retries,
safe_url_for_log(url),
_safe_url_for_log(url),
wait,
exc,
)
@@ -454,7 +369,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
"""
from tools.url_safety import is_safe_url
if not is_safe_url(url):
raise ValueError(f"Blocked unsafe URL (SSRF protection): {safe_url_for_log(url)}")
raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}")
import asyncio
import httpx
@@ -462,11 +377,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
_log = _logging.getLogger(__name__)
last_exc = None
async with httpx.AsyncClient(
timeout=30.0,
follow_redirects=True,
event_hooks={"response": [_ssrf_redirect_guard]},
) as client:
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
for attempt in range(retries + 1):
try:
response = await client.get(
@@ -488,7 +399,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
"Audio cache retry %d/%d for %s (%.1fs): %s",
attempt + 1,
retries,
safe_url_for_log(url),
_safe_url_for_log(url),
wait,
exc,
)
@@ -591,14 +502,6 @@ class MessageType(Enum):
COMMAND = "command" # /command style
class ProcessingOutcome(Enum):
"""Result classification for message-processing lifecycle hooks."""
SUCCESS = "success"
FAILURE = "failure"
CANCELLED = "cancelled"
@dataclass
class MessageEvent:
"""
@@ -626,9 +529,8 @@ class MessageEvent:
reply_to_message_id: Optional[str] = None
reply_to_text: Optional[str] = None # Text of the replied-to message (for context injection)
# Auto-loaded skill(s) for topic/channel bindings (e.g., Telegram DM Topics,
# Discord channel_skill_bindings). A single name or ordered list.
auto_skill: Optional[str | list[str]] = None
# Auto-loaded skill for topic/channel bindings (e.g., Telegram DM Topics)
auto_skill: Optional[str] = None
# Internal flag — set for synthetic events (e.g. background process
# completion notifications) that must bypass user authorization checks.
@@ -650,9 +552,6 @@ class MessageEvent:
raw = parts[0][1:].lower() if parts else None
if raw and "@" in raw:
raw = raw.split("@", 1)[0]
# Reject file paths: valid command names never contain /
if raw and "/" in raw:
return None
return raw
def get_command_args(self) -> str:
@@ -726,7 +625,6 @@ class BasePlatformAdapter(ABC):
# Gateway shutdown cancels these so an old gateway instance doesn't keep
# working on a task after --replace or manual restarts.
self._background_tasks: set[asyncio.Task] = set()
self._expected_cancelled_tasks: set[asyncio.Task] = set()
# Chats where auto-TTS on voice input is disabled (set by /voice off)
self._auto_tts_disabled_chats: set = set()
# Chats where typing indicator is paused (e.g. during approval waits).
@@ -1235,7 +1133,7 @@ class BasePlatformAdapter(ABC):
async def on_processing_start(self, event: MessageEvent) -> None:
"""Hook called when background processing begins."""
async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
async def on_processing_complete(self, event: MessageEvent, success: bool) -> None:
"""Hook called when background processing completes."""
async def _run_processing_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
@@ -1396,7 +1294,7 @@ class BasePlatformAdapter(ABC):
# session lifecycle and its cleanup races with the running task
# (see PR #4926).
cmd = event.get_command()
if cmd in ("approve", "deny", "status", "stop", "new", "reset", "background"):
if cmd in ("approve", "deny", "status", "stop", "new", "reset"):
logger.debug(
"[%s] Command '/%s' bypassing active-session guard for %s",
self.name, cmd, session_key,
@@ -1454,7 +1352,6 @@ class BasePlatformAdapter(ABC):
return
if hasattr(task, "add_done_callback"):
task.add_done_callback(self._background_tasks.discard)
task.add_done_callback(self._expected_cancelled_tasks.discard)
@staticmethod
def _get_human_delay() -> float:
@@ -1591,7 +1488,7 @@ class BasePlatformAdapter(ABC):
logger.info(
"[%s] Sending image: %s (alt=%s)",
self.name,
safe_url_for_log(image_url),
_safe_url_for_log(image_url),
alt_text[:30] if alt_text else "",
)
# Route animated GIFs through send_animation for proper playback
@@ -1683,11 +1580,7 @@ class BasePlatformAdapter(ABC):
# Determine overall success for the processing hook
processing_ok = delivery_succeeded if delivery_attempted else not bool(response)
await self._run_processing_hook(
"on_processing_complete",
event,
ProcessingOutcome.SUCCESS if processing_ok else ProcessingOutcome.FAILURE,
)
await self._run_processing_hook("on_processing_complete", event, processing_ok)
# Check if there's a pending message that was queued during our processing
if session_key in self._pending_messages:
@@ -1706,14 +1599,10 @@ class BasePlatformAdapter(ABC):
return # Already cleaned up
except asyncio.CancelledError:
current_task = asyncio.current_task()
outcome = ProcessingOutcome.CANCELLED
if current_task is None or current_task not in self._expected_cancelled_tasks:
outcome = ProcessingOutcome.FAILURE
await self._run_processing_hook("on_processing_complete", event, outcome)
await self._run_processing_hook("on_processing_complete", event, False)
raise
except Exception as e:
await self._run_processing_hook("on_processing_complete", event, ProcessingOutcome.FAILURE)
await self._run_processing_hook("on_processing_complete", event, False)
logger.error("[%s] Error handling message: %s", self.name, e, exc_info=True)
# Send the error to the user so they aren't left with radio silence
try:
@@ -1757,12 +1646,10 @@ class BasePlatformAdapter(ABC):
"""
tasks = [task for task in self._background_tasks if not task.done()]
for task in tasks:
self._expected_cancelled_tasks.add(task)
task.cancel()
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
self._background_tasks.clear()
self._expected_cancelled_tasks.clear()
self._pending_messages.clear()
self._active_sessions.clear()
+15 -65
View File
@@ -49,7 +49,6 @@ from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
MessageType,
ProcessingOutcome,
SendResult,
cache_image_from_url,
cache_audio_from_url,
@@ -606,35 +605,22 @@ class DiscordAdapter(BasePlatformAdapter):
if not self._client.user or self._client.user not in message.mentions:
return
# "all" falls through to handle_message
# Multi-agent filtering: if the message mentions specific bots
# but NOT this bot, the sender is talking to another agent —
# stay silent. Messages with no bot mentions (general chat)
# still fall through to _handle_message for the existing
# DISCORD_REQUIRE_MENTION check.
#
# This replaces the older DISCORD_IGNORE_NO_MENTION logic
# with bot-aware filtering that works correctly when multiple
# agents share a channel.
if not isinstance(message.channel, discord.DMChannel) and message.mentions:
_self_mentioned = (
# If the message @mentions other users but NOT the bot, the
# sender is talking to someone else — stay silent. Only
# applies in server channels; in DMs the user is always
# talking to the bot (mentions are just references).
# Controlled by DISCORD_IGNORE_NO_MENTION (default: true).
_ignore_no_mention = os.getenv(
"DISCORD_IGNORE_NO_MENTION", "true"
).lower() in ("true", "1", "yes")
if _ignore_no_mention and message.mentions and not isinstance(message.channel, discord.DMChannel):
_bot_mentioned = (
self._client.user is not None
and self._client.user in message.mentions
)
_other_bots_mentioned = any(
m.bot and m != self._client.user
for m in message.mentions
)
# If other bots are mentioned but we're not → not for us
if _other_bots_mentioned and not _self_mentioned:
return
# If humans are mentioned but we're not → not for us
# (preserves old DISCORD_IGNORE_NO_MENTION=true behavior)
_ignore_no_mention = os.getenv(
"DISCORD_IGNORE_NO_MENTION", "true"
).lower() in ("true", "1", "yes")
if _ignore_no_mention and not _self_mentioned and not _other_bots_mentioned:
return
if not _bot_mentioned:
return # Talking to someone else, don't interrupt
await self._handle_message(message)
@@ -768,17 +754,14 @@ class DiscordAdapter(BasePlatformAdapter):
if hasattr(message, "add_reaction"):
await self._add_reaction(message, "👀")
async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
async def on_processing_complete(self, event: MessageEvent, success: bool) -> None:
"""Swap the in-progress reaction for a final success/failure reaction."""
if not self._reactions_enabled():
return
message = event.raw_message
if hasattr(message, "add_reaction"):
await self._remove_reaction(message, "👀")
if outcome == ProcessingOutcome.SUCCESS:
await self._add_reaction(message, "")
elif outcome == ProcessingOutcome.FAILURE:
await self._add_reaction(message, "")
await self._add_reaction(message, "" if success else "")
async def send(
self,
@@ -1905,42 +1888,14 @@ class DiscordAdapter(BasePlatformAdapter):
chat_topic=chat_topic,
)
_parent_id = str(getattr(getattr(interaction, "channel", None), "parent_id", "") or "")
_skills = self._resolve_channel_skills(thread_id, _parent_id or None)
event = MessageEvent(
text=text,
message_type=MessageType.TEXT,
source=source,
raw_message=interaction,
auto_skill=_skills,
)
await self.handle_message(event)
def _resolve_channel_skills(self, channel_id: str, parent_id: str | None = None) -> list[str] | None:
"""Look up auto-skill bindings for a Discord channel/forum thread.
Config format (in platform extra):
channel_skill_bindings:
- id: "123456"
skills: ["skill-a", "skill-b"]
Also checks parent_id so forum threads inherit the forum's bindings.
"""
bindings = self.config.extra.get("channel_skill_bindings", [])
if not bindings:
return None
ids_to_check = {channel_id}
if parent_id:
ids_to_check.add(parent_id)
for entry in bindings:
entry_id = str(entry.get("id", ""))
if entry_id in ids_to_check:
skills = entry.get("skills") or entry.get("skill")
if isinstance(skills, str):
return [skills]
if isinstance(skills, list) and skills:
return list(dict.fromkeys(skills)) # dedup, preserve order
return None
def _thread_parent_channel(self, channel: Any) -> Any:
"""Return the parent text channel when invoked from a thread."""
return getattr(channel, "parent", None) or channel
@@ -2525,10 +2480,6 @@ class DiscordAdapter(BasePlatformAdapter):
if not event_text or not event_text.strip():
event_text = "(The user sent a message with no text content)"
_chan = message.channel
_parent_id = str(getattr(_chan, "parent_id", "") or "")
_chan_id = str(getattr(_chan, "id", ""))
_skills = self._resolve_channel_skills(_chan_id, _parent_id or None)
event = MessageEvent(
text=event_text,
message_type=msg_type,
@@ -2539,7 +2490,6 @@ class DiscordAdapter(BasePlatformAdapter):
media_types=media_types,
reply_to_message_id=str(message.reference.message_id) if message.reference else None,
timestamp=message.created_at,
auto_skill=_skills,
)
# Track thread participation so the bot won't require @mention for
+1 -5
View File
@@ -195,11 +195,7 @@ def _extract_attachments(
ext = Path(filename).suffix.lower()
if ext in _IMAGE_EXTS:
try:
cached_path = cache_image_from_bytes(payload, ext)
except ValueError:
logger.debug("Skipping non-image attachment %s (invalid magic bytes)", filename)
continue
cached_path = cache_image_from_bytes(payload, ext)
attachments.append({
"path": cached_path,
"filename": filename,
+8 -16
View File
@@ -973,8 +973,7 @@ def _run_official_feishu_ws_client(ws_client: Any, adapter: Any) -> None:
return await original_connect(*args, **kwargs)
def _configure_with_overrides(conf: Any) -> Any:
if original_configure is None:
raise RuntimeError("Feishu _configure_with_overrides called but original_configure is None")
assert original_configure is not None
result = original_configure(conf)
_apply_runtime_ws_overrides()
return result
@@ -1190,8 +1189,6 @@ class FeishuAdapter(BasePlatformAdapter):
lambda data: self._on_reaction_event("im.message.reaction.deleted_v1", data)
)
.register_p2_card_action_trigger(self._on_card_action_trigger)
.register_p2_im_chat_member_bot_added_v1(self._on_bot_added_to_chat)
.register_p2_im_chat_member_bot_deleted_v1(self._on_bot_removed_from_chat)
.build()
)
@@ -1582,18 +1579,13 @@ class FeishuAdapter(BasePlatformAdapter):
return SendResult(success=False, error=f"Image file not found: {image_path}")
try:
import io as _io
with open(image_path, "rb") as f:
image_bytes = f.read()
# Wrap in BytesIO so lark SDK's MultipartEncoder can read .name and .tell()
image_file = _io.BytesIO(image_bytes)
image_file.name = os.path.basename(image_path)
body = self._build_image_upload_body(
image_type=_FEISHU_IMAGE_UPLOAD_TYPE,
image=image_file,
)
request = self._build_image_upload_request(body)
upload_response = await asyncio.to_thread(self._client.im.v1.image.create, request)
with open(image_path, "rb") as image_file:
body = self._build_image_upload_body(
image_type=_FEISHU_IMAGE_UPLOAD_TYPE,
image=image_file,
)
request = self._build_image_upload_request(body)
upload_response = await asyncio.to_thread(self._client.im.v1.image.create, request)
image_key = self._extract_response_field(upload_response, "image_key")
if not image_key:
return self._response_error_result(
+12 -42
View File
@@ -18,7 +18,6 @@ Environment variables:
MATRIX_REQUIRE_MENTION Require @mention in rooms (default: true)
MATRIX_FREE_RESPONSE_ROOMS Comma-separated room IDs exempt from mention requirement
MATRIX_AUTO_THREAD Auto-create threads for room messages (default: true)
MATRIX_DM_MENTION_THREADS Create a thread when bot is @mentioned in a DM (default: false)
"""
from __future__ import annotations
@@ -41,7 +40,6 @@ from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
MessageType,
ProcessingOutcome,
SendResult,
)
@@ -178,9 +176,6 @@ class MatrixAdapter(BasePlatformAdapter):
self._reactions_enabled: bool = os.getenv(
"MATRIX_REACTIONS", "true"
).lower() not in ("false", "0", "no")
# Tracks the reaction event_id for in-progress (eyes) reactions.
# Key: (room_id, message_event_id) → reaction_event_id (for the eyes reaction).
self._pending_reactions: dict[tuple[str, str], str] = {}
# Text batching: merge rapid successive messages (Telegram-style).
# Matrix clients split long messages around 4000 chars.
@@ -1044,13 +1039,6 @@ class MatrixAdapter(BasePlatformAdapter):
if not self._is_bot_mentioned(body, formatted_body):
return
# DM mention-thread: when enabled, @mentioning bot in a DM creates a thread.
if is_dm and not thread_id:
dm_mention_threads = os.getenv("MATRIX_DM_MENTION_THREADS", "false").lower() in ("true", "1", "yes")
if dm_mention_threads and self._is_bot_mentioned(body, source_content.get("formatted_body")):
thread_id = event.event_id
self._track_thread(thread_id)
# Strip mention from body when present (including in DMs).
if self._is_bot_mentioned(body, source_content.get("formatted_body")):
body = self._strip_mention(body)
@@ -1368,13 +1356,6 @@ class MatrixAdapter(BasePlatformAdapter):
if not self._is_bot_mentioned(body, formatted_body):
return
# DM mention-thread: when enabled, @mentioning bot in a DM creates a thread.
if is_dm and not thread_id:
dm_mention_threads = os.getenv("MATRIX_DM_MENTION_THREADS", "false").lower() in ("true", "1", "yes")
if dm_mention_threads and self._is_bot_mentioned(body, source_content.get("formatted_body")):
thread_id = event.event_id
self._track_thread(thread_id)
# Strip mention from body when present (including in DMs).
if self._is_bot_mentioned(body, source_content.get("formatted_body")):
body = self._strip_mention(body)
@@ -1455,14 +1436,12 @@ class MatrixAdapter(BasePlatformAdapter):
async def _send_reaction(
self, room_id: str, event_id: str, emoji: str,
) -> Optional[str]:
"""Send an emoji reaction to a message in a room.
Returns the reaction event_id on success, None on failure.
"""
) -> bool:
"""Send an emoji reaction to a message in a room."""
import nio
if not self._client:
return None
return False
content = {
"m.relates_to": {
"rel_type": "m.annotation",
@@ -1477,12 +1456,12 @@ class MatrixAdapter(BasePlatformAdapter):
)
if isinstance(resp, nio.RoomSendResponse):
logger.debug("Matrix: sent reaction %s to %s", emoji, event_id)
return resp.event_id
return True
logger.debug("Matrix: reaction send failed: %s", resp)
return None
return False
except Exception as exc:
logger.debug("Matrix: reaction send error: %s", exc)
return None
return False
async def _redact_reaction(
self, room_id: str, reaction_event_id: str, reason: str = "",
@@ -1497,12 +1476,10 @@ class MatrixAdapter(BasePlatformAdapter):
msg_id = event.message_id
room_id = event.source.chat_id
if msg_id and room_id:
reaction_event_id = await self._send_reaction(room_id, msg_id, "\U0001f440")
if reaction_event_id:
self._pending_reactions[(room_id, msg_id)] = reaction_event_id
await self._send_reaction(room_id, msg_id, "\U0001f440")
async def on_processing_complete(
self, event: MessageEvent, outcome: ProcessingOutcome,
self, event: MessageEvent, success: bool,
) -> None:
"""Replace eyes with checkmark (success) or cross (failure)."""
if not self._reactions_enabled:
@@ -1511,18 +1488,11 @@ class MatrixAdapter(BasePlatformAdapter):
room_id = event.source.chat_id
if not msg_id or not room_id:
return
if outcome == ProcessingOutcome.CANCELLED:
return
# Remove the eyes reaction first, if we tracked its event_id.
reaction_key = (room_id, msg_id)
if reaction_key in self._pending_reactions:
eyes_event_id = self._pending_reactions.pop(reaction_key)
if not await self._redact_reaction(room_id, eyes_event_id):
logger.debug("Matrix: failed to redact eyes reaction %s", eyes_event_id)
# Note: Matrix doesn't support removing a specific reaction easily
# without tracking the reaction event_id. We send the new reaction;
# the eyes stays (acceptable UX — both are visible).
await self._send_reaction(
room_id,
msg_id,
"\u2705" if outcome == ProcessingOutcome.SUCCESS else "\u274c",
room_id, msg_id, "\u2705" if success else "\u274c",
)
async def _on_reaction(self, room: Any, event: Any) -> None:
+2 -26
View File
@@ -39,7 +39,6 @@ from gateway.platforms.base import (
MessageType,
SendResult,
SUPPORTED_DOCUMENT_TYPES,
safe_url_for_log,
cache_document_from_bytes,
)
@@ -657,19 +656,8 @@ class SlackAdapter(BasePlatformAdapter):
try:
import httpx
async def _ssrf_redirect_guard(response):
"""Re-check redirect targets so public URLs cannot bounce into private IPs."""
if response.is_redirect and response.next_request:
redirect_url = str(response.next_request.url)
if not is_safe_url(redirect_url):
raise ValueError("Blocked redirect to private/internal address")
# Download the image first
async with httpx.AsyncClient(
timeout=30.0,
follow_redirects=True,
event_hooks={"response": [_ssrf_redirect_guard]},
) as client:
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
response = await client.get(image_url)
response.raise_for_status()
@@ -686,7 +674,7 @@ class SlackAdapter(BasePlatformAdapter):
except Exception as e: # pragma: no cover - defensive logging
logger.warning(
"[Slack] Failed to upload image from URL %s, falling back to text: %s",
safe_url_for_log(image_url),
image_url,
e,
exc_info=True,
)
@@ -1608,18 +1596,6 @@ class SlackAdapter(BasePlatformAdapter):
)
response.raise_for_status()
# Slack may return an HTML sign-in/redirect page
# instead of actual media bytes (e.g. expired token,
# restricted file access). Detect this early so we
# don't cache bogus data and confuse downstream tools.
ct = response.headers.get("content-type", "")
if "text/html" in ct:
raise ValueError(
"Slack returned HTML instead of media "
f"(content-type: {ct}); "
"check bot token scopes and file permissions"
)
if audio:
from gateway.platforms.base import cache_audio_from_bytes
return cache_audio_from_bytes(response.content, ext)
+8 -68
View File
@@ -60,7 +60,6 @@ from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
MessageType,
ProcessingOutcome,
SendResult,
cache_image_from_bytes,
cache_audio_from_bytes,
@@ -518,45 +517,6 @@ class TelegramAdapter(BasePlatformAdapter):
# Build the application
builder = Application.builder().token(self.config.token)
custom_base_url = self.config.extra.get("base_url")
if custom_base_url:
builder = builder.base_url(custom_base_url)
builder = builder.base_file_url(
self.config.extra.get("base_file_url", custom_base_url)
)
logger.info(
"[%s] Using custom Telegram base_url: %s",
self.name, custom_base_url,
)
# PTB defaults (pool_timeout=1s) are too aggressive on flaky networks and
# can trigger "Pool timeout: All connections in the connection pool are occupied"
# during reconnect/bootstrap. Use safer defaults and allow env overrides.
def _env_int(name: str, default: int) -> int:
try:
return int(os.getenv(name, str(default)))
except (TypeError, ValueError):
return default
def _env_float(name: str, default: float) -> float:
try:
return float(os.getenv(name, str(default)))
except (TypeError, ValueError):
return default
request_kwargs = {
"connection_pool_size": _env_int("HERMES_TELEGRAM_HTTP_POOL_SIZE", 512),
"pool_timeout": _env_float("HERMES_TELEGRAM_HTTP_POOL_TIMEOUT", 8.0),
"connect_timeout": _env_float("HERMES_TELEGRAM_HTTP_CONNECT_TIMEOUT", 10.0),
"read_timeout": _env_float("HERMES_TELEGRAM_HTTP_READ_TIMEOUT", 20.0),
"write_timeout": _env_float("HERMES_TELEGRAM_HTTP_WRITE_TIMEOUT", 20.0),
}
proxy_configured = any(
(os.getenv(k) or "").strip()
for k in ("HTTPS_PROXY", "HTTP_PROXY", "ALL_PROXY", "https_proxy", "http_proxy", "all_proxy")
)
disable_fallback = (os.getenv("HERMES_TELEGRAM_DISABLE_FALLBACK_IPS", "").strip().lower() in ("1", "true", "yes", "on"))
fallback_ips = self._fallback_ips()
if not fallback_ips:
fallback_ips = await discover_fallback_ips()
@@ -565,32 +525,16 @@ class TelegramAdapter(BasePlatformAdapter):
self.name,
", ".join(fallback_ips),
)
if fallback_ips and not proxy_configured and not disable_fallback:
if fallback_ips:
logger.info(
"[%s] Telegram fallback IPs active: %s",
self.name,
", ".join(fallback_ips),
)
# Keep request/update pools separate to reduce contention during
# polling reconnect + bot API bootstrap/delete_webhook calls.
request = HTTPXRequest(
**request_kwargs,
httpx_kwargs={"transport": TelegramFallbackTransport(fallback_ips)},
)
get_updates_request = HTTPXRequest(
**request_kwargs,
httpx_kwargs={"transport": TelegramFallbackTransport(fallback_ips)},
)
else:
if proxy_configured:
logger.info("[%s] Proxy configured; skipping Telegram fallback-IP transport", self.name)
elif disable_fallback:
logger.info("[%s] Telegram fallback-IP transport disabled via env", self.name)
request = HTTPXRequest(**request_kwargs)
get_updates_request = HTTPXRequest(**request_kwargs)
builder = builder.request(request).get_updates_request(get_updates_request)
transport = TelegramFallbackTransport(fallback_ips)
request = HTTPXRequest(httpx_kwargs={"transport": transport})
get_updates_request = HTTPXRequest(httpx_kwargs={"transport": transport})
builder = builder.request(request).get_updates_request(get_updates_request)
self._app = builder.build()
self._bot = self._app.bot
@@ -2788,7 +2732,7 @@ class TelegramAdapter(BasePlatformAdapter):
if chat_id and message_id:
await self._set_reaction(chat_id, message_id, "\U0001f440")
async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
async def on_processing_complete(self, event: MessageEvent, success: bool) -> None:
"""Swap the in-progress reaction for a final success/failure reaction.
Unlike Discord (additive reactions), Telegram's set_message_reaction
@@ -2798,9 +2742,5 @@ class TelegramAdapter(BasePlatformAdapter):
return
chat_id = getattr(event.source, "chat_id", None)
message_id = getattr(event, "message_id", None)
if chat_id and message_id and outcome != ProcessingOutcome.CANCELLED:
await self._set_reaction(
chat_id,
message_id,
"\U0001f44d" if outcome == ProcessingOutcome.SUCCESS else "\U0001f44e",
)
if chat_id and message_id:
await self._set_reaction(chat_id, message_id, "\u2705" if success else "\u274c")
+1 -2
View File
@@ -110,8 +110,7 @@ class TelegramFallbackTransport(httpx.AsyncBaseTransport):
logger.warning("[Telegram] Fallback IP %s failed: %s", ip, exc)
continue
if last_error is None:
raise RuntimeError("All Telegram fallback IPs exhausted but no error was recorded")
assert last_error is not None
raise last_error
async def aclose(self) -> None:
-1
View File
@@ -201,7 +201,6 @@ class WebhookAdapter(BasePlatformAdapter):
"dingtalk",
"feishu",
"wecom",
"weixin",
"bluebubbles",
):
return await self._deliver_cross_platform(
+2 -10
View File
@@ -696,11 +696,7 @@ class WeComAdapter(BasePlatformAdapter):
if kind == "image":
ext = self._detect_image_ext(raw)
try:
return cache_image_from_bytes(raw, ext), self._mime_for_ext(ext, fallback="image/jpeg")
except ValueError as exc:
logger.warning("[%s] Rejected non-image bytes: %s", self.name, exc)
return None
return cache_image_from_bytes(raw, ext), self._mime_for_ext(ext, fallback="image/jpeg")
filename = str(media.get("filename") or media.get("name") or "wecom_file")
return cache_document_from_bytes(raw, filename), mimetypes.guess_type(filename)[0] or "application/octet-stream"
@@ -726,11 +722,7 @@ class WeComAdapter(BasePlatformAdapter):
content_type = str(headers.get("content-type") or "").split(";", 1)[0].strip() or "application/octet-stream"
if kind == "image":
ext = self._guess_extension(url, content_type, fallback=self._detect_image_ext(raw))
try:
return cache_image_from_bytes(raw, ext), content_type or self._mime_for_ext(ext, fallback="image/jpeg")
except ValueError as exc:
logger.warning("[%s] Rejected non-image bytes from %s: %s", self.name, url, exc)
return None
return cache_image_from_bytes(raw, ext), content_type or self._mime_for_ext(ext, fallback="image/jpeg")
filename = self._guess_filename(url, headers.get("content-disposition"), content_type)
return cache_document_from_bytes(raw, filename), content_type
File diff suppressed because it is too large Load Diff
+77 -270
View File
@@ -481,7 +481,6 @@ class GatewayRunner:
self._prefill_messages = self._load_prefill_messages()
self._ephemeral_system_prompt = self._load_ephemeral_system_prompt()
self._reasoning_config = self._load_reasoning_config()
self._service_tier = self._load_service_tier()
self._show_reasoning = self._load_show_reasoning()
self._provider_routing = self._load_provider_routing()
self._fallback_model = self._load_fallback_model()
@@ -515,6 +514,12 @@ class GatewayRunner:
self._agent_cache: Dict[str, tuple] = {}
self._agent_cache_lock = _threading.Lock()
# Track active fallback model/provider when primary is rate-limited.
# Set after an agent run where fallback was activated; cleared when
# the primary model succeeds again or the user switches via /model.
self._effective_model: Optional[str] = None
self._effective_provider: Optional[str] = None
# Per-session model overrides from /model command.
# Key: session_key, Value: dict with model/provider/api_key/base_url/api_mode
self._session_model_overrides: Dict[str, Dict[str, str]] = {}
@@ -777,7 +782,6 @@ class GatewayRunner:
def _resolve_turn_agent_config(self, user_message: str, model: str, runtime_kwargs: dict) -> dict:
from agent.smart_model_routing import resolve_turn_route
from hermes_cli.models import resolve_fast_mode_overrides
primary = {
"model": model,
@@ -789,19 +793,7 @@ class GatewayRunner:
"args": list(runtime_kwargs.get("args") or []),
"credential_pool": runtime_kwargs.get("credential_pool"),
}
route = resolve_turn_route(user_message, getattr(self, "_smart_model_routing", {}), primary)
service_tier = getattr(self, "_service_tier", None)
if not service_tier:
route["request_overrides"] = None
return route
try:
overrides = resolve_fast_mode_overrides(route.get("model"))
except Exception:
overrides = None
route["request_overrides"] = overrides
return route
return resolve_turn_route(user_message, getattr(self, "_smart_model_routing", {}), primary)
async def _handle_adapter_fatal_error(self, adapter: BasePlatformAdapter) -> None:
"""React to an adapter failure after startup.
@@ -953,33 +945,6 @@ class GatewayRunner:
logger.warning("Unknown reasoning_effort '%s', using default (medium)", effort)
return result
@staticmethod
def _load_service_tier() -> str | None:
"""Load Priority Processing setting from config.yaml.
Reads agent.service_tier from config.yaml. Accepted values mirror the CLI:
"fast"/"priority"/"on" => "priority", while "normal"/"off" disables it.
Returns None when unset or unsupported.
"""
raw = ""
try:
import yaml as _y
cfg_path = _hermes_home / "config.yaml"
if cfg_path.exists():
with open(cfg_path, encoding="utf-8") as _f:
cfg = _y.safe_load(_f) or {}
raw = str(cfg.get("agent", {}).get("service_tier", "") or "").strip()
except Exception:
pass
value = raw.lower()
if not value or value in {"normal", "default", "standard", "off", "none"}:
return None
if value in {"fast", "priority", "on"}:
return "priority"
logger.warning("Unknown service_tier '%s', ignoring", raw)
return None
@staticmethod
def _load_show_reasoning() -> bool:
"""Load show_reasoning toggle from config.yaml display section."""
@@ -1110,7 +1075,6 @@ class GatewayRunner:
"MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS",
"FEISHU_ALLOWED_USERS",
"WECOM_ALLOWED_USERS",
"WEIXIN_ALLOWED_USERS",
"BLUEBUBBLES_ALLOWED_USERS",
"GATEWAY_ALLOWED_USERS")
)
@@ -1123,7 +1087,6 @@ class GatewayRunner:
"MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS",
"FEISHU_ALLOW_ALL_USERS",
"WECOM_ALLOW_ALL_USERS",
"WEIXIN_ALLOW_ALL_USERS",
"BLUEBUBBLES_ALLOW_ALL_USERS")
)
if not _any_allowlist and not _allow_all:
@@ -1348,28 +1311,12 @@ class GatewayRunner:
for key, entry in _expired_entries:
try:
await self._async_flush_memories(entry.session_id)
# Shut down memory provider and close tool resources
# on the cached agent. Idle agents live in
# _agent_cache (not _running_agents), so look there.
_cached_agent = None
_cache_lock = getattr(self, "_agent_cache_lock", None)
if _cache_lock is not None:
with _cache_lock:
_cached = self._agent_cache.get(key)
_cached_agent = _cached[0] if isinstance(_cached, tuple) else _cached if _cached else None
# Fall back to _running_agents in case the agent is
# still mid-turn when the expiry fires.
if _cached_agent is None:
_cached_agent = self._running_agents.get(key)
if _cached_agent and _cached_agent is not _AGENT_PENDING_SENTINEL:
# Shut down memory provider on the cached agent
cached_agent = self._running_agents.get(key)
if cached_agent and cached_agent is not _AGENT_PENDING_SENTINEL:
try:
if hasattr(_cached_agent, 'shutdown_memory_provider'):
_cached_agent.shutdown_memory_provider()
except Exception:
pass
try:
if hasattr(_cached_agent, 'close'):
_cached_agent.close()
if hasattr(cached_agent, 'shutdown_memory_provider'):
cached_agent.shutdown_memory_provider()
except Exception:
pass
# Mark as flushed and persist to disk so the flag
@@ -1552,14 +1499,6 @@ class GatewayRunner:
agent.shutdown_memory_provider()
except Exception:
pass
# Close tool resources (terminal sandboxes, browser daemons,
# background processes, httpx clients) to prevent zombie
# process accumulation.
try:
if hasattr(agent, 'close'):
agent.close()
except Exception:
pass
for platform, adapter in list(self.adapters.items()):
try:
@@ -1582,25 +1521,7 @@ class GatewayRunner:
self._pending_messages.clear()
self._pending_approvals.clear()
self._shutdown_event.set()
# Global cleanup: kill any remaining tool subprocesses not tied
# to a specific agent (catch-all for zombie prevention).
try:
from tools.process_registry import process_registry
process_registry.kill_all()
except Exception:
pass
try:
from tools.terminal_tool import cleanup_all_environments
cleanup_all_environments()
except Exception:
pass
try:
from tools.browser_tool import cleanup_all_browsers
cleanup_all_browsers()
except Exception:
pass
from gateway.status import remove_pid_file, write_runtime_status
remove_pid_file()
try:
@@ -1707,13 +1628,6 @@ class GatewayRunner:
return None
return WeComAdapter(config)
elif platform == Platform.WEIXIN:
from gateway.platforms.weixin import WeixinAdapter, check_weixin_requirements
if not check_weixin_requirements():
logger.warning("Weixin: aiohttp/cryptography not installed")
return None
return WeixinAdapter(config)
elif platform == Platform.MATTERMOST:
from gateway.platforms.mattermost import MattermostAdapter, check_mattermost_requirements
if not check_mattermost_requirements():
@@ -1789,7 +1703,6 @@ class GatewayRunner:
Platform.DINGTALK: "DINGTALK_ALLOWED_USERS",
Platform.FEISHU: "FEISHU_ALLOWED_USERS",
Platform.WECOM: "WECOM_ALLOWED_USERS",
Platform.WEIXIN: "WEIXIN_ALLOWED_USERS",
Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOWED_USERS",
}
platform_allow_all_map = {
@@ -1805,7 +1718,6 @@ class GatewayRunner:
Platform.DINGTALK: "DINGTALK_ALLOW_ALL_USERS",
Platform.FEISHU: "FEISHU_ALLOW_ALL_USERS",
Platform.WECOM: "WECOM_ALLOW_ALL_USERS",
Platform.WEIXIN: "WEIXIN_ALLOW_ALL_USERS",
Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOW_ALL_USERS",
}
@@ -2085,11 +1997,6 @@ class GatewayRunner:
return await self._handle_approve_command(event)
return await self._handle_deny_command(event)
# /background must bypass the running-agent guard — it starts a
# parallel task and must never interrupt the active conversation.
if _cmd_def_inner and _cmd_def_inner.name == "background":
return await self._handle_background_command(event)
if event.message_type == MessageType.PHOTO:
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
adapter = self.adapters.get(source.platform)
@@ -2171,9 +2078,6 @@ class GatewayRunner:
if canonical == "reasoning":
return await self._handle_reasoning_command(event)
if canonical == "fast":
return await self._handle_fast_command(event)
if canonical == "verbose":
return await self._handle_verbose_command(event)
@@ -2442,8 +2346,8 @@ class GatewayRunner:
# Build session context
context = build_session_context(source, self.config, session_entry)
# Set session context variables for tools (task-local, concurrency-safe)
_session_env_tokens = self._set_session_env(context)
# Set environment variables for tools
self._set_session_env(context)
# Read privacy.redact_pii from config (re-read per message)
_redact_pii = False
@@ -2516,41 +2420,37 @@ class GatewayRunner:
session_entry.was_auto_reset = False
session_entry.auto_reset_reason = None
# Auto-load skill(s) for topic/channel bindings (Telegram DM Topics,
# Discord channel_skill_bindings). Supports a single name or ordered list.
# Only inject on NEW sessions — ongoing conversations already have the
# skill content in their conversation history from the first message.
_auto = getattr(event, "auto_skill", None)
if _is_new_session and _auto:
_skill_names = [_auto] if isinstance(_auto, str) else list(_auto)
# Auto-load skill for DM topic bindings (e.g., Telegram Private Chat Topics)
# Only inject on NEW sessions — for ongoing conversations the skill content
# is already in the conversation history from the first message.
if _is_new_session and getattr(event, "auto_skill", None):
try:
from agent.skill_commands import _load_skill_payload, _build_skill_message
_combined_parts: list[str] = []
_loaded_names: list[str] = []
for _sname in _skill_names:
_loaded = _load_skill_payload(_sname, task_id=_quick_key)
if _loaded:
_loaded_skill, _skill_dir, _display_name = _loaded
_note = (
f'[SYSTEM: The "{_display_name}" skill is auto-loaded. '
f"Follow its instructions for this session.]"
_skill_name = event.auto_skill
_loaded = _load_skill_payload(_skill_name, task_id=_quick_key)
if _loaded:
_loaded_skill, _skill_dir, _display_name = _loaded
_activation_note = (
f'[SYSTEM: This conversation is in a topic with the "{_display_name}" skill '
f"auto-loaded. Follow its instructions for the duration of this session.]"
)
_skill_msg = _build_skill_message(
_loaded_skill, _skill_dir, _activation_note,
user_instruction=event.text,
)
if _skill_msg:
event.text = _skill_msg
logger.info(
"[Gateway] Auto-loaded skill '%s' for DM topic session %s",
_skill_name, session_key,
)
_part = _build_skill_message(_loaded_skill, _skill_dir, _note)
if _part:
_combined_parts.append(_part)
_loaded_names.append(_sname)
else:
logger.warning("[Gateway] Auto-skill '%s' not found", _sname)
if _combined_parts:
# Append the user's original text after all skill payloads
_combined_parts.append(event.text)
event.text = "\n\n".join(_combined_parts)
logger.info(
"[Gateway] Auto-loaded skill(s) %s for session %s",
_loaded_names, session_key,
else:
logger.warning(
"[Gateway] DM topic skill '%s' not found in available skills",
_skill_name,
)
except Exception as e:
logger.warning("[Gateway] Failed to auto-load skill(s) %s: %s", _skill_names, e)
logger.warning("[Gateway] Failed to auto-load topic skill '%s': %s", event.auto_skill, e)
# Load conversation history from transcript
history = self.session_store.load_transcript(session_entry.session_id)
@@ -3276,8 +3176,8 @@ class GatewayRunner:
"Try again or use /reset to start a fresh session."
)
finally:
# Restore session context variables to their pre-handler state
self._clear_session_env(_session_env_tokens)
# Clear session env
self._clear_session_env()
def _format_session_info(self) -> str:
"""Resolve current model config and return a formatted info block.
@@ -3377,22 +3277,8 @@ class GatewayRunner:
_flush_task.add_done_callback(self._background_tasks.discard)
except Exception as e:
logger.debug("Gateway memory flush on reset failed: %s", e)
# Close tool resources on the old agent (terminal sandboxes, browser
# daemons, background processes) before evicting from cache.
# Guard with getattr because test fixtures may skip __init__.
_cache_lock = getattr(self, "_agent_cache_lock", None)
if _cache_lock is not None:
with _cache_lock:
_cached = self._agent_cache.get(session_key)
_old_agent = _cached[0] if isinstance(_cached, tuple) else _cached if _cached else None
if _old_agent is not None:
try:
if hasattr(_old_agent, "close"):
_old_agent.close()
except Exception:
pass
self._evict_cached_agent(session_key)
try:
from tools.env_passthrough import clear_env_passthrough
clear_env_passthrough()
@@ -3961,7 +3847,6 @@ class GatewayRunner:
# Resolve current provider from config
current_provider = "openrouter"
model_cfg = {}
config_path = _hermes_home / 'config.yaml'
try:
if config_path.exists():
@@ -4702,7 +4587,6 @@ class GatewayRunner:
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90"))
reasoning_config = self._load_reasoning_config()
self._reasoning_config = reasoning_config
self._service_tier = self._load_service_tier()
turn_route = self._resolve_turn_agent_config(prompt, model, runtime_kwargs)
def run_sync():
@@ -4714,8 +4598,6 @@ class GatewayRunner:
verbose_logging=False,
enabled_toolsets=enabled_toolsets,
reasoning_config=reasoning_config,
service_tier=self._service_tier,
request_overrides=turn_route.get("request_overrides"),
providers_allowed=pr.get("only"),
providers_ignored=pr.get("ignore"),
providers_order=pr.get("order"),
@@ -4865,7 +4747,6 @@ class GatewayRunner:
model = _resolve_gateway_model(user_config)
platform_key = _platform_config_key(source.platform)
reasoning_config = self._load_reasoning_config()
self._service_tier = self._load_service_tier()
turn_route = self._resolve_turn_agent_config(question, model, runtime_kwargs)
pr = self._provider_routing
@@ -4892,8 +4773,6 @@ class GatewayRunner:
verbose_logging=False,
enabled_toolsets=[],
reasoning_config=reasoning_config,
service_tier=self._service_tier,
request_overrides=turn_route.get("request_overrides"),
providers_allowed=pr.get("only"),
providers_ignored=pr.get("ignore"),
providers_order=pr.get("order"),
@@ -5047,82 +4926,15 @@ class GatewayRunner:
else:
return f"🧠 ✓ Reasoning effort set to `{effort}` (this session only)"
async def _handle_fast_command(self, event: MessageEvent) -> str:
"""Handle /fast — mirror the CLI Priority Processing toggle in gateway chats."""
import yaml
from hermes_cli.models import model_supports_fast_mode
args = event.get_command_args().strip().lower()
config_path = _hermes_home / "config.yaml"
self._service_tier = self._load_service_tier()
user_config = _load_gateway_config()
model = _resolve_gateway_model(user_config)
if not model_supports_fast_mode(model):
return "⚡ /fast is only available for OpenAI models that support Priority Processing."
def _save_config_key(key_path: str, value):
"""Save a dot-separated key to config.yaml."""
try:
user_config = {}
if config_path.exists():
with open(config_path, encoding="utf-8") as f:
user_config = yaml.safe_load(f) or {}
keys = key_path.split(".")
current = user_config
for k in keys[:-1]:
if k not in current or not isinstance(current[k], dict):
current[k] = {}
current = current[k]
current[keys[-1]] = value
atomic_yaml_write(config_path, user_config)
return True
except Exception as e:
logger.error("Failed to save config key %s: %s", key_path, e)
return False
if not args or args == "status":
status = "fast" if self._service_tier == "priority" else "normal"
return (
"⚡ Priority Processing\n\n"
f"Current mode: `{status}`\n\n"
"_Usage:_ `/fast <normal|fast|status>`"
)
if args in {"fast", "on"}:
self._service_tier = "priority"
saved_value = "fast"
label = "FAST"
elif args in {"normal", "off"}:
self._service_tier = None
saved_value = "normal"
label = "NORMAL"
else:
return (
f"⚠️ Unknown argument: `{args}`\n\n"
"**Valid options:** normal, fast, status"
)
if _save_config_key("agent.service_tier", saved_value):
return f"⚡ ✓ Priority Processing: **{label}** (saved to config)\n_(takes effect on next message)_"
return f"⚡ ✓ Priority Processing: **{label}** (this session only)"
async def _handle_yolo_command(self, event: MessageEvent) -> str:
"""Handle /yolo — toggle dangerous command approval bypass for this session only."""
from tools.approval import (
disable_session_yolo,
enable_session_yolo,
is_session_yolo_enabled,
)
session_key = self._session_key_for_source(event.source)
current = is_session_yolo_enabled(session_key)
"""Handle /yolo — toggle dangerous command approval bypass."""
current = bool(os.environ.get("HERMES_YOLO_MODE"))
if current:
disable_session_yolo(session_key)
return "⚠️ YOLO mode **OFF** for this session — dangerous commands will require approval."
os.environ.pop("HERMES_YOLO_MODE", None)
return "⚠️ YOLO mode **OFF** — dangerous commands will require approval."
else:
enable_session_yolo(session_key)
return "⚡ YOLO mode **ON** for this session — all commands auto-approved. Use with caution."
os.environ["HERMES_YOLO_MODE"] = "1"
return "⚡ YOLO mode **ON** — all commands auto-approved. Use with caution."
async def _handle_verbose_command(self, event: MessageEvent) -> str:
"""Handle /verbose command — cycle tool progress display mode.
@@ -5788,7 +5600,7 @@ class GatewayRunner:
Platform.TELEGRAM, Platform.DISCORD, Platform.SLACK, Platform.WHATSAPP,
Platform.SIGNAL, Platform.MATTERMOST, Platform.MATRIX,
Platform.HOMEASSISTANT, Platform.EMAIL, Platform.SMS, Platform.DINGTALK,
Platform.FEISHU, Platform.WECOM, Platform.WEIXIN, Platform.BLUEBUBBLES, Platform.LOCAL,
Platform.FEISHU, Platform.WECOM, Platform.BLUEBUBBLES, Platform.LOCAL,
})
async def _handle_update_command(self, event: MessageEvent) -> str:
@@ -6176,27 +5988,20 @@ class GatewayRunner:
return True
def _set_session_env(self, context: SessionContext) -> list:
"""Set session context variables for the current async task.
Uses ``contextvars`` instead of ``os.environ`` so that concurrent
gateway messages cannot overwrite each other's session state.
Returns a list of reset tokens; pass them to ``_clear_session_env``
in a ``finally`` block.
"""
from gateway.session_context import set_session_vars
return set_session_vars(
platform=context.source.platform.value,
chat_id=context.source.chat_id,
chat_name=context.source.chat_name or "",
thread_id=str(context.source.thread_id) if context.source.thread_id else "",
)
def _clear_session_env(self, tokens: list) -> None:
"""Restore session context variables to their pre-handler values."""
from gateway.session_context import clear_session_vars
clear_session_vars(tokens)
def _set_session_env(self, context: SessionContext) -> None:
"""Set environment variables for the current session."""
os.environ["HERMES_SESSION_PLATFORM"] = context.source.platform.value
os.environ["HERMES_SESSION_CHAT_ID"] = context.source.chat_id
if context.source.chat_name:
os.environ["HERMES_SESSION_CHAT_NAME"] = context.source.chat_name
if context.source.thread_id:
os.environ["HERMES_SESSION_THREAD_ID"] = str(context.source.thread_id)
def _clear_session_env(self) -> None:
"""Clear session environment variables."""
for var in ["HERMES_SESSION_PLATFORM", "HERMES_SESSION_CHAT_ID", "HERMES_SESSION_CHAT_NAME", "HERMES_SESSION_THREAD_ID"]:
if var in os.environ:
del os.environ[var]
async def _enrich_message_with_vision(
self,
@@ -6944,7 +6749,6 @@ class GatewayRunner:
pr = self._provider_routing
reasoning_config = self._load_reasoning_config()
self._reasoning_config = reasoning_config
self._service_tier = self._load_service_tier()
# Set up streaming consumer if enabled
_stream_consumer = None
_stream_delta_cb = None
@@ -7007,8 +6811,6 @@ class GatewayRunner:
ephemeral_system_prompt=combined_ephemeral or None,
prefill_messages=self._prefill_messages or None,
reasoning_config=reasoning_config,
service_tier=self._service_tier,
request_overrides=turn_route.get("request_overrides"),
providers_allowed=pr.get("only"),
providers_ignored=pr.get("ignore"),
providers_order=pr.get("order"),
@@ -7033,8 +6835,6 @@ class GatewayRunner:
agent.stream_delta_callback = _stream_delta_cb
agent.status_callback = _status_callback_sync
agent.reasoning_config = reasoning_config
agent.service_tier = self._service_tier
agent.request_overrides = turn_route.get("request_overrides")
# Background review delivery — send "💾 Memory updated" etc. to user
def _bg_review_send(message: str) -> None:
@@ -7566,9 +7366,16 @@ class GatewayRunner:
if _agent is not None and hasattr(_agent, 'model'):
_cfg_model = _resolve_gateway_model()
if _agent.model != _cfg_model and not self._is_intentional_model_switch(session_key, _agent.model):
self._effective_model = _agent.model
self._effective_provider = getattr(_agent, 'provider', None)
# Fallback activated — evict cached agent so the next
# message starts fresh and retries the primary model.
self._evict_cached_agent(session_key)
else:
# Primary model worked (or intentional /model switch)
# — clear any stale fallback state.
self._effective_model = None
self._effective_provider = None
# Check if we were interrupted OR have a queued message (/queue).
result = result_holder[0]
@@ -7776,7 +7583,7 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
# setups (each profile using a distinct HERMES_HOME) will naturally
# allow concurrent instances without tripping this guard.
import time as _time
from gateway.status import get_running_pid, remove_pid_file, terminate_pid
from gateway.status import get_running_pid, remove_pid_file
existing_pid = get_running_pid()
if existing_pid is not None and existing_pid != os.getpid():
if replace:
@@ -7785,10 +7592,10 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
existing_pid,
)
try:
terminate_pid(existing_pid, force=False)
os.kill(existing_pid, signal.SIGTERM)
except ProcessLookupError:
pass # Already gone
except (PermissionError, OSError):
except PermissionError:
logger.error(
"Permission denied killing PID %d. Cannot replace.",
existing_pid,
@@ -7808,9 +7615,9 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
existing_pid,
)
try:
terminate_pid(existing_pid, force=True)
os.kill(existing_pid, signal.SIGKILL)
_time.sleep(0.5)
except (ProcessLookupError, PermissionError, OSError):
except (ProcessLookupError, PermissionError):
pass
remove_pid_file()
# Also release all scoped locks left by the old process.
+18 -1
View File
@@ -32,6 +32,9 @@ def _now() -> datetime:
# PII redaction helpers
# ---------------------------------------------------------------------------
_PHONE_RE = re.compile(r"^\+?\d[\d\-\s]{6,}$")
def _hash_id(value: str) -> str:
"""Deterministic 12-char hex hash of an identifier."""
return hashlib.sha256(value.encode("utf-8")).hexdigest()[:12]
@@ -55,6 +58,10 @@ def _hash_chat_id(value: str) -> str:
return _hash_id(value)
def _looks_like_phone(value: str) -> bool:
"""Return True if *value* looks like a phone number (E.164 or similar)."""
return bool(_PHONE_RE.match(value.strip()))
from .config import (
Platform,
GatewayConfig,
@@ -137,6 +144,15 @@ class SessionSource:
chat_id_alt=data.get("chat_id_alt"),
)
@classmethod
def local_cli(cls) -> "SessionSource":
"""Create a source representing the local CLI."""
return cls(
platform=Platform.LOCAL,
chat_id="cli",
chat_name="CLI terminal",
chat_type="dm",
)
@dataclass
@@ -494,7 +510,8 @@ class SessionStore:
"""
def __init__(self, sessions_dir: Path, config: GatewayConfig,
has_active_processes_fn=None):
has_active_processes_fn=None,
on_auto_reset=None):
self.sessions_dir = sessions_dir
self.config = config
self._entries: Dict[str, SessionEntry] = {}
-113
View File
@@ -1,113 +0,0 @@
"""
Session-scoped context variables for the Hermes gateway.
Replaces the previous ``os.environ``-based session state
(``HERMES_SESSION_PLATFORM``, ``HERMES_SESSION_CHAT_ID``, etc.) with
Python's ``contextvars.ContextVar``.
**Why this matters**
The gateway processes messages concurrently via ``asyncio``. When two
messages arrive at the same time the old code did:
os.environ["HERMES_SESSION_THREAD_ID"] = str(context.source.thread_id)
Because ``os.environ`` is *process-global*, Message A's value was
silently overwritten by Message B before Message A's agent finished
running. Background-task notifications and tool calls therefore routed
to the wrong thread.
``contextvars.ContextVar`` values are *task-local*: each ``asyncio``
task (and any ``run_in_executor`` thread it spawns) gets its own copy,
so concurrent messages never interfere.
**Backward compatibility**
The public helper ``get_session_env(name, default="")`` mirrors the old
``os.getenv("HERMES_SESSION_*", ...)`` calls. Existing tool code only
needs to replace the import + call site:
# before
import os
platform = os.getenv("HERMES_SESSION_PLATFORM", "")
# after
from gateway.session_context import get_session_env
platform = get_session_env("HERMES_SESSION_PLATFORM", "")
"""
from contextvars import ContextVar
# ---------------------------------------------------------------------------
# Per-task session variables
# ---------------------------------------------------------------------------
_SESSION_PLATFORM: ContextVar[str] = ContextVar("HERMES_SESSION_PLATFORM", default="")
_SESSION_CHAT_ID: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_ID", default="")
_SESSION_CHAT_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_NAME", default="")
_SESSION_THREAD_ID: ContextVar[str] = ContextVar("HERMES_SESSION_THREAD_ID", default="")
_VAR_MAP = {
"HERMES_SESSION_PLATFORM": _SESSION_PLATFORM,
"HERMES_SESSION_CHAT_ID": _SESSION_CHAT_ID,
"HERMES_SESSION_CHAT_NAME": _SESSION_CHAT_NAME,
"HERMES_SESSION_THREAD_ID": _SESSION_THREAD_ID,
}
def set_session_vars(
platform: str = "",
chat_id: str = "",
chat_name: str = "",
thread_id: str = "",
) -> list:
"""Set all session context variables and return reset tokens.
Call ``clear_session_vars(tokens)`` in a ``finally`` block to restore
the previous values when the handler exits.
Returns a list of ``Token`` objects (one per variable) that can be
passed to ``clear_session_vars``.
"""
tokens = [
_SESSION_PLATFORM.set(platform),
_SESSION_CHAT_ID.set(chat_id),
_SESSION_CHAT_NAME.set(chat_name),
_SESSION_THREAD_ID.set(thread_id),
]
return tokens
def clear_session_vars(tokens: list) -> None:
"""Restore session context variables to their pre-handler values."""
if not tokens:
return
vars_in_order = [
_SESSION_PLATFORM,
_SESSION_CHAT_ID,
_SESSION_CHAT_NAME,
_SESSION_THREAD_ID,
]
for var, token in zip(vars_in_order, tokens):
var.reset(token)
def get_session_env(name: str, default: str = "") -> str:
"""Read a session context variable by its legacy ``HERMES_SESSION_*`` name.
Drop-in replacement for ``os.getenv("HERMES_SESSION_*", default)``.
Resolution order:
1. Context variable (set by the gateway for concurrency-safe access)
2. ``os.environ`` (used by CLI, cron scheduler, and tests)
3. *default*
"""
import os
var = _VAR_MAP.get(name)
if var is not None:
value = var.get()
if value:
return value
# Fall back to os.environ for CLI, cron, and test compatibility
return os.getenv(name, default)
-30
View File
@@ -14,8 +14,6 @@ concurrently under distinct configurations).
import hashlib
import json
import os
import signal
import subprocess
import sys
from datetime import datetime, timezone
from pathlib import Path
@@ -25,7 +23,6 @@ from typing import Any, Optional
_GATEWAY_KIND = "hermes-gateway"
_RUNTIME_STATUS_FILE = "gateway_state.json"
_LOCKS_DIRNAME = "gateway-locks"
_IS_WINDOWS = sys.platform == "win32"
def _get_pid_path() -> Path:
@@ -52,33 +49,6 @@ def _utc_now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def terminate_pid(pid: int, *, force: bool = False) -> None:
"""Terminate a PID with platform-appropriate force semantics.
POSIX uses SIGTERM/SIGKILL. Windows uses taskkill /T /F for true force-kill
because os.kill(..., SIGTERM) is not equivalent to a tree-killing hard stop.
"""
if force and _IS_WINDOWS:
try:
result = subprocess.run(
["taskkill", "/PID", str(pid), "/T", "/F"],
capture_output=True,
text=True,
timeout=10,
)
except FileNotFoundError:
os.kill(pid, signal.SIGTERM)
return
if result.returncode != 0:
details = (result.stderr or result.stdout or "").strip()
raise OSError(details or f"taskkill failed for PID {pid}")
return
sig = signal.SIGTERM if not force else getattr(signal, "SIGKILL", signal.SIGTERM)
os.kill(pid, sig)
def _scope_hash(identity: str) -> str:
return hashlib.sha256(identity.encode("utf-8")).hexdigest()[:16]
+5 -14
View File
@@ -205,20 +205,11 @@ class GatewayStreamConsumer:
await self._send_or_edit(self._accumulated)
return
# Tool boundary: reset message state so the next text chunk
# creates a fresh message below any tool-progress messages.
#
# Exception: when _message_id is "__no_edit__" the platform
# never returned a real message ID (e.g. Signal, webhook with
# github_comment delivery). Resetting to None would re-enter
# the "first send" path on every tool boundary and post one
# platform message per tool call — that is what caused 155
# comments under a single PR. Instead, keep all state so the
# full continuation is delivered once via _send_fallback_final.
# (When editing fails mid-stream due to flood control the id is
# a real string like "msg_1", not "__no_edit__", so that case
# still resets and creates a fresh segment as intended.)
if got_segment_break and self._message_id != "__no_edit__":
# Tool boundary: the should_edit block above already flushed
# accumulated text without a cursor. Reset state so the next
# text chunk creates a fresh message below any tool-progress
# messages the gateway sent in between.
if got_segment_break:
self._message_id = None
self._accumulated = ""
self._last_sent_text = ""
+30 -92
View File
@@ -70,6 +70,7 @@ DEFAULT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
DEFAULT_QWEN_BASE_URL = "https://portal.qwen.ai/v1"
DEFAULT_GITHUB_MODELS_BASE_URL = "https://api.githubcopilot.com"
DEFAULT_COPILOT_ACP_BASE_URL = "acp://copilot"
DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai"
CODEX_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
CODEX_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
@@ -198,14 +199,6 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
api_key_env_vars=("DEEPSEEK_API_KEY",),
base_url_env_var="DEEPSEEK_BASE_URL",
),
"xai": ProviderConfig(
id="xai",
name="xAI",
auth_type="api_key",
inference_base_url="https://api.x.ai/v1",
api_key_env_vars=("XAI_API_KEY",),
base_url_env_var="XAI_BASE_URL",
),
"ai-gateway": ProviderConfig(
id="ai-gateway",
name="AI Gateway",
@@ -712,27 +705,6 @@ def write_credential_pool(provider_id: str, entries: List[Dict[str, Any]]) -> Pa
return _save_auth_store(auth_store)
def suppress_credential_source(provider_id: str, source: str) -> None:
"""Mark a credential source as suppressed so it won't be re-seeded."""
with _auth_store_lock():
auth_store = _load_auth_store()
suppressed = auth_store.setdefault("suppressed_sources", {})
provider_list = suppressed.setdefault(provider_id, [])
if source not in provider_list:
provider_list.append(source)
_save_auth_store(auth_store)
def is_source_suppressed(provider_id: str, source: str) -> bool:
"""Check if a credential source has been suppressed by the user."""
try:
auth_store = _load_auth_store()
suppressed = auth_store.get("suppressed_sources", {})
return source in suppressed.get(provider_id, [])
except Exception:
return False
def get_provider_auth_state(provider_id: str) -> Optional[Dict[str, Any]]:
"""Return persisted auth state for a provider, or None."""
auth_store = _load_auth_store()
@@ -745,57 +717,6 @@ def get_active_provider() -> Optional[str]:
return auth_store.get("active_provider")
def is_provider_explicitly_configured(provider_id: str) -> bool:
"""Return True only if the user has explicitly configured this provider.
Checks:
1. active_provider in auth.json matches
2. model.provider in config.yaml matches
3. Provider-specific env vars are set (e.g. ANTHROPIC_API_KEY)
This is used to gate auto-discovery of external credentials (e.g.
Claude Code's ~/.claude/.credentials.json) so they are never used
without the user's explicit choice. See PR #4210 for the same
pattern applied to the setup wizard gate.
"""
normalized = (provider_id or "").strip().lower()
# 1. Check auth.json active_provider
try:
auth_store = _load_auth_store()
active = (auth_store.get("active_provider") or "").strip().lower()
if active and active == normalized:
return True
except Exception:
pass
# 2. Check config.yaml model.provider
try:
from hermes_cli.config import load_config
cfg = load_config()
model_cfg = cfg.get("model")
if isinstance(model_cfg, dict):
cfg_provider = (model_cfg.get("provider") or "").strip().lower()
if cfg_provider == normalized:
return True
except Exception:
pass
# 3. Check provider-specific env vars
# Exclude CLAUDE_CODE_OAUTH_TOKEN — it's set by Claude Code itself,
# not by the user explicitly configuring anthropic in Hermes.
_IMPLICIT_ENV_VARS = {"CLAUDE_CODE_OAUTH_TOKEN"}
pconfig = PROVIDER_REGISTRY.get(normalized)
if pconfig and pconfig.auth_type == "api_key":
for env_var in pconfig.api_key_env_vars:
if env_var in _IMPLICIT_ENV_VARS:
continue
if has_usable_secret(os.getenv(env_var, "")):
return True
return False
def clear_provider_auth(provider_id: Optional[str] = None) -> bool:
"""
Clear auth state for a provider. Used by `hermes logout`.
@@ -898,7 +819,7 @@ def resolve_provider(
_PROVIDER_ALIASES = {
"glm": "zai", "z-ai": "zai", "z.ai": "zai", "zhipu": "zai",
"google": "gemini", "google-gemini": "gemini", "google-ai-studio": "gemini",
"kimi": "kimi-coding", "kimi-for-coding": "kimi-coding", "moonshot": "kimi-coding",
"kimi": "kimi-coding", "moonshot": "kimi-coding",
"minimax-china": "minimax-cn", "minimax_cn": "minimax-cn",
"claude": "anthropic", "claude-code": "anthropic",
"github": "copilot", "github-copilot": "copilot",
@@ -1521,15 +1442,7 @@ def _resolve_verify(
if effective_insecure:
return False
if effective_ca:
ca_path = str(effective_ca)
if not os.path.isfile(ca_path):
import logging
logging.getLogger("hermes.auth").warning(
"CA bundle path does not exist: %s — falling back to default certificates",
ca_path,
)
return True
return ca_path
return str(effective_ca)
return True
@@ -2429,6 +2342,33 @@ def resolve_external_process_provider_credentials(provider_id: str) -> Dict[str,
}
# =============================================================================
# External credential detection
# =============================================================================
def detect_external_credentials() -> List[Dict[str, Any]]:
"""Scan for credentials from other CLI tools that Hermes can reuse.
Returns a list of dicts, each with:
- provider: str -- Hermes provider id (e.g. "openai-codex")
- path: str -- filesystem path where creds were found
- label: str -- human-friendly description for the setup UI
"""
found: List[Dict[str, Any]] = []
# Codex CLI: ~/.codex/auth.json (importable, not shared)
cli_tokens = _import_codex_cli_tokens()
if cli_tokens:
codex_path = Path.home() / ".codex" / "auth.json"
found.append({
"provider": "openai-codex",
"path": str(codex_path),
"label": f"Codex CLI credentials found ({codex_path}) — run `hermes auth` to create a separate session",
})
return found
# =============================================================================
# CLI Commands — login / logout
# =============================================================================
@@ -2632,8 +2572,6 @@ def _prompt_model_selection(
title=effective_title,
)
idx = menu.show()
from hermes_cli.curses_ui import flush_stdin
flush_stdin()
if idx is None:
return None
print()
+2 -5
View File
@@ -347,11 +347,8 @@ def auth_remove_command(args) -> None:
print("Cleared Hermes Anthropic OAuth credentials")
elif removed.source == "claude_code" and provider == "anthropic":
from hermes_cli.auth import suppress_credential_source
suppress_credential_source(provider, "claude_code")
print("Suppressed claude_code credential — it will not be re-seeded.")
print("Note: Claude Code credentials still live in ~/.claude/.credentials.json")
print("Run `hermes auth add anthropic` to re-enable if needed.")
print("Note: Claude Code credentials live in ~/.claude/.credentials.json")
print(" Remove them manually if you want to deauthorize Claude Code.")
def auth_reset_command(args) -> None:
+6
View File
@@ -90,6 +90,12 @@ HERMES_CADUCEUS = """[#CD7F32]⠀⢀⣀⡀⠀⣀⣀
[#B8860B]⠀⠈⠳⠈⣡⠞⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[/]
[#B8860B]⠀⠈⠀[/]"""
COMPACT_BANNER = """
[bold #FFD700]╔══════════════════════════════════════════════════════════════╗[/]
[bold #FFD700]║[/] [#FFBF00]⚕ NOUS HERMES[/] [dim #B8860B]- AI Agent Framework[/] [bold #FFD700]║[/]
[bold #FFD700]║[/] [#CD7F32]Messenger of the Digital Gods[/] [dim #B8860B]Nous Research[/] [bold #FFD700]║[/]
[bold #FFD700]╚══════════════════════════════════════════════════════════════╝[/]
"""
# =========================================================================
+140
View File
@@ -0,0 +1,140 @@
"""Shared curses-based multi-select checklist for Hermes CLI.
Used by both ``hermes tools`` and ``hermes skills`` to present a
toggleable list of items. Falls back to a numbered text UI when
curses is unavailable (Windows without curses, piped stdin, etc.).
"""
import sys
from typing import List, Set
from hermes_cli.colors import Colors, color
def curses_checklist(
title: str,
items: List[str],
pre_selected: Set[int],
) -> Set[int]:
"""Multi-select checklist. Returns set of **selected** indices.
Args:
title: Header text shown at the top of the checklist.
items: Display labels for each row.
pre_selected: Indices that start checked.
Returns:
The indices the user confirmed as checked. On cancel (ESC/q),
returns ``pre_selected`` unchanged.
"""
# Safety: return defaults when stdin is not a terminal.
if not sys.stdin.isatty():
return set(pre_selected)
try:
import curses
selected = set(pre_selected)
result = [None]
def _ui(stdscr):
curses.curs_set(0)
if curses.has_colors():
curses.start_color()
curses.use_default_colors()
curses.init_pair(1, curses.COLOR_GREEN, -1)
curses.init_pair(2, curses.COLOR_YELLOW, -1)
curses.init_pair(3, 8, -1) # dim gray
cursor = 0
scroll_offset = 0
while True:
stdscr.clear()
max_y, max_x = stdscr.getmaxyx()
# Header
try:
hattr = curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0)
stdscr.addnstr(0, 0, title, max_x - 1, hattr)
stdscr.addnstr(
1, 0,
" ↑↓ navigate SPACE toggle ENTER confirm ESC cancel",
max_x - 1, curses.A_DIM,
)
except curses.error:
pass
# Scrollable item list
visible_rows = max_y - 3
if cursor < scroll_offset:
scroll_offset = cursor
elif cursor >= scroll_offset + visible_rows:
scroll_offset = cursor - visible_rows + 1
for draw_i, i in enumerate(
range(scroll_offset, min(len(items), scroll_offset + visible_rows))
):
y = draw_i + 3
if y >= max_y - 1:
break
check = "" if i in selected else " "
arrow = "" if i == cursor else " "
line = f" {arrow} [{check}] {items[i]}"
attr = curses.A_NORMAL
if i == cursor:
attr = curses.A_BOLD
if curses.has_colors():
attr |= curses.color_pair(1)
try:
stdscr.addnstr(y, 0, line, max_x - 1, attr)
except curses.error:
pass
stdscr.refresh()
key = stdscr.getch()
if key in (curses.KEY_UP, ord("k")):
cursor = (cursor - 1) % len(items)
elif key in (curses.KEY_DOWN, ord("j")):
cursor = (cursor + 1) % len(items)
elif key == ord(" "):
selected.symmetric_difference_update({cursor})
elif key in (curses.KEY_ENTER, 10, 13):
result[0] = set(selected)
return
elif key in (27, ord("q")):
result[0] = set(pre_selected)
return
curses.wrapper(_ui)
return result[0] if result[0] is not None else set(pre_selected)
except Exception:
pass # fall through to numbered fallback
# ── Numbered text fallback ────────────────────────────────────────────
selected = set(pre_selected)
print(color(f"\n {title}", Colors.YELLOW))
print(color(" Toggle by number, Enter to confirm.\n", Colors.DIM))
while True:
for i, label in enumerate(items):
check = "" if i in selected else " "
print(f" {i + 1:3}. [{check}] {label}")
print()
try:
raw = input(color(" Number to toggle, 's' to save, 'q' to cancel: ", Colors.DIM)).strip()
except (KeyboardInterrupt, EOFError):
return set(pre_selected)
if raw.lower() == "s" or raw == "":
return selected
if raw.lower() == "q":
return set(pre_selected)
try:
idx = int(raw) - 1
if 0 <= idx < len(items):
selected.symmetric_difference_update({idx})
except ValueError:
print(color(" Invalid input", Colors.DIM))
+11 -14
View File
@@ -16,18 +16,8 @@ from collections.abc import Callable, Mapping
from dataclasses import dataclass
from typing import Any
# prompt_toolkit is an optional CLI dependency — only needed for
# SlashCommandCompleter and SlashCommandAutoSuggest. Gateway and test
# environments that lack it must still be able to import this module
# for resolve_command, gateway_help_lines, and COMMAND_REGISTRY.
try:
from prompt_toolkit.auto_suggest import AutoSuggest, Suggestion
from prompt_toolkit.completion import Completer, Completion
except ImportError: # pragma: no cover
AutoSuggest = object # type: ignore[assignment,misc]
Completer = object # type: ignore[assignment,misc]
Suggestion = None # type: ignore[assignment]
Completion = None # type: ignore[assignment]
from prompt_toolkit.auto_suggest import AutoSuggest, Suggestion
from prompt_toolkit.completion import Completer, Completion
# ---------------------------------------------------------------------------
@@ -83,7 +73,8 @@ COMMAND_REGISTRY: list[CommandDef] = [
args_hint="<question>"),
CommandDef("queue", "Queue a prompt for the next turn (doesn't interrupt)", "Session",
aliases=("q",), args_hint="<prompt>"),
CommandDef("status", "Show session info", "Session"),
CommandDef("status", "Show session info", "Session",
gateway_only=True),
CommandDef("profile", "Show active profile name and home directory", "Info"),
CommandDef("sethome", "Set this chat as the home channel", "Session",
gateway_only=True, aliases=("set-home",)),
@@ -110,7 +101,7 @@ COMMAND_REGISTRY: list[CommandDef] = [
args_hint="[level|show|hide]",
subcommands=("none", "minimal", "low", "medium", "high", "xhigh", "show", "hide", "on", "off")),
CommandDef("fast", "Toggle fast mode — OpenAI Priority Processing / Anthropic Fast Mode (Normal/Fast)", "Configuration",
args_hint="[normal|fast|status]",
cli_only=True, args_hint="[normal|fast|status]",
subcommands=("normal", "fast", "status", "on", "off")),
CommandDef("skin", "Show or change the display skin/theme", "Configuration",
cli_only=True, args_hint="[name]"),
@@ -183,6 +174,12 @@ def resolve_command(name: str) -> CommandDef | None:
return _COMMAND_LOOKUP.get(name.lower().lstrip("/"))
def register_plugin_command(cmd: CommandDef) -> None:
"""Append a plugin-defined command to the registry and refresh lookups."""
COMMAND_REGISTRY.append(cmd)
rebuild_lookups()
def rebuild_lookups() -> None:
"""Rebuild all derived lookup dicts from the current COMMAND_REGISTRY.
+3 -68
View File
@@ -39,9 +39,6 @@ _EXTRA_ENV_KEYS = frozenset({
"DINGTALK_CLIENT_ID", "DINGTALK_CLIENT_SECRET",
"FEISHU_APP_ID", "FEISHU_APP_SECRET", "FEISHU_ENCRYPT_KEY", "FEISHU_VERIFICATION_TOKEN",
"WECOM_BOT_ID", "WECOM_SECRET",
"WEIXIN_ACCOUNT_ID", "WEIXIN_TOKEN", "WEIXIN_BASE_URL", "WEIXIN_CDN_BASE_URL",
"WEIXIN_HOME_CHANNEL", "WEIXIN_HOME_CHANNEL_NAME", "WEIXIN_DM_POLICY", "WEIXIN_GROUP_POLICY",
"WEIXIN_ALLOWED_USERS", "WEIXIN_GROUP_ALLOWED_USERS", "WEIXIN_ALLOW_ALL_USERS",
"BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_PASSWORD",
"TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT",
"WHATSAPP_MODE", "WHATSAPP_ENABLED",
@@ -141,68 +138,6 @@ def managed_error(action: str = "modify configuration"):
print(format_managed_message(action), file=sys.stderr)
# =============================================================================
# Container-aware CLI (NixOS container mode)
# =============================================================================
def _is_inside_container() -> bool:
"""Detect if we're already running inside a Docker/Podman container."""
# Standard Docker/Podman indicators
if os.path.exists("/.dockerenv"):
return True
# Podman uses /run/.containerenv
if os.path.exists("/run/.containerenv"):
return True
# Check cgroup for container runtime evidence (works for both Docker & Podman)
try:
with open("/proc/1/cgroup", "r") as f:
cgroup = f.read()
if "docker" in cgroup or "podman" in cgroup or "/lxc/" in cgroup:
return True
except (OSError, IOError):
pass
return False
def get_container_exec_info() -> Optional[dict]:
"""Read container mode metadata from HERMES_HOME/.container-mode.
Returns a dict with keys: backend, container_name, hermes_bin
or None if container mode is not active or we're already inside the container.
The .container-mode file is written by the NixOS activation script when
container.enable = true. It tells the host CLI to exec into the container
instead of running locally.
"""
if _is_inside_container():
return None
container_mode_file = get_hermes_home() / ".container-mode"
if not container_mode_file.exists():
return None
try:
info = {}
with open(container_mode_file, "r") as f:
for line in f:
line = line.strip()
if "=" in line and not line.startswith("#"):
key, _, value = line.partition("=")
info[key.strip()] = value.strip()
backend = info.get("backend", "docker")
container_name = info.get("container_name", "hermes-agent")
hermes_bin = info.get("hermes_bin", "/data/current-package/bin/hermes")
return {
"backend": backend,
"container_name": container_name,
"hermes_bin": hermes_bin,
}
except (OSError, IOError):
return None
# =============================================================================
# Config paths
# =============================================================================
@@ -1271,8 +1206,8 @@ OPTIONAL_ENV_VARS = {
"advanced": True,
},
"API_SERVER_KEY": {
"description": "Bearer token for API server authentication. Required for non-loopback binding; server refuses to start without it. On loopback (127.0.0.1), all requests are allowed if empty.",
"prompt": "API server auth key (required for network access)",
"description": "Bearer token for API server authentication. If empty, all requests are allowed (local use only).",
"prompt": "API server auth key (optional)",
"url": None,
"password": True,
"category": "messaging",
@@ -1287,7 +1222,7 @@ OPTIONAL_ENV_VARS = {
"advanced": True,
},
"API_SERVER_HOST": {
"description": "Host/bind address for the API server (default: 127.0.0.1). Use 0.0.0.0 for network access — server refuses to start without API_SERVER_KEY.",
"description": "Host/bind address for the API server (default: 127.0.0.1). Use 0.0.0.0 for network access — requires API_SERVER_KEY for security.",
"prompt": "API server host",
"url": None,
"password": False,
+12
View File
@@ -31,6 +31,13 @@ logger = logging.getLogger(__name__)
# OAuth device code flow constants (same client ID as opencode/Copilot CLI)
COPILOT_OAUTH_CLIENT_ID = "Ov23li8tweQw6odWQebz"
COPILOT_DEVICE_CODE_URL = "https://github.com/login/device/code"
COPILOT_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token"
# Copilot API constants
COPILOT_TOKEN_EXCHANGE_URL = "https://api.github.com/copilot_internal/v2/token"
COPILOT_API_BASE_URL = "https://api.githubcopilot.com"
# Token type prefixes
_CLASSIC_PAT_PREFIX = "ghp_"
_SUPPORTED_PREFIXES = ("gho_", "github_pat_", "ghu_")
@@ -43,6 +50,11 @@ _DEVICE_CODE_POLL_INTERVAL = 5 # seconds
_DEVICE_CODE_POLL_SAFETY_MARGIN = 3 # seconds
def is_classic_pat(token: str) -> bool:
"""Check if a token is a classic PAT (ghp_*), which Copilot doesn't support."""
return token.strip().startswith(_CLASSIC_PAT_PREFIX)
def validate_copilot_token(token: str) -> tuple[bool, str]:
"""Validate that a token is usable with the Copilot API.
-23
View File
@@ -10,28 +10,6 @@ from typing import Callable, List, Optional, Set
from hermes_cli.colors import Colors, color
def flush_stdin() -> None:
"""Flush any stray bytes from the stdin input buffer.
Must be called after ``curses.wrapper()`` (or any terminal-mode library
like simple_term_menu) returns, **before** the next ``input()`` /
``getpass.getpass()`` call. ``curses.endwin()`` restores the terminal
but does NOT drain the OS input buffer leftover escape-sequence bytes
(from arrow keys, terminal mode-switch responses, or rapid keypresses)
remain buffered and silently get consumed by the next ``input()`` call,
corrupting user data (e.g. writing ``^[^[`` into .env files).
On non-TTY stdin (piped, redirected) or Windows, this is a no-op.
"""
try:
if not sys.stdin.isatty():
return
import termios
termios.tcflush(sys.stdin, termios.TCIFLUSH)
except Exception:
pass
def curses_checklist(
title: str,
items: List[str],
@@ -153,7 +131,6 @@ def curses_checklist(
return
curses.wrapper(_draw)
flush_stdin()
return result_holder[0] if result_holder[0] is not None else cancel_returns
except Exception:
+5 -1
View File
@@ -32,6 +32,11 @@ def _get_git_commit(project_root: Path) -> str:
return "(unknown)"
def _key_present(name: str) -> str:
"""Return 'set' or 'not set' for an env var."""
return "set" if os.getenv(name) else "not set"
def _redact(value: str) -> str:
"""Redact all but first 4 and last 4 chars."""
if not value:
@@ -119,7 +124,6 @@ def _configured_platforms() -> list[str]:
"dingtalk": "DINGTALK_CLIENT_ID",
"feishu": "FEISHU_APP_ID",
"wecom": "WECOM_BOT_ID",
"weixin": "WEIXIN_ACCOUNT_ID",
}
return [name for name, env in checks.items() if os.getenv(env)]
+32 -200
View File
@@ -14,7 +14,6 @@ from pathlib import Path
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
from gateway.status import terminate_pid
from hermes_cli.config import get_env_value, get_hermes_home, save_env_value, is_managed, managed_error
# display_hermes_home is imported lazily at call sites to avoid ImportError
# when hermes_constants is cached from a pre-update version during `hermes update`.
@@ -163,7 +162,7 @@ def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None)
"""Kill any running gateway processes. Returns count killed.
Args:
force: Use the platform's force-kill mechanism instead of graceful terminate.
force: Use SIGKILL instead of SIGTERM.
exclude_pids: PIDs to skip (e.g. service-managed PIDs that were just
restarted and should not be killed).
"""
@@ -172,7 +171,10 @@ def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None)
for pid in pids:
try:
terminate_pid(pid, force=force)
if force and not is_windows():
os.kill(pid, signal.SIGKILL)
else:
os.kill(pid, signal.SIGTERM)
killed += 1
except ProcessLookupError:
# Process already gone
@@ -180,8 +182,6 @@ def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None)
except PermissionError:
print(f"⚠ Permission denied to kill PID {pid}")
except OSError as exc:
print(f"Failed to kill PID {pid}: {exc}")
return killed
@@ -251,18 +251,18 @@ SERVICE_DESCRIPTION = "Hermes Agent Gateway - Messaging Platform Integration"
def _profile_suffix() -> str:
"""Derive a service-name suffix from the current HERMES_HOME.
Returns ``""`` for the default root, the profile name for
``<root>/profiles/<name>``, or a short hash for any other path.
Works correctly in Docker (HERMES_HOME=/opt/data) and standard deployments.
Returns ``""`` for the default ``~/.hermes``, the profile name for
``~/.hermes/profiles/<name>``, or a short hash for any other custom
HERMES_HOME path.
"""
import hashlib
import re
from hermes_constants import get_default_hermes_root
from pathlib import Path as _Path
home = get_hermes_home().resolve()
default = get_default_hermes_root().resolve()
default = (_Path.home() / ".hermes").resolve()
if home == default:
return ""
# Detect <root>/profiles/<name> pattern → use the profile name
# Detect ~/.hermes/profiles/<name> pattern → use the profile name
profiles_root = (default / "profiles").resolve()
try:
rel = home.relative_to(profiles_root)
@@ -287,9 +287,9 @@ def _profile_arg(hermes_home: str | None = None) -> str:
service definition for a different user (e.g. system service).
"""
import re
from hermes_constants import get_default_hermes_root
from pathlib import Path as _Path
home = Path(hermes_home or str(get_hermes_home())).resolve()
default = get_default_hermes_root().resolve()
default = (_Path.home() / ".hermes").resolve()
if home == default:
return ""
profiles_root = (default / "profiles").resolve()
@@ -316,6 +316,8 @@ def get_service_name() -> str:
return f"{_SERVICE_BASE}-{suffix}"
SERVICE_NAME = _SERVICE_BASE # backward-compat for external importers; prefer get_service_name()
def get_systemd_unit_path(system: bool = False) -> Path:
name = get_service_name()
@@ -589,6 +591,17 @@ def get_python_path() -> str:
return str(venv_python)
return sys.executable
def get_hermes_cli_path() -> str:
"""Get the path to the hermes CLI."""
# Check if installed via pip
import shutil
hermes_bin = shutil.which("hermes")
if hermes_bin:
return hermes_bin
# Fallback to direct module execution
return f"{get_python_path()} -m hermes_cli.main"
# =============================================================================
# Systemd (Linux)
@@ -605,24 +618,6 @@ def _build_user_local_paths(home: Path, path_entries: list[str]) -> list[str]:
return [p for p in candidates if p not in path_entries and Path(p).exists()]
def _remap_path_for_user(path: str, target_home_dir: str) -> str:
"""Remap *path* from the current user's home to *target_home_dir*.
If *path* lives under ``Path.home()`` the corresponding prefix is swapped
to *target_home_dir*; otherwise the path is returned unchanged.
/root/.hermes/hermes-agent -> /home/alice/.hermes/hermes-agent
/opt/hermes -> /opt/hermes (kept as-is)
"""
current_home = Path.home().resolve()
resolved = Path(path).resolve()
try:
relative = resolved.relative_to(current_home)
return str(Path(target_home_dir) / relative)
except ValueError:
return str(resolved)
def _hermes_home_for_target_user(target_home_dir: str) -> str:
"""Remap the current HERMES_HOME to the equivalent under a target user's home.
@@ -670,15 +665,6 @@ def generate_systemd_unit(system: bool = False, run_as_user: str | None = None)
username, group_name, home_dir = _system_service_identity(run_as_user)
hermes_home = _hermes_home_for_target_user(home_dir)
profile_arg = _profile_arg(hermes_home)
# Remap all paths that may resolve under the calling user's home
# (e.g. /root/) to the target user's home so the service can
# actually access them.
python_path = _remap_path_for_user(python_path, home_dir)
working_dir = _remap_path_for_user(working_dir, home_dir)
venv_dir = _remap_path_for_user(venv_dir, home_dir)
venv_bin = _remap_path_for_user(venv_bin, home_dir)
node_bin = _remap_path_for_user(node_bin, home_dir)
path_entries = [_remap_path_for_user(p, home_dir) for p in path_entries]
path_entries.extend(_build_user_local_paths(Path(home_dir), path_entries))
path_entries.extend(common_bin_paths)
sane_path = ":".join(path_entries)
@@ -1196,19 +1182,7 @@ def launchd_start():
def launchd_stop():
label = get_launchd_label()
target = f"{_launchd_domain()}/{label}"
# bootout unloads the service definition so KeepAlive doesn't respawn
# the process. A plain `kill SIGTERM` only signals the process — launchd
# immediately restarts it because KeepAlive.SuccessfulExit = false.
# `hermes gateway start` re-bootstraps when it detects the job is unloaded.
try:
subprocess.run(["launchctl", "bootout", target], check=True, timeout=90)
except subprocess.CalledProcessError as e:
if e.returncode in (3, 113):
pass # Already unloaded — nothing to stop.
else:
raise
_wait_for_gateway_exit(timeout=10.0, force_after=5.0)
subprocess.run(["launchctl", "kill", "SIGTERM", f"{_launchd_domain()}/{label}"], check=True, timeout=30)
print("✓ Service stopped")
def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
@@ -1220,7 +1194,7 @@ def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
Args:
timeout: Total seconds to wait before giving up.
force_after: Seconds of graceful waiting before escalating to force-kill.
force_after: Seconds of graceful waiting before sending SIGKILL.
"""
import time
from gateway.status import get_running_pid
@@ -1237,15 +1211,15 @@ def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
if not force_sent and time.monotonic() >= force_deadline:
# Grace period expired — force-kill the specific PID.
try:
terminate_pid(pid, force=True)
os.kill(pid, signal.SIGKILL)
print(f"⚠ Gateway PID {pid} did not exit gracefully; sent SIGKILL")
except (ProcessLookupError, PermissionError, OSError):
except (ProcessLookupError, PermissionError):
return # Already gone or we can't touch it.
force_sent = True
time.sleep(0.3)
# Timed out even after force-kill.
# Timed out even after SIGKILL.
remaining_pid = get_running_pid()
if remaining_pid is not None:
print(f"⚠ Gateway PID {remaining_pid} still running after {timeout}s — restart may fail")
@@ -1624,12 +1598,6 @@ _PLATFORMS = [
"help": "Chat ID for scheduled results and notifications."},
],
},
{
"key": "weixin",
"label": "Weixin / WeChat",
"emoji": "💬",
"token_var": "WEIXIN_ACCOUNT_ID",
},
{
"key": "bluebubbles",
"label": "BlueBubbles (iMessage)",
@@ -1702,13 +1670,6 @@ def _platform_status(platform: dict) -> str:
if val or password or homeserver:
return "partially configured"
return "not configured"
if platform.get("key") == "weixin":
token = get_env_value("WEIXIN_TOKEN")
if val and token:
return "configured"
if val or token:
return "partially configured"
return "not configured"
if val:
return "configured"
return "not configured"
@@ -1812,7 +1773,7 @@ def _setup_standard_platform(platform: dict):
print_warning(" Open access enabled — anyone can use your bot!")
elif access_idx == 1:
print_success(" DM pairing mode — users will receive a code to request access.")
print_info(" Approve with: hermes pairing approve <platform> <code>")
print_info(" Approve with: hermes pairing approve {platform} {code}")
else:
print_info(" Skipped — configure later with 'hermes gateway setup'")
continue
@@ -1899,133 +1860,6 @@ def _is_service_running() -> bool:
return len(find_gateway_pids()) > 0
def _setup_weixin():
"""Interactive setup for Weixin / WeChat personal accounts."""
print()
print(color(" ─── 💬 Weixin / WeChat Setup ───", Colors.CYAN))
print()
print_info(" 1. Hermes will open Tencent iLink QR login in this terminal.")
print_info(" 2. Use WeChat to scan and confirm the QR code.")
print_info(" 3. Hermes will store the returned account_id/token in ~/.hermes/.env.")
print_info(" 4. This adapter supports native text, image, video, and document delivery.")
existing_account = get_env_value("WEIXIN_ACCOUNT_ID")
existing_token = get_env_value("WEIXIN_TOKEN")
if existing_account and existing_token:
print()
print_success("Weixin is already configured.")
if not prompt_yes_no(" Reconfigure Weixin?", False):
return
try:
from gateway.platforms.weixin import check_weixin_requirements, qr_login
except Exception as exc:
print_error(f" Weixin adapter import failed: {exc}")
print_info(" Install gateway dependencies first, then retry.")
return
if not check_weixin_requirements():
print_error(" Missing dependencies: Weixin needs aiohttp and cryptography.")
print_info(" Install them, then rerun `hermes gateway setup`.")
return
print()
if not prompt_yes_no(" Start QR login now?", True):
print_info(" Cancelled.")
return
import asyncio
try:
credentials = asyncio.run(qr_login(str(get_hermes_home())))
except KeyboardInterrupt:
print()
print_warning(" Weixin setup cancelled.")
return
except Exception as exc:
print_error(f" QR login failed: {exc}")
return
if not credentials:
print_warning(" QR login did not complete.")
return
account_id = credentials.get("account_id", "")
token = credentials.get("token", "")
base_url = credentials.get("base_url", "")
user_id = credentials.get("user_id", "")
save_env_value("WEIXIN_ACCOUNT_ID", account_id)
save_env_value("WEIXIN_TOKEN", token)
if base_url:
save_env_value("WEIXIN_BASE_URL", base_url)
save_env_value("WEIXIN_CDN_BASE_URL", get_env_value("WEIXIN_CDN_BASE_URL") or "https://novac2c.cdn.weixin.qq.com/c2c")
print()
access_choices = [
"Use DM pairing approval (recommended)",
"Allow all direct messages",
"Only allow listed user IDs",
"Disable direct messages",
]
access_idx = prompt_choice(" How should direct messages be authorized?", access_choices, 0)
if access_idx == 0:
save_env_value("WEIXIN_DM_POLICY", "pairing")
save_env_value("WEIXIN_ALLOW_ALL_USERS", "false")
save_env_value("WEIXIN_ALLOWED_USERS", "")
print_success(" DM pairing enabled.")
print_info(" Unknown DM users can request access and you approve them with `hermes pairing approve`.")
elif access_idx == 1:
save_env_value("WEIXIN_DM_POLICY", "open")
save_env_value("WEIXIN_ALLOW_ALL_USERS", "true")
save_env_value("WEIXIN_ALLOWED_USERS", "")
print_warning(" Open DM access enabled for Weixin.")
elif access_idx == 2:
default_allow = user_id or ""
allowlist = prompt(" Allowed Weixin user IDs (comma-separated)", default_allow, password=False).replace(" ", "")
save_env_value("WEIXIN_DM_POLICY", "allowlist")
save_env_value("WEIXIN_ALLOW_ALL_USERS", "false")
save_env_value("WEIXIN_ALLOWED_USERS", allowlist)
print_success(" Weixin allowlist saved.")
else:
save_env_value("WEIXIN_DM_POLICY", "disabled")
save_env_value("WEIXIN_ALLOW_ALL_USERS", "false")
save_env_value("WEIXIN_ALLOWED_USERS", "")
print_warning(" Direct messages disabled.")
print()
group_choices = [
"Disable group chats (recommended)",
"Allow all group chats",
"Only allow listed group chat IDs",
]
group_idx = prompt_choice(" How should group chats be handled?", group_choices, 0)
if group_idx == 0:
save_env_value("WEIXIN_GROUP_POLICY", "disabled")
save_env_value("WEIXIN_GROUP_ALLOWED_USERS", "")
print_info(" Group chats disabled.")
elif group_idx == 1:
save_env_value("WEIXIN_GROUP_POLICY", "open")
save_env_value("WEIXIN_GROUP_ALLOWED_USERS", "")
print_warning(" All group chats enabled.")
else:
allow_groups = prompt(" Allowed group chat IDs (comma-separated)", "", password=False).replace(" ", "")
save_env_value("WEIXIN_GROUP_POLICY", "allowlist")
save_env_value("WEIXIN_GROUP_ALLOWED_USERS", allow_groups)
print_success(" Group allowlist saved.")
if user_id:
print()
if prompt_yes_no(f" Use your Weixin user ID ({user_id}) as the home channel?", True):
save_env_value("WEIXIN_HOME_CHANNEL", user_id)
print_success(f" Home channel set to {user_id}")
print()
print_success("Weixin configured!")
print_info(f" Account ID: {account_id}")
if user_id:
print_info(f" User ID: {user_id}")
def _setup_signal():
"""Interactive setup for Signal messenger."""
import shutil
@@ -2201,8 +2035,6 @@ def gateway_setup():
_setup_whatsapp()
elif platform["key"] == "signal":
_setup_signal()
elif platform["key"] == "weixin":
_setup_weixin()
else:
_setup_standard_platform(platform)
+33 -113
View File
@@ -97,11 +97,10 @@ def _apply_profile_override() -> None:
consume = 1
break
# 2. If no flag, check active_profile in the hermes root
# 2. If no flag, check ~/.hermes/active_profile
if profile_name is None:
try:
from hermes_constants import get_default_hermes_root
active_path = get_default_hermes_root() / "active_profile"
active_path = Path.home() / ".hermes" / "active_profile"
if active_path.exists():
name = active_path.read_text().strip()
if name and name != "default":
@@ -528,56 +527,6 @@ def _resolve_last_cli_session() -> Optional[str]:
return None
def _exec_in_container(container_info: dict, cli_args: list):
"""Replace the current process with a command inside the managed container.
Uses os.execvp to hand off to docker/podman exec, preserving the TTY
so the interactive CLI works seamlessly inside the container.
Args:
container_info: dict with backend, container_name, hermes_bin
cli_args: the original CLI arguments (everything after 'hermes')
"""
import shutil
import subprocess
backend = container_info["backend"]
container_name = container_info["container_name"]
hermes_bin = container_info["hermes_bin"]
# Find the container runtime on PATH
runtime = shutil.which(backend)
if not runtime:
print(f"Warning: {backend} not found on PATH, falling back to host CLI.",
file=sys.stderr)
return # Fall through to normal CLI
# Check if the container is actually running
try:
result = subprocess.run(
[runtime, "inspect", "--format", "{{.State.Running}}", container_name],
capture_output=True, text=True, timeout=5
)
if result.returncode != 0 or result.stdout.strip().lower() != "true":
print(f"Warning: container '{container_name}' is not running, falling back to host CLI.",
file=sys.stderr)
return
except (subprocess.TimeoutExpired, OSError):
return # Fall through on any error
# Filter out --host flag from forwarded args (it's not meaningful inside)
forwarded_args = [a for a in cli_args if a != "--host"]
# Build the exec command
exec_cmd = [runtime, "exec", "-it", container_name, hermes_bin] + forwarded_args
print(f"Routing to container '{container_name}' via {backend}...",
file=sys.stderr)
# Replace the current process — this never returns on success
os.execvp(runtime, exec_cmd)
def _resolve_session_by_name_or_id(name_or_id: str) -> Optional[str]:
"""Resolve a session name (title) or ID to a session ID.
@@ -606,21 +555,6 @@ def _resolve_session_by_name_or_id(name_or_id: str) -> Optional[str]:
def cmd_chat(args):
"""Run interactive chat CLI."""
# ── Container-aware routing ──────────────────────────────────────────
# When NixOS container mode is active and we're on the host, exec into
# the managed container instead of running locally. --host bypasses this.
if not getattr(args, "host", False):
try:
from hermes_cli.config import get_container_exec_info
container_info = get_container_exec_info()
if container_info:
_exec_in_container(container_info, sys.argv[1:])
# _exec_in_container calls os.execvp which replaces the process.
# If we get here, the exec failed.
sys.exit(1)
except Exception:
pass # Fall through to normal CLI on any detection error
# Resolve --continue into --resume with the latest CLI session or by name
continue_val = getattr(args, "continue_last", None)
if continue_val and not getattr(args, "resume", None):
@@ -1738,8 +1672,6 @@ def _remove_custom_provider(config):
title="Select provider to remove:",
)
idx = menu.show()
from hermes_cli.curses_ui import flush_stdin
flush_stdin()
print()
except (ImportError, NotImplementedError, OSError, subprocess.SubprocessError):
for i, c in enumerate(choices, 1):
@@ -1765,9 +1697,8 @@ def _remove_custom_provider(config):
def _model_flow_named_custom(config, provider_info):
"""Handle a named custom provider from config.yaml custom_providers list.
Always probes the endpoint's /models API to let the user pick a model.
If a model was previously saved, it is pre-selected in the menu.
Falls back to the saved model if probing fails.
If the entry has a saved model name, activates it immediately.
Otherwise probes the endpoint's /models API to let the user pick one.
"""
from hermes_cli.auth import _save_model_choice, deactivate_provider
from hermes_cli.config import load_config, save_config
@@ -1778,37 +1709,46 @@ def _model_flow_named_custom(config, provider_info):
api_key = provider_info.get("api_key", "")
saved_model = provider_info.get("model", "")
# If a model is saved, just activate immediately — no probing needed
if saved_model:
_save_model_choice(saved_model)
cfg = load_config()
model = cfg.get("model")
if not isinstance(model, dict):
model = {"default": model} if model else {}
cfg["model"] = model
model["provider"] = "custom"
model["base_url"] = base_url
if api_key:
model["api_key"] = api_key
save_config(cfg)
deactivate_provider()
print(f"✅ Switched to: {saved_model}")
print(f" Provider: {name} ({base_url})")
return
# No saved model — probe endpoint and let user pick
print(f" Provider: {name}")
print(f" URL: {base_url}")
if saved_model:
print(f" Current: {saved_model}")
print()
print("Fetching available models...")
print("No model saved for this provider. Fetching available models...")
models = fetch_api_models(api_key, base_url, timeout=8.0)
if models:
default_idx = 0
if saved_model and saved_model in models:
default_idx = models.index(saved_model)
print(f"Found {len(models)} model(s):\n")
try:
from simple_term_menu import TerminalMenu
menu_items = [
f" {m} (current)" if m == saved_model else f" {m}"
for m in models
] + [" Cancel"]
menu_items = [f" {m}" for m in models] + [" Cancel"]
menu = TerminalMenu(
menu_items, cursor_index=default_idx,
menu_items, cursor_index=0,
menu_cursor="-> ", menu_cursor_style=("fg_green", "bold"),
menu_highlight_style=("fg_green",),
cycle_cursor=True, clear_screen=False,
title=f"Select model from {name}:",
)
idx = menu.show()
from hermes_cli.curses_ui import flush_stdin
flush_stdin()
print()
if idx is None or idx >= len(models):
print("Cancelled.")
@@ -1816,8 +1756,7 @@ def _model_flow_named_custom(config, provider_info):
model_name = models[idx]
except (ImportError, NotImplementedError, OSError, subprocess.SubprocessError):
for i, m in enumerate(models, 1):
suffix = " (current)" if m == saved_model else ""
print(f" {i}. {m}{suffix}")
print(f" {i}. {m}")
print(f" {len(models) + 1}. Cancel")
print()
try:
@@ -1833,13 +1772,6 @@ def _model_flow_named_custom(config, provider_info):
except (ValueError, KeyboardInterrupt, EOFError):
print("\nCancelled.")
return
elif saved_model:
print("Could not fetch models from endpoint.")
try:
model_name = input(f"Model name [{saved_model}]: ").strip() or saved_model
except (KeyboardInterrupt, EOFError):
print("\nCancelled.")
return
else:
print("Could not fetch models from endpoint. Enter model name manually.")
try:
@@ -1935,8 +1867,6 @@ def _prompt_reasoning_effort_selection(efforts, current_effort=""):
title="Select reasoning effort:",
)
idx = menu.show()
from hermes_cli.curses_ui import flush_stdin
flush_stdin()
if idx is None:
return None
print()
@@ -3379,11 +3309,10 @@ def _invalidate_update_cache():
``hermes update``, every profile is now current.
"""
homes = []
# Default profile home (Docker-aware — uses /opt/data in Docker)
from hermes_constants import get_default_hermes_root
default_home = get_default_hermes_root()
# Default profile home
default_home = Path.home() / ".hermes"
homes.append(default_home)
# Named profiles under <root>/profiles/
# Named profiles under ~/.hermes/profiles/
profiles_root = default_home / "profiles"
if profiles_root.is_dir():
for entry in profiles_root.iterdir():
@@ -4120,10 +4049,7 @@ def cmd_profile(args):
print(f" {name} chat Start chatting")
print(f" {name} gateway start Start the messaging gateway")
if clone or clone_all:
try:
profile_dir_display = "~/" + str(profile_dir.relative_to(Path.home()))
except ValueError:
profile_dir_display = str(profile_dir)
profile_dir_display = f"~/.hermes/profiles/{name}"
print(f"\n Edit {profile_dir_display}/.env for different API keys")
print(f" Edit {profile_dir_display}/SOUL.md for different personality")
print()
@@ -4451,12 +4377,6 @@ For more help on a command:
default=None,
help="Session source tag for filtering (default: cli). Use 'tool' for third-party integrations that should not appear in user session lists."
)
chat_parser.add_argument(
"--host",
action="store_true",
default=False,
help="Run on the host even when NixOS container mode is active (bypass container exec)"
)
chat_parser.set_defaults(func=cmd_chat)
# =========================================================================
+38 -62
View File
@@ -76,22 +76,17 @@ _STRIP_VENDOR_ONLY_PROVIDERS: frozenset[str] = frozenset({
"copilot-acp",
})
# Providers whose native naming is authoritative -- pass through unchanged.
_AUTHORITATIVE_NATIVE_PROVIDERS: frozenset[str] = frozenset({
# Providers whose own naming is authoritative -- pass through unchanged.
_PASSTHROUGH_PROVIDERS: frozenset[str] = frozenset({
"gemini",
"huggingface",
"openai-codex",
})
# Direct providers that accept bare native names but should repair a matching
# provider/ prefix when users copy the aggregator form into config.yaml.
_MATCHING_PREFIX_STRIP_PROVIDERS: frozenset[str] = frozenset({
"zai",
"kimi-coding",
"minimax",
"minimax-cn",
"alibaba",
"qwen-oauth",
"huggingface",
"openai-codex",
"custom",
})
@@ -173,40 +168,6 @@ def _dots_to_hyphens(model_name: str) -> str:
return model_name.replace(".", "-")
def _normalize_provider_alias(provider_name: str) -> str:
"""Resolve provider aliases to Hermes' canonical ids."""
raw = (provider_name or "").strip().lower()
if not raw:
return raw
try:
from hermes_cli.models import normalize_provider
return normalize_provider(raw)
except Exception:
return raw
def _strip_matching_provider_prefix(model_name: str, target_provider: str) -> str:
"""Strip ``provider/`` only when the prefix matches the target provider.
This prevents arbitrary slash-bearing model IDs from being mangled on
native providers while still repairing manual config values like
``zai/glm-5.1`` for the ``zai`` provider.
"""
if "/" not in model_name:
return model_name
prefix, remainder = model_name.split("/", 1)
if not prefix.strip() or not remainder.strip():
return model_name
normalized_prefix = _normalize_provider_alias(prefix)
normalized_target = _normalize_provider_alias(target_provider)
if normalized_prefix and normalized_prefix == normalized_target:
return remainder.strip()
return model_name
def detect_vendor(model_name: str) -> Optional[str]:
"""Detect the vendor slug from a bare model name.
@@ -344,37 +305,24 @@ def normalize_model_for_provider(model_input: str, target_provider: str) -> str:
if not name:
return name
provider = _normalize_provider_alias(target_provider)
provider = (target_provider or "").strip().lower()
# --- Aggregators: need vendor/model format ---
if provider in _AGGREGATOR_PROVIDERS:
return _prepend_vendor(name)
# --- Anthropic / OpenCode: strip matching provider prefix, dots -> hyphens ---
# --- Anthropic / OpenCode: strip vendor, dots -> hyphens ---
if provider in _DOT_TO_HYPHEN_PROVIDERS:
bare = _strip_matching_provider_prefix(name, provider)
if "/" in bare:
return bare
bare = _strip_vendor_prefix(name)
return _dots_to_hyphens(bare)
# --- Copilot: strip matching provider prefix, keep dots ---
# --- Copilot: strip vendor, keep dots ---
if provider in _STRIP_VENDOR_ONLY_PROVIDERS:
return _strip_matching_provider_prefix(name, provider)
return _strip_vendor_prefix(name)
# --- DeepSeek: map to one of two canonical names ---
if provider == "deepseek":
bare = _strip_matching_provider_prefix(name, provider)
if "/" in bare:
return bare
return _normalize_for_deepseek(bare)
# --- Direct providers: repair matching provider prefixes only ---
if provider in _MATCHING_PREFIX_STRIP_PROVIDERS:
return _strip_matching_provider_prefix(name, provider)
# --- Authoritative native providers: preserve user-facing slugs as-is ---
if provider in _AUTHORITATIVE_NATIVE_PROVIDERS:
return name
return _normalize_for_deepseek(name)
# --- Custom & all others: pass through as-is ---
return name
@@ -384,3 +332,31 @@ def normalize_model_for_provider(model_input: str, target_provider: str) -> str:
# Batch / convenience helpers
# ---------------------------------------------------------------------------
def model_display_name(model_id: str) -> str:
"""Return a short, human-readable display name for a model id.
Strips the vendor prefix (if any) for a cleaner display in menus
and status bars, while preserving dots for readability.
Examples::
>>> model_display_name("anthropic/claude-sonnet-4.6")
'claude-sonnet-4.6'
>>> model_display_name("claude-sonnet-4-6")
'claude-sonnet-4-6'
"""
return _strip_vendor_prefix((model_id or "").strip())
def is_aggregator_provider(provider: str) -> bool:
"""Check if a provider is an aggregator that needs vendor/model format."""
return (provider or "").strip().lower() in _AGGREGATOR_PROVIDERS
def vendor_for_model(model_name: str) -> str:
"""Return the vendor slug for a model, or ``""`` if unknown.
Convenience wrapper around :func:`detect_vendor` that never returns
``None``.
"""
return detect_vendor(model_name) or ""
+79 -35
View File
@@ -809,69 +809,42 @@ def list_authenticated_providers(
})
seen_slugs.add(slug)
# --- 2. Check Hermes-only providers (nous, openai-codex, copilot, opencode-go) ---
# --- 2. Check Hermes-only providers (nous, openai-codex, copilot) ---
from hermes_cli.providers import HERMES_OVERLAYS
from hermes_cli.auth import PROVIDER_REGISTRY as _auth_registry
# Build reverse mapping: models.dev ID → Hermes provider ID.
# HERMES_OVERLAYS keys may be models.dev IDs (e.g. "github-copilot")
# while _PROVIDER_MODELS and config.yaml use Hermes IDs ("copilot").
_mdev_to_hermes = {v: k for k, v in PROVIDER_TO_MODELS_DEV.items()}
for pid, overlay in HERMES_OVERLAYS.items():
if pid in seen_slugs:
continue
# Resolve Hermes slug — e.g. "github-copilot" → "copilot"
hermes_slug = _mdev_to_hermes.get(pid, pid)
if hermes_slug in seen_slugs:
continue
# Check if credentials exist
has_creds = False
if overlay.extra_env_vars:
has_creds = any(os.environ.get(ev) for ev in overlay.extra_env_vars)
# Also check api_key_env_vars from PROVIDER_REGISTRY for api_key auth_type
if not has_creds and overlay.auth_type == "api_key":
for _key in (pid, hermes_slug):
pcfg = _auth_registry.get(_key)
if pcfg and pcfg.api_key_env_vars:
if any(os.environ.get(ev) for ev in pcfg.api_key_env_vars):
has_creds = True
break
if not has_creds and overlay.auth_type in ("oauth_device_code", "oauth_external", "external_process"):
if overlay.auth_type in ("oauth_device_code", "oauth_external", "external_process"):
# These use auth stores, not env vars — check for auth.json entries
try:
from hermes_cli.auth import _load_auth_store
store = _load_auth_store()
providers_store = store.get("providers", {})
pool_store = store.get("credential_pool", {})
if store and (
pid in providers_store or hermes_slug in providers_store
or pid in pool_store or hermes_slug in pool_store
):
if store and (pid in store.get("providers", {}) or pid in store.get("credential_pool", {})):
has_creds = True
except Exception as exc:
logger.debug("Auth store check failed for %s: %s", pid, exc)
if not has_creds:
continue
# Use curated list — look up by Hermes slug, fall back to overlay key
model_ids = curated.get(hermes_slug, []) or curated.get(pid, [])
# Use curated list
model_ids = curated.get(pid, [])
total = len(model_ids)
top = model_ids[:max_models]
results.append({
"slug": hermes_slug,
"name": get_label(hermes_slug),
"is_current": hermes_slug == current_provider or pid == current_provider,
"slug": pid,
"name": get_label(pid),
"is_current": pid == current_provider,
"is_user_defined": False,
"models": top,
"total_models": total,
"source": "hermes",
})
seen_slugs.add(pid)
seen_slugs.add(hermes_slug)
# --- 3. User-defined endpoints from config ---
if user_providers and isinstance(user_providers, dict):
@@ -942,3 +915,74 @@ def list_authenticated_providers(
return results
# ---------------------------------------------------------------------------
# Fuzzy suggestions
# ---------------------------------------------------------------------------
def suggest_models(raw_input: str, limit: int = 3) -> List[str]:
"""Return fuzzy model suggestions for a (possibly misspelled) input."""
query = raw_input.strip()
if not query:
return []
results = search_models_dev(query, limit=limit)
suggestions: list[str] = []
for r in results:
mid = r.get("model_id", "")
if mid:
suggestions.append(mid)
return suggestions[:limit]
# ---------------------------------------------------------------------------
# Custom provider switch
# ---------------------------------------------------------------------------
def switch_to_custom_provider() -> CustomAutoResult:
"""Handle bare '/model --provider custom' — resolve endpoint and auto-detect model."""
from hermes_cli.runtime_provider import (
resolve_runtime_provider,
_auto_detect_local_model,
)
try:
runtime = resolve_runtime_provider(requested="custom")
except Exception as e:
return CustomAutoResult(
success=False,
error_message=f"Could not resolve custom endpoint: {e}",
)
cust_base = runtime.get("base_url", "")
cust_key = runtime.get("api_key", "")
if not cust_base or "openrouter.ai" in cust_base:
return CustomAutoResult(
success=False,
error_message=(
"No custom endpoint configured. "
"Set model.base_url in config.yaml, or set OPENAI_BASE_URL "
"in .env, or run: hermes setup -> Custom OpenAI-compatible endpoint"
),
)
detected_model = _auto_detect_local_model(cust_base)
if not detected_model:
return CustomAutoResult(
success=False,
base_url=cust_base,
api_key=cust_key,
error_message=(
f"Custom endpoint at {cust_base} is reachable but no single "
f"model was auto-detected. Specify the model explicitly: "
f"/model <model-name> --provider custom"
),
)
return CustomAutoResult(
success=True,
model=detected_model,
base_url=cust_base,
api_key=cust_key,
)
+34 -14
View File
@@ -20,6 +20,9 @@ COPILOT_EDITOR_VERSION = "vscode/1.104.1"
COPILOT_REASONING_EFFORTS_GPT5 = ["minimal", "low", "medium", "high"]
COPILOT_REASONING_EFFORTS_O_SERIES = ["low", "medium", "high"]
# Backward-compatible aliases for the earlier GitHub Models-backed Copilot work.
GITHUB_MODELS_BASE_URL = COPILOT_BASE_URL
GITHUB_MODELS_CATALOG_URL = COPILOT_MODELS_URL
# Fallback OpenRouter snapshot used when the live catalog is unavailable.
# (model_id, display description shown in menus)
@@ -129,19 +132,6 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
"glm-4.5",
"glm-4.5-flash",
],
"xai": [
"grok-4.20-0309-reasoning",
"grok-4.20-0309-non-reasoning",
"grok-4.20-multi-agent-0309",
"grok-4-1-fast-reasoning",
"grok-4-1-fast-non-reasoning",
"grok-4-fast-reasoning",
"grok-4-fast-non-reasoning",
"grok-4-0709",
"grok-code-fast-1",
"grok-3",
"grok-3-mini",
],
"kimi-coding": [
"kimi-for-coding",
"kimi-k2.5",
@@ -429,6 +419,12 @@ _FREE_TIER_CACHE_TTL: int = 180 # seconds (3 minutes)
_free_tier_cache: tuple[bool, float] | None = None # (result, timestamp)
def clear_nous_free_tier_cache() -> None:
"""Invalidate the cached free-tier result (e.g. after login/logout)."""
global _free_tier_cache
_free_tier_cache = None
def check_nous_free_tier() -> bool:
"""Check if the current Nous Portal user is on a free (unpaid) tier.
@@ -614,7 +610,6 @@ def menu_labels(*, force_refresh: bool = False) -> list[str]:
return labels
# ---------------------------------------------------------------------------
# Pricing helpers — fetch live pricing from OpenRouter-compatible /v1/models
# ---------------------------------------------------------------------------
@@ -647,6 +642,31 @@ def _format_price_per_mtok(per_token_str: str) -> str:
return f"${per_m:.2f}"
def format_pricing_label(pricing: dict[str, str] | None) -> str:
"""Build a compact pricing label like 'in $3 · out $15 · cache $0.30/Mtok'.
Returns empty string when pricing is unavailable.
"""
if not pricing:
return ""
prompt_price = pricing.get("prompt", "")
completion_price = pricing.get("completion", "")
if not prompt_price and not completion_price:
return ""
inp = _format_price_per_mtok(prompt_price)
out = _format_price_per_mtok(completion_price)
if inp == "free" and out == "free":
return "free"
cache_read = pricing.get("input_cache_read", "")
cache_str = _format_price_per_mtok(cache_read) if cache_read else ""
if inp == out and not cache_str:
return f"{inp}/Mtok"
parts = [f"in {inp}", f"out {out}"]
if cache_str and cache_str != "?" and cache_str != inp:
parts.append(f"cache {cache_str}")
return " · ".join(parts) + "/Mtok"
def format_model_pricing_table(
models: list[tuple[str, str]],
pricing_map: dict[str, dict[str, str]],
+6 -21
View File
@@ -42,11 +42,6 @@ _PROFILE_DIRS = [
"plans",
"workspace",
"cron",
# Per-profile HOME for subprocesses: isolates system tool configs (git,
# ssh, gh, npm …) so credentials don't bleed between profiles. In Docker
# this also ensures tool configs land inside the persistent volume.
# See hermes_constants.get_subprocess_home() and issue #4426.
"home",
]
# Files copied during --clone (if they exist in the source)
@@ -120,26 +115,16 @@ _HERMES_SUBCOMMANDS = frozenset({
def _get_profiles_root() -> Path:
"""Return the directory where named profiles are stored.
Anchored to the hermes root, NOT to the current HERMES_HOME
(which may itself be a profile). This ensures ``coder profile list``
can see all profiles.
In Docker/custom deployments where HERMES_HOME points outside
``~/.hermes``, profiles live under ``HERMES_HOME/profiles/`` so
they persist on the mounted volume.
Always ``~/.hermes/profiles/`` anchored to the user's home,
NOT to the current HERMES_HOME (which may itself be a profile).
This ensures ``coder profile list`` can see all profiles.
"""
return _get_default_hermes_home() / "profiles"
return Path.home() / ".hermes" / "profiles"
def _get_default_hermes_home() -> Path:
"""Return the default (pre-profile) HERMES_HOME path.
In standard deployments this is ``~/.hermes``.
In Docker/custom deployments where HERMES_HOME is outside ``~/.hermes``
(e.g. ``/opt/data``), returns HERMES_HOME directly.
"""
from hermes_constants import get_default_hermes_root
return get_default_hermes_root()
"""Return the default (pre-profile) HERMES_HOME path."""
return Path.home() / ".hermes"
def _get_active_profile_path() -> Path:
+40 -9
View File
@@ -127,11 +127,6 @@ HERMES_OVERLAYS: Dict[str, HermesOverlay] = {
is_aggregator=True,
base_url_env_var="HF_BASE_URL",
),
"xai": HermesOverlay(
transport="openai_chat",
base_url_override="https://api.x.ai/v1",
base_url_env_var="XAI_BASE_URL",
),
}
@@ -153,6 +148,10 @@ class ProviderDef:
doc: str = ""
source: str = "" # "models.dev", "hermes", "user-config"
@property
def is_user_defined(self) -> bool:
return self.source == "user-config"
# -- Aliases ------------------------------------------------------------------
# Maps human-friendly / legacy names to canonical provider IDs.
@@ -168,10 +167,6 @@ ALIASES: Dict[str, str] = {
"z.ai": "zai",
"zhipu": "zai",
# xai
"x-ai": "xai",
"x.ai": "xai",
# kimi-for-coding (models.dev ID)
"kimi": "kimi-for-coding",
"kimi-coding": "kimi-for-coding",
@@ -267,6 +262,12 @@ def normalize_provider(name: str) -> str:
return ALIASES.get(key, key)
def get_overlay(provider_id: str) -> Optional[HermesOverlay]:
"""Get Hermes overlay for a provider, if one exists."""
canonical = normalize_provider(provider_id)
return HERMES_OVERLAYS.get(canonical)
def get_provider(name: str) -> Optional[ProviderDef]:
"""Look up a provider by id or alias, merging all data sources.
@@ -349,6 +350,36 @@ def get_label(provider_id: str) -> str:
return canonical
# For direct import compat, expose as module-level dict
# Built on demand by get_label() calls
LABELS: Dict[str, str] = {
# Static entries for backward compat — get_label() is the proper API
"openrouter": "OpenRouter",
"nous": "Nous Portal",
"openai-codex": "OpenAI Codex",
"copilot-acp": "GitHub Copilot ACP",
"github-copilot": "GitHub Copilot",
"anthropic": "Anthropic",
"zai": "Z.AI / GLM",
"kimi-for-coding": "Kimi / Moonshot",
"minimax": "MiniMax",
"minimax-cn": "MiniMax (China)",
"deepseek": "DeepSeek",
"alibaba": "Alibaba Cloud (DashScope)",
"vercel": "Vercel AI Gateway",
"opencode": "OpenCode Zen",
"opencode-go": "OpenCode Go",
"kilo": "Kilo Gateway",
"huggingface": "Hugging Face",
"local": "Local endpoint",
"custom": "Custom endpoint",
# Legacy Hermes IDs (point to same providers)
"ai-gateway": "Vercel AI Gateway",
"kilocode": "Kilo Gateway",
"copilot": "GitHub Copilot",
"kimi-coding": "Kimi / Moonshot",
"opencode-zen": "OpenCode Zen",
}
def is_aggregator(provider: str) -> bool:
+141 -9
View File
@@ -173,6 +173,147 @@ def _setup_copilot_reasoning_selection(
_set_reasoning_effort(config, "none")
def _setup_provider_model_selection(config, provider_id, current_model, prompt_choice, prompt_fn):
"""Model selection for API-key providers with live /models detection.
Tries the provider's /models endpoint first. Falls back to a
hardcoded default list with a warning if the endpoint is unreachable.
Always offers a 'Custom model' escape hatch.
"""
from hermes_cli.auth import PROVIDER_REGISTRY, resolve_api_key_provider_credentials
from hermes_cli.config import get_env_value
from hermes_cli.models import (
copilot_model_api_mode,
fetch_api_models,
fetch_github_model_catalog,
normalize_copilot_model_id,
normalize_opencode_model_id,
opencode_model_api_mode,
)
pconfig = PROVIDER_REGISTRY[provider_id]
is_copilot_catalog_provider = provider_id in {"copilot", "copilot-acp"}
# Resolve API key and base URL for the probe
if is_copilot_catalog_provider:
api_key = ""
if provider_id == "copilot":
creds = resolve_api_key_provider_credentials(provider_id)
api_key = creds.get("api_key", "")
base_url = creds.get("base_url", "") or pconfig.inference_base_url
else:
try:
creds = resolve_api_key_provider_credentials("copilot")
api_key = creds.get("api_key", "")
except Exception:
pass
base_url = pconfig.inference_base_url
catalog = fetch_github_model_catalog(api_key)
current_model = normalize_copilot_model_id(
current_model,
catalog=catalog,
api_key=api_key,
) or current_model
else:
api_key = ""
for ev in pconfig.api_key_env_vars:
api_key = get_env_value(ev) or os.getenv(ev, "")
if api_key:
break
base_url_env = pconfig.base_url_env_var or ""
base_url = (get_env_value(base_url_env) if base_url_env else "") or pconfig.inference_base_url
catalog = None
# Try live /models endpoint
if is_copilot_catalog_provider and catalog:
live_models = [item.get("id", "") for item in catalog if item.get("id")]
else:
live_models = fetch_api_models(api_key, base_url)
if live_models:
provider_models = live_models
print_info(f"Found {len(live_models)} model(s) from {pconfig.name} API")
else:
fallback_provider_id = "copilot" if provider_id == "copilot-acp" else provider_id
provider_models = _DEFAULT_PROVIDER_MODELS.get(fallback_provider_id, [])
if provider_models:
print_warning(
f"Could not auto-detect models from {pconfig.name} API — showing defaults.\n"
f" Use \"Custom model\" if the model you expect isn't listed."
)
if provider_id in {"opencode-zen", "opencode-go"}:
provider_models = [normalize_opencode_model_id(provider_id, mid) for mid in provider_models]
current_model = normalize_opencode_model_id(provider_id, current_model)
provider_models = list(dict.fromkeys(mid for mid in provider_models if mid))
model_choices = list(provider_models)
model_choices.append("Custom model")
model_choices.append(f"Keep current ({current_model})")
keep_idx = len(model_choices) - 1
model_idx = prompt_choice("Select default model:", model_choices, keep_idx)
selected_model = current_model
if model_idx < len(provider_models):
selected_model = provider_models[model_idx]
if is_copilot_catalog_provider:
selected_model = normalize_copilot_model_id(
selected_model,
catalog=catalog,
api_key=api_key,
) or selected_model
elif provider_id in {"opencode-zen", "opencode-go"}:
selected_model = normalize_opencode_model_id(provider_id, selected_model)
_set_default_model(config, selected_model)
elif model_idx == len(provider_models):
custom = prompt_fn("Enter model name")
if custom:
if is_copilot_catalog_provider:
selected_model = normalize_copilot_model_id(
custom,
catalog=catalog,
api_key=api_key,
) or custom
elif provider_id in {"opencode-zen", "opencode-go"}:
selected_model = normalize_opencode_model_id(provider_id, custom)
else:
selected_model = custom
_set_default_model(config, selected_model)
else:
# "Keep current" selected — validate it's compatible with the new
# provider. OpenRouter-formatted names (containing "/") won't work
# on direct-API providers and would silently break the gateway.
if "/" in (current_model or "") and provider_models:
print_warning(
f"Current model \"{current_model}\" looks like an OpenRouter model "
f"and won't work with {pconfig.name}. "
f"Switching to {provider_models[0]}."
)
selected_model = provider_models[0]
_set_default_model(config, provider_models[0])
if provider_id == "copilot" and selected_model:
model_cfg = _model_config_dict(config)
model_cfg["api_mode"] = copilot_model_api_mode(
selected_model,
catalog=catalog,
api_key=api_key,
)
config["model"] = model_cfg
_setup_copilot_reasoning_selection(
config,
selected_model,
prompt_choice,
catalog=catalog,
api_key=api_key,
)
elif provider_id in {"opencode-zen", "opencode-go"} and selected_model:
model_cfg = _model_config_dict(config)
model_cfg["api_mode"] = opencode_model_api_mode(provider_id, selected_model)
config["model"] = model_cfg
# Import config helpers
from hermes_cli.config import (
@@ -338,8 +479,6 @@ def _curses_prompt_choice(question: str, choices: list, default: int = 0) -> int
return
curses.wrapper(_curses_menu)
from hermes_cli.curses_ui import flush_stdin
flush_stdin()
return result_holder[0]
except Exception:
return -1
@@ -2030,12 +2169,6 @@ def _setup_whatsapp():
print_info("or personal self-chat) and pair via QR code.")
def _setup_weixin():
"""Configure Weixin (personal WeChat) via iLink Bot API QR login."""
from hermes_cli.gateway import _setup_weixin as _gateway_setup_weixin
_gateway_setup_weixin()
def _setup_bluebubbles():
"""Configure BlueBubbles iMessage gateway."""
print_header("BlueBubbles (iMessage)")
@@ -2155,7 +2288,6 @@ _GATEWAY_PLATFORMS = [
("Matrix", "MATRIX_ACCESS_TOKEN", _setup_matrix),
("Mattermost", "MATTERMOST_TOKEN", _setup_mattermost),
("WhatsApp", "WHATSAPP_ENABLED", _setup_whatsapp),
("Weixin (WeChat)", "WEIXIN_ACCOUNT_ID", _setup_weixin),
("BlueBubbles (iMessage)", "BLUEBUBBLES_SERVER_URL", _setup_bluebubbles),
("Webhooks (GitHub, GitLab, etc.)", "WEBHOOK_ENABLED", _setup_webhooks),
]
-1
View File
@@ -31,7 +31,6 @@ PLATFORMS = {
"dingtalk": "💬 DingTalk",
"feishu": "🪽 Feishu",
"wecom": "💬 WeCom",
"weixin": "💬 Weixin",
"webhook": "🔗 Webhook",
}
+24 -27
View File
@@ -151,8 +151,7 @@ def do_search(query: str, source: str = "all", limit: int = 10,
auth = GitHubAuth()
sources = create_source_router(auth)
with c.status("[bold]Searching registries..."):
results = unified_search(query, sources, source_filter=source, limit=limit)
results = unified_search(query, sources, source_filter=source, limit=limit)
if not results:
c.print("[dim]No skills found matching your query.[/]\n")
@@ -188,7 +187,7 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
Official skills are always shown first, regardless of source filter.
"""
from tools.skills_hub import (
GitHubAuth, create_source_router, parallel_search_sources,
GitHubAuth, create_source_router,
)
# Clamp page_size to safe range
@@ -199,23 +198,27 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
auth = GitHubAuth()
sources = create_source_router(auth)
# Collect results from all (or filtered) sources in parallel.
# Per-source limits are generous — parallelism + 30s timeout cap prevents hangs.
# Collect results from all (or filtered) sources
# Use empty query to get everything; per-source limits prevent overload
_TRUST_RANK = {"builtin": 3, "trusted": 2, "community": 1}
_PER_SOURCE_LIMIT = {
"official": 200, "skills-sh": 200, "well-known": 50,
"github": 200, "clawhub": 500, "claude-marketplace": 100,
"lobehub": 500,
}
_PER_SOURCE_LIMIT = {"official": 100, "skills-sh": 100, "well-known": 25, "github": 100, "clawhub": 50,
"claude-marketplace": 50, "lobehub": 50}
with c.status("[bold]Fetching skills from registries..."):
all_results, source_counts, timed_out = parallel_search_sources(
sources,
query="",
per_source_limits=_PER_SOURCE_LIMIT,
source_filter=source,
overall_timeout=30,
)
all_results: list = []
source_counts: dict = {}
for src in sources:
sid = src.source_id()
if source != "all" and sid != source and sid != "official":
# Always include official source for the "first" placement
continue
try:
limit = _PER_SOURCE_LIMIT.get(sid, 50)
results = src.search("", limit=limit)
source_counts[sid] = len(results)
all_results.extend(results)
except Exception:
continue
if not all_results:
c.print("[dim]No skills found in the Skills Hub.[/]\n")
@@ -249,11 +252,8 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
# Build header
source_label = f"{source}" if source != "all" else "— all sources"
loaded_label = f"{total} skills loaded"
if timed_out:
loaded_label += f", {len(timed_out)} source(s) still loading"
c.print(f"\n[bold]Skills Hub — Browse {source_label}[/]"
f" [dim]({loaded_label}, page {page}/{total_pages})[/]")
f" [dim]({total} skills, page {page}/{total_pages})[/]")
if official_count > 0 and page == 1:
c.print(f"[bright_cyan]★ {official_count} official optional skill(s) from Nous Research[/]")
c.print()
@@ -300,11 +300,8 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
parts = [f"{sid}: {ct}" for sid, ct in sorted(source_counts.items())]
c.print(f" [dim]Sources: {', '.join(parts)}[/]")
if timed_out:
c.print(f" [yellow]⚡ Slow sources skipped: {', '.join(timed_out)} "
f"— run again for cached results[/]")
c.print("[dim]Tip: 'hermes skills search <query>' searches deeper across all registries[/]\n")
c.print("[dim]Use: hermes skills inspect <identifier> to preview, "
"hermes skills install <identifier> to install[/]\n")
def do_install(identifier: str, category: str = "", force: bool = False,
-1
View File
@@ -305,7 +305,6 @@ def show_status(args):
"DingTalk": ("DINGTALK_CLIENT_ID", None),
"Feishu": ("FEISHU_APP_ID", "FEISHU_HOME_CHANNEL"),
"WeCom": ("WECOM_BOT_ID", "WECOM_HOME_CHANNEL"),
"Weixin": ("WEIXIN_ACCOUNT_ID", "WEIXIN_HOME_CHANNEL"),
"BlueBubbles": ("BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_HOME_CHANNEL"),
}
-3
View File
@@ -133,7 +133,6 @@ PLATFORMS = {
"dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"},
"feishu": {"label": "🪽 Feishu", "default_toolset": "hermes-feishu"},
"wecom": {"label": "💬 WeCom", "default_toolset": "hermes-wecom"},
"weixin": {"label": "💬 Weixin", "default_toolset": "hermes-weixin"},
"api_server": {"label": "🌐 API Server", "default_toolset": "hermes-api-server"},
"mattermost": {"label": "💬 Mattermost", "default_toolset": "hermes-mattermost"},
"webhook": {"label": "🔗 Webhook", "default_toolset": "hermes-webhook"},
@@ -721,8 +720,6 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
return
curses.wrapper(_curses_menu)
from hermes_cli.curses_ui import flush_stdin
flush_stdin()
return result_holder[0]
except Exception:
+4 -65
View File
@@ -17,45 +17,6 @@ def get_hermes_home() -> Path:
return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
def get_default_hermes_root() -> Path:
"""Return the root Hermes directory for profile-level operations.
In standard deployments this is ``~/.hermes``.
In Docker or custom deployments where ``HERMES_HOME`` points outside
``~/.hermes`` (e.g. ``/opt/data``), returns ``HERMES_HOME`` directly
that IS the root.
In profile mode where ``HERMES_HOME`` is ``<root>/profiles/<name>``,
returns ``<root>`` so that ``profile list`` can see all profiles.
Works both for standard (``~/.hermes/profiles/coder``) and Docker
(``/opt/data/profiles/coder``) layouts.
Import-safe no dependencies beyond stdlib.
"""
native_home = Path.home() / ".hermes"
env_home = os.environ.get("HERMES_HOME", "")
if not env_home:
return native_home
env_path = Path(env_home)
try:
env_path.resolve().relative_to(native_home.resolve())
# HERMES_HOME is under ~/.hermes (normal or profile mode)
return native_home
except ValueError:
pass
# Docker / custom deployment.
# Check if this is a profile path: <root>/profiles/<name>
# If the immediate parent dir is named "profiles", the root is
# the grandparent — this covers Docker profiles correctly.
if env_path.parent.name == "profiles":
return env_path.parent.parent
# Not a profile path — HERMES_HOME itself is the root
return env_path
def get_optional_skills_dir(default: Path | None = None) -> Path:
"""Return the optional-skills directory, honoring package-manager wrappers.
@@ -111,32 +72,6 @@ def display_hermes_home() -> str:
return str(home)
def get_subprocess_home() -> str | None:
"""Return a per-profile HOME directory for subprocesses, or None.
When ``{HERMES_HOME}/home/`` exists on disk, subprocesses should use it
as ``HOME`` so system tools (git, ssh, gh, npm ) write their configs
inside the Hermes data directory instead of the OS-level ``/root`` or
``~/``. This provides:
* **Docker persistence** tool configs land inside the persistent volume.
* **Profile isolation** each profile gets its own git identity, SSH
keys, gh tokens, etc.
The Python process's own ``os.environ["HOME"]`` and ``Path.home()`` are
**never** modified only subprocess environments should inject this value.
Activation is directory-based: if the ``home/`` subdirectory doesn't
exist, returns ``None`` and behavior is unchanged.
"""
hermes_home = os.getenv("HERMES_HOME")
if not hermes_home:
return None
profile_home = os.path.join(hermes_home, "home")
if os.path.isdir(profile_home):
return profile_home
return None
VALID_REASONING_EFFORTS = ("minimal", "low", "medium", "high", "xhigh")
@@ -170,7 +105,11 @@ def is_termux() -> bool:
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
OPENROUTER_MODELS_URL = f"{OPENROUTER_BASE_URL}/models"
OPENROUTER_CHAT_URL = f"{OPENROUTER_BASE_URL}/chat/completions"
AI_GATEWAY_BASE_URL = "https://ai-gateway.vercel.sh/v1"
AI_GATEWAY_MODELS_URL = f"{AI_GATEWAY_BASE_URL}/models"
AI_GATEWAY_CHAT_URL = f"{AI_GATEWAY_BASE_URL}/chat/completions"
NOUS_API_BASE_URL = "https://inference-api.nousresearch.com/v1"
NOUS_API_CHAT_URL = f"{NOUS_API_BASE_URL}/chat/completions"
+66
View File
@@ -520,6 +520,72 @@ class SessionDB:
)
self._execute_write(_do)
def set_token_counts(
self,
session_id: str,
input_tokens: int = 0,
output_tokens: int = 0,
model: str = None,
cache_read_tokens: int = 0,
cache_write_tokens: int = 0,
reasoning_tokens: int = 0,
estimated_cost_usd: Optional[float] = None,
actual_cost_usd: Optional[float] = None,
cost_status: Optional[str] = None,
cost_source: Optional[str] = None,
pricing_version: Optional[str] = None,
billing_provider: Optional[str] = None,
billing_base_url: Optional[str] = None,
billing_mode: Optional[str] = None,
) -> None:
"""Set token counters to absolute values (not increment).
Use this when the caller provides cumulative totals from a completed
conversation run (e.g. the gateway, where the cached agent's
session_prompt_tokens already reflects the running total).
"""
def _do(conn):
conn.execute(
"""UPDATE sessions SET
input_tokens = ?,
output_tokens = ?,
cache_read_tokens = ?,
cache_write_tokens = ?,
reasoning_tokens = ?,
estimated_cost_usd = ?,
actual_cost_usd = CASE
WHEN ? IS NULL THEN actual_cost_usd
ELSE ?
END,
cost_status = COALESCE(?, cost_status),
cost_source = COALESCE(?, cost_source),
pricing_version = COALESCE(?, pricing_version),
billing_provider = COALESCE(billing_provider, ?),
billing_base_url = COALESCE(billing_base_url, ?),
billing_mode = COALESCE(billing_mode, ?),
model = COALESCE(model, ?)
WHERE id = ?""",
(
input_tokens,
output_tokens,
cache_read_tokens,
cache_write_tokens,
reasoning_tokens,
estimated_cost_usd,
actual_cost_usd,
actual_cost_usd,
cost_status,
cost_source,
pricing_version,
billing_provider,
billing_base_url,
billing_mode,
model,
session_id,
),
)
self._execute_write(_do)
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Get a session by ID."""
with self._lock:
+13
View File
@@ -89,6 +89,13 @@ def get_timezone() -> Optional[ZoneInfo]:
return _cached_tz
def get_timezone_name() -> str:
"""Return the IANA name of the configured timezone, or empty string."""
if not _cache_resolved:
get_timezone() # populates cache
return _cached_tz_name or ""
def now() -> datetime:
"""
Return the current time as a timezone-aware datetime.
@@ -103,3 +110,9 @@ def now() -> datetime:
return datetime.now().astimezone()
def reset_cache() -> None:
"""Clear the cached timezone. Used by tests and after config changes."""
global _cached_tz, _cached_tz_name, _cache_resolved
_cached_tz = None
_cached_tz_name = None
_cache_resolved = False
-16
View File
@@ -611,22 +611,6 @@
chown ${cfg.user}:${cfg.group} ${cfg.stateDir}/.hermes/.managed
chmod 0644 ${cfg.stateDir}/.hermes/.managed
# Container mode metadata — tells the host CLI to exec into the
# container instead of running locally. Removed when container mode
# is disabled so the host CLI falls back to native execution.
${if cfg.container.enable then ''
cat > ${cfg.stateDir}/.hermes/.container-mode <<'HERMES_CONTAINER_MODE_EOF'
# Written by NixOS activation script. Do not edit manually.
backend=${cfg.container.backend}
container_name=${containerName}
hermes_bin=${containerDataDir}/current-package/bin/hermes
HERMES_CONTAINER_MODE_EOF
chown ${cfg.user}:${cfg.group} ${cfg.stateDir}/.hermes/.container-mode
chmod 0644 ${cfg.stateDir}/.hermes/.container-mode
'' else ''
rm -f ${cfg.stateDir}/.hermes/.container-mode
''}
# Seed auth file if provided
${lib.optionalString (cfg.authFile != null) ''
${if cfg.authFileForceOverwrite then ''
+5 -5
View File
@@ -16,7 +16,7 @@ dependencies = [
"anthropic>=0.39.0,<1",
"python-dotenv>=1.2.1,<2",
"fire>=0.7.1,<1",
"httpx[socks]>=0.28.1,<1",
"httpx>=0.28.1,<1",
"rich>=14.3.3,<15",
"tenacity>=9.1.4,<10",
"pyyaml>=6.0.2,<7",
@@ -88,10 +88,10 @@ all = [
"hermes-agent[modal]",
"hermes-agent[daytona]",
"hermes-agent[messaging]",
# matrix: python-olm (required by matrix-nio[e2e]) is upstream-broken on
# modern macOS (archived libolm, C++ errors with Clang 21+). On Linux the
# [matrix] extra's own marker pulls in the [e2e] variant automatically.
"hermes-agent[matrix]; sys_platform == 'linux'",
# matrix excluded: python-olm (required by matrix-nio[e2e]) is upstream-broken
# on modern macOS (archived libolm, C++ errors with Clang 21+). Including it
# here causes the entire [all] install to fail, dropping all other extras.
# Users who need Matrix can install manually: pip install 'hermes-agent[matrix]'
"hermes-agent[cron]",
"hermes-agent[cli]",
"hermes-agent[dev]",
+50 -228
View File
@@ -359,9 +359,8 @@ def _sanitize_surrogates(text: str) -> str:
def _sanitize_messages_surrogates(messages: list) -> bool:
"""Sanitize surrogate characters from all string content in a messages list.
Walks message dicts in-place. Returns True if any surrogates were found
and replaced, False otherwise. Covers content/text, name, and tool call
metadata/arguments so retries don't fail on a non-content field.
Walks message dicts in-place. Returns True if any surrogates were found
and replaced, False otherwise.
"""
found = False
for msg in messages:
@@ -378,88 +377,6 @@ def _sanitize_messages_surrogates(messages: list) -> bool:
if isinstance(text, str) and _SURROGATE_RE.search(text):
part["text"] = _SURROGATE_RE.sub('\ufffd', text)
found = True
name = msg.get("name")
if isinstance(name, str) and _SURROGATE_RE.search(name):
msg["name"] = _SURROGATE_RE.sub('\ufffd', name)
found = True
tool_calls = msg.get("tool_calls")
if isinstance(tool_calls, list):
for tc in tool_calls:
if not isinstance(tc, dict):
continue
tc_id = tc.get("id")
if isinstance(tc_id, str) and _SURROGATE_RE.search(tc_id):
tc["id"] = _SURROGATE_RE.sub('\ufffd', tc_id)
found = True
fn = tc.get("function")
if isinstance(fn, dict):
fn_name = fn.get("name")
if isinstance(fn_name, str) and _SURROGATE_RE.search(fn_name):
fn["name"] = _SURROGATE_RE.sub('\ufffd', fn_name)
found = True
fn_args = fn.get("arguments")
if isinstance(fn_args, str) and _SURROGATE_RE.search(fn_args):
fn["arguments"] = _SURROGATE_RE.sub('\ufffd', fn_args)
found = True
return found
def _strip_non_ascii(text: str) -> str:
"""Remove non-ASCII characters, replacing with closest ASCII equivalent or removing.
Used as a last resort when the system encoding is ASCII and can't handle
any non-ASCII characters (e.g. LANG=C on Chromebooks).
"""
return text.encode('ascii', errors='ignore').decode('ascii')
def _sanitize_messages_non_ascii(messages: list) -> bool:
"""Strip non-ASCII characters from all string content in a messages list.
This is a last-resort recovery for systems with ASCII-only encoding
(LANG=C, Chromebooks, minimal containers). Returns True if any
non-ASCII content was found and sanitized.
"""
found = False
for msg in messages:
if not isinstance(msg, dict):
continue
# Sanitize content (string)
content = msg.get("content")
if isinstance(content, str):
sanitized = _strip_non_ascii(content)
if sanitized != content:
msg["content"] = sanitized
found = True
elif isinstance(content, list):
for part in content:
if isinstance(part, dict):
text = part.get("text")
if isinstance(text, str):
sanitized = _strip_non_ascii(text)
if sanitized != text:
part["text"] = sanitized
found = True
# Sanitize name field (can contain non-ASCII in tool results)
name = msg.get("name")
if isinstance(name, str):
sanitized = _strip_non_ascii(name)
if sanitized != name:
msg["name"] = sanitized
found = True
# Sanitize tool_calls
tool_calls = msg.get("tool_calls")
if isinstance(tool_calls, list):
for tc in tool_calls:
if isinstance(tc, dict):
fn = tc.get("function", {})
if isinstance(fn, dict):
fn_args = fn.get("arguments")
if isinstance(fn_args, str):
sanitized = _strip_non_ascii(fn_args)
if sanitized != fn_args:
fn["arguments"] = sanitized
found = True
return found
@@ -689,17 +606,6 @@ class AIAgent:
else:
self.api_mode = "chat_completions"
try:
from hermes_cli.model_normalize import (
_AGGREGATOR_PROVIDERS,
normalize_model_for_provider,
)
if self.provider not in _AGGREGATOR_PROVIDERS:
self.model = normalize_model_for_provider(self.model, self.provider)
except Exception:
pass
# Direct OpenAI sessions use the Responses API path. GPT-5.x tool
# calls with reasoning are rejected on /v1/chat/completions, and
# Hermes is a tool-using client by default.
@@ -721,6 +627,7 @@ class AIAgent:
self.suppress_status_output = False
self.thinking_callback = thinking_callback
self.reasoning_callback = reasoning_callback
self._reasoning_deltas_fired = False # Set by _fire_reasoning_delta, reset per API call
self.clarify_callback = clarify_callback
self.step_callback = step_callback
self.stream_delta_callback = stream_delta_callback
@@ -947,7 +854,6 @@ class AIAgent:
client_kwargs["default_headers"] = headers
self.api_key = client_kwargs.get("api_key", "")
self.base_url = client_kwargs.get("base_url", self.base_url)
try:
self.client = self._create_openai_client(client_kwargs, reason="agent_init", shared=True)
if not self.quiet_mode:
@@ -1244,9 +1150,6 @@ class AIAgent:
except (TypeError, ValueError):
_config_context_length = None
# Store for reuse in switch_model (so config override persists across model switches)
self._config_context_length = _config_context_length
# Check custom_providers per-model context_length
if _config_context_length is None:
_custom_providers = _agent_cfg.get("custom_providers")
@@ -1401,6 +1304,7 @@ class AIAgent:
if hasattr(self, "context_compressor") and self.context_compressor:
self.context_compressor.last_prompt_tokens = 0
self.context_compressor.last_completion_tokens = 0
self.context_compressor.last_total_tokens = 0
self.context_compressor.compression_count = 0
self.context_compressor._context_probed = False
self.context_compressor._context_probe_persistable = False
@@ -1484,7 +1388,6 @@ class AIAgent:
base_url=self.base_url,
api_key=self.api_key,
provider=self.provider,
config_context_length=getattr(self, "_config_context_length", None),
)
self.context_compressor.model = self.model
self.context_compressor.base_url = self.base_url
@@ -1977,14 +1880,19 @@ class AIAgent:
except Exception as e:
logger.debug("Background memory/skill review failed: %s", e)
finally:
# Close all resources (httpx client, subprocesses, etc.) so
# GC doesn't try to clean them up on a dead asyncio event
# loop (which produces "Event loop is closed" errors).
# Explicitly close the OpenAI/httpx client so GC doesn't
# try to clean it up on a dead asyncio event loop (which
# produces "Event loop is closed" errors in the terminal).
if review_agent is not None:
try:
review_agent.close()
except Exception:
pass
client = getattr(review_agent, "client", None)
if client is not None:
try:
review_agent._close_openai_client(
client, reason="bg_review_done", shared=True
)
review_agent.client = None
except Exception:
pass
t = threading.Thread(target=_run_review, daemon=True, name="bg-review")
t.start()
@@ -2724,64 +2632,6 @@ class AIAgent:
except Exception:
pass
def close(self) -> None:
"""Release all resources held by this agent instance.
Cleans up subprocess resources that would otherwise become orphans:
- Background processes tracked in ProcessRegistry
- Terminal sandbox environments
- Browser daemon sessions
- Active child agents (subagent delegation)
- OpenAI/httpx client connections
Safe to call multiple times (idempotent). Each cleanup step is
independently guarded so a failure in one does not prevent the rest.
"""
task_id = getattr(self, "session_id", None) or ""
# 1. Kill background processes for this task
try:
from tools.process_registry import process_registry
process_registry.kill_all(task_id=task_id)
except Exception:
pass
# 2. Clean terminal sandbox environments
try:
from tools.terminal_tool import cleanup_vm
cleanup_vm(task_id)
except Exception:
pass
# 3. Clean browser daemon sessions
try:
from tools.browser_tool import cleanup_browser
cleanup_browser(task_id)
except Exception:
pass
# 4. Close active child agents
try:
with self._active_children_lock:
children = list(self._active_children)
self._active_children.clear()
for child in children:
try:
child.close()
except Exception:
pass
except Exception:
pass
# 5. Close the OpenAI/httpx client
try:
client = getattr(self, "client", None)
if client is not None:
self._close_openai_client(client, reason="agent_close", shared=True)
self.client = None
except Exception:
pass
def _hydrate_todo_store(self, history: List[Dict[str, Any]]) -> None:
"""
Recover todo state from conversation history.
@@ -3074,7 +2924,7 @@ class AIAgent:
@staticmethod
def _cap_delegate_task_calls(tool_calls: list) -> list:
"""Truncate excess delegate_task calls to max_concurrent_children.
"""Truncate excess delegate_task calls to MAX_CONCURRENT_CHILDREN.
The delegate_tool caps the task list inside a single call, but the
model can emit multiple separate delegate_task tool_calls in one
@@ -3082,24 +2932,23 @@ class AIAgent:
Returns the original list if no truncation was needed.
"""
from tools.delegate_tool import _get_max_concurrent_children
max_children = _get_max_concurrent_children()
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
delegate_count = sum(1 for tc in tool_calls if tc.function.name == "delegate_task")
if delegate_count <= max_children:
if delegate_count <= MAX_CONCURRENT_CHILDREN:
return tool_calls
kept_delegates = 0
truncated = []
for tc in tool_calls:
if tc.function.name == "delegate_task":
if kept_delegates < max_children:
if kept_delegates < MAX_CONCURRENT_CHILDREN:
truncated.append(tc)
kept_delegates += 1
else:
truncated.append(tc)
logger.warning(
"Truncated %d excess delegate_task call(s) to enforce "
"max_concurrent_children=%d limit",
delegate_count - max_children, max_children,
"MAX_CONCURRENT_CHILDREN=%d limit",
delegate_count - MAX_CONCURRENT_CHILDREN, MAX_CONCURRENT_CHILDREN,
)
return truncated
@@ -4026,6 +3875,7 @@ class AIAgent:
max_stream_retries = 1
has_tool_calls = False
first_delta_fired = False
self._reasoning_deltas_fired = False
# Accumulate streamed text so we can recover if get_final_response()
# returns empty output (e.g. chatgpt.com backend-api sends
# response.incomplete instead of response.completed).
@@ -4534,6 +4384,7 @@ class AIAgent:
def _fire_reasoning_delta(self, text: str) -> None:
"""Fire reasoning callback if registered."""
self._reasoning_deltas_fired = True
cb = self.reasoning_callback
if cb is not None:
try:
@@ -4663,6 +4514,10 @@ class AIAgent:
role = "assistant"
reasoning_parts: list = []
usage_obj = None
# Reset per-call reasoning tracking so _build_assistant_message
# knows whether reasoning was already displayed during streaming.
self._reasoning_deltas_fired = False
_first_chunk_seen = False
for chunk in stream:
last_chunk_time["t"] = time.time()
@@ -4830,20 +4685,13 @@ class AIAgent:
works unchanged.
"""
has_tool_use = False
self._reasoning_deltas_fired = False
# Reset stale-stream timer for this attempt
last_chunk_time["t"] = time.time()
# Use the Anthropic SDK's streaming context manager
with self._anthropic_client.messages.stream(**api_kwargs) as stream:
for event in stream:
# Update stale-stream timer on every event so the
# outer poll loop knows data is flowing. Without
# this, the detector kills healthy long-running
# Opus streams after 180 s even when events are
# actively arriving (the chat_completions path
# already does this at the top of its chunk loop).
last_chunk_time["t"] = time.time()
if self._interrupt_requested:
break
@@ -4867,7 +4715,6 @@ class AIAgent:
if text and not has_tool_use:
_fire_first_delta()
self._fire_stream_delta(text)
deltas_were_sent["yes"] = True
elif delta_type == "thinking_delta":
thinking_text = getattr(delta, "thinking", "")
if thinking_text:
@@ -5158,7 +5005,7 @@ class AIAgent:
# when no explicit key is in the fallback config.
if fb_base_url_hint and "ollama.com" in fb_base_url_hint.lower() and not fb_api_key_hint:
fb_api_key_hint = os.getenv("OLLAMA_API_KEY") or None
fb_client, _resolved_fb_model = resolve_provider_client(
fb_client, _ = resolve_provider_client(
fb_provider, model=fb_model, raw_codex=True,
explicit_base_url=fb_base_url_hint,
explicit_api_key=fb_api_key_hint)
@@ -5167,12 +5014,6 @@ class AIAgent:
"Fallback to %s failed: provider not configured",
fb_provider)
return self._try_activate_fallback() # try next in chain
try:
from hermes_cli.model_normalize import normalize_model_for_provider
fb_model = normalize_model_for_provider(fb_model, fb_provider)
except Exception:
pass
# Determine api_mode from provider / base URL
fb_api_mode = "chat_completions"
@@ -5657,7 +5498,7 @@ class AIAgent:
preserve_dots=self._anthropic_preserve_dots(),
context_length=ctx_len,
base_url=getattr(self, "_anthropic_base_url", None),
fast_mode=(self.request_overrides or {}).get("speed") == "fast",
fast_mode=self.request_overrides.get("speed") == "fast",
)
if self.api_mode == "codex_responses":
@@ -7321,7 +7162,7 @@ class AIAgent:
self._thinking_prefill_retries = 0
self._last_content_with_tools = None
self._mute_post_response = False
self._unicode_sanitization_passes = 0
self._surrogate_sanitized = False
# Pre-turn connection health check: detect and clean up dead TCP
# connections left over from provider outages or dropped streams.
@@ -7761,7 +7602,6 @@ class AIAgent:
finish_reason = "stop"
response = None # Guard against UnboundLocalError if all retries fail
api_kwargs = None # Guard against UnboundLocalError in except handler
while retry_count < max_retries:
try:
@@ -8307,40 +8147,22 @@ class AIAgent:
self.thinking_callback("")
# -----------------------------------------------------------
# UnicodeEncodeError recovery. Two common causes:
# 1. Lone surrogates (U+D800..U+DFFF) from clipboard paste
# (Google Docs, rich-text editors) — sanitize and retry.
# 2. ASCII codec on systems with LANG=C or non-UTF-8 locale
# (e.g. Chromebooks) — any non-ASCII character fails.
# Detect via the error message mentioning 'ascii' codec.
# We sanitize messages in-place and may retry twice:
# first to strip surrogates, then once more for pure
# ASCII-only locale sanitization if needed.
# Surrogate character recovery. UnicodeEncodeError happens
# when the messages contain lone surrogates (U+D800..U+DFFF)
# that are invalid UTF-8. Common source: clipboard paste
# from Google Docs or similar rich-text editors. We sanitize
# the entire messages list in-place and retry once.
# -----------------------------------------------------------
if isinstance(api_error, UnicodeEncodeError) and getattr(self, '_unicode_sanitization_passes', 0) < 2:
_err_str = str(api_error).lower()
_is_ascii_codec = "'ascii'" in _err_str or "ascii" in _err_str
_surrogates_found = _sanitize_messages_surrogates(messages)
if _surrogates_found:
self._unicode_sanitization_passes += 1
if isinstance(api_error, UnicodeEncodeError) and not getattr(self, '_surrogate_sanitized', False):
self._surrogate_sanitized = True
if _sanitize_messages_surrogates(messages):
self._vprint(
f"{self.log_prefix}⚠️ Stripped invalid surrogate characters from messages. Retrying...",
force=True,
)
continue
if _is_ascii_codec:
# ASCII codec: the system encoding can't handle
# non-ASCII characters at all. Sanitize all
# non-ASCII content from messages and retry.
if _sanitize_messages_non_ascii(messages):
self._unicode_sanitization_passes += 1
self._vprint(
f"{self.log_prefix}⚠️ System encoding is ASCII — stripped non-ASCII characters from messages. Retrying...",
force=True,
)
continue
# Nothing to sanitize in messages — might be in system
# prompt or prefill. Fall through to normal error path.
# Surrogates weren't in messages — might be in system
# prompt or prefill. Fall through to normal error path.
status_code = getattr(api_error, "status_code", None)
error_context = self._extract_api_error_context(api_error)
@@ -8796,10 +8618,9 @@ class AIAgent:
if self._try_activate_fallback():
retry_count = 0
continue
if api_kwargs is not None:
self._dump_api_request_debug(
api_kwargs, reason="non_retryable_client_error", error=api_error,
)
self._dump_api_request_debug(
api_kwargs, reason="non_retryable_client_error", error=api_error,
)
self._emit_status(
f"❌ Non-retryable error (HTTP {status_code}): "
f"{self._summarize_api_error(api_error)}"
@@ -8902,10 +8723,9 @@ class AIAgent:
self.log_prefix, max_retries, _final_summary,
_provider, _model, len(api_messages), f"{approx_tokens:,}",
)
if api_kwargs is not None:
self._dump_api_request_debug(
api_kwargs, reason="max_retries_exhausted", error=api_error,
)
self._dump_api_request_debug(
api_kwargs, reason="max_retries_exhausted", error=api_error,
)
self._persist_session(messages, conversation_history)
_final_response = f"API call failed after {max_retries} retries: {_final_summary}"
if _is_stream_drop:
@@ -9543,6 +9363,7 @@ class AIAgent:
# Reset retry counter/signature on successful content
if hasattr(self, '_empty_content_retries'):
self._empty_content_retries = 0
self._last_empty_content_signature = None
self._thinking_prefill_retries = 0
if (
@@ -9614,6 +9435,7 @@ class AIAgent:
# If an assistant message with tool_calls was already appended,
# the API expects a role="tool" result for every tool_call_id.
# Fill in error results for any that weren't answered yet.
pending_handled = False
for idx in range(len(messages) - 1, -1, -1):
msg = messages[idx]
if not isinstance(msg, dict):
+7 -36
View File
@@ -1082,19 +1082,10 @@ install_node_deps() {
log_success "Node.js dependencies installed"
# Install Playwright browser + system dependencies.
# Playwright's --with-deps only supports apt-based systems natively.
# Playwright's install-deps only supports apt/dnf/zypper natively.
# For Arch/Manjaro we install the system libs via pacman first.
# Other systems must install Chromium dependencies manually.
log_info "Installing browser engine (Playwright Chromium)..."
case "$DISTRO" in
ubuntu|debian|raspbian|pop|linuxmint|elementary|zorin|kali|parrot)
log_info "Playwright may request sudo to install browser system dependencies (shared libraries)."
log_info "This is standard Playwright setup — Hermes itself does not require root access."
cd "$INSTALL_DIR" && npx playwright install --with-deps chromium 2>/dev/null || {
log_warn "Playwright browser installation failed — browser tools will not work."
log_warn "Try running manually: cd $INSTALL_DIR && npx playwright install --with-deps chromium"
}
;;
arch|manjaro)
if command -v pacman &> /dev/null; then
log_info "Arch/Manjaro detected — installing Chromium system dependencies via pacman..."
@@ -1109,35 +1100,15 @@ install_node_deps() {
log_warn " sudo pacman -S nss atk at-spi2-core cups libdrm libxkbcommon mesa pango cairo alsa-lib"
fi
fi
cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || {
log_warn "Playwright browser installation failed — browser tools will not work."
}
;;
fedora|rhel|centos|rocky|alma)
log_warn "Playwright does not support automatic dependency installation on RPM-based systems."
log_info "Install Chromium system dependencies manually before using browser tools:"
log_info " sudo dnf install nss atk at-spi2-core cups-libs libdrm libxkbcommon mesa-libgbm pango cairo alsa-lib"
cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || {
log_warn "Playwright browser installation failed — install dependencies above and retry."
}
;;
opensuse*|sles)
log_warn "Playwright does not support automatic dependency installation on zypper-based systems."
log_info "Install Chromium system dependencies manually before using browser tools:"
log_info " sudo zypper install mozilla-nss libatk-1_0-0 at-spi2-core cups-libs libdrm2 libxkbcommon0 Mesa-libgbm1 pango cairo libasound2"
cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || {
log_warn "Playwright browser installation failed — install dependencies above and retry."
}
;;
*)
log_warn "Playwright does not support automatic dependency installation on $DISTRO."
log_info "Install Chromium/browser system dependencies for your distribution, then run:"
log_info " cd $INSTALL_DIR && npx playwright install chromium"
log_info "Browser tools will not work until dependencies are installed."
cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || true
;;
*)
log_info "Playwright may request sudo to install browser system dependencies (shared libraries)."
log_info "This is standard Playwright setup — Hermes itself does not require root access."
cd "$INSTALL_DIR" && npx playwright install --with-deps chromium 2>/dev/null || true
;;
esac
log_success "Browser engine setup complete"
log_success "Browser engine installed"
fi
# Install WhatsApp bridge dependencies
-13
View File
@@ -68,22 +68,9 @@ class TestInitialize:
resp = await agent.initialize(protocol_version=1)
caps = resp.agent_capabilities
assert isinstance(caps, AgentCapabilities)
assert caps.load_session is True
assert caps.session_capabilities is not None
assert caps.session_capabilities.fork is not None
assert caps.session_capabilities.list is not None
assert caps.session_capabilities.resume is not None
@pytest.mark.asyncio
async def test_initialize_capabilities_wire_format(self, agent):
"""Verify the JSON wire format uses correct aliases so ACP clients see the right keys."""
resp = await agent.initialize(protocol_version=1)
payload = resp.agent_capabilities.model_dump(by_alias=True, exclude_none=True)
assert payload["loadSession"] is True
session_caps = payload["sessionCapabilities"]
assert "fork" in session_caps
assert "list" in session_caps
assert "resume" in session_caps
# ---------------------------------------------------------------------------
+10
View File
@@ -17,6 +17,7 @@ from agent.anthropic_adapter import (
build_anthropic_kwargs,
convert_messages_to_anthropic,
convert_tools_to_anthropic,
get_anthropic_token_source,
is_claude_code_token_valid,
normalize_anthropic_response,
normalize_model_name,
@@ -180,6 +181,15 @@ class TestResolveAnthropicToken:
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
assert resolve_anthropic_token() == "sk-ant-oat01-mytoken"
def test_reports_claude_json_primary_key_source(self, monkeypatch, tmp_path):
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
(tmp_path / ".claude.json").write_text(json.dumps({"primaryApiKey": "sk-ant-api03-primary"}))
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
assert get_anthropic_token_source("sk-ant-api03-primary") == "claude_json_primary_api_key"
def test_does_not_resolve_primary_api_key_as_native_anthropic_token(self, monkeypatch, tmp_path):
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
+226 -55
View File
@@ -9,6 +9,7 @@ import pytest
from agent.auxiliary_client import (
get_text_auxiliary_client,
get_vision_auxiliary_client,
get_available_vision_backends,
resolve_vision_provider_client,
resolve_provider_client,
@@ -19,6 +20,7 @@ from agent.auxiliary_client import (
_get_provider_chain,
_is_payment_error,
_try_payment_fallback,
_resolve_forced_provider,
_resolve_auto,
)
@@ -658,23 +660,19 @@ class TestGetTextAuxiliaryClient:
assert client is None
assert model is None
def test_custom_endpoint_uses_codex_wrapper_when_runtime_requests_responses_api(self):
with patch("agent.auxiliary_client._resolve_custom_runtime",
return_value=("https://api.openai.com/v1", "sk-test", "codex_responses")), \
patch("agent.auxiliary_client._read_main_model", return_value="gpt-5.3-codex"), \
patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = get_text_auxiliary_client()
from agent.auxiliary_client import CodexAuxiliaryClient
assert isinstance(client, CodexAuxiliaryClient)
assert model == "gpt-5.3-codex"
assert mock_openai.call_args.kwargs["base_url"] == "https://api.openai.com/v1"
assert mock_openai.call_args.kwargs["api_key"] == "sk-test"
class TestVisionClientFallback:
"""Vision client auto mode resolves known-good multimodal backends."""
def test_vision_returns_none_without_any_credentials(self):
with (
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
patch("agent.auxiliary_client._try_anthropic", return_value=(None, None)),
):
client, model = get_vision_auxiliary_client()
assert client is None
assert model is None
def test_vision_auto_includes_active_provider_when_configured(self, monkeypatch):
"""Active provider appears in available backends when credentials exist."""
monkeypatch.setenv("ANTHROPIC_API_KEY", "***")
@@ -756,6 +754,21 @@ class TestAuxiliaryPoolAwareness:
assert call_kwargs["base_url"] == "https://api.githubcopilot.com"
assert call_kwargs["default_headers"]["Editor-Version"]
def test_vision_auto_uses_active_provider_as_fallback(self, monkeypatch):
"""When no OpenRouter/Nous available, vision auto falls back to active provider."""
monkeypatch.setenv("ANTHROPIC_API_KEY", "***")
with (
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
patch("agent.auxiliary_client._read_main_provider", return_value="anthropic"),
patch("agent.auxiliary_client._read_main_model", return_value="claude-sonnet-4"),
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="***"),
):
client, model = get_vision_auxiliary_client()
assert client is not None
assert client.__class__.__name__ == "AnthropicAuxiliaryClient"
def test_vision_auto_prefers_active_provider_over_openrouter(self, monkeypatch):
"""Active provider is tried before OpenRouter in vision auto."""
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
@@ -787,6 +800,43 @@ class TestAuxiliaryPoolAwareness:
assert client is not None
assert provider == "custom:local"
def test_vision_direct_endpoint_override(self, monkeypatch):
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
monkeypatch.setenv("AUXILIARY_VISION_BASE_URL", "http://localhost:4567/v1")
monkeypatch.setenv("AUXILIARY_VISION_API_KEY", "vision-key")
monkeypatch.setenv("AUXILIARY_VISION_MODEL", "vision-model")
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = get_vision_auxiliary_client()
assert model == "vision-model"
assert mock_openai.call_args.kwargs["base_url"] == "http://localhost:4567/v1"
assert mock_openai.call_args.kwargs["api_key"] == "vision-key"
def test_vision_direct_endpoint_without_key_uses_placeholder(self, monkeypatch):
"""Vision endpoint without API key should use 'no-key-required' placeholder."""
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
monkeypatch.setenv("AUXILIARY_VISION_BASE_URL", "http://localhost:4567/v1")
monkeypatch.setenv("AUXILIARY_VISION_MODEL", "vision-model")
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = get_vision_auxiliary_client()
assert client is not None
assert model == "vision-model"
assert mock_openai.call_args.kwargs["api_key"] == "no-key-required"
def test_vision_uses_openrouter_when_available(self, monkeypatch):
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = get_vision_auxiliary_client()
assert model == "google/gemini-3-flash-preview"
assert client is not None
def test_vision_uses_nous_when_available(self, monkeypatch):
with patch("agent.auxiliary_client._read_nous_auth") as mock_nous, \
patch("agent.auxiliary_client.OpenAI"):
mock_nous.return_value = {"access_token": "nous-tok"}
client, model = get_vision_auxiliary_client()
assert model == "google/gemini-3-flash-preview"
assert client is not None
def test_vision_config_google_provider_uses_gemini_credentials(self, monkeypatch):
config = {
"auxiliary": {
@@ -812,6 +862,53 @@ class TestAuxiliaryPoolAwareness:
assert mock_openai.call_args.kwargs["api_key"] == "gemini-key"
assert mock_openai.call_args.kwargs["base_url"] == "https://generativelanguage.googleapis.com/v1beta/openai"
def test_vision_forced_main_uses_custom_endpoint(self, monkeypatch):
"""When explicitly forced to 'main', vision CAN use custom endpoint."""
config = {
"model": {
"provider": "custom",
"base_url": "http://localhost:1234/v1",
"default": "my-local-model",
}
}
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "main")
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = get_vision_auxiliary_client()
assert client is not None
assert model == "my-local-model"
def test_vision_forced_main_returns_none_without_creds(self, monkeypatch):
"""Forced main with no credentials still returns None."""
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "main")
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
# Clear client cache to avoid stale entries from previous tests
from agent.auxiliary_client import _client_cache
_client_cache.clear()
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client._read_main_provider", return_value=""), \
patch("agent.auxiliary_client._read_main_model", return_value=""), \
patch("agent.auxiliary_client._select_pool_entry", return_value=(False, None)), \
patch("agent.auxiliary_client._resolve_custom_runtime", return_value=(None, None)), \
patch("agent.auxiliary_client._read_codex_access_token", return_value=None), \
patch("agent.auxiliary_client._resolve_api_key_provider", return_value=(None, None)):
client, model = get_vision_auxiliary_client()
assert client is None
assert model is None
def test_vision_forced_codex(self, monkeypatch, codex_auth_dir):
"""When forced to 'codex', vision uses Codex OAuth."""
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "codex")
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client.OpenAI"):
client, model = get_vision_auxiliary_client()
from agent.auxiliary_client import CodexAuxiliaryClient
assert isinstance(client, CodexAuxiliaryClient)
assert model == "gpt-5.2-codex"
class TestGetAuxiliaryProvider:
@@ -851,6 +948,122 @@ class TestGetAuxiliaryProvider:
assert _get_auxiliary_provider("web_extract") == "main"
class TestResolveForcedProvider:
"""Tests for _resolve_forced_provider with explicit provider selection."""
def test_forced_openrouter(self, monkeypatch):
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = _resolve_forced_provider("openrouter")
assert model == "google/gemini-3-flash-preview"
assert client is not None
def test_forced_openrouter_no_key(self, monkeypatch):
with patch("agent.auxiliary_client._read_nous_auth", return_value=None):
client, model = _resolve_forced_provider("openrouter")
assert client is None
assert model is None
def test_forced_nous(self, monkeypatch):
with patch("agent.auxiliary_client._read_nous_auth") as mock_nous, \
patch("agent.auxiliary_client.OpenAI"):
mock_nous.return_value = {"access_token": "nous-tok"}
client, model = _resolve_forced_provider("nous")
assert model == "google/gemini-3-flash-preview"
assert client is not None
def test_forced_nous_not_configured(self, monkeypatch):
with patch("agent.auxiliary_client._read_nous_auth", return_value=None):
client, model = _resolve_forced_provider("nous")
assert client is None
assert model is None
def test_forced_main_uses_custom(self, monkeypatch):
config = {
"model": {
"provider": "custom",
"base_url": "http://local:8080/v1",
"default": "my-local-model",
}
}
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = _resolve_forced_provider("main")
assert model == "my-local-model"
def test_forced_main_uses_config_saved_custom_endpoint(self, monkeypatch):
config = {
"model": {
"provider": "custom",
"base_url": "http://local:8080/v1",
"default": "my-local-model",
}
}
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client._read_codex_access_token", return_value=None), \
patch("agent.auxiliary_client._resolve_api_key_provider", return_value=(None, None)), \
patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = _resolve_forced_provider("main")
assert client is not None
assert model == "my-local-model"
call_kwargs = mock_openai.call_args
assert call_kwargs.kwargs["base_url"] == "http://local:8080/v1"
def test_forced_main_skips_openrouter_nous(self, monkeypatch):
"""Even if OpenRouter key is set, 'main' skips it."""
config = {
"model": {
"provider": "custom",
"base_url": "http://local:8080/v1",
"default": "my-local-model",
}
}
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = _resolve_forced_provider("main")
# Should use custom endpoint, not OpenRouter
assert model == "my-local-model"
def test_forced_main_falls_to_codex(self, codex_auth_dir, monkeypatch):
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client.OpenAI"):
client, model = _resolve_forced_provider("main")
from agent.auxiliary_client import CodexAuxiliaryClient
assert isinstance(client, CodexAuxiliaryClient)
assert model == "gpt-5.2-codex"
def test_forced_codex(self, codex_auth_dir, monkeypatch):
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client.OpenAI"):
client, model = _resolve_forced_provider("codex")
from agent.auxiliary_client import CodexAuxiliaryClient
assert isinstance(client, CodexAuxiliaryClient)
assert model == "gpt-5.2-codex"
def test_forced_codex_no_token(self, monkeypatch):
with patch("agent.auxiliary_client._read_codex_access_token", return_value=None):
client, model = _resolve_forced_provider("codex")
assert client is None
assert model is None
def test_forced_unknown_returns_none(self, monkeypatch):
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client._read_codex_access_token", return_value=None):
client, model = _resolve_forced_provider("invalid-provider")
assert client is None
assert model is None
class TestTaskSpecificOverrides:
"""Integration tests for per-task provider routing via get_text_auxiliary_client(task=...)."""
@@ -1124,45 +1337,3 @@ class TestCallLlmPaymentFallback:
task="compression",
messages=[{"role": "user", "content": "hello"}],
)
# ---------------------------------------------------------------------------
# Gate: _resolve_api_key_provider must skip anthropic when not configured
# ---------------------------------------------------------------------------
def test_resolve_api_key_provider_skips_unconfigured_anthropic(monkeypatch):
"""_resolve_api_key_provider must not try anthropic when user never configured it."""
from collections import OrderedDict
from hermes_cli.auth import ProviderConfig
# Build a minimal registry with only "anthropic" so the loop is guaranteed
# to reach it without being short-circuited by earlier providers.
fake_registry = OrderedDict({
"anthropic": ProviderConfig(
id="anthropic",
name="Anthropic",
auth_type="api_key",
inference_base_url="https://api.anthropic.com",
api_key_env_vars=("ANTHROPIC_API_KEY",),
),
})
called = []
def mock_try_anthropic():
called.append("anthropic")
return None, None
monkeypatch.setattr("agent.auxiliary_client._try_anthropic", mock_try_anthropic)
monkeypatch.setattr("hermes_cli.auth.PROVIDER_REGISTRY", fake_registry)
monkeypatch.setattr(
"hermes_cli.auth.is_provider_explicitly_configured",
lambda pid: False,
)
from agent.auxiliary_client import _resolve_api_key_provider
_resolve_api_key_provider()
assert "anthropic" not in called, \
"_try_anthropic() should not be called when anthropic is not explicitly configured"
@@ -12,17 +12,6 @@ def _isolate(tmp_path, monkeypatch):
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir()
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
for env_var in (
"AUXILIARY_VISION_PROVIDER",
"AUXILIARY_VISION_MODEL",
"AUXILIARY_VISION_BASE_URL",
"AUXILIARY_VISION_API_KEY",
"CONTEXT_VISION_PROVIDER",
"CONTEXT_VISION_MODEL",
"CONTEXT_VISION_BASE_URL",
"CONTEXT_VISION_API_KEY",
):
monkeypatch.delenv(env_var, raising=False)
# Write a minimal config so load_config doesn't fail
(hermes_home / "config.yaml").write_text("model:\n default: test-model\n")
@@ -160,83 +149,3 @@ class TestResolveProviderClientNamedCustom:
# "coffee" doesn't exist in custom_providers
client, model = resolve_provider_client("coffee", "test")
assert client is None
class TestResolveProviderClientModelNormalization:
"""Direct-provider auxiliary routing should normalize models like main runtime."""
def test_matching_native_prefix_is_stripped_for_main_provider(self, tmp_path):
_write_config(tmp_path, {
"model": {"default": "zai/glm-5.1", "provider": "zai"},
})
with (
patch("hermes_cli.auth.resolve_api_key_provider_credentials", return_value={
"api_key": "glm-key",
"base_url": "https://api.z.ai/api/paas/v4",
}),
patch("agent.auxiliary_client.OpenAI") as mock_openai,
):
mock_openai.return_value = MagicMock()
from agent.auxiliary_client import resolve_provider_client
client, model = resolve_provider_client("main", "zai/glm-5.1")
assert client is not None
assert model == "glm-5.1"
def test_non_matching_prefix_is_preserved_for_direct_provider(self, tmp_path):
_write_config(tmp_path, {
"model": {"default": "zai/glm-5.1", "provider": "zai"},
})
with (
patch("hermes_cli.auth.resolve_api_key_provider_credentials", return_value={
"api_key": "glm-key",
"base_url": "https://api.z.ai/api/paas/v4",
}),
patch("agent.auxiliary_client.OpenAI") as mock_openai,
):
mock_openai.return_value = MagicMock()
from agent.auxiliary_client import resolve_provider_client
client, model = resolve_provider_client("zai", "google/gemini-2.5-pro")
assert client is not None
assert model == "google/gemini-2.5-pro"
def test_aggregator_vendor_slug_is_preserved(self, monkeypatch):
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
mock_openai.return_value = MagicMock()
from agent.auxiliary_client import resolve_provider_client
client, model = resolve_provider_client(
"openrouter", "anthropic/claude-sonnet-4.6"
)
assert client is not None
assert model == "anthropic/claude-sonnet-4.6"
class TestResolveVisionProviderClientModelNormalization:
"""Vision auto-routing should reuse the same provider-specific normalization."""
def test_vision_auto_strips_matching_main_provider_prefix(self, tmp_path):
_write_config(tmp_path, {
"model": {"default": "zai/glm-5.1", "provider": "zai"},
})
with (
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
patch("hermes_cli.auth.resolve_api_key_provider_credentials", return_value={
"api_key": "glm-key",
"base_url": "https://api.z.ai/api/paas/v4",
}),
patch("agent.auxiliary_client.OpenAI") as mock_openai,
):
mock_openai.return_value = MagicMock()
from agent.auxiliary_client import resolve_vision_provider_client
provider, client, model = resolve_vision_provider_client()
assert provider == "zai"
assert client is not None
assert model == "glm-5.1"
+25
View File
@@ -38,6 +38,16 @@ class TestShouldCompress:
assert compressor.should_compress(prompt_tokens=50000) is False
class TestShouldCompressPreflight:
def test_short_messages(self, compressor):
msgs = [{"role": "user", "content": "short"}]
assert compressor.should_compress_preflight(msgs) is False
def test_long_messages(self, compressor):
# Each message ~100k chars / 4 = 25k tokens, need >85k threshold
msgs = [{"role": "user", "content": "x" * 400000}]
assert compressor.should_compress_preflight(msgs) is True
class TestUpdateFromResponse:
def test_updates_fields(self, compressor):
@@ -48,12 +58,27 @@ class TestUpdateFromResponse:
})
assert compressor.last_prompt_tokens == 5000
assert compressor.last_completion_tokens == 1000
assert compressor.last_total_tokens == 6000
def test_missing_fields_default_zero(self, compressor):
compressor.update_from_response({})
assert compressor.last_prompt_tokens == 0
class TestGetStatus:
def test_returns_expected_keys(self, compressor):
status = compressor.get_status()
assert "last_prompt_tokens" in status
assert "threshold_tokens" in status
assert "context_length" in status
assert "usage_percent" in status
assert "compression_count" in status
def test_usage_percent_calculation(self, compressor):
compressor.last_prompt_tokens = 50000
status = compressor.get_status()
assert status["usage_percent"] == 50.0
class TestCompress:
def _make_messages(self, n):
-42
View File
@@ -83,24 +83,6 @@ def test_parse_references_strips_trailing_punctuation():
assert refs[1].target == "https://example.com/docs"
def test_parse_quoted_references_with_spaces_and_preserve_unquoted_ranges():
from agent.context_references import parse_context_references
refs = parse_context_references(
'review @file:"C:\\Users\\Simba\\My Project\\main.py":7-9 '
'and @folder:"docs and specs" plus @file:src/main.py:1-2'
)
assert [ref.kind for ref in refs] == ["file", "folder", "file"]
assert refs[0].target == r"C:\Users\Simba\My Project\main.py"
assert refs[0].line_start == 7
assert refs[0].line_end == 9
assert refs[1].target == "docs and specs"
assert refs[2].target == "src/main.py"
assert refs[2].line_start == 1
assert refs[2].line_end == 2
def test_expand_file_range_and_folder_listing(sample_repo: Path):
from agent.context_references import preprocess_context_references
@@ -124,30 +106,6 @@ def test_expand_file_range_and_folder_listing(sample_repo: Path):
assert not result.warnings
def test_expand_quoted_file_reference_with_spaces(tmp_path: Path):
from agent.context_references import preprocess_context_references
workspace = tmp_path / "repo"
folder = workspace / "docs and specs"
folder.mkdir(parents=True)
file_path = folder / "release notes.txt"
file_path.write_text("line 1\nline 2\nline 3\n", encoding="utf-8")
result = preprocess_context_references(
'Review @file:"docs and specs/release notes.txt":2-3',
cwd=workspace,
context_length=100_000,
)
assert result.expanded
assert result.message.startswith("Review")
assert "line 1" not in result.message
assert "line 2" in result.message
assert "line 3" in result.message
assert "release notes.txt" in result.message
assert not result.warnings
def test_expand_git_diff_staged_and_log(sample_repo: Path):
from agent.context_references import preprocess_context_references
+52 -32
View File
@@ -567,7 +567,6 @@ def test_singleton_seed_does_not_clobber_manual_oauth_entry(tmp_path, monkeypatc
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
monkeypatch.setattr("hermes_cli.auth.is_provider_explicitly_configured", lambda pid: True)
_write_auth_store(
tmp_path,
{
@@ -703,6 +702,53 @@ def test_least_used_strategy_selects_lowest_count(tmp_path, monkeypatch):
assert entry.access_token == "sk-or-light"
def test_mark_used_increments_request_count(tmp_path, monkeypatch):
"""mark_used should increment the request_count of the current entry."""
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
monkeypatch.setattr(
"agent.credential_pool.get_pool_strategy",
lambda _provider: "fill_first",
)
monkeypatch.setattr(
"agent.credential_pool._seed_from_singletons",
lambda provider, entries: (False, set()),
)
monkeypatch.setattr(
"agent.credential_pool._seed_from_env",
lambda provider, entries: (False, set()),
)
_write_auth_store(
tmp_path,
{
"version": 1,
"credential_pool": {
"openrouter": [
{
"id": "key-a",
"label": "test",
"auth_type": "api_key",
"priority": 0,
"source": "manual",
"access_token": "sk-or-test",
"request_count": 5,
},
]
},
},
)
from agent.credential_pool import load_pool
pool = load_pool("openrouter")
entry = pool.select()
assert entry is not None
assert entry.request_count == 5
pool.mark_used()
updated = pool.current()
assert updated is not None
assert updated.request_count == 6
def test_thread_safety_concurrent_select(tmp_path, monkeypatch):
"""Concurrent select() calls should not corrupt pool state."""
import threading as _threading
@@ -752,6 +798,7 @@ def test_thread_safety_concurrent_select(tmp_path, monkeypatch):
entry = pool.select()
if entry:
results.append(entry.id)
pool.mark_used(entry.id)
except Exception as exc:
errors.append(exc)
@@ -1009,8 +1056,8 @@ def test_acquire_lease_prefers_unleased_entry(tmp_path, monkeypatch):
assert first == "cred-1"
assert second == "cred-2"
assert pool._active_leases.get("cred-1", 0) == 1
assert pool._active_leases.get("cred-2", 0) == 1
assert pool.active_lease_count("cred-1") == 1
assert pool.active_lease_count("cred-2") == 1
@@ -1040,34 +1087,7 @@ def test_release_lease_decrements_counter(tmp_path, monkeypatch):
pool = load_pool("openrouter")
leased = pool.acquire_lease()
assert leased == "cred-1"
assert pool._active_leases.get("cred-1", 0) == 1
assert pool.active_lease_count("cred-1") == 1
pool.release_lease("cred-1")
assert pool._active_leases.get("cred-1", 0) == 0
def test_load_pool_does_not_seed_claude_code_when_anthropic_not_configured(tmp_path, monkeypatch):
"""Claude Code credentials must not be auto-seeded when the user never selected anthropic."""
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
_write_auth_store(tmp_path, {"version": 1, "credential_pool": {}})
# Claude Code credentials exist on disk
monkeypatch.setattr(
"agent.anthropic_adapter.read_claude_code_credentials",
lambda: {"accessToken": "sk-ant...oken", "refreshToken": "rt", "expiresAt": 9999999999999},
)
monkeypatch.setattr(
"agent.anthropic_adapter.read_hermes_oauth_credentials",
lambda: None,
)
# User configured kimi-coding, NOT anthropic
monkeypatch.setattr(
"hermes_cli.auth.is_provider_explicitly_configured",
lambda pid: pid == "kimi-coding",
)
from agent.credential_pool import load_pool
pool = load_pool("anthropic")
# Should NOT have seeded the claude_code entry
assert pool.entries() == []
assert pool.active_lease_count("cred-1") == 0
+22 -16
View File
@@ -75,6 +75,28 @@ class TestClassifiedError:
e3 = ClassifiedError(reason=FailoverReason.billing)
assert e3.is_auth is False
def test_is_transient_property(self):
transient_reasons = [
FailoverReason.rate_limit,
FailoverReason.overloaded,
FailoverReason.server_error,
FailoverReason.timeout,
FailoverReason.unknown,
]
for reason in transient_reasons:
e = ClassifiedError(reason=reason)
assert e.is_transient is True, f"{reason} should be transient"
non_transient = [
FailoverReason.auth,
FailoverReason.billing,
FailoverReason.model_not_found,
FailoverReason.format_error,
]
for reason in non_transient:
e = ClassifiedError(reason=reason)
assert e.is_transient is False, f"{reason} should NOT be transient"
def test_defaults(self):
e = ClassifiedError(reason=FailoverReason.unknown)
assert e.retryable is True
@@ -249,22 +271,6 @@ class TestClassifyApiError:
assert result.reason == FailoverReason.rate_limit
assert result.should_fallback is True
def test_alibaba_rate_increased_too_quickly(self):
"""Alibaba/DashScope returns a unique throttling message.
Port from anomalyco/opencode#21355.
"""
msg = (
"Upstream error from Alibaba: Request rate increased too quickly. "
"To ensure system stability, please adjust your client logic to "
"scale requests more smoothly over time."
)
e = MockAPIError(msg, status_code=400)
result = classify_api_error(e)
assert result.reason == FailoverReason.rate_limit
assert result.retryable is True
assert result.should_rotate_credential is True
# ── Server errors ──
def test_500_server_error(self):
+40
View File
@@ -7,6 +7,7 @@ from pathlib import Path
from hermes_state import SessionDB
from agent.insights import (
InsightsEngine,
_get_pricing,
_estimate_cost,
_format_duration,
_bar_chart,
@@ -117,6 +118,45 @@ def populated_db(db):
return db
# =========================================================================
# Pricing helpers
# =========================================================================
class TestPricing:
def test_provider_prefix_stripped(self):
pricing = _get_pricing("anthropic/claude-sonnet-4-20250514")
assert pricing["input"] == 3.00
assert pricing["output"] == 15.00
def test_unknown_models_do_not_use_heuristics(self):
pricing = _get_pricing("some-new-opus-model")
assert pricing == _DEFAULT_PRICING
pricing = _get_pricing("anthropic/claude-haiku-future")
assert pricing == _DEFAULT_PRICING
def test_unknown_model_returns_zero_cost(self):
"""Unknown/custom models should NOT have fabricated costs."""
pricing = _get_pricing("totally-unknown-model-xyz")
assert pricing == _DEFAULT_PRICING
assert pricing["input"] == 0.0
assert pricing["output"] == 0.0
def test_custom_endpoint_model_zero_cost(self):
"""Self-hosted models should return zero cost."""
for model in ["FP16_Hermes_4.5", "Hermes_4.5_1T_epoch2", "my-local-llama"]:
pricing = _get_pricing(model)
assert pricing["input"] == 0.0, f"{model} should have zero cost"
assert pricing["output"] == 0.0, f"{model} should have zero cost"
def test_none_model(self):
pricing = _get_pricing(None)
assert pricing == _DEFAULT_PRICING
def test_empty_model(self):
pricing = _get_pricing("")
assert pricing == _DEFAULT_PRICING
class TestHasKnownPricing:
def test_known_commercial_model(self):
assert _has_known_pricing("gpt-4o", provider="openai") is True
+299
View File
@@ -0,0 +1,299 @@
"""End-to-end test: a SQLite-backed memory plugin exercising the full interface.
This proves a real plugin can register as a MemoryProvider and get wired
into the agent loop via MemoryManager. Uses SQLite + FTS5 (stdlib, no
external deps, no API keys).
"""
import json
import os
import sqlite3
import tempfile
import pytest
from unittest.mock import patch, MagicMock
from agent.memory_provider import MemoryProvider
from agent.memory_manager import MemoryManager
from agent.builtin_memory_provider import BuiltinMemoryProvider
# ---------------------------------------------------------------------------
# SQLite FTS5 memory provider — a real, minimal plugin implementation
# ---------------------------------------------------------------------------
class SQLiteMemoryProvider(MemoryProvider):
"""Minimal SQLite + FTS5 memory provider for testing.
Demonstrates the full MemoryProvider interface with a real backend.
No external dependencies just stdlib sqlite3.
"""
def __init__(self, db_path: str = ":memory:"):
self._db_path = db_path
self._conn = None
@property
def name(self) -> str:
return "sqlite_memory"
def is_available(self) -> bool:
return True # SQLite is always available
def initialize(self, session_id: str, **kwargs) -> None:
self._conn = sqlite3.connect(self._db_path)
self._conn.execute("PRAGMA journal_mode=WAL")
self._conn.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS memories
USING fts5(content, context, session_id)
""")
self._session_id = session_id
def system_prompt_block(self) -> str:
if not self._conn:
return ""
count = self._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
if count == 0:
return ""
return (
f"# SQLite Memory Plugin\n"
f"Active. {count} memories stored.\n"
f"Use sqlite_recall to search, sqlite_retain to store."
)
def prefetch(self, query: str, *, session_id: str = "") -> str:
if not self._conn or not query:
return ""
# FTS5 search
try:
rows = self._conn.execute(
"SELECT content FROM memories WHERE memories MATCH ? LIMIT 5",
(query,)
).fetchall()
if not rows:
return ""
results = [row[0] for row in rows]
return "## SQLite Memory\n" + "\n".join(f"- {r}" for r in results)
except sqlite3.OperationalError:
return ""
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
if not self._conn:
return
combined = f"User: {user_content}\nAssistant: {assistant_content}"
self._conn.execute(
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
(combined, "conversation", self._session_id),
)
self._conn.commit()
def get_tool_schemas(self):
return [
{
"name": "sqlite_retain",
"description": "Store a fact to SQLite memory.",
"parameters": {
"type": "object",
"properties": {
"content": {"type": "string", "description": "What to remember"},
"context": {"type": "string", "description": "Category/context"},
},
"required": ["content"],
},
},
{
"name": "sqlite_recall",
"description": "Search SQLite memory.",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query"},
},
"required": ["query"],
},
},
]
def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
if tool_name == "sqlite_retain":
content = args.get("content", "")
context = args.get("context", "explicit")
if not content:
return json.dumps({"error": "content is required"})
self._conn.execute(
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
(content, context, self._session_id),
)
self._conn.commit()
return json.dumps({"result": "Stored."})
elif tool_name == "sqlite_recall":
query = args.get("query", "")
if not query:
return json.dumps({"error": "query is required"})
try:
rows = self._conn.execute(
"SELECT content, context FROM memories WHERE memories MATCH ? LIMIT 10",
(query,)
).fetchall()
results = [{"content": r[0], "context": r[1]} for r in rows]
return json.dumps({"results": results})
except sqlite3.OperationalError:
return json.dumps({"results": []})
return json.dumps({"error": f"Unknown tool: {tool_name}"})
def on_memory_write(self, action, target, content):
"""Mirror built-in memory writes to SQLite."""
if action == "add" and self._conn:
self._conn.execute(
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
(content, f"builtin_{target}", self._session_id),
)
self._conn.commit()
def shutdown(self):
if self._conn:
self._conn.close()
self._conn = None
# ---------------------------------------------------------------------------
# End-to-end tests
# ---------------------------------------------------------------------------
class TestSQLiteMemoryPlugin:
"""Full lifecycle test with the SQLite provider."""
def test_full_lifecycle(self):
"""Exercise init → store → recall → sync → prefetch → shutdown."""
mgr = MemoryManager()
builtin = BuiltinMemoryProvider()
sqlite_mem = SQLiteMemoryProvider()
mgr.add_provider(builtin)
mgr.add_provider(sqlite_mem)
# Initialize
mgr.initialize_all(session_id="test-session-1", platform="cli")
assert sqlite_mem._conn is not None
# System prompt — empty at first
prompt = mgr.build_system_prompt()
assert "SQLite Memory Plugin" not in prompt
# Store via tool call
result = json.loads(mgr.handle_tool_call(
"sqlite_retain", {"content": "User prefers dark mode", "context": "preference"}
))
assert result["result"] == "Stored."
# System prompt now shows count
prompt = mgr.build_system_prompt()
assert "1 memories stored" in prompt
# Recall via tool call
result = json.loads(mgr.handle_tool_call(
"sqlite_recall", {"query": "dark mode"}
))
assert len(result["results"]) == 1
assert "dark mode" in result["results"][0]["content"]
# Sync a turn (auto-stores conversation)
mgr.sync_all("What's my theme?", "You prefer dark mode.")
count = sqlite_mem._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
assert count == 2 # 1 explicit + 1 synced
# Prefetch for next turn
prefetched = mgr.prefetch_all("dark mode")
assert "dark mode" in prefetched
# Memory bridge — mirroring builtin writes
mgr.on_memory_write("add", "user", "Timezone: US Pacific")
count = sqlite_mem._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
assert count == 3
# Shutdown
mgr.shutdown_all()
assert sqlite_mem._conn is None
def test_tool_routing_with_builtin(self):
"""Verify builtin + plugin tools coexist without conflict."""
mgr = MemoryManager()
builtin = BuiltinMemoryProvider()
sqlite_mem = SQLiteMemoryProvider()
mgr.add_provider(builtin)
mgr.add_provider(sqlite_mem)
mgr.initialize_all(session_id="test-2")
# Builtin has no tools
assert len(builtin.get_tool_schemas()) == 0
# SQLite has 2 tools
schemas = mgr.get_all_tool_schemas()
names = {s["name"] for s in schemas}
assert names == {"sqlite_retain", "sqlite_recall"}
# Routing works
assert mgr.has_tool("sqlite_retain")
assert mgr.has_tool("sqlite_recall")
assert not mgr.has_tool("memory") # builtin doesn't register this
def test_second_external_plugin_rejected(self):
"""Only one external memory provider is allowed at a time."""
mgr = MemoryManager()
p1 = SQLiteMemoryProvider()
p2 = SQLiteMemoryProvider()
# Hack name for p2
p2._name_override = "sqlite_memory_2"
original_name = p2.__class__.name
type(p2).name = property(lambda self: getattr(self, '_name_override', 'sqlite_memory'))
mgr.add_provider(p1)
mgr.add_provider(p2) # should be rejected
# Only p1 was accepted
assert len(mgr.providers) == 1
assert mgr.provider_names == ["sqlite_memory"]
# Restore class
type(p2).name = original_name
mgr.shutdown_all()
def test_provider_failure_isolation(self):
"""Failing external provider doesn't break builtin."""
from agent.builtin_memory_provider import BuiltinMemoryProvider
mgr = MemoryManager()
builtin = BuiltinMemoryProvider() # name="builtin", always accepted
ext = SQLiteMemoryProvider()
mgr.add_provider(builtin)
mgr.add_provider(ext)
mgr.initialize_all(session_id="test-4")
# Break external provider's connection
ext._conn.close()
ext._conn = None
# Sync — external fails silently, builtin (no-op sync) succeeds
mgr.sync_all("user", "assistant") # should not raise
mgr.shutdown_all()
def test_plugin_registration_flow(self):
"""Simulate the full plugin load → agent init path."""
# Simulate what AIAgent.__init__ does via plugins/memory/ discovery
provider = SQLiteMemoryProvider()
mem_mgr = MemoryManager()
mem_mgr.add_provider(BuiltinMemoryProvider())
if provider.is_available():
mem_mgr.add_provider(provider)
mem_mgr.initialize_all(session_id="agent-session")
assert len(mem_mgr.providers) == 2
assert mem_mgr.provider_names == ["builtin", "sqlite_memory"]
assert provider._conn is not None # initialized = connection established
mem_mgr.shutdown_all()
+157 -4
View File
@@ -6,6 +6,8 @@ from unittest.mock import MagicMock, patch
from agent.memory_provider import MemoryProvider
from agent.memory_manager import MemoryManager
from agent.builtin_memory_provider import BuiltinMemoryProvider
# ---------------------------------------------------------------------------
# Concrete test provider
@@ -116,7 +118,7 @@ class TestMemoryManager:
def test_empty_manager(self):
mgr = MemoryManager()
assert mgr.providers == []
assert [p.name for p in mgr.providers] == []
assert mgr.provider_names == []
assert mgr.get_all_tool_schemas() == []
assert mgr.build_system_prompt() == ""
assert mgr.prefetch_all("test") == ""
@@ -126,7 +128,7 @@ class TestMemoryManager:
p = FakeMemoryProvider("test1")
mgr.add_provider(p)
assert len(mgr.providers) == 1
assert [p.name for p in mgr.providers] == ["test1"]
assert mgr.provider_names == ["test1"]
def test_get_provider_by_name(self):
mgr = MemoryManager()
@@ -141,7 +143,7 @@ class TestMemoryManager:
p2 = FakeMemoryProvider("external")
mgr.add_provider(p1)
mgr.add_provider(p2)
assert [p.name for p in mgr.providers] == ["builtin", "external"]
assert mgr.provider_names == ["builtin", "external"]
def test_second_external_rejected(self):
"""Only one non-builtin provider is allowed."""
@@ -152,7 +154,7 @@ class TestMemoryManager:
mgr.add_provider(builtin)
mgr.add_provider(ext1)
mgr.add_provider(ext2) # should be rejected
assert [p.name for p in mgr.providers] == ["builtin", "mem0"]
assert mgr.provider_names == ["builtin", "mem0"]
assert len(mgr.providers) == 2
def test_system_prompt_merges_blocks(self):
@@ -319,6 +321,17 @@ class TestMemoryManager:
mgr.on_pre_compress([{"role": "user", "content": "old"}])
assert p.pre_compress_called
def test_on_memory_write_skips_builtin(self):
"""on_memory_write should skip the builtin provider."""
mgr = MemoryManager()
builtin = BuiltinMemoryProvider()
external = FakeMemoryProvider("external")
mgr.add_provider(builtin)
mgr.add_provider(external)
mgr.on_memory_write("add", "memory", "test fact")
assert external.memory_writes == [("add", "memory", "test fact")]
def test_shutdown_all_reverse_order(self):
mgr = MemoryManager()
order = []
@@ -372,6 +385,146 @@ class TestMemoryManager:
assert result == "works fine"
# ---------------------------------------------------------------------------
# BuiltinMemoryProvider tests
# ---------------------------------------------------------------------------
class TestBuiltinMemoryProvider:
def test_name(self):
p = BuiltinMemoryProvider()
assert p.name == "builtin"
def test_always_available(self):
p = BuiltinMemoryProvider()
assert p.is_available()
def test_no_tools(self):
"""Builtin provider exposes no tools (memory tool is agent-level)."""
p = BuiltinMemoryProvider()
assert p.get_tool_schemas() == []
def test_system_prompt_with_store(self):
store = MagicMock()
store.format_for_system_prompt.side_effect = lambda t: f"BLOCK_{t}" if t == "memory" else f"BLOCK_{t}"
p = BuiltinMemoryProvider(
memory_store=store,
memory_enabled=True,
user_profile_enabled=True,
)
block = p.system_prompt_block()
assert "BLOCK_memory" in block
assert "BLOCK_user" in block
def test_system_prompt_memory_disabled(self):
store = MagicMock()
store.format_for_system_prompt.return_value = "content"
p = BuiltinMemoryProvider(
memory_store=store,
memory_enabled=False,
user_profile_enabled=False,
)
assert p.system_prompt_block() == ""
def test_system_prompt_no_store(self):
p = BuiltinMemoryProvider(memory_store=None, memory_enabled=True)
assert p.system_prompt_block() == ""
def test_prefetch_returns_empty(self):
p = BuiltinMemoryProvider()
assert p.prefetch("anything") == ""
def test_store_property(self):
store = MagicMock()
p = BuiltinMemoryProvider(memory_store=store)
assert p.store is store
def test_initialize_loads_from_disk(self):
store = MagicMock()
p = BuiltinMemoryProvider(memory_store=store)
p.initialize(session_id="test")
store.load_from_disk.assert_called_once()
# ---------------------------------------------------------------------------
# Plugin registration tests
# ---------------------------------------------------------------------------
class TestSingleProviderGating:
"""Only the configured provider should activate."""
def test_no_provider_configured_means_builtin_only(self):
"""When memory.provider is empty, no plugin providers activate."""
mgr = MemoryManager()
builtin = BuiltinMemoryProvider()
mgr.add_provider(builtin)
# Simulate what run_agent.py does when provider=""
configured = ""
available_plugins = [
FakeMemoryProvider("holographic"),
FakeMemoryProvider("mem0"),
]
# With empty config, no plugins should be added
if configured:
for p in available_plugins:
if p.name == configured and p.is_available():
mgr.add_provider(p)
assert mgr.provider_names == ["builtin"]
def test_configured_provider_activates(self):
"""Only the named provider should be added."""
mgr = MemoryManager()
builtin = BuiltinMemoryProvider()
mgr.add_provider(builtin)
configured = "holographic"
p1 = FakeMemoryProvider("holographic")
p2 = FakeMemoryProvider("mem0")
p3 = FakeMemoryProvider("hindsight")
for p in [p1, p2, p3]:
if p.name == configured and p.is_available():
mgr.add_provider(p)
assert mgr.provider_names == ["builtin", "holographic"]
assert p1.initialized is False # not initialized by the gating logic itself
def test_unavailable_provider_skipped(self):
"""If the configured provider is unavailable, it should be skipped."""
mgr = MemoryManager()
builtin = BuiltinMemoryProvider()
mgr.add_provider(builtin)
configured = "holographic"
p1 = FakeMemoryProvider("holographic", available=False)
for p in [p1]:
if p.name == configured and p.is_available():
mgr.add_provider(p)
assert mgr.provider_names == ["builtin"]
def test_nonexistent_provider_results_in_builtin_only(self):
"""If the configured name doesn't match any plugin, only builtin remains."""
mgr = MemoryManager()
builtin = BuiltinMemoryProvider()
mgr.add_provider(builtin)
configured = "nonexistent"
plugins = [FakeMemoryProvider("holographic"), FakeMemoryProvider("mem0")]
for p in plugins:
if p.name == configured and p.is_available():
mgr.add_provider(p)
assert mgr.provider_names == ["builtin"]
class TestPluginMemoryDiscovery:
"""Memory providers are discovered from plugins/memory/ directory."""
+56
View File
@@ -11,6 +11,7 @@ from agent.prompt_builder import (
_scan_context_content,
_truncate_content,
_parse_skill_file,
_read_skill_conditions,
_skill_should_show,
_find_hermes_md,
_find_git_root,
@@ -774,6 +775,61 @@ class TestPromptBuilderConstants:
# Conditional skill activation
# =========================================================================
class TestReadSkillConditions:
def test_no_conditions_returns_empty_lists(self, tmp_path):
skill_file = tmp_path / "SKILL.md"
skill_file.write_text("---\nname: test\ndescription: A skill\n---\n")
conditions = _read_skill_conditions(skill_file)
assert conditions["fallback_for_toolsets"] == []
assert conditions["requires_toolsets"] == []
assert conditions["fallback_for_tools"] == []
assert conditions["requires_tools"] == []
def test_reads_fallback_for_toolsets(self, tmp_path):
skill_file = tmp_path / "SKILL.md"
skill_file.write_text(
"---\nname: ddg\ndescription: DuckDuckGo\nmetadata:\n hermes:\n fallback_for_toolsets: [web]\n---\n"
)
conditions = _read_skill_conditions(skill_file)
assert conditions["fallback_for_toolsets"] == ["web"]
def test_reads_requires_toolsets(self, tmp_path):
skill_file = tmp_path / "SKILL.md"
skill_file.write_text(
"---\nname: openhue\ndescription: Hue lights\nmetadata:\n hermes:\n requires_toolsets: [terminal]\n---\n"
)
conditions = _read_skill_conditions(skill_file)
assert conditions["requires_toolsets"] == ["terminal"]
def test_reads_multiple_conditions(self, tmp_path):
skill_file = tmp_path / "SKILL.md"
skill_file.write_text(
"---\nname: test\ndescription: Test\nmetadata:\n hermes:\n fallback_for_toolsets: [browser]\n requires_tools: [terminal]\n---\n"
)
conditions = _read_skill_conditions(skill_file)
assert conditions["fallback_for_toolsets"] == ["browser"]
assert conditions["requires_tools"] == ["terminal"]
def test_missing_file_returns_empty(self, tmp_path):
conditions = _read_skill_conditions(tmp_path / "missing.md")
assert conditions == {}
def test_logs_condition_read_failures_and_returns_empty(self, tmp_path, monkeypatch, caplog):
skill_file = tmp_path / "SKILL.md"
skill_file.write_text("---\nname: broken\n---\n")
def boom(*args, **kwargs):
raise OSError("read exploded")
monkeypatch.setattr(type(skill_file), "read_text", boom)
with caplog.at_level(logging.DEBUG, logger="agent.prompt_builder"):
conditions = _read_skill_conditions(skill_file)
assert conditions == {}
assert "Failed to read skill conditions" in caplog.text
assert str(skill_file) in caplog.text
class TestSkillShouldShow:
def test_no_filter_info_always_shows(self):
assert _skill_should_show({}, None, None) is True
-85
View File
@@ -1,85 +0,0 @@
"""Tests for CLI /status command behavior."""
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from cli import HermesCLI
from hermes_cli.commands import resolve_command
def _make_cli():
cli_obj = HermesCLI.__new__(HermesCLI)
cli_obj.config = {}
cli_obj.console = MagicMock()
cli_obj.agent = None
cli_obj.conversation_history = []
cli_obj.session_id = "session-123"
cli_obj._pending_input = MagicMock()
cli_obj._status_bar_visible = True
cli_obj.model = "openai/gpt-5.4"
cli_obj.provider = "openai"
cli_obj.session_start = datetime(2026, 4, 9, 19, 24)
cli_obj._agent_running = False
cli_obj._session_db = MagicMock()
cli_obj._session_db.get_session.return_value = None
return cli_obj
def test_status_command_is_available_in_cli_registry():
cmd = resolve_command("status")
assert cmd is not None
assert cmd.gateway_only is False
def test_process_command_status_dispatches_without_toggling_status_bar():
cli_obj = _make_cli()
with patch.object(cli_obj, "_show_session_status", create=True) as mock_status:
assert cli_obj.process_command("/status") is True
mock_status.assert_called_once_with()
assert cli_obj._status_bar_visible is True
def test_statusbar_still_toggles_visibility():
cli_obj = _make_cli()
assert cli_obj.process_command("/statusbar") is True
assert cli_obj._status_bar_visible is False
def test_status_prefix_prefers_status_command_over_statusbar_toggle():
cli_obj = _make_cli()
with patch.object(cli_obj, "_show_session_status") as mock_status:
assert cli_obj.process_command("/sta") is True
mock_status.assert_called_once_with()
assert cli_obj._status_bar_visible is True
def test_show_session_status_prints_gateway_style_summary():
cli_obj = _make_cli()
cli_obj.agent = SimpleNamespace(
session_total_tokens=321,
session_api_calls=4,
)
cli_obj._session_db.get_session.return_value = {
"title": "My titled session",
"started_at": 1775791440,
}
with patch("cli.display_hermes_home", return_value="~/.hermes"):
cli_obj._show_session_status()
printed = "\n".join(str(call.args[0]) for call in cli_obj.console.print.call_args_list)
assert "Hermes CLI Status" in printed
assert "Session ID: session-123" in printed
assert "Path: ~/.hermes" in printed
assert "Title: My titled session" in printed
assert "Model: openai/gpt-5.4 (openai)" in printed
assert "Tokens: 321" in printed
assert "Agent Running: No" in printed
_, kwargs = cli_obj.console.print.call_args
assert kwargs.get("highlight") is False
assert kwargs.get("markup") is False
+10 -4
View File
@@ -619,14 +619,17 @@ class TestReasoningDeltasFiredFlag(unittest.TestCase):
agent = AIAgent.__new__(AIAgent)
agent.reasoning_callback = None
agent.stream_delta_callback = None
agent._reasoning_deltas_fired = False
agent.verbose_logging = False
return agent
def test_fire_reasoning_delta_calls_callback(self):
def test_fire_reasoning_delta_sets_flag(self):
agent = self._make_agent()
captured = []
agent.reasoning_callback = lambda t: captured.append(t)
self.assertFalse(agent._reasoning_deltas_fired)
agent._fire_reasoning_delta("thinking...")
self.assertTrue(agent._reasoning_deltas_fired)
self.assertEqual(captured, ["thinking..."])
def test_build_assistant_message_skips_callback_when_already_streamed(self):
@@ -637,7 +640,8 @@ class TestReasoningDeltasFiredFlag(unittest.TestCase):
agent.reasoning_callback = lambda t: captured.append(t)
agent.stream_delta_callback = lambda t: None # streaming is active
# Simulate streaming having already fired reasoning
# Simulate streaming having fired reasoning
agent._reasoning_deltas_fired = True
msg = SimpleNamespace(
content="I'll merge that.",
@@ -661,8 +665,9 @@ class TestReasoningDeltasFiredFlag(unittest.TestCase):
agent.reasoning_callback = lambda t: captured.append(t)
agent.stream_delta_callback = lambda t: None # streaming active
# Reasoning came through content tags, not reasoning_content deltas.
# Callback should not fire since streaming is active.
# Even though _reasoning_deltas_fired is False (reasoning came through
# content tags, not reasoning_content deltas), callback should not fire
agent._reasoning_deltas_fired = False
msg = SimpleNamespace(
content="I'll merge that.",
@@ -684,6 +689,7 @@ class TestReasoningDeltasFiredFlag(unittest.TestCase):
agent.reasoning_callback = lambda t: captured.append(t)
# No streaming
agent.stream_delta_callback = None
agent._reasoning_deltas_fired = False
msg = SimpleNamespace(
content="I'll merge that.",
+58 -150
View File
@@ -1,4 +1,4 @@
"""Shared fixtures for gateway e2e tests (Telegram, Discord).
"""Shared fixtures for Telegram gateway e2e tests.
These tests exercise the full async message flow:
adapter.handle_message(event)
@@ -14,22 +14,19 @@ import sys
import uuid
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from unittest.mock import AsyncMock, MagicMock
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent, SendResult
from gateway.session import SessionEntry, SessionSource, build_session_key
# Platform library mocks
#Ensure telegram module is available (mock it if not installed)
# Ensure telegram module is available (mock it if not installed)
def _ensure_telegram_mock():
"""Install mock telegram modules so TelegramAdapter can be imported."""
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
return # Real library installed
return # Real library installed
telegram_mod = MagicMock()
telegram_mod.Update = MagicMock()
@@ -54,118 +51,24 @@ def _ensure_telegram_mock():
sys.modules.setdefault(name, telegram_mod)
# Ensure discord module is available (mock it if not installed)
def _ensure_discord_mock():
"""Install mock discord modules so DiscordAdapter can be imported."""
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
return # Real library installed
discord_mod = MagicMock()
discord_mod.Intents.default.return_value = MagicMock()
discord_mod.DMChannel = type("DMChannel", (), {})
discord_mod.Thread = type("Thread", (), {})
discord_mod.ForumChannel = type("ForumChannel", (), {})
discord_mod.Interaction = object
discord_mod.app_commands = SimpleNamespace(
describe=lambda **kwargs: (lambda fn: fn),
choices=lambda **kwargs: (lambda fn: fn),
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
)
discord_mod.opus.is_loaded.return_value = True
ext_mod = MagicMock()
commands_mod = MagicMock()
commands_mod.Bot = MagicMock
ext_mod.commands = commands_mod
sys.modules.setdefault("discord", discord_mod)
sys.modules.setdefault("discord.ext", ext_mod)
sys.modules.setdefault("discord.ext.commands", commands_mod)
sys.modules.setdefault("discord.opus", discord_mod.opus)
def _ensure_slack_mock():
"""Install mock slack modules so SlackAdapter can be imported."""
if "slack_bolt" in sys.modules and hasattr(sys.modules["slack_bolt"], "__file__"):
return # Real library installed
slack_bolt = MagicMock()
slack_bolt.async_app.AsyncApp = MagicMock
slack_bolt.adapter.socket_mode.async_handler.AsyncSocketModeHandler = MagicMock
slack_sdk = MagicMock()
slack_sdk.web.async_client.AsyncWebClient = MagicMock
for name, mod in [
("slack_bolt", slack_bolt),
("slack_bolt.async_app", slack_bolt.async_app),
("slack_bolt.adapter", slack_bolt.adapter),
("slack_bolt.adapter.socket_mode", slack_bolt.adapter.socket_mode),
("slack_bolt.adapter.socket_mode.async_handler", slack_bolt.adapter.socket_mode.async_handler),
("slack_sdk", slack_sdk),
("slack_sdk.web", slack_sdk.web),
("slack_sdk.web.async_client", slack_sdk.web.async_client),
]:
sys.modules.setdefault(name, mod)
_ensure_telegram_mock()
_ensure_discord_mock()
_ensure_slack_mock()
from gateway.platforms.discord import DiscordAdapter # noqa: E402
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
import gateway.platforms.slack as _slack_mod # noqa: E402
_slack_mod.SLACK_AVAILABLE = True
from gateway.platforms.slack import SlackAdapter # noqa: E402
#GatewayRunner factory (based on tests/gateway/test_status_command.py)
# Platform-generic factories
def make_source(platform: Platform, chat_id: str = "e2e-chat-1", user_id: str = "e2e-user-1") -> SessionSource:
return SessionSource(
platform=platform,
chat_id=chat_id,
user_id=user_id,
user_name="e2e_tester",
chat_type="dm",
)
def make_session_entry(platform: Platform, source: SessionSource = None) -> SessionEntry:
source = source or make_source(platform)
return SessionEntry(
session_key=build_session_key(source),
session_id=f"sess-{uuid.uuid4().hex[:8]}",
created_at=datetime.now(),
updated_at=datetime.now(),
platform=platform,
chat_type="dm",
)
def make_event(platform: Platform, text: str = "/help", chat_id: str = "e2e-chat-1", user_id: str = "e2e-user-1") -> MessageEvent:
return MessageEvent(
text=text,
source=make_source(platform, chat_id, user_id),
message_id=f"msg-{uuid.uuid4().hex[:8]}",
)
def make_runner(platform: Platform, session_entry: SessionEntry = None) -> "GatewayRunner":
def make_runner(session_entry: SessionEntry) -> "GatewayRunner":
"""Create a GatewayRunner with mocked internals for e2e testing.
Skips __init__ to avoid filesystem/network side effects.
All command-dispatch dependencies are wired manually.
"""
from gateway.run import GatewayRunner
if session_entry is None:
session_entry = make_session_entry(platform)
runner = object.__new__(GatewayRunner)
runner.config = GatewayConfig(
platforms={platform: PlatformConfig(enabled=True, token="e2e-test-token")}
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="e2e-test-token")}
)
runner.adapters = {}
runner._voice_mode = {}
@@ -196,6 +99,7 @@ def make_runner(platform: Platform, session_entry: SessionEntry = None) -> "Gate
runner._capture_gateway_honcho_if_configured = lambda *a, **kw: None
runner._emit_gateway_run_progress = AsyncMock()
# Pairing store (used by authorization rejection path)
runner.pairing_store = MagicMock()
runner.pairing_store._is_rate_limited = MagicMock(return_value=False)
runner.pairing_store.generate_code = MagicMock(return_value="ABC123")
@@ -203,63 +107,67 @@ def make_runner(platform: Platform, session_entry: SessionEntry = None) -> "Gate
return runner
def make_adapter(platform: Platform, runner=None):
"""Create a platform adapter wired to *runner*, with send methods mocked."""
if runner is None:
runner = make_runner(platform)
#TelegramAdapter factory
def make_adapter(runner) -> TelegramAdapter:
"""Create a TelegramAdapter wired to *runner*, with send methods mocked.
connect() is NOT called no polling, no token lock, no real HTTP.
"""
config = PlatformConfig(enabled=True, token="e2e-test-token")
adapter = TelegramAdapter(config)
if platform == Platform.DISCORD:
with patch.object(DiscordAdapter, "_load_participated_threads", return_value=set()):
adapter = DiscordAdapter(config)
platform_key = Platform.DISCORD
elif platform == Platform.SLACK:
adapter = SlackAdapter(config)
platform_key = Platform.SLACK
else:
adapter = TelegramAdapter(config)
platform_key = Platform.TELEGRAM
# Mock outbound methods so tests can capture what was sent
adapter.send = AsyncMock(return_value=SendResult(success=True, message_id="e2e-resp-1"))
adapter.send_typing = AsyncMock()
# Wire adapter ↔ runner
adapter.set_message_handler(runner._handle_message)
runner.adapters[platform_key] = adapter
runner.adapters[Platform.TELEGRAM] = adapter
return adapter
async def send_and_capture(adapter, text: str, platform: Platform, **event_kwargs) -> AsyncMock:
"""Send a message through the full e2e flow and return the send mock."""
event = make_event(platform, text, **event_kwargs)
#Helpers
def make_source(chat_id: str = "e2e-chat-1", user_id: str = "e2e-user-1") -> SessionSource:
return SessionSource(
platform=Platform.TELEGRAM,
chat_id=chat_id,
user_id=user_id,
user_name="e2e_tester",
chat_type="dm",
)
def make_event(text: str, chat_id: str = "e2e-chat-1", user_id: str = "e2e-user-1") -> MessageEvent:
return MessageEvent(
text=text,
source=make_source(chat_id, user_id),
message_id=f"msg-{uuid.uuid4().hex[:8]}",
)
def make_session_entry(source: SessionSource = None) -> SessionEntry:
source = source or make_source()
return SessionEntry(
session_key=build_session_key(source),
session_id=f"sess-{uuid.uuid4().hex[:8]}",
created_at=datetime.now(),
updated_at=datetime.now(),
platform=Platform.TELEGRAM,
chat_type="dm",
)
async def send_and_capture(adapter: TelegramAdapter, text: str, **event_kwargs) -> AsyncMock:
"""Send a message through the full e2e flow and return the send mock.
Drives: adapter.handle_message background task runner dispatch adapter.send.
"""
event = make_event(text, **event_kwargs)
adapter.send.reset_mock()
await adapter.handle_message(event)
# Let the background task complete
await asyncio.sleep(0.3)
return adapter.send
# Parametrized fixtures for platform-generic tests
@pytest.fixture(params=[Platform.TELEGRAM, Platform.DISCORD, Platform.SLACK], ids=["telegram", "discord", "slack"])
def platform(request):
return request.param
@pytest.fixture()
def source(platform):
return make_source(platform)
@pytest.fixture()
def session_entry(platform, source):
return make_session_entry(platform, source)
@pytest.fixture()
def runner(platform, session_entry):
return make_runner(platform, session_entry)
@pytest.fixture()
def adapter(platform, runner):
return make_adapter(platform, runner)
@@ -1,4 +1,4 @@
"""E2E tests for gateway slash commands (Telegram, Discord).
"""E2E tests for Telegram gateway slash commands.
Each test drives a message through the full async pipeline:
adapter.handle_message(event)
@@ -7,7 +7,6 @@ Each test drives a message through the full async pipeline:
adapter.send() (captured for assertions)
No LLM involved only gateway-level commands are tested.
Tests are parametrized over platforms via the ``platform`` fixture in conftest.
"""
import asyncio
@@ -16,15 +15,46 @@ from unittest.mock import AsyncMock
import pytest
from gateway.platforms.base import SendResult
from tests.e2e.conftest import make_event, send_and_capture
from tests.e2e.conftest import (
make_adapter,
make_event,
make_runner,
make_session_entry,
make_source,
send_and_capture,
)
class TestSlashCommands:
#Fixtures
@pytest.fixture()
def source():
return make_source()
@pytest.fixture()
def session_entry(source):
return make_session_entry(source)
@pytest.fixture()
def runner(session_entry):
return make_runner(session_entry)
@pytest.fixture()
def adapter(runner):
return make_adapter(runner)
#Tests
class TestTelegramSlashCommands:
"""Gateway slash commands dispatched through the full adapter pipeline."""
@pytest.mark.asyncio
async def test_help_returns_command_list(self, adapter, platform):
send = await send_and_capture(adapter, "/help", platform)
async def test_help_returns_command_list(self, adapter):
send = await send_and_capture(adapter, "/help")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
@@ -32,23 +62,24 @@ class TestSlashCommands:
assert "/status" in response_text
@pytest.mark.asyncio
async def test_status_shows_session_info(self, adapter, platform):
send = await send_and_capture(adapter, "/status", platform)
async def test_status_shows_session_info(self, adapter):
send = await send_and_capture(adapter, "/status")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
# Status output includes session metadata
assert "session" in response_text.lower() or "Session" in response_text
@pytest.mark.asyncio
async def test_new_resets_session(self, adapter, runner, platform):
send = await send_and_capture(adapter, "/new", platform)
async def test_new_resets_session(self, adapter, runner):
send = await send_and_capture(adapter, "/new")
send.assert_called_once()
runner.session_store.reset_session.assert_called_once()
@pytest.mark.asyncio
async def test_stop_when_no_agent_running(self, adapter, platform):
send = await send_and_capture(adapter, "/stop", platform)
async def test_stop_when_no_agent_running(self, adapter):
send = await send_and_capture(adapter, "/stop")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
@@ -56,8 +87,8 @@ class TestSlashCommands:
assert "no" in response_lower or "stop" in response_lower or "not running" in response_lower
@pytest.mark.asyncio
async def test_commands_shows_listing(self, adapter, platform):
send = await send_and_capture(adapter, "/commands", platform)
async def test_commands_shows_listing(self, adapter):
send = await send_and_capture(adapter, "/commands")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
@@ -65,25 +96,29 @@ class TestSlashCommands:
assert "/" in response_text
@pytest.mark.asyncio
async def test_sequential_commands_share_session(self, adapter, platform):
async def test_sequential_commands_share_session(self, adapter):
"""Two commands from the same chat_id should both succeed."""
send_help = await send_and_capture(adapter, "/help", platform)
send_help = await send_and_capture(adapter, "/help")
send_help.assert_called_once()
send_status = await send_and_capture(adapter, "/status", platform)
send_status = await send_and_capture(adapter, "/status")
send_status.assert_called_once()
@pytest.mark.asyncio
async def test_provider_shows_current_provider(self, adapter, platform):
send = await send_and_capture(adapter, "/provider", platform)
@pytest.mark.xfail(
reason="Bug: _handle_provider_command references unbound model_cfg when config.yaml is absent",
strict=False,
)
async def test_provider_shows_current_provider(self, adapter):
send = await send_and_capture(adapter, "/provider")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
assert "provider" in response_text.lower()
@pytest.mark.asyncio
async def test_verbose_responds(self, adapter, platform):
send = await send_and_capture(adapter, "/verbose", platform)
async def test_verbose_responds(self, adapter):
send = await send_and_capture(adapter, "/verbose")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
@@ -91,50 +126,42 @@ class TestSlashCommands:
assert "verbose" in response_text.lower() or "tool_progress" in response_text
@pytest.mark.asyncio
async def test_personality_lists_options(self, adapter, platform):
send = await send_and_capture(adapter, "/personality", platform)
async def test_personality_lists_options(self, adapter):
send = await send_and_capture(adapter, "/personality")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
assert "personalit" in response_text.lower() # matches "personality" or "personalities"
@pytest.mark.asyncio
async def test_yolo_toggles_mode(self, adapter, platform):
send = await send_and_capture(adapter, "/yolo", platform)
async def test_yolo_toggles_mode(self, adapter):
send = await send_and_capture(adapter, "/yolo")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
assert "yolo" in response_text.lower()
@pytest.mark.asyncio
async def test_compress_command(self, adapter, platform):
send = await send_and_capture(adapter, "/compress", platform)
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
assert "compress" in response_text.lower() or "context" in response_text.lower()
class TestSessionLifecycle:
"""Verify session state changes across command sequences."""
@pytest.mark.asyncio
async def test_new_then_status_reflects_reset(self, adapter, runner, session_entry, platform):
async def test_new_then_status_reflects_reset(self, adapter, runner, session_entry):
"""After /new, /status should report the fresh session."""
await send_and_capture(adapter, "/new", platform)
await send_and_capture(adapter, "/new")
runner.session_store.reset_session.assert_called_once()
send = await send_and_capture(adapter, "/status", platform)
send = await send_and_capture(adapter, "/status")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
# Session ID from the entry should appear in the status output
assert session_entry.session_id[:8] in response_text
@pytest.mark.asyncio
async def test_new_is_idempotent(self, adapter, runner, platform):
async def test_new_is_idempotent(self, adapter, runner):
"""/new called twice should not crash."""
await send_and_capture(adapter, "/new", platform)
await send_and_capture(adapter, "/new", platform)
await send_and_capture(adapter, "/new")
await send_and_capture(adapter, "/new")
assert runner.session_store.reset_session.call_count == 2
@@ -142,11 +169,11 @@ class TestAuthorization:
"""Verify the pipeline handles unauthorized users."""
@pytest.mark.asyncio
async def test_unauthorized_user_gets_pairing_response(self, adapter, runner, platform):
async def test_unauthorized_user_gets_pairing_response(self, adapter, runner):
"""Unauthorized DM should trigger pairing code, not a command response."""
runner._is_user_authorized = lambda _source: False
event = make_event(platform, "/help")
event = make_event("/help")
adapter.send.reset_mock()
await adapter.handle_message(event)
await asyncio.sleep(0.3)
@@ -158,11 +185,11 @@ class TestAuthorization:
assert "recognize" in response_text.lower() or "pair" in response_text.lower() or "ABC123" in response_text
@pytest.mark.asyncio
async def test_unauthorized_user_does_not_get_help(self, adapter, runner, platform):
async def test_unauthorized_user_does_not_get_help(self, adapter, runner):
"""Unauthorized user should NOT see the help command output."""
runner._is_user_authorized = lambda _source: False
event = make_event(platform, "/help")
event = make_event("/help")
adapter.send.reset_mock()
await adapter.handle_message(event)
await asyncio.sleep(0.3)
@@ -177,12 +204,12 @@ class TestSendFailureResilience:
"""Verify the pipeline handles send failures gracefully."""
@pytest.mark.asyncio
async def test_send_failure_does_not_crash_pipeline(self, adapter, platform):
async def test_send_failure_does_not_crash_pipeline(self, adapter):
"""If send() returns failure, the pipeline should not raise."""
adapter.send = AsyncMock(return_value=SendResult(success=False, error="network timeout"))
adapter.set_message_handler(adapter._message_handler) # re-wire with same handler
adapter.set_message_handler(adapter._message_handler) # re-wire with same handler
event = make_event(platform, "/help")
event = make_event("/help")
# Should not raise — pipeline handles send failures internally
await adapter.handle_message(event)
await asyncio.sleep(0.3)
-132
View File
@@ -1,132 +0,0 @@
"""Tests for the API server bind-address startup guard.
Validates that is_network_accessible() correctly classifies addresses and
that connect() refuses to start on non-loopback without API_SERVER_KEY.
"""
import socket
from unittest.mock import AsyncMock, patch
import pytest
from gateway.config import PlatformConfig
from gateway.platforms.api_server import APIServerAdapter
from gateway.platforms.base import is_network_accessible
# ---------------------------------------------------------------------------
# Unit tests: is_network_accessible()
# ---------------------------------------------------------------------------
class TestIsNetworkAccessible:
"""Direct tests for the address classification helper."""
# -- Loopback (safe, should return False) --
def test_ipv4_loopback(self):
assert is_network_accessible("127.0.0.1") is False
def test_ipv6_loopback(self):
assert is_network_accessible("::1") is False
def test_ipv4_mapped_loopback(self):
# ::ffff:127.0.0.1 — Python's is_loopback returns False for mapped
# addresses; the helper must unwrap and check ipv4_mapped.
assert is_network_accessible("::ffff:127.0.0.1") is False
# -- Network-accessible (should return True) --
def test_ipv4_wildcard(self):
assert is_network_accessible("0.0.0.0") is True
def test_ipv6_wildcard(self):
# This is the bypass vector that the string-based check missed.
assert is_network_accessible("::") is True
def test_ipv4_mapped_unspecified(self):
assert is_network_accessible("::ffff:0.0.0.0") is True
def test_private_ipv4(self):
assert is_network_accessible("10.0.0.1") is True
def test_private_ipv4_class_c(self):
assert is_network_accessible("192.168.1.1") is True
def test_public_ipv4(self):
assert is_network_accessible("8.8.8.8") is True
# -- Hostname resolution --
def test_localhost_resolves_to_loopback(self):
loopback_result = [
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0)),
]
with patch("gateway.platforms.base._socket.getaddrinfo", return_value=loopback_result):
assert is_network_accessible("localhost") is False
def test_hostname_resolving_to_non_loopback(self):
non_loopback_result = [
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("10.0.0.1", 0)),
]
with patch("gateway.platforms.base._socket.getaddrinfo", return_value=non_loopback_result):
assert is_network_accessible("my-server.local") is True
def test_hostname_mixed_resolution(self):
"""If a hostname resolves to both loopback and non-loopback, it's
network-accessible (any non-loopback address is enough)."""
mixed_result = [
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0)),
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("10.0.0.1", 0)),
]
with patch("gateway.platforms.base._socket.getaddrinfo", return_value=mixed_result):
assert is_network_accessible("dual-host.local") is True
def test_dns_failure_fails_closed(self):
"""Unresolvable hostnames should require an API key (fail closed)."""
with patch(
"gateway.platforms.base._socket.getaddrinfo",
side_effect=socket.gaierror("Name resolution failed"),
):
assert is_network_accessible("nonexistent.invalid") is True
# ---------------------------------------------------------------------------
# Integration tests: connect() startup guard
# ---------------------------------------------------------------------------
class TestConnectBindGuard:
"""Verify that connect() refuses dangerous configurations."""
@pytest.mark.asyncio
async def test_refuses_ipv4_wildcard_without_key(self):
adapter = APIServerAdapter(PlatformConfig(enabled=True, extra={"host": "0.0.0.0"}))
result = await adapter.connect()
assert result is False
@pytest.mark.asyncio
async def test_refuses_ipv6_wildcard_without_key(self):
adapter = APIServerAdapter(PlatformConfig(enabled=True, extra={"host": "::"}))
result = await adapter.connect()
assert result is False
def test_allows_loopback_without_key(self):
"""Loopback with no key should pass the guard."""
adapter = APIServerAdapter(PlatformConfig(enabled=True, extra={"host": "127.0.0.1"}))
assert adapter._api_key == ""
# The guard condition: is_network_accessible(host) AND NOT api_key
# For loopback, is_network_accessible is False so the guard does not block.
assert is_network_accessible(adapter._host) is False
@pytest.mark.asyncio
async def test_allows_wildcard_with_key(self):
"""Non-loopback with a key should pass the guard."""
adapter = APIServerAdapter(
PlatformConfig(enabled=True, extra={"host": "0.0.0.0", "key": "sk-test"})
)
# The guard checks: is_network_accessible(host) AND NOT api_key
# With a key set, the guard should not block.
assert adapter._api_key == "sk-test"
assert is_network_accessible("0.0.0.0") is True
# Combined: the guard condition is False (key is set), so it passes
+33 -4
View File
@@ -141,7 +141,7 @@ class TestBlockingGatewayApproval:
def test_resolve_single_pops_oldest_fifo(self):
"""resolve_gateway_approval without resolve_all resolves oldest first."""
from tools.approval import (
resolve_gateway_approval,
resolve_gateway_approval, pending_approval_count,
_ApprovalEntry, _gateway_queues,
)
session_key = "test-fifo"
@@ -154,7 +154,7 @@ class TestBlockingGatewayApproval:
assert e1.event.is_set()
assert e1.result == "once"
assert not e2.event.is_set()
assert len(_gateway_queues[session_key]) == 1
assert pending_approval_count(session_key) == 1
def test_unregister_signals_all_entries(self):
"""unregister_gateway_notify signals all waiting entries to prevent hangs."""
@@ -173,6 +173,35 @@ class TestBlockingGatewayApproval:
assert e1.event.is_set()
assert e2.event.is_set()
def test_clear_session_signals_all_entries(self):
"""clear_session should unblock all waiting approval threads."""
from tools.approval import (
register_gateway_notify, clear_session,
_ApprovalEntry, _gateway_queues,
)
session_key = "test-clear"
register_gateway_notify(session_key, lambda d: None)
e1 = _ApprovalEntry({"command": "cmd1"})
e2 = _ApprovalEntry({"command": "cmd2"})
_gateway_queues[session_key] = [e1, e2]
clear_session(session_key)
assert e1.event.is_set()
assert e2.event.is_set()
def test_pending_approval_count(self):
from tools.approval import (
pending_approval_count, _ApprovalEntry, _gateway_queues,
)
session_key = "test-count"
assert pending_approval_count(session_key) == 0
_gateway_queues[session_key] = [
_ApprovalEntry({"command": "a"}),
_ApprovalEntry({"command": "b"}),
]
assert pending_approval_count(session_key) == 2
# ------------------------------------------------------------------
# /approve command
@@ -477,7 +506,7 @@ class TestBlockingApprovalE2E:
from tools.approval import (
register_gateway_notify, unregister_gateway_notify,
resolve_gateway_approval, check_all_command_guards,
_gateway_queues,
pending_approval_count,
)
session_key = "e2e-parallel"
@@ -516,7 +545,7 @@ class TestBlockingApprovalE2E:
time.sleep(0.05)
assert len(notified) == 3
assert len(_gateway_queues.get(session_key, [])) == 3
assert pending_approval_count(session_key) == 3
# Approve all at once
count = resolve_gateway_approval(session_key, "session", resolve_all=True)
-1
View File
@@ -308,7 +308,6 @@ class TestBackgroundInCLICommands:
def test_background_autocompletes(self):
"""The /background command appears in autocomplete results."""
pytest.importorskip("prompt_toolkit")
from hermes_cli.commands import SlashCommandCompleter
from prompt_toolkit.document import Document
+7 -33
View File
@@ -6,7 +6,7 @@ from types import SimpleNamespace
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, ProcessingOutcome, SendResult
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
from gateway.session import SessionSource, build_session_key
@@ -44,8 +44,8 @@ class DummyTelegramAdapter(BasePlatformAdapter):
async def on_processing_start(self, event: MessageEvent) -> None:
self.processing_hooks.append(("start", event.message_id))
async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
self.processing_hooks.append(("complete", event.message_id, outcome))
async def on_processing_complete(self, event: MessageEvent, success: bool) -> None:
self.processing_hooks.append(("complete", event.message_id, success))
def _make_event(chat_id: str, thread_id: str, message_id: str = "1") -> MessageEvent:
@@ -142,7 +142,7 @@ class TestBasePlatformTopicSessions:
]
assert adapter.processing_hooks == [
("start", "1"),
("complete", "1", ProcessingOutcome.SUCCESS),
("complete", "1", True),
]
@pytest.mark.asyncio
@@ -168,7 +168,7 @@ class TestBasePlatformTopicSessions:
assert adapter.processing_hooks == [
("start", "1"),
("complete", "1", ProcessingOutcome.FAILURE),
("complete", "1", False),
]
@pytest.mark.asyncio
@@ -190,7 +190,7 @@ class TestBasePlatformTopicSessions:
assert adapter.processing_hooks == [
("start", "1"),
("complete", "1", ProcessingOutcome.FAILURE),
("complete", "1", False),
]
@pytest.mark.asyncio
@@ -218,31 +218,5 @@ class TestBasePlatformTopicSessions:
assert adapter.processing_hooks == [
("start", "1"),
("complete", "1", ProcessingOutcome.FAILURE),
]
@pytest.mark.asyncio
async def test_cancel_background_tasks_marks_expected_cancellation_cancelled(self):
adapter = DummyTelegramAdapter()
release = asyncio.Event()
async def handler(_event):
await release.wait()
return "ack"
async def hold_typing(_chat_id, interval=2.0, metadata=None):
await asyncio.Event().wait()
adapter.set_message_handler(handler)
adapter._keep_typing = hold_typing
event = _make_event("-1001", "17585")
await adapter.handle_message(event)
await asyncio.sleep(0)
await adapter.cancel_background_tasks()
assert adapter.processing_hooks == [
("start", "1"),
("complete", "1", ProcessingOutcome.CANCELLED),
("complete", "1", False),
]
@@ -160,22 +160,6 @@ class TestCommandBypassActiveSession:
assert sk not in adapter._pending_messages
assert any("handled:status" in r for r in adapter.sent_responses)
@pytest.mark.asyncio
async def test_background_bypasses_guard(self):
"""/background must bypass so it spawns a parallel task, not an interrupt."""
adapter = _make_adapter()
sk = _session_key()
adapter._active_sessions[sk] = asyncio.Event()
await adapter.handle_message(_make_event("/background summarize HN"))
assert sk not in adapter._pending_messages, (
"/background was queued as a pending message instead of being dispatched"
)
assert any("handled:background" in r for r in adapter.sent_responses), (
"/background response was not sent back to the user"
)
# ---------------------------------------------------------------------------
# Tests: non-bypass messages still get queued
+30 -2
View File
@@ -1,7 +1,7 @@
"""Tests for the delivery routing module."""
from gateway.config import Platform
from gateway.delivery import DeliveryTarget
from gateway.config import Platform, GatewayConfig, PlatformConfig, HomeChannel
from gateway.delivery import DeliveryRouter, DeliveryTarget, parse_deliver_spec
from gateway.session import SessionSource
@@ -41,6 +41,28 @@ class TestParseTargetPlatformChat:
assert target.platform == Platform.LOCAL
class TestParseDeliverSpec:
def test_none_returns_default(self):
result = parse_deliver_spec(None)
assert result == "origin"
def test_empty_string_returns_default(self):
result = parse_deliver_spec("")
assert result == "origin"
def test_custom_default(self):
result = parse_deliver_spec(None, default="local")
assert result == "local"
def test_passthrough_string(self):
result = parse_deliver_spec("telegram")
assert result == "telegram"
def test_passthrough_list(self):
result = parse_deliver_spec(["local", "telegram"])
assert result == ["local", "telegram"]
class TestTargetToStringRoundtrip:
def test_origin_roundtrip(self):
origin = SessionSource(platform=Platform.TELEGRAM, chat_id="111", thread_id="42")
@@ -65,4 +87,10 @@ class TestTargetToStringRoundtrip:
assert reparsed.chat_id == "999"
class TestDeliveryRouter:
def test_resolve_targets_does_not_duplicate_local_when_explicit(self):
router = DeliveryRouter(GatewayConfig(always_log_local=True))
targets = router.resolve_targets(["local"])
assert [target.platform for target in targets] == [Platform.LOCAL]
@@ -1,64 +0,0 @@
"""Tests for Discord channel_skill_bindings auto-skill resolution."""
from unittest.mock import MagicMock
import pytest
def _make_adapter():
"""Create a minimal DiscordAdapter with mocked config."""
from gateway.platforms.discord import DiscordAdapter
adapter = object.__new__(DiscordAdapter)
adapter.config = MagicMock()
adapter.config.extra = {}
return adapter
class TestResolveChannelSkills:
def test_no_bindings_returns_none(self):
adapter = _make_adapter()
assert adapter._resolve_channel_skills("123") is None
def test_match_by_channel_id(self):
adapter = _make_adapter()
adapter.config.extra = {
"channel_skill_bindings": [
{"id": "100", "skills": ["skill-a", "skill-b"]},
]
}
assert adapter._resolve_channel_skills("100") == ["skill-a", "skill-b"]
def test_match_by_parent_id(self):
adapter = _make_adapter()
adapter.config.extra = {
"channel_skill_bindings": [
{"id": "200", "skills": ["forum-skill"]},
]
}
# channel_id doesn't match, but parent_id does (forum thread)
assert adapter._resolve_channel_skills("999", parent_id="200") == ["forum-skill"]
def test_no_match_returns_none(self):
adapter = _make_adapter()
adapter.config.extra = {
"channel_skill_bindings": [
{"id": "100", "skills": ["skill-a"]},
]
}
assert adapter._resolve_channel_skills("999") is None
def test_single_skill_string(self):
adapter = _make_adapter()
adapter.config.extra = {
"channel_skill_bindings": [
{"id": "100", "skill": "solo-skill"},
]
}
assert adapter._resolve_channel_skills("100") == ["solo-skill"]
def test_dedup_preserves_order(self):
adapter = _make_adapter()
adapter.config.extra = {
"channel_skill_bindings": [
{"id": "100", "skills": ["a", "b", "a", "c", "b"]},
]
}
assert adapter._resolve_channel_skills("100") == ["a", "b", "c"]
+2 -16
View File
@@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome, SendResult
from gateway.platforms.base import MessageEvent, MessageType, SendResult
from gateway.session import SessionSource, build_session_key
@@ -212,7 +212,7 @@ async def test_reactions_disabled_via_env_zero(adapter, monkeypatch):
event = _make_event("5", raw_message)
await adapter.on_processing_start(event)
await adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS)
await adapter.on_processing_complete(event, success=True)
raw_message.add_reaction.assert_not_awaited()
raw_message.remove_reaction.assert_not_awaited()
@@ -232,17 +232,3 @@ async def test_reactions_enabled_by_default(adapter, monkeypatch):
await adapter.on_processing_start(event)
raw_message.add_reaction.assert_awaited_once_with("👀")
@pytest.mark.asyncio
async def test_on_processing_complete_cancelled_removes_eyes_without_terminal_reaction(adapter):
raw_message = SimpleNamespace(
add_reaction=AsyncMock(),
remove_reaction=AsyncMock(),
)
event = _make_event("7", raw_message)
await adapter.on_processing_complete(event, ProcessingOutcome.CANCELLED)
raw_message.remove_reaction.assert_awaited_once_with("👀", adapter._client.user)
raw_message.add_reaction.assert_not_awaited()
-191
View File
@@ -1,191 +0,0 @@
"""Tests for gateway /fast support and Priority Processing routing."""
import sys
import threading
import types
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
import yaml
import gateway.run as gateway_run
from gateway.config import Platform
from gateway.platforms.base import MessageEvent
from gateway.session import SessionSource
class _CapturingAgent:
last_init = None
last_run = None
def __init__(self, *args, **kwargs):
type(self).last_init = dict(kwargs)
self.tools = []
def run_conversation(self, user_message, conversation_history=None, task_id=None, persist_user_message=None):
type(self).last_run = {
"user_message": user_message,
"conversation_history": conversation_history,
"task_id": task_id,
"persist_user_message": persist_user_message,
}
return {
"final_response": "ok",
"messages": [],
"api_calls": 1,
"completed": True,
}
def _install_fake_agent(monkeypatch):
fake_run_agent = types.ModuleType("run_agent")
fake_run_agent.AIAgent = _CapturingAgent
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
def _make_runner():
runner = object.__new__(gateway_run.GatewayRunner)
runner.adapters = {}
runner._ephemeral_system_prompt = ""
runner._prefill_messages = []
runner._reasoning_config = None
runner._service_tier = None
runner._provider_routing = {}
runner._fallback_model = None
runner._smart_model_routing = {}
runner._running_agents = {}
runner._pending_model_notes = {}
runner._session_db = None
runner._agent_cache = {}
runner._agent_cache_lock = threading.Lock()
runner._session_model_overrides = {}
runner.hooks = SimpleNamespace(loaded_hooks=False)
runner.config = SimpleNamespace(streaming=None)
runner.session_store = SimpleNamespace(
get_or_create_session=lambda source: SimpleNamespace(session_id="session-1"),
load_transcript=lambda session_id: [],
)
runner._get_or_create_gateway_honcho = lambda session_key: (None, None)
runner._enrich_message_with_vision = AsyncMock(return_value="ENRICHED")
return runner
def _make_source() -> SessionSource:
return SessionSource(
platform=Platform.TELEGRAM,
chat_id="12345",
chat_type="dm",
user_id="user-1",
)
def _make_event(text: str) -> MessageEvent:
return MessageEvent(text=text, source=_make_source(), message_id="m1")
def test_turn_route_injects_priority_processing_without_changing_runtime():
runner = _make_runner()
runner._service_tier = "priority"
runtime_kwargs = {
"api_key": "***",
"base_url": "https://openrouter.ai/api/v1",
"provider": "openrouter",
"api_mode": "chat_completions",
"command": None,
"args": [],
"credential_pool": None,
}
with patch("agent.smart_model_routing.resolve_turn_route", return_value={
"model": "gpt-5.4",
"runtime": dict(runtime_kwargs),
"label": None,
"signature": ("gpt-5.4", "openrouter", "https://openrouter.ai/api/v1", "chat_completions", None, ()),
}):
route = gateway_run.GatewayRunner._resolve_turn_agent_config(runner, "hi", "gpt-5.4", runtime_kwargs)
assert route["runtime"]["provider"] == "openrouter"
assert route["runtime"]["api_mode"] == "chat_completions"
assert route["request_overrides"] == {"service_tier": "priority"}
def test_turn_route_skips_priority_processing_for_unsupported_models():
runner = _make_runner()
runner._service_tier = "priority"
runtime_kwargs = {
"api_key": "***",
"base_url": "https://openrouter.ai/api/v1",
"provider": "openrouter",
"api_mode": "chat_completions",
"command": None,
"args": [],
"credential_pool": None,
}
with patch("agent.smart_model_routing.resolve_turn_route", return_value={
"model": "gpt-5.3-codex",
"runtime": dict(runtime_kwargs),
"label": None,
"signature": ("gpt-5.3-codex", "openrouter", "https://openrouter.ai/api/v1", "chat_completions", None, ()),
}):
route = gateway_run.GatewayRunner._resolve_turn_agent_config(runner, "hi", "gpt-5.3-codex", runtime_kwargs)
assert route["request_overrides"] is None
@pytest.mark.asyncio
async def test_handle_fast_command_persists_config(monkeypatch, tmp_path):
runner = _make_runner()
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.setattr(gateway_run, "_load_gateway_config", lambda: {})
monkeypatch.setattr(gateway_run, "_resolve_gateway_model", lambda config=None: "gpt-5.4")
response = await runner._handle_fast_command(_make_event("/fast fast"))
assert "FAST" in response
assert runner._service_tier == "priority"
saved = yaml.safe_load((tmp_path / "config.yaml").read_text(encoding="utf-8"))
assert saved["agent"]["service_tier"] == "fast"
@pytest.mark.asyncio
async def test_run_agent_passes_priority_processing_to_gateway_agent(monkeypatch, tmp_path):
_install_fake_agent(monkeypatch)
runner = _make_runner()
(tmp_path / "config.yaml").write_text("agent:\n service_tier: fast\n", encoding="utf-8")
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.setattr(gateway_run, "_env_path", tmp_path / ".env")
monkeypatch.setattr(gateway_run, "load_dotenv", lambda *args, **kwargs: None)
monkeypatch.setattr(gateway_run, "_load_gateway_config", lambda: {})
monkeypatch.setattr(gateway_run, "_resolve_gateway_model", lambda config=None: "gpt-5.4")
monkeypatch.setattr(
gateway_run,
"_resolve_runtime_agent_kwargs",
lambda: {
"provider": "openrouter",
"api_mode": "chat_completions",
"base_url": "https://openrouter.ai/api/v1",
"api_key": "***",
},
)
import hermes_cli.tools_config as tools_config
monkeypatch.setattr(tools_config, "_get_platform_tools", lambda user_config, platform_key: {"core"})
_CapturingAgent.last_init = None
result = await runner._run_agent(
message="hi",
context_prompt="",
history=[],
source=_make_source(),
session_id="session-1",
session_key="agent:main:telegram:dm:12345",
)
assert result["final_response"] == "ok"
assert _CapturingAgent.last_init["service_tier"] == "priority"
assert _CapturingAgent.last_init["request_overrides"] == {"service_tier": "priority"}
@@ -128,16 +128,12 @@ async def test_internal_event_bypasses_authorization(monkeypatch, tmp_path):
monkeypatch.setattr(GatewayRunner, "_is_user_authorized", tracking_auth)
# Stop execution before the agent runner so the test doesn't block in
# run_in_executor. Auth check happens before _handle_message_with_agent.
async def _raise(*_a, **_kw):
raise RuntimeError("sentinel — stop here")
monkeypatch.setattr(GatewayRunner, "_handle_message_with_agent", _raise)
# _handle_message will proceed past auth check and eventually fail on
# downstream logic. We just need to verify auth is skipped.
try:
await runner._handle_message(event)
except RuntimeError:
pass # Expected sentinel
except Exception:
pass # Expected — downstream code needs more setup
assert not auth_called, (
"_is_user_authorized should NOT be called for internal events"
@@ -179,16 +175,10 @@ async def test_internal_event_does_not_trigger_pairing(monkeypatch, tmp_path):
runner.pairing_store.generate_code = tracking_generate
# Stop execution before the agent runner so the test doesn't block in
# run_in_executor. Pairing check happens before _handle_message_with_agent.
async def _raise(*_a, **_kw):
raise RuntimeError("sentinel — stop here")
monkeypatch.setattr(GatewayRunner, "_handle_message_with_agent", _raise)
try:
await runner._handle_message(event)
except RuntimeError:
pass # Expected sentinel
except Exception:
pass # Expected — downstream code needs more setup
assert not generate_called, (
"Pairing code should NOT be generated for internal events"
+8 -76
View File
@@ -1943,7 +1943,7 @@ class TestMatrixReactions:
with patch.dict("sys.modules", {"nio": fake_nio}):
result = await self.adapter._send_reaction("!room:ex", "$event1", "👍")
assert result == "$reaction1"
assert result is True
mock_client.room_send.assert_called_once()
args = mock_client.room_send.call_args
assert args[0][1] == "m.reaction"
@@ -1956,77 +1956,13 @@ class TestMatrixReactions:
self.adapter._client = None
with patch.dict("sys.modules", {"nio": _make_fake_nio()}):
result = await self.adapter._send_reaction("!room:ex", "$ev", "👍")
assert result is None
assert result is False
@pytest.mark.asyncio
async def test_on_processing_start_sends_eyes(self):
"""on_processing_start should send 👀 reaction."""
from gateway.platforms.base import MessageEvent, MessageType
self.adapter._reactions_enabled = True
self.adapter._send_reaction = AsyncMock(return_value="$reaction_event_123")
source = MagicMock()
source.chat_id = "!room:ex"
event = MessageEvent(
text="hello",
message_type=MessageType.TEXT,
source=source,
raw_message={},
message_id="$msg1",
)
await self.adapter.on_processing_start(event)
self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "👀")
assert self.adapter._pending_reactions == {("!room:ex", "$msg1"): "$reaction_event_123"}
@pytest.mark.asyncio
async def test_on_processing_complete_sends_check(self):
from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome
self.adapter._reactions_enabled = True
self.adapter._pending_reactions = {("!room:ex", "$msg1"): "$eyes_reaction_123"}
self.adapter._redact_reaction = AsyncMock(return_value=True)
self.adapter._send_reaction = AsyncMock(return_value="$check_reaction_456")
source = MagicMock()
source.chat_id = "!room:ex"
event = MessageEvent(
text="hello",
message_type=MessageType.TEXT,
source=source,
raw_message={},
message_id="$msg1",
)
await self.adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS)
self.adapter._redact_reaction.assert_called_once_with("!room:ex", "$eyes_reaction_123")
self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "")
@pytest.mark.asyncio
async def test_on_processing_complete_sends_cross_on_failure(self):
from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome
self.adapter._reactions_enabled = True
self.adapter._pending_reactions = {("!room:ex", "$msg1"): "$eyes_reaction_123"}
self.adapter._redact_reaction = AsyncMock(return_value=True)
self.adapter._send_reaction = AsyncMock(return_value="$cross_reaction_456")
source = MagicMock()
source.chat_id = "!room:ex"
event = MessageEvent(
text="hello",
message_type=MessageType.TEXT,
source=source,
raw_message={},
message_id="$msg1",
)
await self.adapter.on_processing_complete(event, ProcessingOutcome.FAILURE)
self.adapter._redact_reaction.assert_called_once_with("!room:ex", "$eyes_reaction_123")
self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "")
@pytest.mark.asyncio
async def test_on_processing_complete_cancelled_sends_no_terminal_reaction(self):
from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome
self.adapter._reactions_enabled = True
self.adapter._send_reaction = AsyncMock(return_value=True)
@@ -2039,18 +1975,15 @@ class TestMatrixReactions:
raw_message={},
message_id="$msg1",
)
await self.adapter.on_processing_complete(event, ProcessingOutcome.CANCELLED)
self.adapter._send_reaction.assert_not_called()
await self.adapter.on_processing_start(event)
self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "👀")
@pytest.mark.asyncio
async def test_on_processing_complete_no_pending_reaction(self):
"""on_processing_complete should skip redaction if no eyes reaction was tracked."""
from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome
async def test_on_processing_complete_sends_check(self):
from gateway.platforms.base import MessageEvent, MessageType
self.adapter._reactions_enabled = True
self.adapter._pending_reactions = {}
self.adapter._redact_reaction = AsyncMock()
self.adapter._send_reaction = AsyncMock(return_value="$check_reaction_789")
self.adapter._send_reaction = AsyncMock(return_value=True)
source = MagicMock()
source.chat_id = "!room:ex"
@@ -2061,8 +1994,7 @@ class TestMatrixReactions:
raw_message={},
message_id="$msg1",
)
await self.adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS)
self.adapter._redact_reaction.assert_not_called()
await self.adapter.on_processing_complete(event, success=True)
self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "")
@pytest.mark.asyncio
-108
View File
@@ -436,95 +436,6 @@ class TestThreadPersistence:
assert len(data) == 5
# ---------------------------------------------------------------------------
# DM mention-thread feature
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_dm_mention_thread_disabled_by_default(monkeypatch):
"""Default (dm_mention_threads=false): DM with mention should NOT create a thread."""
monkeypatch.delenv("MATRIX_DM_MENTION_THREADS", raising=False)
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
adapter = _make_adapter()
room = _make_room(member_count=2)
event = _make_event("@hermes:example.org help me", event_id="$dm1")
await adapter._on_room_message(room, event)
adapter.handle_message.assert_awaited_once()
msg = adapter.handle_message.await_args.args[0]
assert msg.source.thread_id is None
@pytest.mark.asyncio
async def test_dm_mention_thread_creates_thread(monkeypatch):
"""MATRIX_DM_MENTION_THREADS=true: DM with @mention creates a thread."""
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
adapter = _make_adapter()
room = _make_room(member_count=2)
event = _make_event("@hermes:example.org help me", event_id="$dm1")
with patch.object(adapter, "_save_participated_threads"):
await adapter._on_room_message(room, event)
adapter.handle_message.assert_awaited_once()
msg = adapter.handle_message.await_args.args[0]
assert msg.source.thread_id == "$dm1"
assert msg.text == "help me"
@pytest.mark.asyncio
async def test_dm_mention_thread_no_mention_no_thread(monkeypatch):
"""MATRIX_DM_MENTION_THREADS=true: DM without mention does NOT create a thread."""
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
adapter = _make_adapter()
room = _make_room(member_count=2)
event = _make_event("hello without mention", event_id="$dm1")
await adapter._on_room_message(room, event)
adapter.handle_message.assert_awaited_once()
msg = adapter.handle_message.await_args.args[0]
assert msg.source.thread_id is None
@pytest.mark.asyncio
async def test_dm_mention_thread_preserves_existing_thread(monkeypatch):
"""MATRIX_DM_MENTION_THREADS=true: DM already in a thread keeps that thread_id."""
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
adapter = _make_adapter()
adapter._bot_participated_threads.add("$existing_thread")
room = _make_room(member_count=2)
event = _make_event("@hermes:example.org help me", thread_id="$existing_thread")
await adapter._on_room_message(room, event)
adapter.handle_message.assert_awaited_once()
msg = adapter.handle_message.await_args.args[0]
assert msg.source.thread_id == "$existing_thread"
@pytest.mark.asyncio
async def test_dm_mention_thread_tracks_participation(monkeypatch):
"""DM mention-thread tracks the thread in _bot_participated_threads."""
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
adapter = _make_adapter()
room = _make_room(member_count=2)
event = _make_event("@hermes:example.org help", event_id="$dm1")
with patch.object(adapter, "_save_participated_threads"):
await adapter._on_room_message(room, event)
assert "$dm1" in adapter._bot_participated_threads
# ---------------------------------------------------------------------------
# YAML config bridge
# ---------------------------------------------------------------------------
@@ -569,25 +480,6 @@ class TestMatrixConfigBridge:
assert os.getenv("MATRIX_FREE_RESPONSE_ROOMS") == "!room1:example.org,!room2:example.org"
assert os.getenv("MATRIX_AUTO_THREAD") == "false"
def test_yaml_bridge_sets_dm_mention_threads(self, monkeypatch, tmp_path):
"""Matrix YAML dm_mention_threads should bridge to env var."""
monkeypatch.delenv("MATRIX_DM_MENTION_THREADS", raising=False)
import os
import yaml
yaml_content = {"matrix": {"dm_mention_threads": True}}
config_file = tmp_path / "config.yaml"
config_file.write_text(yaml.dump(yaml_content))
yaml_cfg = yaml.safe_load(config_file.read_text())
matrix_cfg = yaml_cfg.get("matrix", {})
if isinstance(matrix_cfg, dict):
if "dm_mention_threads" in matrix_cfg and not os.getenv("MATRIX_DM_MENTION_THREADS"):
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", str(matrix_cfg["dm_mention_threads"]).lower())
assert os.getenv("MATRIX_DM_MENTION_THREADS") == "true"
def test_env_vars_take_precedence_over_yaml(self, monkeypatch):
"""Env vars should not be overwritten by YAML values."""
monkeypatch.setenv("MATRIX_REQUIRE_MENTION", "true")
+4 -202
View File
@@ -34,45 +34,6 @@ def _make_timeout_error() -> httpx.TimeoutException:
return httpx.TimeoutException("timed out")
# ---------------------------------------------------------------------------
# cache_image_from_bytes (base.py)
# ---------------------------------------------------------------------------
class TestCacheImageFromBytes:
"""Tests for gateway.platforms.base.cache_image_from_bytes"""
def test_caches_valid_jpeg(self, tmp_path, monkeypatch):
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
from gateway.platforms.base import cache_image_from_bytes
path = cache_image_from_bytes(b"\xff\xd8\xff fake jpeg data", ".jpg")
assert path.endswith(".jpg")
def test_caches_valid_png(self, tmp_path, monkeypatch):
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
from gateway.platforms.base import cache_image_from_bytes
path = cache_image_from_bytes(b"\x89PNG\r\n\x1a\n fake png data", ".png")
assert path.endswith(".png")
def test_rejects_html_content(self, tmp_path, monkeypatch):
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
from gateway.platforms.base import cache_image_from_bytes
with pytest.raises(ValueError, match="non-image data"):
cache_image_from_bytes(b"<!DOCTYPE html><html><title>Slack</title></html>", ".png")
def test_rejects_empty_data(self, tmp_path, monkeypatch):
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
from gateway.platforms.base import cache_image_from_bytes
with pytest.raises(ValueError, match="non-image data"):
cache_image_from_bytes(b"", ".jpg")
def test_rejects_plain_text(self, tmp_path, monkeypatch):
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
from gateway.platforms.base import cache_image_from_bytes
with pytest.raises(ValueError, match="non-image data"):
cache_image_from_bytes(b"just some text, not an image", ".jpg")
# ---------------------------------------------------------------------------
# cache_image_from_url (base.py)
# ---------------------------------------------------------------------------
@@ -110,7 +71,7 @@ class TestCacheImageFromUrl:
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
fake_response = MagicMock()
fake_response.content = b"\xff\xd8\xff image data"
fake_response.content = b"image data"
fake_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
@@ -140,7 +101,7 @@ class TestCacheImageFromUrl:
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
ok_response = MagicMock()
ok_response.content = b"\xff\xd8\xff image data"
ok_response.content = b"image data"
ok_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
@@ -376,134 +337,6 @@ class TestCacheAudioFromUrl:
mock_sleep.assert_not_called()
# ---------------------------------------------------------------------------
# SSRF redirect guard tests (base.py)
# ---------------------------------------------------------------------------
class TestSSRFRedirectGuard:
"""cache_image_from_url / cache_audio_from_url must reject redirects
that land on private/internal hosts (e.g. cloud metadata endpoint)."""
def _make_redirect_response(self, target_url: str):
"""Build a mock httpx response that looks like a redirect."""
resp = MagicMock()
resp.is_redirect = True
resp.next_request = MagicMock(url=target_url)
return resp
def _make_client_capturing_hooks(self):
"""Return (mock_client, captured_kwargs dict) where captured_kwargs
will contain the kwargs passed to httpx.AsyncClient()."""
captured = {}
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
def factory(*args, **kwargs):
captured.update(kwargs)
return mock_client
return mock_client, captured, factory
def test_image_blocks_private_redirect(self, tmp_path, monkeypatch):
"""cache_image_from_url rejects a redirect to a private IP."""
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
redirect_resp = self._make_redirect_response(
"http://169.254.169.254/latest/meta-data"
)
mock_client, captured, factory = self._make_client_capturing_hooks()
async def fake_get(_url, **kwargs):
# Simulate httpx calling the response event hooks
for hook in captured["event_hooks"]["response"]:
await hook(redirect_resp)
mock_client.get = AsyncMock(side_effect=fake_get)
def fake_safe(url):
return url == "https://public.example.com/image.png"
async def run():
with patch("tools.url_safety.is_safe_url", side_effect=fake_safe), \
patch("httpx.AsyncClient", side_effect=factory):
from gateway.platforms.base import cache_image_from_url
await cache_image_from_url(
"https://public.example.com/image.png", ext=".png"
)
with pytest.raises(ValueError, match="Blocked redirect"):
asyncio.run(run())
def test_audio_blocks_private_redirect(self, tmp_path, monkeypatch):
"""cache_audio_from_url rejects a redirect to a private IP."""
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
redirect_resp = self._make_redirect_response(
"http://10.0.0.1/internal/secrets"
)
mock_client, captured, factory = self._make_client_capturing_hooks()
async def fake_get(_url, **kwargs):
for hook in captured["event_hooks"]["response"]:
await hook(redirect_resp)
mock_client.get = AsyncMock(side_effect=fake_get)
def fake_safe(url):
return url == "https://public.example.com/voice.ogg"
async def run():
with patch("tools.url_safety.is_safe_url", side_effect=fake_safe), \
patch("httpx.AsyncClient", side_effect=factory):
from gateway.platforms.base import cache_audio_from_url
await cache_audio_from_url(
"https://public.example.com/voice.ogg", ext=".ogg"
)
with pytest.raises(ValueError, match="Blocked redirect"):
asyncio.run(run())
def test_safe_redirect_allowed(self, tmp_path, monkeypatch):
"""A redirect to a public IP is allowed through."""
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
redirect_resp = self._make_redirect_response(
"https://cdn.example.com/real-image.png"
)
ok_response = MagicMock()
ok_response.content = b"\xff\xd8\xff fake jpeg"
ok_response.raise_for_status = MagicMock()
ok_response.is_redirect = False
mock_client, captured, factory = self._make_client_capturing_hooks()
call_count = 0
async def fake_get(_url, **kwargs):
nonlocal call_count
call_count += 1
# First call triggers redirect hook, second returns data
for hook in captured["event_hooks"]["response"]:
await hook(redirect_resp if call_count == 1 else ok_response)
return ok_response
mock_client.get = AsyncMock(side_effect=fake_get)
async def run():
with patch("tools.url_safety.is_safe_url", return_value=True), \
patch("httpx.AsyncClient", side_effect=factory):
from gateway.platforms.base import cache_image_from_url
return await cache_image_from_url(
"https://public.example.com/image.png", ext=".jpg"
)
path = asyncio.run(run())
assert path.endswith(".jpg")
# ---------------------------------------------------------------------------
# Slack mock setup (mirrors existing test_slack.py approach)
# ---------------------------------------------------------------------------
@@ -562,9 +395,8 @@ class TestSlackDownloadSlackFile:
adapter = _make_slack_adapter()
fake_response = MagicMock()
fake_response.content = b"\x89PNG\r\n\x1a\n fake png"
fake_response.content = b"fake image bytes"
fake_response.raise_for_status = MagicMock()
fake_response.headers = {"content-type": "image/png"}
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=fake_response)
@@ -581,44 +413,14 @@ class TestSlackDownloadSlackFile:
assert path.endswith(".jpg")
mock_client.get.assert_called_once()
def test_rejects_html_response(self, tmp_path, monkeypatch):
"""An HTML sign-in page from Slack is rejected, not cached as image."""
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
adapter = _make_slack_adapter()
fake_response = MagicMock()
fake_response.content = b"<!DOCTYPE html><html><title>Slack</title></html>"
fake_response.raise_for_status = MagicMock()
fake_response.headers = {"content-type": "text/html; charset=utf-8"}
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=fake_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
async def run():
with patch("httpx.AsyncClient", return_value=mock_client):
await adapter._download_slack_file(
"https://files.slack.com/img.jpg", ext=".jpg"
)
with pytest.raises(ValueError, match="HTML instead of media"):
asyncio.run(run())
# Verify nothing was cached
img_dir = tmp_path / "img"
if img_dir.exists():
assert list(img_dir.iterdir()) == []
def test_retries_on_timeout_then_succeeds(self, tmp_path, monkeypatch):
"""Timeout on first attempt triggers retry; success on second."""
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
adapter = _make_slack_adapter()
fake_response = MagicMock()
fake_response.content = b"\x89PNG\r\n\x1a\n image bytes"
fake_response.content = b"image bytes"
fake_response.raise_for_status = MagicMock()
fake_response.headers = {"content-type": "image/png"}
mock_client = AsyncMock()
mock_client.get = AsyncMock(
+9
View File
@@ -7,6 +7,7 @@ from gateway.session import (
_hash_id,
_hash_sender_id,
_hash_chat_id,
_looks_like_phone,
)
from gateway.config import Platform, HomeChannel
@@ -38,6 +39,14 @@ class TestHashHelpers:
assert len(result) == 12
assert "12345" not in result
def test_looks_like_phone(self):
assert _looks_like_phone("+15551234567")
assert _looks_like_phone("15551234567")
assert _looks_like_phone("+1-555-123-4567")
assert not _looks_like_phone("alice")
assert not _looks_like_phone("user-123")
assert not _looks_like_phone("")
# ---------------------------------------------------------------------------
# Integration: build_session_context_prompt

Some files were not shown because too many files have changed in this diff Show More