Compare commits

..

1 Commits

Author SHA1 Message Date
cokemine fe2bf02b57 fix(models): correct probed_url selection logic
Updated the logic for determining the probed_url in the probe_api_models function to use the first tried URL instead of the last. This change ensures that the most relevant URL is returned when probing for models. Additionally, improved the output message in the _model_flow_custom function to provide clearer guidance based on the suggested_base_url.
2026-04-09 02:26:38 -07:00
145 changed files with 3913 additions and 6414 deletions
+2 -2
View File
@@ -27,8 +27,8 @@ jobs:
with:
python-version: '3.11'
- name: Install ascii-guard
run: python -m pip install ascii-guard==2.3.0 pyyaml==6.0.3
- name: Install Python dependencies
run: python -m pip install ascii-guard pyyaml
- name: Extract skill metadata for dashboard
run: python3 website/scripts/extract-skills.py
+2 -2
View File
@@ -27,8 +27,8 @@ jobs:
timeout-minutes: 30
steps:
- uses: actions/checkout@v4
- uses: DeterminateSystems/nix-installer-action@ef8a148080ab6020fd15196c2084a2eea5ff2d25 # v22
- uses: DeterminateSystems/magic-nix-cache-action@565684385bcd71bad329742eefe8d12f2e765b39 # v13
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main
- name: Check flake
if: runner.os == 'Linux'
run: nix flake check --print-build-logs
-3
View File
@@ -1,8 +1,5 @@
FROM debian:13.4
# Disable Python stdout buffering to ensure logs are printed immediately
ENV PYTHONUNBUFFERED=1
# Install system dependencies in one layer, clear APT cache
RUN apt-get update && \
apt-get install -y --no-install-recommends \
+83 -27
View File
@@ -485,6 +485,35 @@ def _prefer_refreshable_claude_code_token(env_token: str, creds: Optional[Dict[s
return None
def get_anthropic_token_source(token: Optional[str] = None) -> str:
"""Best-effort source classification for an Anthropic credential token."""
token = (token or "").strip()
if not token:
return "none"
env_token = os.getenv("ANTHROPIC_TOKEN", "").strip()
if env_token and env_token == token:
return "anthropic_token_env"
cc_env_token = os.getenv("CLAUDE_CODE_OAUTH_TOKEN", "").strip()
if cc_env_token and cc_env_token == token:
return "claude_code_oauth_token_env"
creds = read_claude_code_credentials()
if creds and creds.get("accessToken") == token:
return str(creds.get("source") or "claude_code_credentials")
managed_key = read_claude_managed_key()
if managed_key and managed_key == token:
return "claude_json_primary_api_key"
api_key = os.getenv("ANTHROPIC_API_KEY", "").strip()
if api_key and api_key == token:
return "anthropic_api_key_env"
return "unknown"
def resolve_anthropic_token() -> Optional[str]:
"""Resolve an Anthropic token from all available sources.
@@ -691,6 +720,21 @@ def run_hermes_oauth_login_pure() -> Optional[Dict[str, Any]]:
}
def _save_hermes_oauth_credentials(access_token: str, refresh_token: str, expires_at_ms: int) -> None:
"""Save OAuth credentials to ~/.hermes/.anthropic_oauth.json."""
data = {
"accessToken": access_token,
"refreshToken": refresh_token,
"expiresAt": expires_at_ms,
}
try:
_HERMES_OAUTH_FILE.parent.mkdir(parents=True, exist_ok=True)
_HERMES_OAUTH_FILE.write_text(json.dumps(data, indent=2), encoding="utf-8")
_HERMES_OAUTH_FILE.chmod(0o600)
except (OSError, IOError) as e:
logger.debug("Failed to save Hermes OAuth credentials: %s", e)
def read_hermes_oauth_credentials() -> Optional[Dict[str, Any]]:
"""Read Hermes-managed OAuth credentials from ~/.hermes/.anthropic_oauth.json."""
if _HERMES_OAUTH_FILE.exists():
@@ -739,6 +783,39 @@ def _sanitize_tool_id(tool_id: str) -> str:
return sanitized or "tool_0"
def _convert_openai_image_part_to_anthropic(part: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Convert an OpenAI-style image block to Anthropic's image source format."""
image_data = part.get("image_url", {})
url = image_data.get("url", "") if isinstance(image_data, dict) else str(image_data)
if not isinstance(url, str) or not url.strip():
return None
url = url.strip()
if url.startswith("data:"):
header, sep, data = url.partition(",")
if sep and ";base64" in header:
media_type = header[5:].split(";", 1)[0] or "image/png"
return {
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": data,
},
}
if url.startswith(("http://", "https://")):
return {
"type": "image",
"source": {
"type": "url",
"url": url,
},
}
return None
def convert_tools_to_anthropic(tools: List[Dict]) -> List[Dict]:
"""Convert OpenAI tool definitions to Anthropic format."""
if not tools:
@@ -1161,27 +1238,10 @@ def build_anthropic_kwargs(
) -> Dict[str, Any]:
"""Build kwargs for anthropic.messages.create().
Naming note — two distinct concepts, easily confused:
max_tokens = OUTPUT token cap for a single response.
Anthropic's API calls this "max_tokens" but it only
limits the *output*. Anthropic's own native SDK
renamed it "max_output_tokens" for clarity.
context_length = TOTAL context window (input tokens + output tokens).
The API enforces: input_tokens + max_tokens ≤ context_length.
Stored on the ContextCompressor; reduced on overflow errors.
When *max_tokens* is None the model's native output ceiling is used
(e.g. 128K for Opus 4.6, 64K for Sonnet 4.6).
When *context_length* is provided and the model's native output ceiling
exceeds it (e.g. a local endpoint with an 8K window), the output cap is
clamped to context_length 1. This only kicks in for unusually small
context windows; for full-size models the native output cap is always
smaller than the context window so no clamping happens.
NOTE: this clamping does not account for prompt size — if the prompt is
large, Anthropic may still reject the request. The caller must detect
"max_tokens too large given prompt" errors and retry with a smaller cap
(see parse_available_output_tokens_from_error + _ephemeral_max_output_tokens).
When *max_tokens* is None, the model's native output limit is used
(e.g. 128K for Opus 4.6, 64K for Sonnet 4.6). If *context_length*
is provided, the effective limit is clamped so it doesn't exceed
the context window.
When *is_oauth* is True, applies Claude Code compatibility transforms:
system prompt prefix, tool name prefixing, and prompt sanitization.
@@ -1196,14 +1256,10 @@ def build_anthropic_kwargs(
anthropic_tools = convert_tools_to_anthropic(tools) if tools else []
model = normalize_model_name(model, preserve_dots=preserve_dots)
# effective_max_tokens = output cap for this call (≠ total context window)
effective_max_tokens = max_tokens or _get_anthropic_max_output(model)
# Clamp output cap to fit inside the total context window.
# Only matters for small custom endpoints where context_length < native
# output ceiling. For standard Anthropic models context_length (e.g.
# 200K) is always larger than the output ceiling (e.g. 128K), so this
# branch is not taken.
# Clamp to context window if the user set a lower context_length
# (e.g. custom endpoint with limited capacity).
if context_length and effective_max_tokens > context_length:
effective_max_tokens = max(context_length - 1, 1)
+69 -50
View File
@@ -702,7 +702,7 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
logger.debug("Auxiliary text client: %s (%s) via pool", pconfig.name, model)
extra = {}
if "api.kimi.com" in base_url.lower():
extra["default_headers"] = {"User-Agent": "KimiCLI/1.3"}
extra["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
elif "api.githubcopilot.com" in base_url.lower():
from hermes_cli.models import copilot_default_headers
@@ -721,7 +721,7 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
logger.debug("Auxiliary text client: %s (%s)", pconfig.name, model)
extra = {}
if "api.kimi.com" in base_url.lower():
extra["default_headers"] = {"User-Agent": "KimiCLI/1.3"}
extra["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
elif "api.githubcopilot.com" in base_url.lower():
from hermes_cli.models import copilot_default_headers
@@ -967,6 +967,40 @@ def _try_anthropic() -> Tuple[Optional[Any], Optional[str]]:
return AnthropicAuxiliaryClient(real_client, model, token, base_url, is_oauth=is_oauth), model
def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[str]]:
"""Resolve a specific forced provider. Returns (None, None) if creds missing."""
if forced == "openrouter":
client, model = _try_openrouter()
if client is None:
logger.warning("auxiliary.provider=openrouter but OPENROUTER_API_KEY not set")
return client, model
if forced == "nous":
client, model = _try_nous()
if client is None:
logger.warning("auxiliary.provider=nous but Nous Portal not configured (run: hermes auth)")
return client, model
if forced == "codex":
client, model = _try_codex()
if client is None:
logger.warning("auxiliary.provider=codex but no Codex OAuth token found (run: hermes model)")
return client, model
if forced == "main":
# "main" = skip OpenRouter/Nous, use the main chat model's credentials.
for try_fn in (_try_custom_endpoint, _try_codex, _resolve_api_key_provider):
client, model = try_fn()
if client is not None:
return client, model
logger.warning("auxiliary.provider=main but no main endpoint credentials found")
return None, None
# Unknown provider name — fall through to auto
logger.warning("Unknown auxiliary.provider=%r, falling back to auto", forced)
return None, None
_AUTO_PROVIDER_LABELS = {
"_try_openrouter": "openrouter",
"_try_nous": "nous",
@@ -1013,32 +1047,6 @@ def _is_payment_error(exc: Exception) -> bool:
return False
def _is_connection_error(exc: Exception) -> bool:
"""Detect connection/network errors that warrant provider fallback.
Returns True for errors indicating the provider endpoint is unreachable
(DNS failure, connection refused, TLS errors, timeouts). These are
distinct from API errors (4xx/5xx) which indicate the provider IS
reachable but returned an error.
"""
from openai import APIConnectionError, APITimeoutError
if isinstance(exc, (APIConnectionError, APITimeoutError)):
return True
# urllib3 / httpx / httpcore connection errors
err_type = type(exc).__name__
if any(kw in err_type for kw in ("Connection", "Timeout", "DNS", "SSL")):
return True
err_lower = str(exc).lower()
if any(kw in err_lower for kw in (
"connection refused", "name or service not known",
"no route to host", "network is unreachable",
"timed out", "connection reset",
)):
return True
return False
def _try_payment_fallback(
failed_provider: str,
task: str = None,
@@ -1103,7 +1111,7 @@ def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
main_model = _read_main_model()
if (main_provider and main_model
and main_provider not in _AGGREGATOR_PROVIDERS
and main_provider not in ("auto", "")):
and main_provider not in ("auto", "custom", "")):
client, resolved = resolve_provider_client(main_provider, main_model)
if client is not None:
logger.info("Auxiliary auto-detect: using main provider %s (%s)",
@@ -1161,7 +1169,7 @@ def _to_async_client(sync_client, model: str):
async_kwargs["default_headers"] = copilot_default_headers()
elif "api.kimi.com" in base_lower:
async_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.3"}
async_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
return AsyncOpenAI(**async_kwargs), model
@@ -1281,13 +1289,7 @@ def resolve_provider_client(
)
return None, None
final_model = model or _read_main_model() or "gpt-4o-mini"
extra = {}
if "api.kimi.com" in custom_base.lower():
extra["default_headers"] = {"User-Agent": "KimiCLI/1.3"}
elif "api.githubcopilot.com" in custom_base.lower():
from hermes_cli.models import copilot_default_headers
extra["default_headers"] = copilot_default_headers()
client = OpenAI(api_key=custom_key, base_url=custom_base, **extra)
client = OpenAI(api_key=custom_key, base_url=custom_base)
return (_to_async_client(client, final_model) if async_mode
else (client, final_model))
# Try custom first, then codex, then API-key providers
@@ -1366,7 +1368,7 @@ def resolve_provider_client(
# Provider-specific headers
headers = {}
if "api.kimi.com" in base_url.lower():
headers["User-Agent"] = "KimiCLI/1.3"
headers["User-Agent"] = "KimiCLI/1.0"
elif "api.githubcopilot.com" in base_url.lower():
from hermes_cli.models import copilot_default_headers
@@ -1461,6 +1463,22 @@ def _strict_vision_backend_available(provider: str) -> bool:
return _resolve_strict_vision_backend(provider)[0] is not None
def _preferred_main_vision_provider() -> Optional[str]:
"""Return the selected main provider when it is also a supported vision backend."""
try:
from hermes_cli.config import load_config
config = load_config()
model_cfg = config.get("model", {})
if isinstance(model_cfg, dict):
provider = _normalize_vision_provider(model_cfg.get("provider", ""))
if provider in _VISION_AUTO_PROVIDER_ORDER:
return provider
except Exception:
pass
return None
def get_available_vision_backends() -> List[str]:
"""Return the currently available vision backends in auto-selection order.
@@ -1574,6 +1592,18 @@ def resolve_vision_provider_client(
return requested, client, final_model
def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
"""Return (client, default_model_slug) for vision/multimodal auxiliary tasks."""
_, client, final_model = resolve_vision_provider_client(async_mode=False)
return client, final_model
def get_async_vision_auxiliary_client():
"""Return (async_client, model_slug) for async vision consumers."""
_, client, final_model = resolve_vision_provider_client(async_mode=True)
return client, final_model
def get_auxiliary_extra_body() -> dict:
"""Return extra_body kwargs for auxiliary API calls.
@@ -2063,18 +2093,7 @@ def call_llm(
# try alternative providers instead of giving up. This handles the
# common case where a user runs out of OpenRouter credits but has
# Codex OAuth or another provider available.
#
# ── Connection error fallback ────────────────────────────────
# When a provider endpoint is unreachable (DNS failure, connection
# refused, timeout), try alternative providers. This handles stale
# Codex/OAuth tokens that authenticate but whose endpoint is down,
# and providers the user never configured that got picked up by
# the auto-detection chain.
should_fallback = _is_payment_error(first_err) or _is_connection_error(first_err)
if should_fallback:
reason = "payment error" if _is_payment_error(first_err) else "connection error"
logger.info("Auxiliary %s: %s on %s (%s), trying fallback",
task or "call", reason, resolved_provider, first_err)
if _is_payment_error(first_err):
fb_client, fb_model, fb_label = _try_payment_fallback(
resolved_provider, task)
if fb_client is not None:
+114
View File
@@ -0,0 +1,114 @@
"""BuiltinMemoryProvider — wraps MEMORY.md / USER.md as a MemoryProvider.
Always registered as the first provider. Cannot be disabled or removed.
This is the existing Hermes memory system exposed through the provider
interface for compatibility with the MemoryManager.
The actual storage logic lives in tools/memory_tool.py (MemoryStore).
This provider is a thin adapter that delegates to MemoryStore and
exposes the memory tool schema.
"""
from __future__ import annotations
import json
import logging
from typing import Any, Dict, List
from agent.memory_provider import MemoryProvider
from tools.registry import tool_error
logger = logging.getLogger(__name__)
class BuiltinMemoryProvider(MemoryProvider):
"""Built-in file-backed memory (MEMORY.md + USER.md).
Always active, never disabled by other providers. The `memory` tool
is handled by run_agent.py's agent-level tool interception (not through
the normal registry), so get_tool_schemas() returns an empty list —
the memory tool is already wired separately.
"""
def __init__(
self,
memory_store=None,
memory_enabled: bool = False,
user_profile_enabled: bool = False,
):
self._store = memory_store
self._memory_enabled = memory_enabled
self._user_profile_enabled = user_profile_enabled
@property
def name(self) -> str:
return "builtin"
def is_available(self) -> bool:
"""Built-in memory is always available."""
return True
def initialize(self, session_id: str, **kwargs) -> None:
"""Load memory from disk if not already loaded."""
if self._store is not None:
self._store.load_from_disk()
def system_prompt_block(self) -> str:
"""Return MEMORY.md and USER.md content for the system prompt.
Uses the frozen snapshot captured at load time. This ensures the
system prompt stays stable throughout a session (preserving the
prompt cache), even though the live entries may change via tool calls.
"""
if not self._store:
return ""
parts = []
if self._memory_enabled:
mem_block = self._store.format_for_system_prompt("memory")
if mem_block:
parts.append(mem_block)
if self._user_profile_enabled:
user_block = self._store.format_for_system_prompt("user")
if user_block:
parts.append(user_block)
return "\n\n".join(parts)
def prefetch(self, query: str, *, session_id: str = "") -> str:
"""Built-in memory doesn't do query-based recall — it's injected via system_prompt_block."""
return ""
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
"""Built-in memory doesn't auto-sync turns — writes happen via the memory tool."""
def get_tool_schemas(self) -> List[Dict[str, Any]]:
"""Return empty list.
The `memory` tool is an agent-level intercepted tool, handled
specially in run_agent.py before normal tool dispatch. It's not
part of the standard tool registry. We don't duplicate it here.
"""
return []
def handle_tool_call(self, tool_name: str, args: Dict[str, Any], **kwargs) -> str:
"""Not used — the memory tool is intercepted in run_agent.py."""
return tool_error("Built-in memory tool is handled by the agent loop")
def shutdown(self) -> None:
"""No cleanup needed — files are saved on every write."""
# -- Property access for backward compatibility --------------------------
@property
def store(self):
"""Access the underlying MemoryStore for legacy code paths."""
return self._store
@property
def memory_enabled(self) -> bool:
return self._memory_enabled
@property
def user_profile_enabled(self) -> bool:
return self._user_profile_enabled
+42 -35
View File
@@ -114,6 +114,7 @@ class ContextCompressor:
self.last_prompt_tokens = 0
self.last_completion_tokens = 0
self.last_total_tokens = 0
self.summary_model = summary_model_override or ""
@@ -125,12 +126,28 @@ class ContextCompressor:
"""Update tracked token usage from API response."""
self.last_prompt_tokens = usage.get("prompt_tokens", 0)
self.last_completion_tokens = usage.get("completion_tokens", 0)
self.last_total_tokens = usage.get("total_tokens", 0)
def should_compress(self, prompt_tokens: int = None) -> bool:
"""Check if context exceeds the compression threshold."""
tokens = prompt_tokens if prompt_tokens is not None else self.last_prompt_tokens
return tokens >= self.threshold_tokens
def should_compress_preflight(self, messages: List[Dict[str, Any]]) -> bool:
"""Quick pre-flight check using rough estimate (before API call)."""
rough_estimate = estimate_messages_tokens_rough(messages)
return rough_estimate >= self.threshold_tokens
def get_status(self) -> Dict[str, Any]:
"""Get current compression status for display/logging."""
return {
"last_prompt_tokens": self.last_prompt_tokens,
"threshold_tokens": self.threshold_tokens,
"context_length": self.context_length,
"usage_percent": min(100, (self.last_prompt_tokens / self.context_length * 100)) if self.context_length else 0,
"compression_count": self.compression_count,
}
# ------------------------------------------------------------------
# Tool output pruning (cheap pre-pass, no LLM call)
# ------------------------------------------------------------------
@@ -674,43 +691,33 @@ Write only the summary body. Do not include any preamble or prefix."""
)
compressed.append(msg)
# If LLM summary failed, insert a static fallback so the model
# knows context was lost rather than silently dropping everything.
if not summary:
if not self.quiet_mode:
logger.warning("Summary generation failed — inserting static fallback context marker")
n_dropped = compress_end - compress_start
summary = (
f"{SUMMARY_PREFIX}\n"
f"Summary generation was unavailable. {n_dropped} conversation turns were "
f"removed to free context space but could not be summarized. The removed "
f"turns contained earlier work in this session. Continue based on the "
f"recent messages below and the current state of any files or resources."
)
_merge_summary_into_tail = False
last_head_role = messages[compress_start - 1].get("role", "user") if compress_start > 0 else "user"
first_tail_role = messages[compress_end].get("role", "user") if compress_end < n_messages else "user"
# Pick a role that avoids consecutive same-role with both neighbors.
# Priority: avoid colliding with head (already committed), then tail.
if last_head_role in ("assistant", "tool"):
summary_role = "user"
else:
summary_role = "assistant"
# If the chosen role collides with the tail AND flipping wouldn't
# collide with the head, flip it.
if summary_role == first_tail_role:
flipped = "assistant" if summary_role == "user" else "user"
if flipped != last_head_role:
summary_role = flipped
if summary:
last_head_role = messages[compress_start - 1].get("role", "user") if compress_start > 0 else "user"
first_tail_role = messages[compress_end].get("role", "user") if compress_end < n_messages else "user"
# Pick a role that avoids consecutive same-role with both neighbors.
# Priority: avoid colliding with head (already committed), then tail.
if last_head_role in ("assistant", "tool"):
summary_role = "user"
else:
# Both roles would create consecutive same-role messages
# (e.g. head=assistant, tail=user — neither role works).
# Merge the summary into the first tail message instead
# of inserting a standalone message that breaks alternation.
_merge_summary_into_tail = True
if not _merge_summary_into_tail:
compressed.append({"role": summary_role, "content": summary})
summary_role = "assistant"
# If the chosen role collides with the tail AND flipping wouldn't
# collide with the head, flip it.
if summary_role == first_tail_role:
flipped = "assistant" if summary_role == "user" else "user"
if flipped != last_head_role:
summary_role = flipped
else:
# Both roles would create consecutive same-role messages
# (e.g. head=assistant, tail=user — neither role works).
# Merge the summary into the first tail message instead
# of inserting a standalone message that breaks alternation.
_merge_summary_into_tail = True
if not _merge_summary_into_tail:
compressed.append({"role": summary_role, "content": summary})
else:
if not self.quiet_mode:
logger.debug("No summary model available — middle turns dropped without summary")
for i in range(compress_end, n_messages):
msg = messages[i].copy()
+20 -8
View File
@@ -18,14 +18,12 @@ import hermes_cli.auth as auth_mod
from hermes_cli.auth import (
CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
KIMI_CODE_BASE_URL,
PROVIDER_REGISTRY,
_codex_access_token_is_expiring,
_decode_jwt_claims,
_import_codex_cli_tokens,
_load_auth_store,
_load_provider_state,
_resolve_kimi_base_url,
_resolve_zai_base_url,
read_credential_pool,
write_credential_pool,
@@ -66,10 +64,10 @@ SUPPORTED_POOL_STRATEGIES = {
}
# Cooldown before retrying an exhausted credential.
# 429 (rate-limited) and 402 (billing/quota) both cool down after 1 hour.
# Provider-supplied reset_at timestamps override these defaults.
# 429 (rate-limited) cools down faster since quotas reset frequently.
# 402 (billing/quota) and other codes use a longer default.
EXHAUSTED_TTL_429_SECONDS = 60 * 60 # 1 hour
EXHAUSTED_TTL_DEFAULT_SECONDS = 60 * 60 # 1 hour
EXHAUSTED_TTL_DEFAULT_SECONDS = 24 * 60 * 60 # 24 hours
# Pool key prefix for custom OpenAI-compatible endpoints.
# Custom endpoints all share provider='custom' but are keyed by their
@@ -633,6 +631,17 @@ class CredentialPool:
return False
return False
def mark_used(self, entry_id: Optional[str] = None) -> None:
"""Increment request_count for tracking. Used by least_used strategy."""
target_id = entry_id or self._current_id
if not target_id:
return
with self._lock:
for idx, entry in enumerate(self._entries):
if entry.id == target_id:
self._entries[idx] = replace(entry, request_count=entry.request_count + 1)
return
def select(self) -> Optional[PooledCredential]:
with self._lock:
return self._select_unlocked()
@@ -794,6 +803,11 @@ class CredentialPool:
else:
self._active_leases[credential_id] = count - 1
def active_lease_count(self, credential_id: str) -> int:
"""Return the number of active leases for a credential."""
with self._lock:
return self._active_leases.get(credential_id, 0)
def try_refresh_current(self) -> Optional[PooledCredential]:
with self._lock:
return self._try_refresh_current_unlocked()
@@ -1070,9 +1084,7 @@ def _seed_from_env(provider: str, entries: List[PooledCredential]) -> Tuple[bool
active_sources.add(source)
auth_type = AUTH_TYPE_OAUTH if provider == "anthropic" and not token.startswith("sk-ant-api") else AUTH_TYPE_API_KEY
base_url = env_url or pconfig.inference_base_url
if provider == "kimi-coding":
base_url = _resolve_kimi_base_url(token, pconfig.inference_base_url, env_url)
elif provider == "zai":
if provider == "zai":
base_url = _resolve_zai_base_url(token, pconfig.inference_base_url, env_url)
changed |= _upsert_entry(
entries,
+76
View File
@@ -67,6 +67,26 @@ def _get_skin():
return None
def get_skin_faces(key: str, default: list) -> list:
"""Get spinner face list from active skin, falling back to default."""
skin = _get_skin()
if skin:
faces = skin.get_spinner_list(key)
if faces:
return faces
return default
def get_skin_verbs() -> list:
"""Get thinking verbs from active skin."""
skin = _get_skin()
if skin:
verbs = skin.get_spinner_list("thinking_verbs")
if verbs:
return verbs
return KawaiiSpinner.THINKING_VERBS
def get_skin_tool_prefix() -> str:
"""Get tool output prefix character from active skin."""
skin = _get_skin()
@@ -703,6 +723,46 @@ class KawaiiSpinner:
return False
# =========================================================================
# Kawaii face arrays (used by AIAgent._execute_tool_calls for spinner text)
# =========================================================================
KAWAII_SEARCH = [
"♪(´ε` )", "(。◕‿◕。)", "ヾ(^∇^)", "(◕ᴗ◕✿)", "( ˘▽˘)っ",
"٩(◕‿◕。)۶", "(✿◠‿◠)", "♪~(´ε` )", "(ノ´ヮ`)*:・゚✧", "(◎o◎)",
]
KAWAII_READ = [
"φ(゜▽゜*)♪", "( ˘▽˘)っ", "(⌐■_■)", "٩(。•́‿•̀。)۶", "(◕‿◕✿)",
"ヾ(@⌒ー⌒@)", "(✧ω✧)", "♪(๑ᴖ◡ᴖ๑)♪", "(≧◡≦)", "( ´ ▽ ` )",
]
KAWAII_TERMINAL = [
"ヽ(>∀<☆)", "(ノ°∀°)", "٩(^ᴗ^)۶", "ヾ(⌐■_■)ノ♪", "(•̀ᴗ•́)و",
"┗(0)┓", "(`・ω・´)", "( ̄▽ ̄)", "(ง •̀_•́)ง", "ヽ(´▽`)/",
]
KAWAII_BROWSER = [
"(ノ°∀°)", "(☞゚ヮ゚)☞", "( ͡° ͜ʖ ͡°)", "┌( ಠ_ಠ)┘", "(⊙_⊙)",
"ヾ(•ω•`)o", "( ̄ω ̄)", "( ˇωˇ )", "(ᵔᴥᵔ)", "(◎o◎)",
]
KAWAII_CREATE = [
"✧*。٩(ˊᗜˋ*)و✧", "(ノ◕ヮ◕)ノ*:・゚✧", "ヽ(>∀<☆)", "٩(♡ε♡)۶", "(◕‿◕)♡",
"✿◕ ‿ ◕✿", "(*≧▽≦)", "ヾ(-)", "(☆▽☆)", "°˖✧◝(⁰▿⁰)◜✧˖°",
]
KAWAII_SKILL = [
"ヾ(@⌒ー⌒@)", "(๑˃ᴗ˂)ﻭ", "٩(◕‿◕。)۶", "(✿╹◡╹)", "ヽ(・∀・)",
"(ノ´ヮ`)*:・゚✧", "♪(๑ᴖ◡ᴖ๑)♪", "(◠‿◠)", "٩(ˊᗜˋ*)و", "(^▽^)",
"ヾ(^∇^)", "(★ω★)/", "٩(。•́‿•̀。)۶", "(◕ᴗ◕✿)", "(◎o◎)",
"(✧ω✧)", "ヽ(>∀<☆)", "( ˘▽˘)っ", "(≧◡≦) ♡", "ヾ( ̄▽ ̄)",
]
KAWAII_THINK = [
"(っ°Д°;)っ", "(;′⌒`)", "(・_・ヾ", "( ´_ゝ`)", "( ̄ヘ ̄)",
"(。-`ω´-)", "( ˘︹˘ )", "(¬_¬)", "ヽ(ー_ー )", "(;一_一)",
]
KAWAII_GENERIC = [
"♪(´ε` )", "(◕‿◕✿)", "ヾ(^∇^)", "٩(◕‿◕。)۶", "(✿◠‿◠)",
"(ノ´ヮ`)*:・゚✧", "ヽ(>∀<☆)", "(☆▽☆)", "( ˘▽˘)っ", "(≧◡≦)",
]
# =========================================================================
# Cute tool message (completion line that replaces the spinner)
# =========================================================================
@@ -910,6 +970,22 @@ _SKY_BLUE = "\033[38;5;117m"
_ANSI_RESET = "\033[0m"
def honcho_session_url(workspace: str, session_name: str) -> str:
"""Build a Honcho app URL for a session."""
from urllib.parse import quote
return (
f"https://app.honcho.dev/explore"
f"?workspace={quote(workspace, safe='')}"
f"&view=sessions"
f"&session={quote(session_name, safe='')}"
)
def _osc8_link(url: str, text: str) -> str:
"""OSC 8 terminal hyperlink (clickable in iTerm2, Ghostty, WezTerm, etc.)."""
return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"
# =========================================================================
# Context pressure display (CLI user-facing warnings)
# =========================================================================
-782
View File
@@ -1,782 +0,0 @@
"""API error classification for smart failover and recovery.
Provides a structured taxonomy of API errors and a priority-ordered
classification pipeline that determines the correct recovery action
(retry, rotate credential, fallback to another provider, compress
context, or abort).
Replaces scattered inline string-matching with a centralized classifier
that the main retry loop in run_agent.py consults for every API failure.
"""
from __future__ import annotations
import enum
import logging
import re
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
logger = logging.getLogger(__name__)
# ── Error taxonomy ──────────────────────────────────────────────────────
class FailoverReason(enum.Enum):
"""Why an API call failed — determines recovery strategy."""
# Authentication / authorization
auth = "auth" # Transient auth (401/403) — refresh/rotate
auth_permanent = "auth_permanent" # Auth failed after refresh — abort
# Billing / quota
billing = "billing" # 402 or confirmed credit exhaustion — rotate immediately
rate_limit = "rate_limit" # 429 or quota-based throttling — backoff then rotate
# Server-side
overloaded = "overloaded" # 503/529 — provider overloaded, backoff
server_error = "server_error" # 500/502 — internal server error, retry
# Transport
timeout = "timeout" # Connection/read timeout — rebuild client + retry
# Context / payload
context_overflow = "context_overflow" # Context too large — compress, not failover
payload_too_large = "payload_too_large" # 413 — compress payload
# Model
model_not_found = "model_not_found" # 404 or invalid model — fallback to different model
# Request format
format_error = "format_error" # 400 bad request — abort or strip + retry
# Provider-specific
thinking_signature = "thinking_signature" # Anthropic thinking block sig invalid
long_context_tier = "long_context_tier" # Anthropic "extra usage" tier gate
# Catch-all
unknown = "unknown" # Unclassifiable — retry with backoff
# ── Classification result ───────────────────────────────────────────────
@dataclass
class ClassifiedError:
"""Structured classification of an API error with recovery hints."""
reason: FailoverReason
status_code: Optional[int] = None
provider: Optional[str] = None
model: Optional[str] = None
message: str = ""
error_context: Dict[str, Any] = field(default_factory=dict)
# Recovery action hints — the retry loop checks these instead of
# re-classifying the error itself.
retryable: bool = True
should_compress: bool = False
should_rotate_credential: bool = False
should_fallback: bool = False
@property
def is_auth(self) -> bool:
return self.reason in (FailoverReason.auth, FailoverReason.auth_permanent)
# ── Provider-specific patterns ──────────────────────────────────────────
# Patterns that indicate billing exhaustion (not transient rate limit)
_BILLING_PATTERNS = [
"insufficient credits",
"insufficient_quota",
"credit balance",
"credits have been exhausted",
"top up your credits",
"payment required",
"billing hard limit",
"exceeded your current quota",
"account is deactivated",
"plan does not include",
]
# Patterns that indicate rate limiting (transient, will resolve)
_RATE_LIMIT_PATTERNS = [
"rate limit",
"rate_limit",
"too many requests",
"throttled",
"requests per minute",
"tokens per minute",
"requests per day",
"try again in",
"please retry after",
"resource_exhausted",
]
# Usage-limit patterns that need disambiguation (could be billing OR rate_limit)
_USAGE_LIMIT_PATTERNS = [
"usage limit",
"quota",
"limit exceeded",
"key limit exceeded",
]
# Patterns confirming usage limit is transient (not billing)
_USAGE_LIMIT_TRANSIENT_SIGNALS = [
"try again",
"retry",
"resets at",
"reset in",
"wait",
"requests remaining",
"periodic",
"window",
]
# Payload-too-large patterns detected from message text (no status_code attr).
# Proxies and some backends embed the HTTP status in the error message.
_PAYLOAD_TOO_LARGE_PATTERNS = [
"request entity too large",
"payload too large",
"error code: 413",
]
# Context overflow patterns
_CONTEXT_OVERFLOW_PATTERNS = [
"context length",
"context size",
"maximum context",
"token limit",
"too many tokens",
"reduce the length",
"exceeds the limit",
"context window",
"prompt is too long",
"prompt exceeds max length",
"max_tokens",
"maximum number of tokens",
# Chinese error messages (some providers return these)
"超过最大长度",
"上下文长度",
]
# Model not found patterns
_MODEL_NOT_FOUND_PATTERNS = [
"is not a valid model",
"invalid model",
"model not found",
"model_not_found",
"does not exist",
"no such model",
"unknown model",
"unsupported model",
]
# Auth patterns (non-status-code signals)
_AUTH_PATTERNS = [
"invalid api key",
"invalid_api_key",
"authentication",
"unauthorized",
"forbidden",
"invalid token",
"token expired",
"token revoked",
"access denied",
]
# Anthropic thinking block signature patterns
_THINKING_SIG_PATTERNS = [
"signature", # Combined with "thinking" check
]
# Transport error type names
_TRANSPORT_ERROR_TYPES = frozenset({
"ReadTimeout", "ConnectTimeout", "PoolTimeout",
"ConnectError", "RemoteProtocolError",
"ConnectionError", "ConnectionResetError",
"ConnectionAbortedError", "BrokenPipeError",
"TimeoutError", "ReadError",
"ServerDisconnectedError",
# OpenAI SDK errors (not subclasses of Python builtins)
"APIConnectionError",
"APITimeoutError",
})
# Server disconnect patterns (no status code, but transport-level)
_SERVER_DISCONNECT_PATTERNS = [
"server disconnected",
"peer closed connection",
"connection reset by peer",
"connection was closed",
"network connection lost",
"unexpected eof",
"incomplete chunked read",
]
# ── Classification pipeline ─────────────────────────────────────────────
def classify_api_error(
error: Exception,
*,
provider: str = "",
model: str = "",
approx_tokens: int = 0,
context_length: int = 200000,
num_messages: int = 0,
) -> ClassifiedError:
"""Classify an API error into a structured recovery recommendation.
Priority-ordered pipeline:
1. Special-case provider-specific patterns (thinking sigs, tier gates)
2. HTTP status code + message-aware refinement
3. Error code classification (from body)
4. Message pattern matching (billing vs rate_limit vs context vs auth)
5. Transport error heuristics
6. Server disconnect + large session → context overflow
7. Fallback: unknown (retryable with backoff)
Args:
error: The exception from the API call.
provider: Current provider name (e.g. "openrouter", "anthropic").
model: Current model slug.
approx_tokens: Approximate token count of the current context.
context_length: Maximum context length for the current model.
Returns:
ClassifiedError with reason and recovery action hints.
"""
status_code = _extract_status_code(error)
error_type = type(error).__name__
body = _extract_error_body(error)
error_code = _extract_error_code(body)
# Build a comprehensive error message string for pattern matching.
# str(error) alone may not include the body message (e.g. OpenAI SDK's
# APIStatusError.__str__ returns the first arg, not the body). Append
# the body message so patterns like "try again" in 402 disambiguation
# are detected even when only present in the structured body.
#
# Also extract metadata.raw — OpenRouter wraps upstream provider errors
# inside {"error": {"message": "Provider returned error", "metadata":
# {"raw": "<actual error JSON>"}}} and the real error message (e.g.
# "context length exceeded") is only in the inner JSON.
_raw_msg = str(error).lower()
_body_msg = ""
_metadata_msg = ""
if isinstance(body, dict):
_err_obj = body.get("error", {})
if isinstance(_err_obj, dict):
_body_msg = (_err_obj.get("message") or "").lower()
# Parse metadata.raw for wrapped provider errors
_metadata = _err_obj.get("metadata", {})
if isinstance(_metadata, dict):
_raw_json = _metadata.get("raw") or ""
if isinstance(_raw_json, str) and _raw_json.strip():
try:
import json
_inner = json.loads(_raw_json)
if isinstance(_inner, dict):
_inner_err = _inner.get("error", {})
if isinstance(_inner_err, dict):
_metadata_msg = (_inner_err.get("message") or "").lower()
except (json.JSONDecodeError, TypeError):
pass
if not _body_msg:
_body_msg = (body.get("message") or "").lower()
# Combine all message sources for pattern matching
parts = [_raw_msg]
if _body_msg and _body_msg not in _raw_msg:
parts.append(_body_msg)
if _metadata_msg and _metadata_msg not in _raw_msg and _metadata_msg not in _body_msg:
parts.append(_metadata_msg)
error_msg = " ".join(parts)
provider_lower = (provider or "").strip().lower()
model_lower = (model or "").strip().lower()
def _result(reason: FailoverReason, **overrides) -> ClassifiedError:
defaults = {
"reason": reason,
"status_code": status_code,
"provider": provider,
"model": model,
"message": _extract_message(error, body),
}
defaults.update(overrides)
return ClassifiedError(**defaults)
# ── 1. Provider-specific patterns (highest priority) ────────────
# Anthropic thinking block signature invalid (400).
# Don't gate on provider — OpenRouter proxies Anthropic errors, so the
# provider may be "openrouter" even though the error is Anthropic-specific.
# The message pattern ("signature" + "thinking") is unique enough.
if (
status_code == 400
and "signature" in error_msg
and "thinking" in error_msg
):
return _result(
FailoverReason.thinking_signature,
retryable=True,
should_compress=False,
)
# Anthropic long-context tier gate (429 "extra usage" + "long context")
if (
status_code == 429
and "extra usage" in error_msg
and "long context" in error_msg
):
return _result(
FailoverReason.long_context_tier,
retryable=True,
should_compress=True,
)
# ── 2. HTTP status code classification ──────────────────────────
if status_code is not None:
classified = _classify_by_status(
status_code, error_msg, error_code, body,
provider=provider_lower, model=model_lower,
approx_tokens=approx_tokens, context_length=context_length,
num_messages=num_messages,
result_fn=_result,
)
if classified is not None:
return classified
# ── 3. Error code classification ────────────────────────────────
if error_code:
classified = _classify_by_error_code(error_code, error_msg, _result)
if classified is not None:
return classified
# ── 4. Message pattern matching (no status code) ────────────────
classified = _classify_by_message(
error_msg, error_type,
approx_tokens=approx_tokens,
context_length=context_length,
result_fn=_result,
)
if classified is not None:
return classified
# ── 5. Server disconnect + large session → context overflow ─────
# Must come BEFORE generic transport error catch — a disconnect on
# a large session is more likely context overflow than a transient
# transport hiccup. Without this ordering, RemoteProtocolError
# always maps to timeout regardless of session size.
is_disconnect = any(p in error_msg for p in _SERVER_DISCONNECT_PATTERNS)
if is_disconnect and not status_code:
is_large = approx_tokens > context_length * 0.6 or approx_tokens > 120000 or num_messages > 200
if is_large:
return _result(
FailoverReason.context_overflow,
retryable=True,
should_compress=True,
)
return _result(FailoverReason.timeout, retryable=True)
# ── 6. Transport / timeout heuristics ───────────────────────────
if error_type in _TRANSPORT_ERROR_TYPES or isinstance(error, (TimeoutError, ConnectionError, OSError)):
return _result(FailoverReason.timeout, retryable=True)
# ── 7. Fallback: unknown ────────────────────────────────────────
return _result(FailoverReason.unknown, retryable=True)
# ── Status code classification ──────────────────────────────────────────
def _classify_by_status(
status_code: int,
error_msg: str,
error_code: str,
body: dict,
*,
provider: str,
model: str,
approx_tokens: int,
context_length: int,
num_messages: int = 0,
result_fn,
) -> Optional[ClassifiedError]:
"""Classify based on HTTP status code with message-aware refinement."""
if status_code == 401:
# Not retryable on its own — credential pool rotation and
# provider-specific refresh (Codex, Anthropic, Nous) run before
# the retryability check in run_agent.py. If those succeed, the
# loop `continue`s. If they fail, retryable=False ensures we
# hit the client-error abort path (which tries fallback first).
return result_fn(
FailoverReason.auth,
retryable=False,
should_rotate_credential=True,
should_fallback=True,
)
if status_code == 403:
# OpenRouter 403 "key limit exceeded" is actually billing
if "key limit exceeded" in error_msg or "spending limit" in error_msg:
return result_fn(
FailoverReason.billing,
retryable=False,
should_rotate_credential=True,
should_fallback=True,
)
return result_fn(
FailoverReason.auth,
retryable=False,
should_fallback=True,
)
if status_code == 402:
return _classify_402(error_msg, result_fn)
if status_code == 404:
if any(p in error_msg for p in _MODEL_NOT_FOUND_PATTERNS):
return result_fn(
FailoverReason.model_not_found,
retryable=False,
should_fallback=True,
)
# Generic 404 — could be model or endpoint
return result_fn(
FailoverReason.model_not_found,
retryable=False,
should_fallback=True,
)
if status_code == 413:
return result_fn(
FailoverReason.payload_too_large,
retryable=True,
should_compress=True,
)
if status_code == 429:
# Already checked long_context_tier above; this is a normal rate limit
return result_fn(
FailoverReason.rate_limit,
retryable=True,
should_rotate_credential=True,
should_fallback=True,
)
if status_code == 400:
return _classify_400(
error_msg, error_code, body,
provider=provider, model=model,
approx_tokens=approx_tokens,
context_length=context_length,
num_messages=num_messages,
result_fn=result_fn,
)
if status_code in (500, 502):
return result_fn(FailoverReason.server_error, retryable=True)
if status_code in (503, 529):
return result_fn(FailoverReason.overloaded, retryable=True)
# Other 4xx — non-retryable
if 400 <= status_code < 500:
return result_fn(
FailoverReason.format_error,
retryable=False,
should_fallback=True,
)
# Other 5xx — retryable
if 500 <= status_code < 600:
return result_fn(FailoverReason.server_error, retryable=True)
return None
def _classify_402(error_msg: str, result_fn) -> ClassifiedError:
"""Disambiguate 402: billing exhaustion vs transient usage limit.
The key insight from OpenClaw: some 402s are transient rate limits
disguised as payment errors. "Usage limit, try again in 5 minutes"
is NOT a billing problem — it's a periodic quota that resets.
"""
# Check for transient usage-limit signals first
has_usage_limit = any(p in error_msg for p in _USAGE_LIMIT_PATTERNS)
has_transient_signal = any(p in error_msg for p in _USAGE_LIMIT_TRANSIENT_SIGNALS)
if has_usage_limit and has_transient_signal:
# Transient quota — treat as rate limit, not billing
return result_fn(
FailoverReason.rate_limit,
retryable=True,
should_rotate_credential=True,
should_fallback=True,
)
# Confirmed billing exhaustion
return result_fn(
FailoverReason.billing,
retryable=False,
should_rotate_credential=True,
should_fallback=True,
)
def _classify_400(
error_msg: str,
error_code: str,
body: dict,
*,
provider: str,
model: str,
approx_tokens: int,
context_length: int,
num_messages: int = 0,
result_fn,
) -> ClassifiedError:
"""Classify 400 Bad Request — context overflow, format error, or generic."""
# Context overflow from 400
if any(p in error_msg for p in _CONTEXT_OVERFLOW_PATTERNS):
return result_fn(
FailoverReason.context_overflow,
retryable=True,
should_compress=True,
)
# Some providers return model-not-found as 400 instead of 404 (e.g. OpenRouter).
if any(p in error_msg for p in _MODEL_NOT_FOUND_PATTERNS):
return result_fn(
FailoverReason.model_not_found,
retryable=False,
should_fallback=True,
)
# Some providers return rate limit / billing errors as 400 instead of 429/402.
# Check these patterns before falling through to format_error.
if any(p in error_msg for p in _RATE_LIMIT_PATTERNS):
return result_fn(
FailoverReason.rate_limit,
retryable=True,
should_rotate_credential=True,
should_fallback=True,
)
if any(p in error_msg for p in _BILLING_PATTERNS):
return result_fn(
FailoverReason.billing,
retryable=False,
should_rotate_credential=True,
should_fallback=True,
)
# Generic 400 + large session → probable context overflow
# Anthropic sometimes returns a bare "Error" message when context is too large
err_body_msg = ""
if isinstance(body, dict):
err_obj = body.get("error", {})
if isinstance(err_obj, dict):
err_body_msg = (err_obj.get("message") or "").strip().lower()
# Responses API (and some providers) use flat body: {"message": "..."}
if not err_body_msg:
err_body_msg = (body.get("message") or "").strip().lower()
is_generic = len(err_body_msg) < 30 or err_body_msg in ("error", "")
is_large = approx_tokens > context_length * 0.4 or approx_tokens > 80000 or num_messages > 80
if is_generic and is_large:
return result_fn(
FailoverReason.context_overflow,
retryable=True,
should_compress=True,
)
# Non-retryable format error
return result_fn(
FailoverReason.format_error,
retryable=False,
should_fallback=True,
)
# ── Error code classification ───────────────────────────────────────────
def _classify_by_error_code(
error_code: str, error_msg: str, result_fn,
) -> Optional[ClassifiedError]:
"""Classify by structured error codes from the response body."""
code_lower = error_code.lower()
if code_lower in ("resource_exhausted", "throttled", "rate_limit_exceeded"):
return result_fn(
FailoverReason.rate_limit,
retryable=True,
should_rotate_credential=True,
)
if code_lower in ("insufficient_quota", "billing_not_active", "payment_required"):
return result_fn(
FailoverReason.billing,
retryable=False,
should_rotate_credential=True,
should_fallback=True,
)
if code_lower in ("model_not_found", "model_not_available", "invalid_model"):
return result_fn(
FailoverReason.model_not_found,
retryable=False,
should_fallback=True,
)
if code_lower in ("context_length_exceeded", "max_tokens_exceeded"):
return result_fn(
FailoverReason.context_overflow,
retryable=True,
should_compress=True,
)
return None
# ── Message pattern classification ──────────────────────────────────────
def _classify_by_message(
error_msg: str,
error_type: str,
*,
approx_tokens: int,
context_length: int,
result_fn,
) -> Optional[ClassifiedError]:
"""Classify based on error message patterns when no status code is available."""
# Payload-too-large patterns (from message text when no status_code)
if any(p in error_msg for p in _PAYLOAD_TOO_LARGE_PATTERNS):
return result_fn(
FailoverReason.payload_too_large,
retryable=True,
should_compress=True,
)
# Billing patterns
if any(p in error_msg for p in _BILLING_PATTERNS):
return result_fn(
FailoverReason.billing,
retryable=False,
should_rotate_credential=True,
should_fallback=True,
)
# Rate limit patterns
if any(p in error_msg for p in _RATE_LIMIT_PATTERNS):
return result_fn(
FailoverReason.rate_limit,
retryable=True,
should_rotate_credential=True,
should_fallback=True,
)
# Context overflow patterns
if any(p in error_msg for p in _CONTEXT_OVERFLOW_PATTERNS):
return result_fn(
FailoverReason.context_overflow,
retryable=True,
should_compress=True,
)
# Auth patterns
if any(p in error_msg for p in _AUTH_PATTERNS):
return result_fn(
FailoverReason.auth,
retryable=True,
should_rotate_credential=True,
)
# Model not found patterns
if any(p in error_msg for p in _MODEL_NOT_FOUND_PATTERNS):
return result_fn(
FailoverReason.model_not_found,
retryable=False,
should_fallback=True,
)
return None
# ── Helpers ─────────────────────────────────────────────────────────────
def _extract_status_code(error: Exception) -> Optional[int]:
"""Walk the error and its cause chain to find an HTTP status code."""
current = error
for _ in range(5): # Max depth to prevent infinite loops
code = getattr(current, "status_code", None)
if isinstance(code, int):
return code
# Some SDKs use .status instead of .status_code
code = getattr(current, "status", None)
if isinstance(code, int) and 100 <= code < 600:
return code
# Walk cause chain
cause = getattr(current, "__cause__", None) or getattr(current, "__context__", None)
if cause is None or cause is current:
break
current = cause
return None
def _extract_error_body(error: Exception) -> dict:
"""Extract the structured error body from an SDK exception."""
body = getattr(error, "body", None)
if isinstance(body, dict):
return body
# Some errors have .response.json()
response = getattr(error, "response", None)
if response is not None:
try:
json_body = response.json()
if isinstance(json_body, dict):
return json_body
except Exception:
pass
return {}
def _extract_error_code(body: dict) -> str:
"""Extract an error code string from the response body."""
if not body:
return ""
error_obj = body.get("error", {})
if isinstance(error_obj, dict):
code = error_obj.get("code") or error_obj.get("type") or ""
if isinstance(code, str) and code.strip():
return code.strip()
# Top-level code
code = body.get("code") or body.get("error_code") or ""
if isinstance(code, (str, int)):
return str(code).strip()
return ""
def _extract_message(error: Exception, body: dict) -> str:
"""Extract the most informative error message."""
# Try structured body first
if body:
error_obj = body.get("error", {})
if isinstance(error_obj, dict):
msg = error_obj.get("message", "")
if isinstance(msg, str) and msg.strip():
return msg.strip()[:500]
msg = body.get("message", "")
if isinstance(msg, str) and msg.strip():
return msg.strip()[:500]
# Fallback to str(error)
return str(error)[:500]
+9
View File
@@ -39,6 +39,15 @@ def _has_known_pricing(model_name: str, provider: str = None, base_url: str = No
return has_known_pricing(model_name, provider=provider, base_url=base_url)
def _get_pricing(model_name: str) -> Dict[str, float]:
"""Look up pricing for a model. Uses fuzzy matching on model name.
Returns _DEFAULT_PRICING (zero cost) for unknown/custom models —
we can't assume costs for self-hosted endpoints, local inference, etc.
"""
return get_pricing(model_name)
def _estimate_cost(
session_or_model: Dict[str, Any] | str,
input_tokens: int = 0,
+5
View File
@@ -134,6 +134,11 @@ class MemoryManager:
"""All registered providers in order."""
return list(self._providers)
@property
def provider_names(self) -> List[str]:
"""Names of all registered providers."""
return [p.name for p in self._providers]
def get_provider(self, name: str) -> Optional[MemoryProvider]:
"""Get a provider by name, or None if not registered."""
for p in self._providers:
-43
View File
@@ -603,49 +603,6 @@ def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
return None
def parse_available_output_tokens_from_error(error_msg: str) -> Optional[int]:
"""Detect an "output cap too large" error and return how many output tokens are available.
Background — two distinct context errors exist:
1. "Prompt too long" — the INPUT itself exceeds the context window.
Fix: compress history and/or halve context_length.
2. "max_tokens too large" — input is fine, but input + requested_output > window.
Fix: reduce max_tokens (the output cap) for this call.
Do NOT touch context_length — the window hasn't shrunk.
Anthropic's API returns errors like:
"max_tokens: 32768 > context_window: 200000 - input_tokens: 190000 = available_tokens: 10000"
Returns the number of output tokens that would fit (e.g. 10000 above), or None if
the error does not look like a max_tokens-too-large error.
"""
error_lower = error_msg.lower()
# Must look like an output-cap error, not a prompt-length error.
is_output_cap_error = (
"max_tokens" in error_lower
and ("available_tokens" in error_lower or "available tokens" in error_lower)
)
if not is_output_cap_error:
return None
# Extract the available_tokens figure.
# Anthropic format: "… = available_tokens: 10000"
patterns = [
r'available_tokens[:\s]+(\d+)',
r'available\s+tokens[:\s]+(\d+)',
# fallback: last number after "=" in expressions like "200000 - 190000 = 10000"
r'=\s*(\d+)\s*$',
]
for pattern in patterns:
match = re.search(pattern, error_lower)
if match:
tokens = int(match.group(1))
if tokens >= 1:
return tokens
return None
def _model_id_matches(candidate_id: str, lookup_model: str) -> bool:
"""Return True if *candidate_id* (from server) matches *lookup_model* (configured).
+111
View File
@@ -135,6 +135,9 @@ class ProviderInfo:
doc: str = "" # documentation URL
model_count: int = 0
def has_api_url(self) -> bool:
return bool(self.api)
# ---------------------------------------------------------------------------
# Provider ID mapping: Hermes ↔ models.dev
@@ -631,6 +634,43 @@ def get_provider_info(provider_id: str) -> Optional[ProviderInfo]:
return _parse_provider_info(mdev_id, raw)
def list_all_providers() -> Dict[str, ProviderInfo]:
"""Return all providers from models.dev as {provider_id: ProviderInfo}.
Returns the full catalog — 109+ providers. For providers that have
a Hermes alias, both the models.dev ID and the Hermes ID are included.
"""
data = fetch_models_dev()
result: Dict[str, ProviderInfo] = {}
for pid, pdata in data.items():
if isinstance(pdata, dict):
info = _parse_provider_info(pid, pdata)
result[pid] = info
return result
def get_providers_for_env_var(env_var: str) -> List[str]:
"""Reverse lookup: find all providers that use a given env var.
Useful for auto-detection: "user has ANTHROPIC_API_KEY set, which
providers does that enable?"
Returns list of models.dev provider IDs.
"""
data = fetch_models_dev()
matches: List[str] = []
for pid, pdata in data.items():
if isinstance(pdata, dict):
env = pdata.get("env", [])
if isinstance(env, list) and env_var in env:
matches.append(pid)
return matches
# ---------------------------------------------------------------------------
# Model-level queries (rich ModelInfo)
# ---------------------------------------------------------------------------
@@ -668,3 +708,74 @@ def get_model_info(
return None
def get_model_info_any_provider(model_id: str) -> Optional[ModelInfo]:
"""Search all providers for a model by ID.
Useful when you have a full slug like "anthropic/claude-sonnet-4.6" or
a bare name and want to find it anywhere. Checks Hermes-mapped providers
first, then falls back to all models.dev providers.
"""
data = fetch_models_dev()
# Try Hermes-mapped providers first (more likely what the user wants)
for hermes_id, mdev_id in PROVIDER_TO_MODELS_DEV.items():
pdata = data.get(mdev_id)
if not isinstance(pdata, dict):
continue
models = pdata.get("models", {})
if not isinstance(models, dict):
continue
raw = models.get(model_id)
if isinstance(raw, dict):
return _parse_model_info(model_id, raw, mdev_id)
# Case-insensitive
model_lower = model_id.lower()
for mid, mdata in models.items():
if mid.lower() == model_lower and isinstance(mdata, dict):
return _parse_model_info(mid, mdata, mdev_id)
# Fall back to ALL providers
for pid, pdata in data.items():
if pid in _get_reverse_mapping():
continue # already checked
if not isinstance(pdata, dict):
continue
models = pdata.get("models", {})
if not isinstance(models, dict):
continue
raw = models.get(model_id)
if isinstance(raw, dict):
return _parse_model_info(model_id, raw, pid)
return None
def list_provider_model_infos(provider_id: str) -> List[ModelInfo]:
"""Return all models for a provider as ModelInfo objects.
Filters out deprecated models by default.
"""
mdev_id = PROVIDER_TO_MODELS_DEV.get(provider_id, provider_id)
data = fetch_models_dev()
pdata = data.get(mdev_id)
if not isinstance(pdata, dict):
return []
models = pdata.get("models", {})
if not isinstance(models, dict):
return []
result: List[ModelInfo] = []
for mid, mdata in models.items():
if not isinstance(mdata, dict):
continue
status = mdata.get("status", "")
if status == "deprecated":
continue
result.append(_parse_model_info(mid, mdata, mdev_id))
return result
+11
View File
@@ -491,6 +491,17 @@ def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]:
return True, {}, ""
def _read_skill_conditions(skill_file: Path) -> dict:
"""Extract conditional activation fields from SKILL.md frontmatter."""
try:
raw = skill_file.read_text(encoding="utf-8")[:2000]
frontmatter, _ = parse_frontmatter(raw)
return extract_skill_conditions(frontmatter)
except Exception as e:
logger.debug("Failed to read skill conditions from %s: %s", skill_file, e)
return {}
def _skill_should_show(
conditions: dict,
available_tools: "set[str] | None",
-242
View File
@@ -1,242 +0,0 @@
"""Rate limit tracking for inference API responses.
Captures x-ratelimit-* headers from provider responses and provides
formatted display for the /usage slash command. Currently supports
the Nous Portal header format (also used by OpenRouter and OpenAI-compatible
APIs that follow the same convention).
Header schema (12 headers total):
x-ratelimit-limit-requests RPM cap
x-ratelimit-limit-requests-1h RPH cap
x-ratelimit-limit-tokens TPM cap
x-ratelimit-limit-tokens-1h TPH cap
x-ratelimit-remaining-requests requests left in minute window
x-ratelimit-remaining-requests-1h requests left in hour window
x-ratelimit-remaining-tokens tokens left in minute window
x-ratelimit-remaining-tokens-1h tokens left in hour window
x-ratelimit-reset-requests seconds until minute request window resets
x-ratelimit-reset-requests-1h seconds until hour request window resets
x-ratelimit-reset-tokens seconds until minute token window resets
x-ratelimit-reset-tokens-1h seconds until hour token window resets
"""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import Any, Dict, Mapping, Optional
@dataclass
class RateLimitBucket:
"""One rate-limit window (e.g. requests per minute)."""
limit: int = 0
remaining: int = 0
reset_seconds: float = 0.0
captured_at: float = 0.0 # time.time() when this was captured
@property
def used(self) -> int:
return max(0, self.limit - self.remaining)
@property
def usage_pct(self) -> float:
if self.limit <= 0:
return 0.0
return (self.used / self.limit) * 100.0
@property
def remaining_seconds_now(self) -> float:
"""Estimated seconds remaining until reset, adjusted for elapsed time."""
elapsed = time.time() - self.captured_at
return max(0.0, self.reset_seconds - elapsed)
@dataclass
class RateLimitState:
"""Full rate-limit state parsed from response headers."""
requests_min: RateLimitBucket = field(default_factory=RateLimitBucket)
requests_hour: RateLimitBucket = field(default_factory=RateLimitBucket)
tokens_min: RateLimitBucket = field(default_factory=RateLimitBucket)
tokens_hour: RateLimitBucket = field(default_factory=RateLimitBucket)
captured_at: float = 0.0 # when the headers were captured
provider: str = ""
@property
def has_data(self) -> bool:
return self.captured_at > 0
@property
def age_seconds(self) -> float:
if not self.has_data:
return float("inf")
return time.time() - self.captured_at
def _safe_int(value: Any, default: int = 0) -> int:
try:
return int(float(value))
except (TypeError, ValueError):
return default
def _safe_float(value: Any, default: float = 0.0) -> float:
try:
return float(value)
except (TypeError, ValueError):
return default
def parse_rate_limit_headers(
headers: Mapping[str, str],
provider: str = "",
) -> Optional[RateLimitState]:
"""Parse x-ratelimit-* headers into a RateLimitState.
Returns None if no rate limit headers are present.
"""
# Quick check: at least one rate limit header must exist
has_any = any(k.lower().startswith("x-ratelimit-") for k in headers)
if not has_any:
return None
now = time.time()
def _bucket(resource: str, suffix: str = "") -> RateLimitBucket:
# e.g. resource="requests", suffix="" -> per-minute
# resource="tokens", suffix="-1h" -> per-hour
tag = f"{resource}{suffix}"
return RateLimitBucket(
limit=_safe_int(headers.get(f"x-ratelimit-limit-{tag}")),
remaining=_safe_int(headers.get(f"x-ratelimit-remaining-{tag}")),
reset_seconds=_safe_float(headers.get(f"x-ratelimit-reset-{tag}")),
captured_at=now,
)
return RateLimitState(
requests_min=_bucket("requests"),
requests_hour=_bucket("requests", "-1h"),
tokens_min=_bucket("tokens"),
tokens_hour=_bucket("tokens", "-1h"),
captured_at=now,
provider=provider,
)
# ── Formatting ──────────────────────────────────────────────────────────
def _fmt_count(n: int) -> str:
"""Human-friendly number: 7999856 -> '8.0M', 33599 -> '33.6K', 799 -> '799'."""
if n >= 1_000_000:
return f"{n / 1_000_000:.1f}M"
if n >= 10_000:
return f"{n / 1_000:.1f}K"
if n >= 1_000:
return f"{n / 1_000:.1f}K"
return str(n)
def _fmt_seconds(seconds: float) -> str:
"""Seconds -> human-friendly duration: '58s', '2m 14s', '58m 57s', '1h 2m'."""
s = max(0, int(seconds))
if s < 60:
return f"{s}s"
if s < 3600:
m, sec = divmod(s, 60)
return f"{m}m {sec}s" if sec else f"{m}m"
h, remainder = divmod(s, 3600)
m = remainder // 60
return f"{h}h {m}m" if m else f"{h}h"
def _bar(pct: float, width: int = 20) -> str:
"""ASCII progress bar: [████████░░░░░░░░░░░░] 40%."""
filled = int(pct / 100.0 * width)
filled = max(0, min(width, filled))
empty = width - filled
return f"[{'' * filled}{'' * empty}]"
def _bucket_line(label: str, bucket: RateLimitBucket, label_width: int = 14) -> str:
"""Format one bucket as a single line."""
if bucket.limit <= 0:
return f" {label:<{label_width}} (no data)"
pct = bucket.usage_pct
used = _fmt_count(bucket.used)
limit = _fmt_count(bucket.limit)
remaining = _fmt_count(bucket.remaining)
reset = _fmt_seconds(bucket.remaining_seconds_now)
bar = _bar(pct)
return f" {label:<{label_width}} {bar} {pct:5.1f}% {used}/{limit} used ({remaining} left, resets in {reset})"
def format_rate_limit_display(state: RateLimitState) -> str:
"""Format rate limit state for terminal/chat display."""
if not state.has_data:
return "No rate limit data yet — make an API request first."
age = state.age_seconds
if age < 5:
freshness = "just now"
elif age < 60:
freshness = f"{int(age)}s ago"
else:
freshness = f"{_fmt_seconds(age)} ago"
provider_label = state.provider.title() if state.provider else "Provider"
lines = [
f"{provider_label} Rate Limits (captured {freshness}):",
"",
_bucket_line("Requests/min", state.requests_min),
_bucket_line("Requests/hr", state.requests_hour),
"",
_bucket_line("Tokens/min", state.tokens_min),
_bucket_line("Tokens/hr", state.tokens_hour),
]
# Add warnings if any bucket is getting hot
warnings = []
for label, bucket in [
("requests/min", state.requests_min),
("requests/hr", state.requests_hour),
("tokens/min", state.tokens_min),
("tokens/hr", state.tokens_hour),
]:
if bucket.limit > 0 and bucket.usage_pct >= 80:
reset = _fmt_seconds(bucket.remaining_seconds_now)
warnings.append(f"{label} at {bucket.usage_pct:.0f}% — resets in {reset}")
if warnings:
lines.append("")
lines.extend(warnings)
return "\n".join(lines)
def format_rate_limit_compact(state: RateLimitState) -> str:
"""One-line compact summary for status bars / gateway messages."""
if not state.has_data:
return "No rate limit data."
rm = state.requests_min
tm = state.tokens_min
rh = state.requests_hour
th = state.tokens_hour
parts = []
if rm.limit > 0:
parts.append(f"RPM: {rm.remaining}/{rm.limit}")
if rh.limit > 0:
parts.append(f"RPH: {_fmt_count(rh.remaining)}/{_fmt_count(rh.limit)} (resets {_fmt_seconds(rh.remaining_seconds_now)})")
if tm.limit > 0:
parts.append(f"TPM: {_fmt_count(tm.remaining)}/{_fmt_count(tm.limit)}")
if th.limit > 0:
parts.append(f"TPH: {_fmt_count(th.remaining)}/{_fmt_count(th.limit)} (resets {_fmt_seconds(th.remaining_seconds_now)})")
return " | ".join(parts)
+2 -8
View File
@@ -159,10 +159,7 @@ class SubdirectoryHintTracker:
def _is_valid_subdir(self, path: Path) -> bool:
"""Check if path is a valid directory to scan for hints."""
try:
if not path.is_dir():
return False
except OSError:
if not path.is_dir():
return False
if path in self._loaded_dirs:
return False
@@ -175,10 +172,7 @@ class SubdirectoryHintTracker:
found_hints = []
for filename in _HINT_FILENAMES:
hint_path = directory / filename
try:
if not hint_path.is_file():
continue
except OSError:
if not hint_path.is_file():
continue
try:
content = hint_path.read_text(encoding="utf-8").strip()
+24
View File
@@ -595,6 +595,30 @@ def get_pricing(
}
def estimate_cost_usd(
model: str,
input_tokens: int,
output_tokens: int,
*,
provider: Optional[str] = None,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> float:
"""Backward-compatible helper for legacy callers.
This uses non-cached input/output only. New code should call
`estimate_usage_cost()` with canonical usage buckets.
"""
result = estimate_usage_cost(
model,
CanonicalUsage(input_tokens=input_tokens, output_tokens=output_tokens),
provider=provider,
base_url=base_url,
api_key=api_key,
)
return float(result.amount_usd or _ZERO)
def format_duration_compact(seconds: float) -> str:
if seconds < 60:
return f"{seconds:.0f}s"
+2 -2
View File
@@ -1158,7 +1158,7 @@ def main(
providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
reasoning_effort (str): OpenRouter reasoning effort level: "none", "minimal", "low", "medium", "high", "xhigh" (default: "medium")
reasoning_effort (str): OpenRouter reasoning effort level: "xhigh", "high", "medium", "low", "minimal", "none" (default: "medium")
reasoning_disabled (bool): Completely disable reasoning/thinking tokens (default: False)
prefill_messages_file (str): Path to JSON file containing prefill messages (list of {role, content} dicts)
max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)
@@ -1227,7 +1227,7 @@ def main(
print("🧠 Reasoning: DISABLED (effort=none)")
elif reasoning_effort:
# Use specified effort level
valid_efforts = ["none", "minimal", "low", "medium", "high", "xhigh"]
valid_efforts = ["xhigh", "high", "medium", "low", "minimal", "none"]
if reasoning_effort not in valid_efforts:
print(f"❌ Error: --reasoning_effort must be one of: {', '.join(valid_efforts)}")
return
+2 -27
View File
@@ -48,25 +48,6 @@ model:
# api_key: "your-key-here" # Uncomment to set here instead of .env
base_url: "https://openrouter.ai/api/v1"
# ── Token limits — two settings, easy to confuse ──────────────────────────
#
# context_length: TOTAL context window (input + output tokens combined).
# Controls when Hermes compresses history and validates requests.
# Leave unset — Hermes auto-detects the correct value from the provider.
# Set manually only when auto-detection is wrong (e.g. a local server with
# a custom num_ctx, or a proxy that doesn't expose /v1/models).
#
# context_length: 131072
#
# max_tokens: OUTPUT cap — maximum tokens the model may generate per response.
# Unrelated to how long your conversation history can be.
# The OpenAI-standard name "max_tokens" is a misnomer; Anthropic's native
# API has since renamed it "max_output_tokens" for clarity.
# Leave unset to use the model's native output ceiling (recommended).
# Set only if you want to deliberately limit individual response length.
#
# max_tokens: 8192
# =============================================================================
# OpenRouter Provider Routing (only applies when using OpenRouter)
# =============================================================================
@@ -136,8 +117,7 @@ terminal:
timeout: 180
docker_mount_cwd_to_workspace: false # SECURITY: off by default. Opt in to mount the launch cwd into Docker /workspace.
lifetime_seconds: 300
# sudo_password: "hunter2" # Optional: pipe a sudo password via sudo -S. SECURITY WARNING: plaintext.
# sudo_password: "" # Explicit empty password: try empty and never open the interactive sudo prompt.
# sudo_password: "" # Enable sudo commands (pipes via sudo -S) - SECURITY WARNING: plaintext!
# -----------------------------------------------------------------------------
# OPTION 2: SSH remote execution
@@ -228,18 +208,13 @@ terminal:
#
# SECURITY WARNING: Password stored in plaintext!
#
# INTERACTIVE PROMPT: If sudo_password is unset and the CLI is running,
# INTERACTIVE PROMPT: If no sudo_password is set and the CLI is running,
# you'll be prompted to enter your password when sudo is needed:
# - 45-second timeout (auto-skips if no input)
# - Press Enter to skip (command fails gracefully)
# - Password is hidden while typing
# - Password is cached for the session
#
# EMPTY PASSWORDS: Setting sudo_password to an explicit empty string is different
# from leaving it unset. Hermes will try an empty password via `sudo -S` and
# will not open the interactive prompt. This is useful for passwordless sudo,
# Touch ID sudo setups, and environments where prompting is just noise.
#
# ALTERNATIVES:
# - SSH backend: Configure passwordless sudo on the remote server
# - Containers: Run as root inside the container (no sudo needed)
+90 -94
View File
@@ -1118,6 +1118,14 @@ HERMES_CADUCEUS = """[#CD7F32]⠀⢀⣀⡀⠀⣀⣀
[#B8860B]⠀⠈⠳⠈⣡⠞⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[/]
[#B8860B]⠀⠈⠀[/]"""
# Compact banner for smaller terminals (fallback)
# Note: built dynamically by _build_compact_banner() to fit terminal width
COMPACT_BANNER = """
[bold #FFD700]╔══════════════════════════════════════════════════════════════╗[/]
[bold #FFD700]║[/] [#FFBF00]⚕ NOUS HERMES[/] [dim #B8860B]- AI Agent Framework[/] [bold #FFD700]║[/]
[bold #FFD700]║[/] [#CD7F32]Messenger of the Digital Gods[/] [dim #B8860B]Nous Research[/] [bold #FFD700]║[/]
[bold #FFD700]╚══════════════════════════════════════════════════════════════╝[/]
"""
def _build_compact_banner() -> str:
@@ -1363,6 +1371,7 @@ class HermesCLI:
self._stream_buf = "" # Partial line buffer for line-buffered rendering
self._stream_started = False # True once first delta arrives
self._stream_box_opened = False # True once the response box header is printed
self._reasoning_stream_started = False # True once live reasoning starts streaming
self._reasoning_preview_buf = "" # Coalesce tiny reasoning chunks for [thinking] output
self._pending_edit_snapshots = {}
@@ -1420,6 +1429,8 @@ class HermesCLI:
self.api_key = api_key or os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY")
else:
self.api_key = api_key or os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY")
self._nous_key_expires_at: Optional[str] = None
self._nous_key_source: Optional[str] = None
# Max turns priority: CLI arg > config file > env var > default
if max_turns is not None: # CLI arg was explicitly set
self.max_turns = max_turns
@@ -1535,7 +1546,6 @@ class HermesCLI:
self._clarify_deadline = 0
self._sudo_state = None
self._sudo_deadline = 0
self._modal_input_snapshot = None
self._approval_state = None
self._approval_deadline = 0
self._approval_lock = threading.Lock()
@@ -1592,12 +1602,7 @@ class HermesCLI:
return f"[{('' * filled) + ('' * max(0, width - filled))}]"
def _get_status_bar_snapshot(self) -> Dict[str, Any]:
# Prefer the agent's model name — it updates on fallback.
# self.model reflects the originally configured model and never
# changes mid-session, so the TUI would show a stale name after
# _try_activate_fallback() switches provider/model.
agent = getattr(self, "agent", None)
model_name = (getattr(agent, "model", None) or self.model or "unknown")
model_name = self.model or "unknown"
model_short = model_name.split("/")[-1] if "/" in model_name else model_name
if model_short.endswith(".gguf"):
model_short = model_short[:-5]
@@ -1623,6 +1628,7 @@ class HermesCLI:
"compressions": 0,
}
agent = getattr(self, "agent", None)
if not agent:
return snapshot
@@ -1995,6 +2001,7 @@ class HermesCLI:
"""
if not text:
return
self._reasoning_stream_started = True
self._reasoning_shown_this_turn = True
if getattr(self, "_stream_box_opened", False):
return
@@ -2204,6 +2211,7 @@ class HermesCLI:
self._stream_buf = ""
self._stream_started = False
self._stream_box_opened = False
self._reasoning_stream_started = False
self._stream_text_ansi = ""
self._stream_prefilt = ""
self._in_reasoning_block = False
@@ -3995,7 +4003,59 @@ class HermesCLI:
print(" To change model or provider, use: hermes model")
def _handle_prompt_command(self, cmd: str):
"""Handle the /prompt command to view or set system prompt."""
parts = cmd.split(maxsplit=1)
if len(parts) > 1:
# Set new prompt
new_prompt = parts[1].strip()
if new_prompt.lower() == "clear":
self.system_prompt = ""
self.agent = None # Force re-init
if save_config_value("agent.system_prompt", ""):
print("(^_^)b System prompt cleared (saved to config)")
else:
print("(^_^) System prompt cleared (session only)")
else:
self.system_prompt = new_prompt
self.agent = None # Force re-init
if save_config_value("agent.system_prompt", new_prompt):
print("(^_^)b System prompt set (saved to config)")
else:
print("(^_^) System prompt set (session only)")
print(f" \"{new_prompt[:60]}{'...' if len(new_prompt) > 60 else ''}\"")
else:
# Show current prompt
print()
print("+" + "-" * 50 + "+")
print("|" + " " * 15 + "(^_^) System Prompt" + " " * 15 + "|")
print("+" + "-" * 50 + "+")
print()
if self.system_prompt:
# Word wrap the prompt for display
words = self.system_prompt.split()
lines = []
current_line = ""
for word in words:
if len(current_line) + len(word) + 1 <= 50:
current_line += (" " if current_line else "") + word
else:
lines.append(current_line)
current_line = word
if current_line:
lines.append(current_line)
for line in lines:
print(f" {line}")
else:
print(" (no custom prompt set - using default)")
print()
print(" Usage:")
print(" /prompt <text> - Set a custom system prompt")
print(" /prompt clear - Remove custom prompt")
print(" /personality - Use a predefined personality")
print()
@staticmethod
@@ -4495,7 +4555,9 @@ class HermesCLI:
self._handle_model_switch(cmd_original)
elif canonical == "provider":
self._show_model_and_providers()
elif canonical == "prompt":
# Use original case so prompt text isn't lowercased
self._handle_prompt_command(cmd_original)
elif canonical == "personality":
# Use original case (handler lowercases the personality name itself)
self._handle_personality_command(cmd_original)
@@ -4975,9 +5037,6 @@ class HermesCLI:
def _try_launch_chrome_debug(port: int, system: str) -> bool:
"""Try to launch Chrome/Chromium with remote debugging enabled.
Uses a dedicated user-data-dir so the debug instance doesn't conflict
with an already-running Chrome using the default profile.
Returns True if a launch command was executed (doesn't guarantee success).
"""
import subprocess as _sp
@@ -4987,20 +5046,10 @@ class HermesCLI:
if not candidates:
return False
# Dedicated profile dir so debug Chrome won't collide with normal Chrome
data_dir = str(_hermes_home / "chrome-debug")
os.makedirs(data_dir, exist_ok=True)
chrome = candidates[0]
try:
_sp.Popen(
[
chrome,
f"--remote-debugging-port={port}",
f"--user-data-dir={data_dir}",
"--no-first-run",
"--no-default-browser-check",
],
[chrome, f"--remote-debugging-port={port}"],
stdout=_sp.DEVNULL,
stderr=_sp.DEVNULL,
start_new_session=True, # detach from terminal
@@ -5075,33 +5124,18 @@ class HermesCLI:
print(f" ✓ Chrome launched and listening on port {_port}")
else:
print(f" ⚠ Chrome launched but port {_port} isn't responding yet")
print(" Try again in a few seconds — the debug instance may still be starting")
print(" You may need to close existing Chrome windows first and retry")
else:
print(" ⚠ Could not auto-launch Chrome")
# Show manual instructions as fallback
_data_dir = str(_hermes_home / "chrome-debug")
sys_name = _plat.system()
if sys_name == "Darwin":
chrome_cmd = (
'open -a "Google Chrome" --args'
f" --remote-debugging-port=9222"
f' --user-data-dir="{_data_dir}"'
" --no-first-run --no-default-browser-check"
)
chrome_cmd = 'open -a "Google Chrome" --args --remote-debugging-port=9222'
elif sys_name == "Windows":
chrome_cmd = (
f'chrome.exe --remote-debugging-port=9222'
f' --user-data-dir="{_data_dir}"'
f" --no-first-run --no-default-browser-check"
)
chrome_cmd = 'chrome.exe --remote-debugging-port=9222'
else:
chrome_cmd = (
f"google-chrome --remote-debugging-port=9222"
f' --user-data-dir="{_data_dir}"'
f" --no-first-run --no-default-browser-check"
)
print(f" Launch Chrome manually:")
print(f" {chrome_cmd}")
chrome_cmd = "google-chrome --remote-debugging-port=9222"
print(f" Launch Chrome manually: {chrome_cmd}")
else:
print(f" ⚠ Port {_port} is not reachable at {cdp_url}")
@@ -5274,7 +5308,7 @@ class HermesCLI:
Usage:
/reasoning Show current effort level and display state
/reasoning <level> Set reasoning effort (none, minimal, low, medium, high, xhigh)
/reasoning <level> Set reasoning effort (none, low, medium, high, xhigh)
/reasoning show|on Show model thinking/reasoning in output
/reasoning hide|off Hide model thinking/reasoning from output
"""
@@ -5292,7 +5326,7 @@ class HermesCLI:
display_state = "on ✓" if self.show_reasoning else "off"
_cprint(f" {_GOLD}Reasoning effort: {level}{_RST}")
_cprint(f" {_GOLD}Reasoning display: {display_state}{_RST}")
_cprint(f" {_DIM}Usage: /reasoning <none|minimal|low|medium|high|xhigh|show|hide>{_RST}")
_cprint(f" {_DIM}Usage: /reasoning <none|low|medium|high|xhigh|show|hide>{_RST}")
return
arg = parts[1].strip().lower()
@@ -5318,7 +5352,7 @@ class HermesCLI:
parsed = _parse_reasoning_config(arg)
if parsed is None:
_cprint(f" {_DIM}(._.) Unknown argument: {arg}{_RST}")
_cprint(f" {_DIM}Valid levels: none, minimal, low, medium, high, xhigh{_RST}")
_cprint(f" {_DIM}Valid levels: none, low, minimal, medium, high, xhigh{_RST}")
_cprint(f" {_DIM}Display: show, hide{_RST}")
return
@@ -5357,7 +5391,7 @@ class HermesCLI:
approx_tokens = estimate_messages_tokens_rough(self.conversation_history)
print(f"🗜️ Compressing {original_count} messages (~{approx_tokens:,} tokens)...")
compressed, _new_system = self.agent._compress_context(
compressed, new_system = self.agent._compress_context(
self.conversation_history,
self.agent._cached_system_prompt or "",
approx_tokens=approx_tokens,
@@ -5374,27 +5408,12 @@ class HermesCLI:
print(f" ❌ Compression failed: {e}")
def _show_usage(self):
"""Show rate limits (if available) and session token usage."""
"""Show cumulative token usage for the current session."""
if not self.agent:
print("(._.) No active agent -- send a message first.")
return
agent = self.agent
calls = agent.session_api_calls
if calls == 0:
print("(._.) No API calls made yet in this session.")
return
# ── Rate limits (shown first when available) ────────────────
rl_state = agent.get_rate_limit_state()
if rl_state and rl_state.has_data:
from agent.rate_limit_tracker import format_rate_limit_display
print()
print(format_rate_limit_display(rl_state))
print()
# ── Session token usage ─────────────────────────────────────
input_tokens = getattr(agent, "session_input_tokens", 0) or 0
output_tokens = getattr(agent, "session_output_tokens", 0) or 0
cache_read_tokens = getattr(agent, "session_cache_read_tokens", 0) or 0
@@ -5402,7 +5421,13 @@ class HermesCLI:
prompt = agent.session_prompt_tokens
completion = agent.session_completion_tokens
total = agent.session_total_tokens
calls = agent.session_api_calls
if calls == 0:
print("(._.) No API calls made yet in this session.")
return
# Current context window state
compressor = agent.context_compressor
last_prompt = compressor.last_prompt_tokens
ctx_len = compressor.context_length
@@ -6180,7 +6205,6 @@ class HermesCLI:
timeout = 45
response_queue = queue.Queue()
self._capture_modal_input_snapshot()
self._sudo_state = {
"response_queue": response_queue,
}
@@ -6193,7 +6217,6 @@ class HermesCLI:
result = response_queue.get(timeout=1)
self._sudo_state = None
self._sudo_deadline = 0
self._restore_modal_input_snapshot()
self._invalidate()
if result:
_cprint(f"\n{_DIM} ✓ Password received (cached for session){_RST}")
@@ -6208,7 +6231,6 @@ class HermesCLI:
self._sudo_state = None
self._sudo_deadline = 0
self._restore_modal_input_snapshot()
self._invalidate()
_cprint(f"\n{_DIM} ⏱ Timeout — continuing without sudo{_RST}")
return ""
@@ -6381,33 +6403,6 @@ class HermesCLI:
def _secret_capture_callback(self, var_name: str, prompt: str, metadata=None) -> dict:
return prompt_for_secret(self, var_name, prompt, metadata)
def _capture_modal_input_snapshot(self) -> None:
"""Temporarily clear the input buffer and save the user's in-progress draft."""
if self._modal_input_snapshot is not None or not getattr(self, "_app", None):
return
try:
buf = self._app.current_buffer
self._modal_input_snapshot = {
"text": buf.text,
"cursor_position": buf.cursor_position,
}
buf.reset()
except Exception:
self._modal_input_snapshot = None
def _restore_modal_input_snapshot(self) -> None:
"""Restore any draft text that was present before a modal prompt opened."""
snapshot = self._modal_input_snapshot
self._modal_input_snapshot = None
if not snapshot or not getattr(self, "_app", None):
return
try:
buf = self._app.current_buffer
buf.text = snapshot.get("text", "")
buf.cursor_position = min(snapshot.get("cursor_position", 0), len(buf.text))
except Exception:
pass
def _submit_secret_response(self, value: str) -> None:
if not self._secret_state:
return
@@ -7135,7 +7130,6 @@ class HermesCLI:
# Sudo password prompt state (similar mechanism to clarify)
self._sudo_state = None # dict with response_queue when active
self._sudo_deadline = 0
self._modal_input_snapshot = None
# Dangerous command approval state (similar mechanism to clarify)
self._approval_state = None # dict with command, description, choices, selected, response_queue
@@ -7207,6 +7201,7 @@ class HermesCLI:
text = event.app.current_buffer.text
self._sudo_state["response_queue"].put(text)
self._sudo_state = None
event.app.current_buffer.reset()
event.app.invalidate()
return
@@ -7411,6 +7406,7 @@ class HermesCLI:
if self._sudo_state:
self._sudo_state["response_queue"].put("")
self._sudo_state = None
event.app.current_buffer.reset()
event.app.invalidate()
return
Generated
+4 -4
View File
@@ -22,16 +22,16 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1775036866,
"narHash": "sha256-ZojAnPuCdy657PbTq5V0Y+AHKhZAIwSIT2cb8UgAz/U=",
"lastModified": 1751274312,
"narHash": "sha256-/bVBlRpECLVzjV19t5KMdMFWSwKLtb5RyXdjz3LJT+g=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "6201e203d09599479a3b3450ed24fa81537ebc4e",
"rev": "50ab793786d9de88ee30ec4e4c24fb4236fc2674",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"ref": "nixos-24.11",
"repo": "nixpkgs",
"type": "github"
}
+1 -1
View File
@@ -2,7 +2,7 @@
description = "Hermes Agent - AI agent framework by Nous Research";
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.11";
flake-parts = {
url = "github:hercules-ci/flake-parts";
inputs.nixpkgs-lib.follows = "nixpkgs";
-15
View File
@@ -532,8 +532,6 @@ def load_gateway_config() -> GatewayConfig:
bridged["reply_prefix"] = platform_cfg["reply_prefix"]
if "require_mention" in platform_cfg:
bridged["require_mention"] = platform_cfg["require_mention"]
if "free_response_channels" in platform_cfg:
bridged["free_response_channels"] = platform_cfg["free_response_channels"]
if "mention_patterns" in platform_cfg:
bridged["mention_patterns"] = platform_cfg["mention_patterns"]
if not bridged:
@@ -548,19 +546,6 @@ def load_gateway_config() -> GatewayConfig:
plat_data["extra"] = extra
extra.update(bridged)
# Slack settings → env vars (env vars take precedence)
slack_cfg = yaml_cfg.get("slack", {})
if isinstance(slack_cfg, dict):
if "require_mention" in slack_cfg and not os.getenv("SLACK_REQUIRE_MENTION"):
os.environ["SLACK_REQUIRE_MENTION"] = str(slack_cfg["require_mention"]).lower()
if "allow_bots" in slack_cfg and not os.getenv("SLACK_ALLOW_BOTS"):
os.environ["SLACK_ALLOW_BOTS"] = str(slack_cfg["allow_bots"]).lower()
frc = slack_cfg.get("free_response_channels")
if frc is not None and not os.getenv("SLACK_FREE_RESPONSE_CHANNELS"):
if isinstance(frc, list):
frc = ",".join(str(v) for v in frc)
os.environ["SLACK_FREE_RESPONSE_CHANNELS"] = str(frc)
# Discord settings → env vars (env vars take precedence)
discord_cfg = yaml_cfg.get("discord", {})
if isinstance(discord_cfg, dict):
+61
View File
@@ -124,6 +124,53 @@ class DeliveryRouter:
self.adapters = adapters or {}
self.output_dir = get_hermes_home() / "cron" / "output"
def resolve_targets(
self,
deliver: Union[str, List[str]],
origin: Optional[SessionSource] = None
) -> List[DeliveryTarget]:
"""
Resolve delivery specification to concrete targets.
Args:
deliver: Delivery spec - "origin", "telegram", ["local", "discord"], etc.
origin: The source where the request originated (for "origin" target)
Returns:
List of resolved delivery targets
"""
if isinstance(deliver, str):
deliver = [deliver]
targets = []
seen_platforms = set()
for target_str in deliver:
target = DeliveryTarget.parse(target_str, origin)
# Resolve home channel if needed
if target.chat_id is None and target.platform != Platform.LOCAL:
home = self.config.get_home_channel(target.platform)
if home:
target.chat_id = home.chat_id
else:
# No home channel configured, skip this platform
continue
# Deduplicate
key = (target.platform, target.chat_id, target.thread_id)
if key not in seen_platforms:
seen_platforms.add(key)
targets.append(target)
# Always include local if configured
if self.config.always_log_local:
local_key = (Platform.LOCAL, None, None)
if local_key not in seen_platforms:
targets.append(DeliveryTarget(platform=Platform.LOCAL))
return targets
async def deliver(
self,
content: str,
@@ -252,5 +299,19 @@ class DeliveryRouter:
return await adapter.send(target.chat_id, content, metadata=send_metadata or None)
def parse_deliver_spec(
deliver: Optional[Union[str, List[str]]],
origin: Optional[SessionSource] = None,
default: str = "origin"
) -> Union[str, List[str]]:
"""
Normalize a delivery specification.
If None or empty, returns the default.
"""
if not deliver:
return default
return deliver
+1 -126
View File
@@ -10,142 +10,18 @@ import logging
import os
import random
import re
import subprocess
import sys
import uuid
from abc import ABC, abstractmethod
from urllib.parse import urlsplit
logger = logging.getLogger(__name__)
def _detect_macos_system_proxy() -> str | None:
"""Read the macOS system HTTP(S) proxy via ``scutil --proxy``.
Returns an ``http://host:port`` URL string if an HTTP or HTTPS proxy is
enabled, otherwise *None*. Falls back silently on non-macOS or on any
subprocess error.
"""
if sys.platform != "darwin":
return None
try:
out = subprocess.check_output(
["scutil", "--proxy"], timeout=3, text=True, stderr=subprocess.DEVNULL,
)
except Exception:
return None
props: dict[str, str] = {}
for line in out.splitlines():
line = line.strip()
if " : " in line:
key, _, val = line.partition(" : ")
props[key.strip()] = val.strip()
# Prefer HTTPS, fall back to HTTP
for enable_key, host_key, port_key in (
("HTTPSEnable", "HTTPSProxy", "HTTPSPort"),
("HTTPEnable", "HTTPProxy", "HTTPPort"),
):
if props.get(enable_key) == "1":
host = props.get(host_key)
port = props.get(port_key)
if host and port:
return f"http://{host}:{port}"
return None
def resolve_proxy_url(platform_env_var: str | None = None) -> str | None:
"""Return a proxy URL from env vars, or macOS system proxy.
Check order:
0. *platform_env_var* (e.g. ``DISCORD_PROXY``) highest priority
1. HTTPS_PROXY / HTTP_PROXY / ALL_PROXY (and lowercase variants)
2. macOS system proxy via ``scutil --proxy`` (auto-detect)
Returns *None* if no proxy is found.
"""
if platform_env_var:
value = (os.environ.get(platform_env_var) or "").strip()
if value:
return value
for key in ("HTTPS_PROXY", "HTTP_PROXY", "ALL_PROXY",
"https_proxy", "http_proxy", "all_proxy"):
value = (os.environ.get(key) or "").strip()
if value:
return value
return _detect_macos_system_proxy()
def proxy_kwargs_for_bot(proxy_url: str | None) -> dict:
"""Build kwargs for ``commands.Bot()`` / ``discord.Client()`` with proxy.
Returns:
- SOCKS URL ``{"connector": ProxyConnector(..., rdns=True)}``
- HTTP URL ``{"proxy": url}``
- *None* ``{}``
``rdns=True`` forces remote DNS resolution through the proxy required
by many SOCKS implementations (Shadowrocket, Clash) and essential for
bypassing DNS pollution behind the GFW.
"""
if not proxy_url:
return {}
if proxy_url.lower().startswith("socks"):
try:
from aiohttp_socks import ProxyConnector
connector = ProxyConnector.from_url(proxy_url, rdns=True)
return {"connector": connector}
except ImportError:
logger.warning(
"aiohttp_socks not installed — SOCKS proxy %s ignored. "
"Run: pip install aiohttp-socks",
proxy_url,
)
return {}
return {"proxy": proxy_url}
def proxy_kwargs_for_aiohttp(proxy_url: str | None) -> tuple[dict, dict]:
"""Build kwargs for standalone ``aiohttp.ClientSession`` with proxy.
Returns ``(session_kwargs, request_kwargs)`` where:
- SOCKS ``({"connector": ProxyConnector(...)}, {})``
- HTTP ``({}, {"proxy": url})``
- None ``({}, {})``
Usage::
sess_kw, req_kw = proxy_kwargs_for_aiohttp(proxy_url)
async with aiohttp.ClientSession(**sess_kw) as session:
async with session.get(url, **req_kw) as resp:
...
"""
if not proxy_url:
return {}, {}
if proxy_url.lower().startswith("socks"):
try:
from aiohttp_socks import ProxyConnector
connector = ProxyConnector.from_url(proxy_url, rdns=True)
return {"connector": connector}, {}
except ImportError:
logger.warning(
"aiohttp_socks not installed — SOCKS proxy %s ignored. "
"Run: pip install aiohttp-socks",
proxy_url,
)
return {}, {}
return {}, {"proxy": proxy_url}
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any, Callable, Awaitable, Tuple
from enum import Enum
import sys
from pathlib import Path as _Path
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
@@ -422,7 +298,6 @@ SUPPORTED_DOCUMENT_TYPES = {
".pdf": "application/pdf",
".md": "text/markdown",
".txt": "text/plain",
".log": "text/plain",
".zip": "application/zip",
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+8 -22
View File
@@ -529,17 +529,10 @@ class DiscordAdapter(BasePlatformAdapter):
intents.members = any(not entry.isdigit() for entry in self._allowed_user_ids)
intents.voice_states = True
# Resolve proxy (DISCORD_PROXY > generic env vars > macOS system proxy)
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_bot
proxy_url = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
if proxy_url:
logger.info("[%s] Using proxy for Discord: %s", self.name, proxy_url)
# Create bot — proxy= for HTTP, connector= for SOCKS
# Create bot
self._client = commands.Bot(
command_prefix="!", # Not really used, we handle raw messages
intents=intents,
**proxy_kwargs_for_bot(proxy_url),
)
adapter_self = self # capture for closure
@@ -1314,11 +1307,8 @@ class DiscordAdapter(BasePlatformAdapter):
# Download the image and send as a Discord file attachment
# (Discord renders attachments inline, unlike plain URLs)
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
async with aiohttp.ClientSession(**_sess_kw) as session:
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30), **_req_kw) as resp:
async with aiohttp.ClientSession() as session:
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
if resp.status != 200:
raise Exception(f"Failed to download image: HTTP {resp.status}")
@@ -1595,7 +1585,7 @@ class DiscordAdapter(BasePlatformAdapter):
await self._run_simple_slash(interaction, f"/model {name}".strip())
@tree.command(name="reasoning", description="Show or change reasoning effort")
@discord.app_commands.describe(effort="Reasoning effort: none, minimal, low, medium, high, or xhigh.")
@discord.app_commands.describe(effort="Reasoning effort: xhigh, high, medium, low, minimal, or none.")
async def slash_reasoning(interaction: discord.Interaction, effort: str = ""):
await self._run_simple_slash(interaction, f"/reasoning {effort}".strip())
@@ -2392,7 +2382,7 @@ class DiscordAdapter(BasePlatformAdapter):
ext or "unknown", content_type,
)
else:
MAX_DOC_BYTES = 32 * 1024 * 1024
MAX_DOC_BYTES = 20 * 1024 * 1024
if att.size and att.size > MAX_DOC_BYTES:
logger.warning(
"[Discord] Document too large (%s bytes), skipping: %s",
@@ -2401,14 +2391,10 @@ class DiscordAdapter(BasePlatformAdapter):
else:
try:
import aiohttp
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
async with aiohttp.ClientSession(**_sess_kw) as session:
async with aiohttp.ClientSession() as session:
async with session.get(
att.url,
timeout=aiohttp.ClientTimeout(total=30),
**_req_kw,
) as resp:
if resp.status != 200:
raise Exception(f"HTTP {resp.status}")
@@ -2420,9 +2406,9 @@ class DiscordAdapter(BasePlatformAdapter):
media_urls.append(cached_path)
media_types.append(doc_mime)
logger.info("[Discord] Cached user document: %s", cached_path)
# Inject text content for plain-text documents (capped at 100 KB)
# Inject text content for .txt/.md files (capped at 100 KB)
MAX_TEXT_INJECT_BYTES = 100 * 1024
if ext in (".md", ".txt", ".log") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
try:
text_content = raw_bytes.decode("utf-8")
display_name = att.filename or f"document{ext}"
+65 -219
View File
@@ -14,7 +14,6 @@ import logging
import os
import re
import time
from dataclasses import dataclass, field
from typing import Dict, Optional, Any, Tuple
try:
@@ -46,14 +45,6 @@ from gateway.platforms.base import (
logger = logging.getLogger(__name__)
@dataclass
class _ThreadContextCache:
"""Cache entry for fetched thread context."""
content: str
fetched_at: float = field(default_factory=time.monotonic)
message_count: int = 0
def check_slack_requirements() -> bool:
"""Check if Slack dependencies are available."""
return SLACK_AVAILABLE
@@ -110,9 +101,6 @@ class SlackAdapter(BasePlatformAdapter):
# session + memory scoping.
self._assistant_threads: Dict[Tuple[str, str], Dict[str, str]] = {}
self._ASSISTANT_THREADS_MAX = 5000
# Cache for _fetch_thread_context results: cache_key → _ThreadContextCache
self._thread_context_cache: Dict[str, _ThreadContextCache] = {}
self._THREAD_CACHE_TTL = 60.0
async def connect(self) -> bool:
"""Connect to Slack via Socket Mode."""
@@ -293,7 +281,6 @@ class SlackAdapter(BasePlatformAdapter):
kwargs = {
"channel": chat_id,
"text": chunk,
"mrkdwn": True,
}
if thread_ts:
kwargs["thread_ts"] = thread_ts
@@ -336,7 +323,9 @@ class SlackAdapter(BasePlatformAdapter):
if not self._app:
return SendResult(success=False, error="Not connected")
try:
# Convert standard markdown → Slack mrkdwn
formatted = self.format_message(content)
await self._get_client(chat_id).chat_update(
channel=chat_id,
ts=message_id,
@@ -468,36 +457,13 @@ class SlackAdapter(BasePlatformAdapter):
text = re.sub(r'(`[^`]+`)', lambda m: _ph(m.group(0)), text)
# 3) Convert markdown links [text](url) → <url|text>
def _convert_markdown_link(m):
label = m.group(1)
url = m.group(2).strip()
if url.startswith('<') and url.endswith('>'):
url = url[1:-1].strip()
return _ph(f'<{url}|{label}>')
text = re.sub(
r'\[([^\]]+)\]\(([^()]*(?:\([^()]*\)[^()]*)*)\)',
_convert_markdown_link,
r'\[([^\]]+)\]\(([^)]+)\)',
lambda m: _ph(f'<{m.group(2)}|{m.group(1)}>'),
text,
)
# 4) Protect existing Slack entities/manual links so escaping and later
# formatting passes don't break them.
text = re.sub(
r'(<(?:[@#!]|(?:https?|mailto|tel):)[^>\n]+>)',
lambda m: _ph(m.group(1)),
text,
)
# 5) Protect blockquote markers before escaping
text = re.sub(r'^(>+\s)', lambda m: _ph(m.group(0)), text, flags=re.MULTILINE)
# 6) Escape Slack control characters in remaining plain text.
# Unescape first so already-escaped input doesn't get double-escaped.
text = text.replace('&amp;', '&').replace('&lt;', '<').replace('&gt;', '>')
text = text.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')
# 7) Convert headers (## Title) → *Title* (bold)
# 4) Convert headers (## Title) → *Title* (bold)
def _convert_header(m):
inner = m.group(1).strip()
# Strip redundant bold markers inside a header
@@ -508,39 +474,34 @@ class SlackAdapter(BasePlatformAdapter):
r'^#{1,6}\s+(.+)$', _convert_header, text, flags=re.MULTILINE
)
# 8) Convert bold+italic: ***text*** → *_text_* (Slack bold wrapping italic)
text = re.sub(
r'\*\*\*(.+?)\*\*\*',
lambda m: _ph(f'*_{m.group(1)}_*'),
text,
)
# 9) Convert bold: **text** → *text* (Slack bold)
# 5) Convert bold: **text** → *text* (Slack bold)
text = re.sub(
r'\*\*(.+?)\*\*',
lambda m: _ph(f'*{m.group(1)}*'),
text,
)
# 10) Convert italic: _text_ stays as _text_ (already Slack italic)
# Single *text* → _text_ (Slack italic)
# 6) Convert italic: _text_ stays as _text_ (already Slack italic)
# Single *text* → _text_ (Slack italic)
text = re.sub(
r'(?<!\*)\*([^*\n]+)\*(?!\*)',
lambda m: _ph(f'_{m.group(1)}_'),
text,
)
# 11) Convert strikethrough: ~~text~~ → ~text~
# 7) Convert strikethrough: ~~text~~ → ~text~
text = re.sub(
r'~~(.+?)~~',
lambda m: _ph(f'~{m.group(1)}~'),
text,
)
# 12) Blockquotes: > prefix is already protected by step 5 above.
# 8) Convert blockquotes: > text → > text (same syntax, just ensure
# no extra escaping happens to the > character)
# Slack uses the same > prefix, so this is a no-op for content.
# 13) Restore placeholders in reverse order
for key in reversed(placeholders):
# 9) Restore placeholders in reverse order
for key in reversed(list(placeholders.keys())):
text = text.replace(key, placeholders[key])
return text
@@ -953,26 +914,9 @@ class SlackAdapter(BasePlatformAdapter):
if v > cutoff
}
# Bot message filtering (SLACK_ALLOW_BOTS / config allow_bots):
# "none" — ignore all bot messages (default, backward-compatible)
# "mentions" — accept bot messages only when they @mention us
# "all" — accept all bot messages (except our own)
# Ignore bot messages (including our own)
if event.get("bot_id") or event.get("subtype") == "bot_message":
allow_bots = self.config.extra.get("allow_bots", "")
if not allow_bots:
allow_bots = os.getenv("SLACK_ALLOW_BOTS", "none")
allow_bots = str(allow_bots).lower().strip()
if allow_bots == "none":
return
elif allow_bots == "mentions":
text_check = event.get("text", "")
if self._bot_user_id and f"<@{self._bot_user_id}>" not in text_check:
return
# "all" falls through to process the message
# Always ignore our own messages to prevent echo loops
msg_user = event.get("user", "")
if msg_user and self._bot_user_id and msg_user == self._bot_user_id:
return
return
# Ignore message edits and deletions
subtype = event.get("subtype")
@@ -1004,7 +948,7 @@ class SlackAdapter(BasePlatformAdapter):
channel_type = event.get("channel_type", "")
if not channel_type and channel_id.startswith("D"):
channel_type = "im"
is_dm = channel_type in ("im", "mpim") # Both 1:1 and group DMs
is_dm = channel_type == "im"
# Build thread_ts for session keying.
# In channels: fall back to ts so each top-level @mention starts a
@@ -1017,8 +961,6 @@ class SlackAdapter(BasePlatformAdapter):
thread_ts = event.get("thread_ts") or ts # ts fallback for channels
# In channels, respond if:
# 0. Channel is in free_response_channels, OR require_mention is
# disabled — always process regardless of mention.
# 1. The bot is @mentioned in this message, OR
# 2. The message is a reply in a thread the bot started/participated in, OR
# 3. The message is in a thread where the bot was previously @mentioned, OR
@@ -1028,29 +970,24 @@ class SlackAdapter(BasePlatformAdapter):
event_thread_ts = event.get("thread_ts")
is_thread_reply = bool(event_thread_ts and event_thread_ts != ts)
if not is_dm and bot_uid:
if channel_id in self._slack_free_response_channels():
pass # Free-response channel — always process
elif not self._slack_require_mention():
pass # Mention requirement disabled globally for Slack
elif not is_mentioned:
reply_to_bot_thread = (
is_thread_reply and event_thread_ts in self._bot_message_ts
if not is_dm and bot_uid and not is_mentioned:
reply_to_bot_thread = (
is_thread_reply and event_thread_ts in self._bot_message_ts
)
in_mentioned_thread = (
event_thread_ts is not None
and event_thread_ts in self._mentioned_threads
)
has_session = (
is_thread_reply
and self._has_active_session_for_thread(
channel_id=channel_id,
thread_ts=event_thread_ts,
user_id=user_id,
)
in_mentioned_thread = (
event_thread_ts is not None
and event_thread_ts in self._mentioned_threads
)
has_session = (
is_thread_reply
and self._has_active_session_for_thread(
channel_id=channel_id,
thread_ts=event_thread_ts,
user_id=user_id,
)
)
if not reply_to_bot_thread and not in_mentioned_thread and not has_session:
return
)
if not reply_to_bot_thread and not in_mentioned_thread and not has_session:
return
if is_mentioned:
# Strip the bot mention from the text
@@ -1191,19 +1128,14 @@ class SlackAdapter(BasePlatformAdapter):
reply_to_message_id=thread_ts if thread_ts != ts else None,
)
# Only react when bot is directly addressed (DM or @mention).
# In listen-all channels (require_mention=false), reacting to every
# casual message would be noisy.
_should_react = is_dm or is_mentioned
if _should_react:
await self._add_reaction(channel_id, ts, "eyes")
# Add 👀 reaction to acknowledge receipt
await self._add_reaction(channel_id, ts, "eyes")
await self.handle_message(msg_event)
if _should_react:
await self._remove_reaction(channel_id, ts, "eyes")
await self._add_reaction(channel_id, ts, "white_check_mark")
# Replace 👀 with ✅ when done
await self._remove_reaction(channel_id, ts, "eyes")
await self._add_reaction(channel_id, ts, "white_check_mark")
# ----- Approval button support (Block Kit) -----
@@ -1297,20 +1229,6 @@ class SlackAdapter(BasePlatformAdapter):
msg_ts = message.get("ts", "")
channel_id = body.get("channel", {}).get("id", "")
user_name = body.get("user", {}).get("name", "unknown")
user_id = body.get("user", {}).get("id", "")
# Only authorized users may click approval buttons. Button clicks
# bypass the normal message auth flow in gateway/run.py, so we must
# check here as well.
allowed_csv = os.getenv("SLACK_ALLOWED_USERS", "").strip()
if allowed_csv:
allowed_ids = {uid.strip() for uid in allowed_csv.split(",") if uid.strip()}
if "*" not in allowed_ids and user_id not in allowed_ids:
logger.warning(
"[Slack] Unauthorized approval click by %s (%s) — ignoring",
user_name, user_id,
)
return
# Map action_id to approval choice
choice_map = {
@@ -1321,9 +1239,10 @@ class SlackAdapter(BasePlatformAdapter):
}
choice = choice_map.get(action_id, "deny")
# Prevent double-clicks — atomic pop; first caller gets False, others get True (default)
if self._approval_resolved.pop(msg_ts, True):
# Prevent double-clicks
if self._approval_resolved.get(msg_ts, False):
return
self._approval_resolved[msg_ts] = True
# Update the message to show the decision and remove buttons
label_map = {
@@ -1378,7 +1297,8 @@ class SlackAdapter(BasePlatformAdapter):
except Exception as exc:
logger.error("Failed to resolve gateway approval from Slack button: %s", exc)
# (approval state already consumed by atomic pop above)
# Clean up stale approval state
self._approval_resolved.pop(msg_ts, None)
# ----- Thread context fetching -----
@@ -1389,104 +1309,57 @@ class SlackAdapter(BasePlatformAdapter):
"""Fetch recent thread messages to provide context when the bot is
mentioned mid-thread for the first time.
This method is only called when there is NO active session for the
thread (guarded at the call site by _has_active_session_for_thread).
That guard ensures thread messages are prepended only on the very
first turn after that the session history already holds them, so
there is no duplication across subsequent turns.
Results are cached for _THREAD_CACHE_TTL seconds per thread to avoid
hammering conversations.replies (Tier 3, ~50 req/min).
Returns a formatted string with prior thread history, or empty string
on failure or if the thread has no prior messages.
Returns a formatted string with thread history, or empty string on
failure or if the thread is empty (just the parent message).
"""
cache_key = f"{channel_id}:{thread_ts}"
now = time.monotonic()
cached = self._thread_context_cache.get(cache_key)
if cached and (now - cached.fetched_at) < self._THREAD_CACHE_TTL:
return cached.content
try:
client = self._get_client(channel_id)
# Retry with exponential backoff for Tier-3 rate limits (429).
result = None
for attempt in range(3):
try:
result = await client.conversations_replies(
channel=channel_id,
ts=thread_ts,
limit=limit + 1, # +1 because it includes the current message
inclusive=True,
)
break
except Exception as exc:
# Check for rate-limit error from slack_sdk
err_str = str(exc).lower()
is_rate_limit = (
"ratelimited" in err_str
or "429" in err_str
or "rate_limited" in err_str
)
if is_rate_limit and attempt < 2:
retry_after = 1.0 * (2 ** attempt) # 1s, 2s
logger.warning(
"[Slack] conversations.replies rate limited; retrying in %.1fs (attempt %d/3)",
retry_after, attempt + 1,
)
await asyncio.sleep(retry_after)
continue
raise
if result is None:
return ""
result = await client.conversations_replies(
channel=channel_id,
ts=thread_ts,
limit=limit + 1, # +1 because it includes the current message
inclusive=True,
)
messages = result.get("messages", [])
if not messages:
return ""
bot_uid = self._team_bot_user_ids.get(team_id, self._bot_user_id)
context_parts = []
for msg in messages:
msg_ts = msg.get("ts", "")
# Exclude the current triggering message — it will be delivered
# as the user message itself, so including it here would duplicate it.
# Skip the current message (the one that triggered this fetch)
if msg_ts == current_ts:
continue
# Exclude our own bot messages to avoid circular context.
# Skip bot messages from ourselves
if msg.get("bot_id") or msg.get("subtype") == "bot_message":
continue
msg_user = msg.get("user", "unknown")
msg_text = msg.get("text", "").strip()
if not msg_text:
continue
# Strip bot mentions from context messages
bot_uid = self._team_bot_user_ids.get(team_id, self._bot_user_id)
if bot_uid:
msg_text = msg_text.replace(f"<@{bot_uid}>", "").strip()
msg_user = msg.get("user", "unknown")
# Mark the thread parent
is_parent = msg_ts == thread_ts
prefix = "[thread parent] " if is_parent else ""
# Resolve user name (cached)
name = await self._resolve_user_name(msg_user, chat_id=channel_id)
context_parts.append(f"{prefix}{name}: {msg_text}")
content = ""
if context_parts:
content = (
"[Thread context — prior messages in this thread (not yet in conversation history):]\n"
+ "\n".join(context_parts)
+ "\n[End of thread context]\n\n"
)
if not context_parts:
return ""
self._thread_context_cache[cache_key] = _ThreadContextCache(
content=content,
fetched_at=now,
message_count=len(context_parts),
return (
"[Thread context — previous messages in this thread:]\n"
+ "\n".join(context_parts)
+ "\n[End of thread context]\n\n"
)
return content
except Exception as e:
logger.warning("[Slack] Failed to fetch thread context: %s", e)
return ""
@@ -1642,30 +1515,3 @@ class SlackAdapter(BasePlatformAdapter):
continue
raise
raise last_exc
# ── Channel mention gating ─────────────────────────────────────────────
def _slack_require_mention(self) -> bool:
"""Return whether channel messages require an explicit bot mention.
Uses explicit-false parsing (like Discord/Matrix) rather than
truthy parsing, since the safe default is True (gating on).
Unrecognised or empty values keep gating enabled.
"""
configured = self.config.extra.get("require_mention")
if configured is not None:
if isinstance(configured, str):
return configured.lower() not in ("false", "0", "no", "off")
return bool(configured)
return os.getenv("SLACK_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no", "off")
def _slack_free_response_channels(self) -> set:
"""Return channel IDs where no @mention is required."""
raw = self.config.extra.get("free_response_channels")
if raw is None:
raw = os.getenv("SLACK_FREE_RESPONSE_CHANNELS", "")
if isinstance(raw, list):
return {str(part).strip() for part in raw if str(part).strip()}
if isinstance(raw, str) and raw.strip():
return {part.strip() for part in raw.split(",") if part.strip()}
return set()
-9
View File
@@ -1398,15 +1398,6 @@ class TelegramAdapter(BasePlatformAdapter):
await query.answer(text="Invalid approval data.")
return
# Only authorized users may click approval buttons.
caller_id = str(getattr(query.from_user, "id", ""))
allowed_csv = os.getenv("TELEGRAM_ALLOWED_USERS", "").strip()
if allowed_csv:
allowed_ids = {uid.strip() for uid in allowed_csv.split(",") if uid.strip()}
if "*" not in allowed_ids and caller_id not in allowed_ids:
await query.answer(text="⛔ You are not authorized to approve commands.")
return
session_key = self._approval_state.pop(approval_id, None)
if not session_key:
await query.answer(text="This approval has already been resolved.")
+5 -3
View File
@@ -45,9 +45,11 @@ _SEED_FALLBACK_IPS: list[str] = ["149.154.167.220"]
def _resolve_proxy_url() -> str | None:
# Delegate to shared implementation (env vars + macOS system proxy detection)
from gateway.platforms.base import resolve_proxy_url
return resolve_proxy_url()
for key in ("HTTPS_PROXY", "HTTP_PROXY", "ALL_PROXY", "https_proxy", "http_proxy", "all_proxy"):
value = (os.environ.get(key) or "").strip()
if value:
return value
return None
class TelegramFallbackTransport(httpx.AsyncBaseTransport):
+25 -22
View File
@@ -514,6 +514,12 @@ class GatewayRunner:
self._agent_cache: Dict[str, tuple] = {}
self._agent_cache_lock = _threading.Lock()
# Track active fallback model/provider when primary is rate-limited.
# Set after an agent run where fallback was activated; cleared when
# the primary model succeeds again or the user switches via /model.
self._effective_model: Optional[str] = None
self._effective_provider: Optional[str] = None
# Per-session model overrides from /model command.
# Key: session_key, Value: dict with model/provider/api_key/base_url/api_mode
self._session_model_overrides: Dict[str, Dict[str, str]] = {}
@@ -919,8 +925,8 @@ class GatewayRunner:
def _load_reasoning_config() -> dict | None:
"""Load reasoning effort from config.yaml.
Reads agent.reasoning_effort from config.yaml. Valid: "none",
"minimal", "low", "medium", "high", "xhigh". Returns None to use
Reads agent.reasoning_effort from config.yaml. Valid: "xhigh",
"high", "medium", "low", "minimal", "none". Returns None to use
default (medium).
"""
from hermes_constants import parse_reasoning_effort
@@ -4834,7 +4840,7 @@ class GatewayRunner:
Usage:
/reasoning Show current effort level and display state
/reasoning <level> Set reasoning effort (none, minimal, low, medium, high, xhigh)
/reasoning <level> Set reasoning effort (none, low, medium, high, xhigh)
/reasoning show|on Show model reasoning in responses
/reasoning hide|off Hide model reasoning from responses
"""
@@ -4879,7 +4885,7 @@ class GatewayRunner:
"🧠 **Reasoning Settings**\n\n"
f"**Effort:** `{level}`\n"
f"**Display:** {display_state}\n\n"
"_Usage:_ `/reasoning <none|minimal|low|medium|high|xhigh|show|hide>`"
"_Usage:_ `/reasoning <none|low|medium|high|xhigh|show|hide>`"
)
# Display toggle
@@ -4897,12 +4903,12 @@ class GatewayRunner:
effort = args.strip()
if effort == "none":
parsed = {"enabled": False}
elif effort in ("minimal", "low", "medium", "high", "xhigh"):
elif effort in ("xhigh", "high", "medium", "low", "minimal"):
parsed = {"enabled": True, "effort": effort}
else:
return (
f"⚠️ Unknown argument: `{effort}`\n\n"
"**Valid levels:** none, minimal, low, medium, high, xhigh\n"
"**Valid levels:** none, low, minimal, medium, high, xhigh\n"
"**Display:** show, hide"
)
@@ -5274,28 +5280,19 @@ class GatewayRunner:
agent = self._running_agents.get(session_key)
if agent and hasattr(agent, "session_total_tokens") and agent.session_api_calls > 0:
lines = []
# Rate limits first (when available from provider headers)
rl_state = agent.get_rate_limit_state()
if rl_state and rl_state.has_data:
from agent.rate_limit_tracker import format_rate_limit_compact
lines.append(f"⏱️ **Rate Limits:** {format_rate_limit_compact(rl_state)}")
lines.append("")
# Session token usage
lines.append("📊 **Session Token Usage**")
lines.append(f"Prompt (input): {agent.session_prompt_tokens:,}")
lines.append(f"Completion (output): {agent.session_completion_tokens:,}")
lines.append(f"Total: {agent.session_total_tokens:,}")
lines.append(f"API calls: {agent.session_api_calls}")
lines = [
"📊 **Session Token Usage**",
f"Prompt (input): {agent.session_prompt_tokens:,}",
f"Completion (output): {agent.session_completion_tokens:,}",
f"Total: {agent.session_total_tokens:,}",
f"API calls: {agent.session_api_calls}",
]
ctx = agent.context_compressor
if ctx.last_prompt_tokens:
pct = min(100, ctx.last_prompt_tokens / ctx.context_length * 100) if ctx.context_length else 0
lines.append(f"Context: {ctx.last_prompt_tokens:,} / {ctx.context_length:,} ({pct:.0f}%)")
if ctx.compression_count:
lines.append(f"Compressions: {ctx.compression_count}")
return "\n".join(lines)
# No running agent -- check session history for a rough count
@@ -7274,9 +7271,15 @@ class GatewayRunner:
if _agent is not None and hasattr(_agent, 'model'):
_cfg_model = _resolve_gateway_model()
if _agent.model != _cfg_model:
self._effective_model = _agent.model
self._effective_provider = getattr(_agent, 'provider', None)
# Fallback activated — evict cached agent so the next
# message starts fresh and retries the primary model.
self._evict_cached_agent(session_key)
else:
# Primary model worked — clear any stale fallback state
self._effective_model = None
self._effective_provider = None
# Check if we were interrupted OR have a queued message (/queue).
result = result_holder[0]
+18 -1
View File
@@ -32,6 +32,9 @@ def _now() -> datetime:
# PII redaction helpers
# ---------------------------------------------------------------------------
_PHONE_RE = re.compile(r"^\+?\d[\d\-\s]{6,}$")
def _hash_id(value: str) -> str:
"""Deterministic 12-char hex hash of an identifier."""
return hashlib.sha256(value.encode("utf-8")).hexdigest()[:12]
@@ -55,6 +58,10 @@ def _hash_chat_id(value: str) -> str:
return _hash_id(value)
def _looks_like_phone(value: str) -> bool:
"""Return True if *value* looks like a phone number (E.164 or similar)."""
return bool(_PHONE_RE.match(value.strip()))
from .config import (
Platform,
GatewayConfig,
@@ -137,6 +144,15 @@ class SessionSource:
chat_id_alt=data.get("chat_id_alt"),
)
@classmethod
def local_cli(cls) -> "SessionSource":
"""Create a source representing the local CLI."""
return cls(
platform=Platform.LOCAL,
chat_id="cli",
chat_name="CLI terminal",
chat_type="dm",
)
@dataclass
@@ -494,7 +510,8 @@ class SessionStore:
"""
def __init__(self, sessions_dir: Path, config: GatewayConfig,
has_active_processes_fn=None):
has_active_processes_fn=None,
on_auto_reset=None):
self.sessions_dir = sessions_dir
self.config = config
self._entries: Dict[str, SessionEntry] = {}
+1 -56
View File
@@ -136,34 +136,7 @@ class GatewayStreamConsumer:
if should_edit and self._accumulated:
# Split overflow: if accumulated text exceeds the platform
# limit, split into properly sized chunks.
if (
len(self._accumulated) > _safe_limit
and self._message_id is None
):
# No existing message to edit (first message or after a
# segment break). Use truncate_message — the same
# helper the non-streaming path uses — to split with
# proper word/code-fence boundaries and chunk
# indicators like "(1/2)".
chunks = self.adapter.truncate_message(
self._accumulated, _safe_limit
)
for chunk in chunks:
await self._send_new_chunk(chunk, self._message_id)
self._accumulated = ""
self._last_sent_text = ""
self._last_edit_time = time.monotonic()
if got_done:
return
if got_segment_break:
self._message_id = None
self._fallback_final_send = False
self._fallback_prefix = ""
continue
# Existing message: edit it with the first chunk, then
# start a new message for the overflow remainder.
# limit, finalize the current message and start a new one.
while (
len(self._accumulated) > _safe_limit
and self._message_id is not None
@@ -253,34 +226,6 @@ class GatewayStreamConsumer:
# Strip trailing whitespace/newlines but preserve leading content
return cleaned.rstrip()
async def _send_new_chunk(self, text: str, reply_to_id: Optional[str]) -> Optional[str]:
"""Send a new message chunk, optionally threaded to a previous message.
Returns the message_id so callers can thread subsequent chunks.
"""
text = self._clean_for_display(text)
if not text.strip():
return reply_to_id
try:
meta = dict(self.metadata) if self.metadata else {}
result = await self.adapter.send(
chat_id=self.chat_id,
content=text,
reply_to=reply_to_id,
metadata=meta,
)
if result.success and result.message_id:
self._message_id = str(result.message_id)
self._already_sent = True
self._last_sent_text = text
return str(result.message_id)
else:
self._edit_supported = False
return reply_to_id
except Exception as e:
logger.error("Stream send chunk error: %s", e)
return reply_to_id
def _visible_prefix(self) -> str:
"""Return the visible text already shown in the streamed message."""
prefix = self._last_sent_text or ""
+34 -15
View File
@@ -70,6 +70,7 @@ DEFAULT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
DEFAULT_QWEN_BASE_URL = "https://portal.qwen.ai/v1"
DEFAULT_GITHUB_MODELS_BASE_URL = "https://api.githubcopilot.com"
DEFAULT_COPILOT_ACP_BASE_URL = "acp://copilot"
DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai"
CODEX_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
CODEX_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
@@ -249,7 +250,7 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
# Kimi Code Endpoint Detection
# =============================================================================
# Kimi Code (kimi.com/code) issues keys prefixed "sk-kimi-" that only work
# Kimi Code (platform.kimi.ai) issues keys prefixed "sk-kimi-" that only work
# on api.kimi.com/coding/v1. Legacy keys from platform.moonshot.ai work on
# api.moonshot.ai/v1 (the default). Auto-detect when user hasn't set
# KIMI_BASE_URL explicitly.
@@ -2341,6 +2342,33 @@ def resolve_external_process_provider_credentials(provider_id: str) -> Dict[str,
}
# =============================================================================
# External credential detection
# =============================================================================
def detect_external_credentials() -> List[Dict[str, Any]]:
"""Scan for credentials from other CLI tools that Hermes can reuse.
Returns a list of dicts, each with:
- provider: str -- Hermes provider id (e.g. "openai-codex")
- path: str -- filesystem path where creds were found
- label: str -- human-friendly description for the setup UI
"""
found: List[Dict[str, Any]] = []
# Codex CLI: ~/.codex/auth.json (importable, not shared)
cli_tokens = _import_codex_cli_tokens()
if cli_tokens:
codex_path = Path.home() / ".codex" / "auth.json"
found.append({
"provider": "openai-codex",
"path": str(codex_path),
"label": f"Codex CLI credentials found ({codex_path}) — run `hermes auth` to create a separate session",
})
return found
# =============================================================================
# CLI Commands — login / logout
# =============================================================================
@@ -2989,15 +3017,12 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
_save_provider_state(auth_store, "nous", auth_state)
saved_to = _save_auth_store(auth_store)
config_path = _update_config_for_provider("nous", inference_base_url)
print()
print("Login successful!")
print(f" Auth state: {saved_to}")
print(f" Config updated: {config_path} (model.provider=nous)")
# Resolve model BEFORE writing provider to config.yaml so we never
# leave the config in a half-updated state (provider=nous but model
# still set to the previous provider's model, e.g. opus from
# OpenRouter). The auth.json active_provider was already set above.
selected_model = None
try:
runtime_key = auth_state.get("agent_key") or auth_state.get("access_token")
if not isinstance(runtime_key, str) or not runtime_key:
@@ -3031,6 +3056,9 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
unavailable_models=unavailable_models,
portal_url=_portal,
)
if selected_model:
_save_model_choice(selected_model)
print(f"Default model set to: {selected_model}")
elif unavailable_models:
_url = (_portal or DEFAULT_NOUS_PORTAL_URL).rstrip("/")
print("No free models currently available.")
@@ -3042,15 +3070,6 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
print()
print(f"Login succeeded, but could not fetch available models. Reason: {message}")
# Write provider + model atomically so config is never mismatched.
config_path = _update_config_for_provider(
"nous", inference_base_url, default_model=selected_model,
)
if selected_model:
_save_model_choice(selected_model)
print(f"Default model set to: {selected_model}")
print(f" Config updated: {config_path} (model.provider=nous)")
except KeyboardInterrupt:
print("\nLogin cancelled.")
raise SystemExit(130)
+6
View File
@@ -90,6 +90,12 @@ HERMES_CADUCEUS = """[#CD7F32]⠀⢀⣀⡀⠀⣀⣀
[#B8860B]⠀⠈⠳⠈⣡⠞⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[/]
[#B8860B]⠀⠈⠀[/]"""
COMPACT_BANNER = """
[bold #FFD700]╔══════════════════════════════════════════════════════════════╗[/]
[bold #FFD700]║[/] [#FFBF00]⚕ NOUS HERMES[/] [dim #B8860B]- AI Agent Framework[/] [bold #FFD700]║[/]
[bold #FFD700]║[/] [#CD7F32]Messenger of the Digital Gods[/] [dim #B8860B]Nous Research[/] [bold #FFD700]║[/]
[bold #FFD700]╚══════════════════════════════════════════════════════════════╝[/]
"""
# =========================================================================
+140
View File
@@ -0,0 +1,140 @@
"""Shared curses-based multi-select checklist for Hermes CLI.
Used by both ``hermes tools`` and ``hermes skills`` to present a
toggleable list of items. Falls back to a numbered text UI when
curses is unavailable (Windows without curses, piped stdin, etc.).
"""
import sys
from typing import List, Set
from hermes_cli.colors import Colors, color
def curses_checklist(
title: str,
items: List[str],
pre_selected: Set[int],
) -> Set[int]:
"""Multi-select checklist. Returns set of **selected** indices.
Args:
title: Header text shown at the top of the checklist.
items: Display labels for each row.
pre_selected: Indices that start checked.
Returns:
The indices the user confirmed as checked. On cancel (ESC/q),
returns ``pre_selected`` unchanged.
"""
# Safety: return defaults when stdin is not a terminal.
if not sys.stdin.isatty():
return set(pre_selected)
try:
import curses
selected = set(pre_selected)
result = [None]
def _ui(stdscr):
curses.curs_set(0)
if curses.has_colors():
curses.start_color()
curses.use_default_colors()
curses.init_pair(1, curses.COLOR_GREEN, -1)
curses.init_pair(2, curses.COLOR_YELLOW, -1)
curses.init_pair(3, 8, -1) # dim gray
cursor = 0
scroll_offset = 0
while True:
stdscr.clear()
max_y, max_x = stdscr.getmaxyx()
# Header
try:
hattr = curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0)
stdscr.addnstr(0, 0, title, max_x - 1, hattr)
stdscr.addnstr(
1, 0,
" ↑↓ navigate SPACE toggle ENTER confirm ESC cancel",
max_x - 1, curses.A_DIM,
)
except curses.error:
pass
# Scrollable item list
visible_rows = max_y - 3
if cursor < scroll_offset:
scroll_offset = cursor
elif cursor >= scroll_offset + visible_rows:
scroll_offset = cursor - visible_rows + 1
for draw_i, i in enumerate(
range(scroll_offset, min(len(items), scroll_offset + visible_rows))
):
y = draw_i + 3
if y >= max_y - 1:
break
check = "" if i in selected else " "
arrow = "" if i == cursor else " "
line = f" {arrow} [{check}] {items[i]}"
attr = curses.A_NORMAL
if i == cursor:
attr = curses.A_BOLD
if curses.has_colors():
attr |= curses.color_pair(1)
try:
stdscr.addnstr(y, 0, line, max_x - 1, attr)
except curses.error:
pass
stdscr.refresh()
key = stdscr.getch()
if key in (curses.KEY_UP, ord("k")):
cursor = (cursor - 1) % len(items)
elif key in (curses.KEY_DOWN, ord("j")):
cursor = (cursor + 1) % len(items)
elif key == ord(" "):
selected.symmetric_difference_update({cursor})
elif key in (curses.KEY_ENTER, 10, 13):
result[0] = set(selected)
return
elif key in (27, ord("q")):
result[0] = set(pre_selected)
return
curses.wrapper(_ui)
return result[0] if result[0] is not None else set(pre_selected)
except Exception:
pass # fall through to numbered fallback
# ── Numbered text fallback ────────────────────────────────────────────
selected = set(pre_selected)
print(color(f"\n {title}", Colors.YELLOW))
print(color(" Toggle by number, Enter to confirm.\n", Colors.DIM))
while True:
for i, label in enumerate(items):
check = "" if i in selected else " "
print(f" {i + 1:3}. [{check}] {label}")
print()
try:
raw = input(color(" Number to toggle, 's' to save, 'q' to cancel: ", Colors.DIM)).strip()
except (KeyboardInterrupt, EOFError):
return set(pre_selected)
if raw.lower() == "s" or raw == "":
return selected
if raw.lower() == "q":
return set(pre_selected)
try:
idx = int(raw) - 1
if 0 <= idx < len(items):
selected.symmetric_difference_update({idx})
except ValueError:
print(color(" Invalid input", Colors.DIM))
+10 -3
View File
@@ -87,7 +87,8 @@ COMMAND_REGISTRY: list[CommandDef] = [
CommandDef("model", "Switch model for this session", "Configuration", args_hint="[model] [--global]"),
CommandDef("provider", "Show available providers and current provider",
"Configuration"),
CommandDef("prompt", "View/set custom system prompt", "Configuration",
cli_only=True, args_hint="[text]", subcommands=("clear",)),
CommandDef("personality", "Set a predefined personality", "Configuration",
args_hint="[name]"),
CommandDef("statusbar", "Toggle the context/model status bar", "Configuration",
@@ -99,7 +100,7 @@ COMMAND_REGISTRY: list[CommandDef] = [
"Configuration"),
CommandDef("reasoning", "Manage reasoning effort and display", "Configuration",
args_hint="[level|show|hide]",
subcommands=("none", "minimal", "low", "medium", "high", "xhigh", "show", "hide", "on", "off")),
subcommands=("none", "low", "minimal", "medium", "high", "xhigh", "show", "hide", "on", "off")),
CommandDef("skin", "Show or change the display skin/theme", "Configuration",
cli_only=True, args_hint="[name]"),
CommandDef("voice", "Toggle voice mode", "Configuration",
@@ -128,7 +129,7 @@ COMMAND_REGISTRY: list[CommandDef] = [
CommandDef("commands", "Browse all commands and skills (paginated)", "Info",
gateway_only=True, args_hint="[page]"),
CommandDef("help", "Show available commands", "Info"),
CommandDef("usage", "Show token usage and rate limits for the current session", "Info"),
CommandDef("usage", "Show token usage for the current session", "Info"),
CommandDef("insights", "Show usage insights and analytics", "Info",
args_hint="[days]"),
CommandDef("platforms", "Show gateway/messaging platform status", "Info",
@@ -169,6 +170,12 @@ def resolve_command(name: str) -> CommandDef | None:
return _COMMAND_LOOKUP.get(name.lower().lstrip("/"))
def register_plugin_command(cmd: CommandDef) -> None:
"""Append a plugin-defined command to the registry and refresh lookups."""
COMMAND_REGISTRY.append(cmd)
rebuild_lookups()
def rebuild_lookups() -> None:
"""Rebuild all derived lookup dicts from the current COMMAND_REGISTRY.
+7 -52
View File
@@ -197,44 +197,14 @@ def _ensure_default_soul_md(home: Path) -> None:
def ensure_hermes_home():
"""Ensure ~/.hermes directory structure exists with secure permissions.
In managed mode (NixOS), dirs are created by the activation script with
setgid + group-writable (2770). We skip mkdir and set umask(0o007) so
any files created (e.g. SOUL.md) are group-writable (0660).
"""
"""Ensure ~/.hermes directory structure exists with secure permissions."""
home = get_hermes_home()
if is_managed():
old_umask = os.umask(0o007)
try:
_ensure_hermes_home_managed(home)
finally:
os.umask(old_umask)
else:
home.mkdir(parents=True, exist_ok=True)
_secure_dir(home)
for subdir in ("cron", "sessions", "logs", "memories"):
d = home / subdir
d.mkdir(parents=True, exist_ok=True)
_secure_dir(d)
_ensure_default_soul_md(home)
def _ensure_hermes_home_managed(home: Path):
"""Managed-mode variant: verify dirs exist (activation creates them), seed SOUL.md."""
if not home.is_dir():
raise RuntimeError(
f"HERMES_HOME {home} does not exist. "
"Run 'sudo nixos-rebuild switch' first."
)
home.mkdir(parents=True, exist_ok=True)
_secure_dir(home)
for subdir in ("cron", "sessions", "logs", "memories"):
d = home / subdir
if not d.is_dir():
raise RuntimeError(
f"{d} does not exist. "
"Run 'sudo nixos-rebuild switch' first."
)
# Inside umask(0o007) scope — SOUL.md will be created as 0660
d.mkdir(parents=True, exist_ok=True)
_secure_dir(d)
_ensure_default_soul_md(home)
@@ -599,7 +569,7 @@ DEFAULT_CONFIG = {
},
# Config schema version - bump this when adding new required fields
"_config_version": 13,
"_config_version": 12,
}
# =============================================================================
@@ -1247,7 +1217,7 @@ OPTIONAL_ENV_VARS = {
"category": "setting",
},
"SUDO_PASSWORD": {
"description": "Sudo password for terminal commands requiring root access; set to an explicit empty string to try empty without prompting",
"description": "Sudo password for terminal commands requiring root access",
"prompt": "Sudo password",
"url": None,
"password": True,
@@ -1731,21 +1701,6 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
ep = providers_dict[key]
print(f"{key}: {ep.get('api', '')}")
# ── Version 12 → 13: clear dead LLM_MODEL / OPENAI_MODEL from .env ──
# These env vars were written by the old setup wizard but nothing reads
# them anymore (config.yaml is the sole source of truth since March 2026).
# Stale entries cause user confusion — see issue report.
if current_ver < 13:
for dead_var in ("LLM_MODEL", "OPENAI_MODEL"):
try:
old_val = get_env_value(dead_var)
if old_val:
save_env_value(dead_var, "")
if not quiet:
print(f" ✓ Cleared {dead_var} from .env (no longer used — config.yaml is source of truth)")
except Exception:
pass
if current_ver < latest_ver and not quiet:
print(f"Config version: {current_ver}{latest_ver}")
+12
View File
@@ -31,6 +31,13 @@ logger = logging.getLogger(__name__)
# OAuth device code flow constants (same client ID as opencode/Copilot CLI)
COPILOT_OAUTH_CLIENT_ID = "Ov23li8tweQw6odWQebz"
COPILOT_DEVICE_CODE_URL = "https://github.com/login/device/code"
COPILOT_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token"
# Copilot API constants
COPILOT_TOKEN_EXCHANGE_URL = "https://api.github.com/copilot_internal/v2/token"
COPILOT_API_BASE_URL = "https://api.githubcopilot.com"
# Token type prefixes
_CLASSIC_PAT_PREFIX = "ghp_"
_SUPPORTED_PREFIXES = ("gho_", "github_pat_", "ghu_")
@@ -43,6 +50,11 @@ _DEVICE_CODE_POLL_INTERVAL = 5 # seconds
_DEVICE_CODE_POLL_SAFETY_MARGIN = 3 # seconds
def is_classic_pat(token: str) -> bool:
"""Check if a token is a classic PAT (ghp_*), which Copilot doesn't support."""
return token.strip().startswith(_CLASSIC_PAT_PREFIX)
def validate_copilot_token(token: str) -> tuple[bool, str]:
"""Validate that a token is usable with the Copilot API.
-332
View File
@@ -1,332 +0,0 @@
"""
Dump command for hermes CLI.
Outputs a compact, plain-text summary of the user's Hermes setup
that can be copy-pasted into Discord/GitHub/Telegram for support context.
No ANSI colors, no checkmarks just data.
"""
import json
import os
import platform
import subprocess
import sys
from pathlib import Path
from hermes_cli.config import get_hermes_home, get_env_path, get_project_root, load_config
from hermes_constants import display_hermes_home
def _get_git_commit(project_root: Path) -> str:
"""Return short git commit hash, or '(unknown)'."""
try:
result = subprocess.run(
["git", "rev-parse", "--short=8", "HEAD"],
capture_output=True, text=True, timeout=5,
cwd=str(project_root),
)
if result.returncode == 0:
return result.stdout.strip()
except Exception:
pass
return "(unknown)"
def _redact(value: str) -> str:
"""Redact all but first 4 and last 4 chars."""
if not value:
return ""
if len(value) < 12:
return "***"
return value[:4] + "..." + value[-4:]
def _gateway_status() -> str:
"""Return a short gateway status string."""
if sys.platform.startswith("linux"):
try:
from hermes_cli.gateway import get_service_name
svc = get_service_name()
except Exception:
svc = "hermes-gateway"
try:
r = subprocess.run(
["systemctl", "--user", "is-active", svc],
capture_output=True, text=True, timeout=5,
)
return "running (systemd)" if r.stdout.strip() == "active" else "stopped"
except Exception:
return "unknown"
elif sys.platform == "darwin":
try:
from hermes_cli.gateway import get_launchd_label
r = subprocess.run(
["launchctl", "list", get_launchd_label()],
capture_output=True, text=True, timeout=5,
)
return "loaded (launchd)" if r.returncode == 0 else "not loaded"
except Exception:
return "unknown"
return "N/A"
def _count_skills(hermes_home: Path) -> int:
"""Count installed skills."""
skills_dir = hermes_home / "skills"
if not skills_dir.is_dir():
return 0
count = 0
for item in skills_dir.rglob("SKILL.md"):
count += 1
return count
def _count_mcp_servers(config: dict) -> int:
"""Count configured MCP servers."""
mcp = config.get("mcp", {})
servers = mcp.get("servers", {})
return len(servers)
def _cron_summary(hermes_home: Path) -> str:
"""Return cron jobs summary."""
jobs_file = hermes_home / "cron" / "jobs.json"
if not jobs_file.exists():
return "0"
try:
with open(jobs_file, encoding="utf-8") as f:
data = json.load(f)
jobs = data.get("jobs", [])
active = sum(1 for j in jobs if j.get("enabled", True))
return f"{active} active / {len(jobs)} total"
except Exception:
return "(error reading)"
def _configured_platforms() -> list[str]:
"""Return list of configured messaging platform names."""
checks = {
"telegram": "TELEGRAM_BOT_TOKEN",
"discord": "DISCORD_BOT_TOKEN",
"slack": "SLACK_BOT_TOKEN",
"whatsapp": "WHATSAPP_ENABLED",
"signal": "SIGNAL_HTTP_URL",
"email": "EMAIL_ADDRESS",
"sms": "TWILIO_ACCOUNT_SID",
"matrix": "MATRIX_HOMESERVER_URL",
"mattermost": "MATTERMOST_URL",
"homeassistant": "HASS_TOKEN",
"dingtalk": "DINGTALK_CLIENT_ID",
"feishu": "FEISHU_APP_ID",
"wecom": "WECOM_BOT_ID",
}
return [name for name, env in checks.items() if os.getenv(env)]
def _memory_provider(config: dict) -> str:
"""Return the active memory provider name."""
mem = config.get("memory", {})
provider = mem.get("provider", "")
return provider if provider else "built-in"
def _get_model_and_provider(config: dict) -> tuple[str, str]:
"""Extract model and provider from config."""
model_cfg = config.get("model", "")
if isinstance(model_cfg, dict):
model = model_cfg.get("default") or model_cfg.get("model") or model_cfg.get("name") or "(not set)"
provider = model_cfg.get("provider") or "(auto)"
elif isinstance(model_cfg, str):
model = model_cfg or "(not set)"
provider = "(auto)"
else:
model = "(not set)"
provider = "(auto)"
return model, provider
def _config_overrides(config: dict) -> dict[str, str]:
"""Find non-default config values worth reporting.
Returns a flat dict of dotpath -> value for interesting overrides.
"""
from hermes_cli.config import DEFAULT_CONFIG
overrides = {}
# Sections with interesting user-facing overrides
interesting_paths = [
("agent", "max_turns"),
("agent", "gateway_timeout"),
("agent", "tool_use_enforcement"),
("terminal", "backend"),
("terminal", "docker_image"),
("terminal", "persistent_shell"),
("browser", "allow_private_urls"),
("compression", "enabled"),
("compression", "threshold"),
("display", "streaming"),
("display", "skin"),
("display", "show_reasoning"),
("smart_model_routing", "enabled"),
("privacy", "redact_pii"),
("tts", "provider"),
]
for section, key in interesting_paths:
default_section = DEFAULT_CONFIG.get(section, {})
user_section = config.get(section, {})
if not isinstance(default_section, dict) or not isinstance(user_section, dict):
continue
default_val = default_section.get(key)
user_val = user_section.get(key)
if user_val is not None and user_val != default_val:
overrides[f"{section}.{key}"] = str(user_val)
# Toolsets (if different from default)
default_toolsets = DEFAULT_CONFIG.get("toolsets", [])
user_toolsets = config.get("toolsets", [])
if user_toolsets != default_toolsets:
overrides["toolsets"] = str(user_toolsets)
# Fallback providers
fallbacks = config.get("fallback_providers", [])
if fallbacks:
overrides["fallback_providers"] = str(fallbacks)
return overrides
def run_dump(args):
"""Output a compact, copy-pasteable setup summary."""
show_keys = getattr(args, "show_keys", False)
# Load env from .env file so key checks work
from dotenv import load_dotenv
env_path = get_env_path()
if env_path.exists():
try:
load_dotenv(env_path, encoding="utf-8")
except UnicodeDecodeError:
load_dotenv(env_path, encoding="latin-1")
# Also try project .env as dev fallback
load_dotenv(get_project_root() / ".env", override=False, encoding="utf-8")
project_root = get_project_root()
hermes_home = get_hermes_home()
try:
from hermes_cli import __version__, __release_date__
except ImportError:
__version__ = "(unknown)"
__release_date__ = ""
commit = _get_git_commit(project_root)
try:
config = load_config()
except Exception:
config = {}
model, provider = _get_model_and_provider(config)
# Profile
try:
from hermes_cli.profiles import get_active_profile_name
profile = get_active_profile_name() or "(default)"
except Exception:
profile = "(default)"
# Terminal backend
terminal_cfg = config.get("terminal", {})
backend = terminal_cfg.get("backend", "local")
# OpenAI SDK version
try:
import openai
openai_ver = openai.__version__
except ImportError:
openai_ver = "not installed"
# OS info
os_info = f"{platform.system()} {platform.release()} {platform.machine()}"
lines = []
lines.append("--- hermes dump ---")
ver_str = f"{__version__}"
if __release_date__:
ver_str += f" ({__release_date__})"
ver_str += f" [{commit}]"
lines.append(f"version: {ver_str}")
lines.append(f"os: {os_info}")
lines.append(f"python: {sys.version.split()[0]}")
lines.append(f"openai_sdk: {openai_ver}")
lines.append(f"profile: {profile}")
lines.append(f"hermes_home: {display_hermes_home()}")
lines.append(f"model: {model}")
lines.append(f"provider: {provider}")
lines.append(f"terminal: {backend}")
# API keys
lines.append("")
lines.append("api_keys:")
api_keys = [
("OPENROUTER_API_KEY", "openrouter"),
("OPENAI_API_KEY", "openai"),
("ANTHROPIC_API_KEY", "anthropic"),
("ANTHROPIC_TOKEN", "anthropic_token"),
("NOUS_API_KEY", "nous"),
("GLM_API_KEY", "glm/zai"),
("ZAI_API_KEY", "zai"),
("KIMI_API_KEY", "kimi"),
("MINIMAX_API_KEY", "minimax"),
("DEEPSEEK_API_KEY", "deepseek"),
("DASHSCOPE_API_KEY", "dashscope"),
("HF_TOKEN", "huggingface"),
("AI_GATEWAY_API_KEY", "ai_gateway"),
("OPENCODE_ZEN_API_KEY", "opencode_zen"),
("OPENCODE_GO_API_KEY", "opencode_go"),
("KILOCODE_API_KEY", "kilocode"),
("FIRECRAWL_API_KEY", "firecrawl"),
("TAVILY_API_KEY", "tavily"),
("BROWSERBASE_API_KEY", "browserbase"),
("FAL_KEY", "fal"),
("ELEVENLABS_API_KEY", "elevenlabs"),
("GITHUB_TOKEN", "github"),
]
for env_var, label in api_keys:
val = os.getenv(env_var, "")
if show_keys and val:
display = _redact(val)
else:
display = "set" if val else "not set"
lines.append(f" {label:<20} {display}")
# Features summary
lines.append("")
lines.append("features:")
toolsets = config.get("toolsets", ["hermes-cli"])
lines.append(f" toolsets: {', '.join(toolsets) if toolsets else '(default)'}")
lines.append(f" mcp_servers: {_count_mcp_servers(config)}")
lines.append(f" memory_provider: {_memory_provider(config)}")
lines.append(f" gateway: {_gateway_status()}")
platforms = _configured_platforms()
lines.append(f" platforms: {', '.join(platforms) if platforms else 'none'}")
lines.append(f" cron_jobs: {_cron_summary(hermes_home)}")
lines.append(f" skills: {_count_skills(hermes_home)}")
# Config overrides (non-default values)
overrides = _config_overrides(config)
if overrides:
lines.append("")
lines.append("config_overrides:")
for key, val in overrides.items():
lines.append(f" {key}: {val}")
lines.append("--- end dump ---")
output = "\n".join(lines)
print(output)
+13
View File
@@ -308,6 +308,8 @@ def get_service_name() -> str:
return f"{_SERVICE_BASE}-{suffix}"
SERVICE_NAME = _SERVICE_BASE # backward-compat for external importers; prefer get_service_name()
def get_systemd_unit_path(system: bool = False) -> Path:
name = get_service_name()
@@ -579,6 +581,17 @@ def get_python_path() -> str:
return str(venv_python)
return sys.executable
def get_hermes_cli_path() -> str:
"""Get the path to the hermes CLI."""
# Check if installed via pip
import shutil
hermes_bin = shutil.which("hermes")
if hermes_bin:
return hermes_bin
# Fallback to direct module execution
return f"{get_python_path()} -m hermes_cli.main"
# =============================================================================
# Systemd (Linux)
+1 -26
View File
@@ -1811,10 +1811,7 @@ def _set_reasoning_effort(config, effort: str) -> None:
def _prompt_reasoning_effort_selection(efforts, current_effort=""):
"""Prompt for a reasoning effort. Returns effort, 'none', or None to keep current."""
deduped = list(dict.fromkeys(str(effort).strip().lower() for effort in efforts if str(effort).strip()))
canonical_order = ("minimal", "low", "medium", "high", "xhigh")
ordered = [effort for effort in canonical_order if effort in deduped]
ordered.extend(effort for effort in deduped if effort not in canonical_order)
ordered = list(dict.fromkeys(str(effort).strip().lower() for effort in efforts if str(effort).strip()))
if not ordered:
return None
@@ -2646,12 +2643,6 @@ def cmd_doctor(args):
run_doctor(args)
def cmd_dump(args):
"""Dump setup summary for support/debugging."""
from hermes_cli.dump import run_dump
run_dump(args)
def cmd_config(args):
"""Configuration management."""
from hermes_cli.config import config_command
@@ -4733,22 +4724,6 @@ For more help on a command:
help="Attempt to fix issues automatically"
)
doctor_parser.set_defaults(func=cmd_doctor)
# =========================================================================
# dump command
# =========================================================================
dump_parser = subparsers.add_parser(
"dump",
help="Dump setup summary for support/debugging",
description="Output a compact, plain-text summary of your Hermes setup "
"that can be copy-pasted into Discord/GitHub for support context"
)
dump_parser.add_argument(
"--show-keys",
action="store_true",
help="Show redacted API key prefixes (first/last 4 chars) instead of just set/not set"
)
dump_parser.set_defaults(func=cmd_dump)
# =========================================================================
# config command
+28
View File
@@ -332,3 +332,31 @@ def normalize_model_for_provider(model_input: str, target_provider: str) -> str:
# Batch / convenience helpers
# ---------------------------------------------------------------------------
def model_display_name(model_id: str) -> str:
"""Return a short, human-readable display name for a model id.
Strips the vendor prefix (if any) for a cleaner display in menus
and status bars, while preserving dots for readability.
Examples::
>>> model_display_name("anthropic/claude-sonnet-4.6")
'claude-sonnet-4.6'
>>> model_display_name("claude-sonnet-4-6")
'claude-sonnet-4-6'
"""
return _strip_vendor_prefix((model_id or "").strip())
def is_aggregator_provider(provider: str) -> bool:
"""Check if a provider is an aggregator that needs vendor/model format."""
return (provider or "").strip().lower() in _AGGREGATOR_PROVIDERS
def vendor_for_model(model_name: str) -> str:
"""Return the vendor slug for a model, or ``""`` if unknown.
Convenience wrapper around :func:`detect_vendor` that never returns
``None``.
"""
return detect_vendor(model_name) or ""
+74 -11
View File
@@ -733,7 +733,6 @@ def list_authenticated_providers(
fetch_models_dev,
get_provider_info as _mdev_pinfo,
)
from hermes_cli.auth import PROVIDER_REGISTRY
from hermes_cli.models import OPENROUTER_MODELS, _PROVIDER_MODELS
results: List[dict] = []
@@ -754,16 +753,9 @@ def list_authenticated_providers(
if not isinstance(pdata, dict):
continue
# Prefer auth.py PROVIDER_REGISTRY for env var names — it's our
# source of truth. models.dev can have wrong mappings (e.g.
# minimax-cn → MINIMAX_API_KEY instead of MINIMAX_CN_API_KEY).
pconfig = PROVIDER_REGISTRY.get(hermes_id)
if pconfig and pconfig.api_key_env_vars:
env_vars = list(pconfig.api_key_env_vars)
else:
env_vars = pdata.get("env", [])
if not isinstance(env_vars, list):
continue
env_vars = pdata.get("env", [])
if not isinstance(env_vars, list):
continue
# Check if any env var is set
has_creds = any(os.environ.get(ev) for ev in env_vars)
@@ -859,3 +851,74 @@ def list_authenticated_providers(
return results
# ---------------------------------------------------------------------------
# Fuzzy suggestions
# ---------------------------------------------------------------------------
def suggest_models(raw_input: str, limit: int = 3) -> List[str]:
"""Return fuzzy model suggestions for a (possibly misspelled) input."""
query = raw_input.strip()
if not query:
return []
results = search_models_dev(query, limit=limit)
suggestions: list[str] = []
for r in results:
mid = r.get("model_id", "")
if mid:
suggestions.append(mid)
return suggestions[:limit]
# ---------------------------------------------------------------------------
# Custom provider switch
# ---------------------------------------------------------------------------
def switch_to_custom_provider() -> CustomAutoResult:
"""Handle bare '/model --provider custom' — resolve endpoint and auto-detect model."""
from hermes_cli.runtime_provider import (
resolve_runtime_provider,
_auto_detect_local_model,
)
try:
runtime = resolve_runtime_provider(requested="custom")
except Exception as e:
return CustomAutoResult(
success=False,
error_message=f"Could not resolve custom endpoint: {e}",
)
cust_base = runtime.get("base_url", "")
cust_key = runtime.get("api_key", "")
if not cust_base or "openrouter.ai" in cust_base:
return CustomAutoResult(
success=False,
error_message=(
"No custom endpoint configured. "
"Set model.base_url in config.yaml, or set OPENAI_BASE_URL "
"in .env, or run: hermes setup -> Custom OpenAI-compatible endpoint"
),
)
detected_model = _auto_detect_local_model(cust_base)
if not detected_model:
return CustomAutoResult(
success=False,
base_url=cust_base,
api_key=cust_key,
error_message=(
f"Custom endpoint at {cust_base} is reachable but no single "
f"model was auto-detected. Specify the model explicitly: "
f"/model <model-name> --provider custom"
),
)
return CustomAutoResult(
success=True,
model=detected_model,
base_url=cust_base,
api_key=cust_key,
)
+43
View File
@@ -20,6 +20,10 @@ COPILOT_EDITOR_VERSION = "vscode/1.104.1"
COPILOT_REASONING_EFFORTS_GPT5 = ["minimal", "low", "medium", "high"]
COPILOT_REASONING_EFFORTS_O_SERIES = ["low", "medium", "high"]
# Backward-compatible aliases for the earlier GitHub Models-backed Copilot work.
GITHUB_MODELS_BASE_URL = COPILOT_BASE_URL
GITHUB_MODELS_CATALOG_URL = COPILOT_MODELS_URL
# (model_id, display description shown in menus)
OPENROUTER_MODELS: list[tuple[str, str]] = [
("anthropic/claude-opus-4.6", "recommended"),
@@ -412,6 +416,12 @@ _FREE_TIER_CACHE_TTL: int = 180 # seconds (3 minutes)
_free_tier_cache: tuple[bool, float] | None = None # (result, timestamp)
def clear_nous_free_tier_cache() -> None:
"""Invalidate the cached free-tier result (e.g. after login/logout)."""
global _free_tier_cache
_free_tier_cache = None
def check_nous_free_tier() -> bool:
"""Check if the current Nous Portal user is on a free (unpaid) tier.
@@ -525,6 +535,14 @@ def model_ids() -> list[str]:
return [mid for mid, _ in OPENROUTER_MODELS]
def menu_labels() -> list[str]:
"""Return display labels like 'anthropic/claude-opus-4.6 (recommended)'."""
labels = []
for mid, desc in OPENROUTER_MODELS:
labels.append(f"{mid} ({desc})" if desc else mid)
return labels
# ---------------------------------------------------------------------------
# Pricing helpers — fetch live pricing from OpenRouter-compatible /v1/models
# ---------------------------------------------------------------------------
@@ -557,6 +575,31 @@ def _format_price_per_mtok(per_token_str: str) -> str:
return f"${per_m:.2f}"
def format_pricing_label(pricing: dict[str, str] | None) -> str:
"""Build a compact pricing label like 'in $3 · out $15 · cache $0.30/Mtok'.
Returns empty string when pricing is unavailable.
"""
if not pricing:
return ""
prompt_price = pricing.get("prompt", "")
completion_price = pricing.get("completion", "")
if not prompt_price and not completion_price:
return ""
inp = _format_price_per_mtok(prompt_price)
out = _format_price_per_mtok(completion_price)
if inp == "free" and out == "free":
return "free"
cache_read = pricing.get("input_cache_read", "")
cache_str = _format_price_per_mtok(cache_read) if cache_read else ""
if inp == out and not cache_str:
return f"{inp}/Mtok"
parts = [f"in {inp}", f"out {out}"]
if cache_str and cache_str != "?" and cache_str != inp:
parts.append(f"cache {cache_str}")
return " · ".join(parts) + "/Mtok"
def format_model_pricing_table(
models: list[tuple[str, str]],
pricing_map: dict[str, dict[str, str]],
+3 -3
View File
@@ -102,7 +102,7 @@ _RESERVED_NAMES = frozenset({
# Hermes subcommands that cannot be used as profile names/aliases
_HERMES_SUBCOMMANDS = frozenset({
"chat", "model", "gateway", "setup", "whatsapp", "login", "logout",
"status", "cron", "doctor", "dump", "config", "pairing", "skills", "tools",
"status", "cron", "doctor", "config", "pairing", "skills", "tools",
"mcp", "sessions", "insights", "version", "update", "uninstall",
"profile", "plugins", "honcho", "acp",
})
@@ -1007,7 +1007,7 @@ _hermes_completion() {
# Top-level subcommands
if [[ "$COMP_CWORD" == 1 ]]; then
local commands="chat model gateway setup status cron doctor dump config skills tools mcp sessions profile update version"
local commands="chat model gateway setup status cron doctor config skills tools mcp sessions profile update version"
COMPREPLY=($(compgen -W "$commands" -- "$cur"))
fi
}
@@ -1032,7 +1032,7 @@ _hermes() {
_arguments \\
'-p[Profile name]:profile:($profiles)' \\
'--profile[Profile name]:profile:($profiles)' \\
'1:command:(chat model gateway setup status cron doctor dump config skills tools mcp sessions profile update version)' \\
'1:command:(chat model gateway setup status cron doctor config skills tools mcp sessions profile update version)' \\
'*::arg:->args'
case $words[1] in
+41
View File
@@ -148,6 +148,10 @@ class ProviderDef:
doc: str = ""
source: str = "" # "models.dev", "hermes", "user-config"
@property
def is_user_defined(self) -> bool:
return self.source == "user-config"
# -- Aliases ------------------------------------------------------------------
# Maps human-friendly / legacy names to canonical provider IDs.
@@ -258,6 +262,12 @@ def normalize_provider(name: str) -> str:
return ALIASES.get(key, key)
def get_overlay(provider_id: str) -> Optional[HermesOverlay]:
"""Get Hermes overlay for a provider, if one exists."""
canonical = normalize_provider(provider_id)
return HERMES_OVERLAYS.get(canonical)
def get_provider(name: str) -> Optional[ProviderDef]:
"""Look up a provider by id or alias, merging all data sources.
@@ -340,6 +350,37 @@ def get_label(provider_id: str) -> str:
return canonical
# For direct import compat, expose as module-level dict
# Built on demand by get_label() calls
LABELS: Dict[str, str] = {
# Static entries for backward compat — get_label() is the proper API
"openrouter": "OpenRouter",
"nous": "Nous Portal",
"openai-codex": "OpenAI Codex",
"copilot-acp": "GitHub Copilot ACP",
"github-copilot": "GitHub Copilot",
"anthropic": "Anthropic",
"zai": "Z.AI / GLM",
"kimi-for-coding": "Kimi / Moonshot",
"minimax": "MiniMax",
"minimax-cn": "MiniMax (China)",
"deepseek": "DeepSeek",
"alibaba": "Alibaba Cloud (DashScope)",
"vercel": "Vercel AI Gateway",
"opencode": "OpenCode Zen",
"opencode-go": "OpenCode Go",
"kilo": "Kilo Gateway",
"huggingface": "Hugging Face",
"local": "Local endpoint",
"custom": "Custom endpoint",
# Legacy Hermes IDs (point to same providers)
"ai-gateway": "Vercel AI Gateway",
"kilocode": "Kilo Gateway",
"copilot": "GitHub Copilot",
"kimi-coding": "Kimi / Moonshot",
"opencode-zen": "OpenCode Zen",
}
def is_aggregator(provider: str) -> bool:
"""Return True when the provider is a multi-model aggregator."""
+164 -171
View File
@@ -172,6 +172,147 @@ def _setup_copilot_reasoning_selection(
_set_reasoning_effort(config, "none")
def _setup_provider_model_selection(config, provider_id, current_model, prompt_choice, prompt_fn):
"""Model selection for API-key providers with live /models detection.
Tries the provider's /models endpoint first. Falls back to a
hardcoded default list with a warning if the endpoint is unreachable.
Always offers a 'Custom model' escape hatch.
"""
from hermes_cli.auth import PROVIDER_REGISTRY, resolve_api_key_provider_credentials
from hermes_cli.config import get_env_value
from hermes_cli.models import (
copilot_model_api_mode,
fetch_api_models,
fetch_github_model_catalog,
normalize_copilot_model_id,
normalize_opencode_model_id,
opencode_model_api_mode,
)
pconfig = PROVIDER_REGISTRY[provider_id]
is_copilot_catalog_provider = provider_id in {"copilot", "copilot-acp"}
# Resolve API key and base URL for the probe
if is_copilot_catalog_provider:
api_key = ""
if provider_id == "copilot":
creds = resolve_api_key_provider_credentials(provider_id)
api_key = creds.get("api_key", "")
base_url = creds.get("base_url", "") or pconfig.inference_base_url
else:
try:
creds = resolve_api_key_provider_credentials("copilot")
api_key = creds.get("api_key", "")
except Exception:
pass
base_url = pconfig.inference_base_url
catalog = fetch_github_model_catalog(api_key)
current_model = normalize_copilot_model_id(
current_model,
catalog=catalog,
api_key=api_key,
) or current_model
else:
api_key = ""
for ev in pconfig.api_key_env_vars:
api_key = get_env_value(ev) or os.getenv(ev, "")
if api_key:
break
base_url_env = pconfig.base_url_env_var or ""
base_url = (get_env_value(base_url_env) if base_url_env else "") or pconfig.inference_base_url
catalog = None
# Try live /models endpoint
if is_copilot_catalog_provider and catalog:
live_models = [item.get("id", "") for item in catalog if item.get("id")]
else:
live_models = fetch_api_models(api_key, base_url)
if live_models:
provider_models = live_models
print_info(f"Found {len(live_models)} model(s) from {pconfig.name} API")
else:
fallback_provider_id = "copilot" if provider_id == "copilot-acp" else provider_id
provider_models = _DEFAULT_PROVIDER_MODELS.get(fallback_provider_id, [])
if provider_models:
print_warning(
f"Could not auto-detect models from {pconfig.name} API — showing defaults.\n"
f" Use \"Custom model\" if the model you expect isn't listed."
)
if provider_id in {"opencode-zen", "opencode-go"}:
provider_models = [normalize_opencode_model_id(provider_id, mid) for mid in provider_models]
current_model = normalize_opencode_model_id(provider_id, current_model)
provider_models = list(dict.fromkeys(mid for mid in provider_models if mid))
model_choices = list(provider_models)
model_choices.append("Custom model")
model_choices.append(f"Keep current ({current_model})")
keep_idx = len(model_choices) - 1
model_idx = prompt_choice("Select default model:", model_choices, keep_idx)
selected_model = current_model
if model_idx < len(provider_models):
selected_model = provider_models[model_idx]
if is_copilot_catalog_provider:
selected_model = normalize_copilot_model_id(
selected_model,
catalog=catalog,
api_key=api_key,
) or selected_model
elif provider_id in {"opencode-zen", "opencode-go"}:
selected_model = normalize_opencode_model_id(provider_id, selected_model)
_set_default_model(config, selected_model)
elif model_idx == len(provider_models):
custom = prompt_fn("Enter model name")
if custom:
if is_copilot_catalog_provider:
selected_model = normalize_copilot_model_id(
custom,
catalog=catalog,
api_key=api_key,
) or custom
elif provider_id in {"opencode-zen", "opencode-go"}:
selected_model = normalize_opencode_model_id(provider_id, custom)
else:
selected_model = custom
_set_default_model(config, selected_model)
else:
# "Keep current" selected — validate it's compatible with the new
# provider. OpenRouter-formatted names (containing "/") won't work
# on direct-API providers and would silently break the gateway.
if "/" in (current_model or "") and provider_models:
print_warning(
f"Current model \"{current_model}\" looks like an OpenRouter model "
f"and won't work with {pconfig.name}. "
f"Switching to {provider_models[0]}."
)
selected_model = provider_models[0]
_set_default_model(config, provider_models[0])
if provider_id == "copilot" and selected_model:
model_cfg = _model_config_dict(config)
model_cfg["api_mode"] = copilot_model_api_mode(
selected_model,
catalog=catalog,
api_key=api_key,
)
config["model"] = model_cfg
_setup_copilot_reasoning_selection(
config,
selected_model,
prompt_choice,
catalog=catalog,
api_key=api_key,
)
elif provider_id in {"opencode-zen", "opencode-go"} and selected_model:
model_cfg = _model_config_dict(config)
model_cfg["api_mode"] = opencode_model_api_mode(provider_id, selected_model)
config["model"] = model_cfg
# Import config helpers
from hermes_cli.config import (
@@ -2431,120 +2572,9 @@ _OPENCLAW_SCRIPT = (
)
def _load_openclaw_migration_module():
"""Load the openclaw_to_hermes migration script as a module.
Returns the loaded module, or None if the script can't be loaded.
"""
if not _OPENCLAW_SCRIPT.exists():
return None
spec = importlib.util.spec_from_file_location(
"openclaw_to_hermes", _OPENCLAW_SCRIPT
)
if spec is None or spec.loader is None:
return None
mod = importlib.util.module_from_spec(spec)
# Register in sys.modules so @dataclass can resolve the module
# (Python 3.11+ requires this for dynamically loaded modules)
import sys as _sys
_sys.modules[spec.name] = mod
try:
spec.loader.exec_module(mod)
except Exception:
_sys.modules.pop(spec.name, None)
raise
return mod
# Item kinds that represent high-impact changes warranting explicit warnings.
# Gateway tokens/channels can hijack messaging platforms from the old agent.
# Config values may have different semantics between OpenClaw and Hermes.
# Instruction/context files (.md) can contain incompatible setup procedures.
_HIGH_IMPACT_KIND_KEYWORDS = {
"gateway": "⚠ Gateway/messaging — this will configure Hermes to use your OpenClaw messaging channels",
"telegram": "⚠ Telegram — this will point Hermes at your OpenClaw Telegram bot",
"slack": "⚠ Slack — this will point Hermes at your OpenClaw Slack workspace",
"discord": "⚠ Discord — this will point Hermes at your OpenClaw Discord bot",
"whatsapp": "⚠ WhatsApp — this will point Hermes at your OpenClaw WhatsApp connection",
"config": "⚠ Config values — OpenClaw settings may not map 1:1 to Hermes equivalents",
"soul": "⚠ Instruction file — may contain OpenClaw-specific setup/restart procedures",
"memory": "⚠ Memory/context file — may reference OpenClaw-specific infrastructure",
"context": "⚠ Context file — may contain OpenClaw-specific instructions",
}
def _print_migration_preview(report: dict):
"""Print a detailed dry-run preview of what migration would do.
Groups items by category and adds explicit warnings for high-impact
changes like gateway token takeover and config value differences.
"""
items = report.get("items", [])
if not items:
print_info("Nothing to migrate.")
return
migrated_items = [i for i in items if i.get("status") == "migrated"]
conflict_items = [i for i in items if i.get("status") == "conflict"]
skipped_items = [i for i in items if i.get("status") == "skipped"]
warnings_shown = set()
if migrated_items:
print(color(" Would import:", Colors.GREEN))
for item in migrated_items:
kind = item.get("kind", "unknown")
dest = item.get("destination", "")
if dest:
dest_short = str(dest).replace(str(Path.home()), "~")
print(f" {kind:<22s}{dest_short}")
else:
print(f" {kind}")
# Check for high-impact items and collect warnings
kind_lower = kind.lower()
dest_lower = str(dest).lower()
for keyword, warning in _HIGH_IMPACT_KIND_KEYWORDS.items():
if keyword in kind_lower or keyword in dest_lower:
warnings_shown.add(warning)
print()
if conflict_items:
print(color(" Would overwrite (conflicts with existing Hermes config):", Colors.YELLOW))
for item in conflict_items:
kind = item.get("kind", "unknown")
reason = item.get("reason", "already exists")
print(f" {kind:<22s} {reason}")
print()
if skipped_items:
print(color(" Would skip:", Colors.DIM))
for item in skipped_items:
kind = item.get("kind", "unknown")
reason = item.get("reason", "")
print(f" {kind:<22s} {reason}")
print()
# Print collected warnings
if warnings_shown:
print(color(" ── Warnings ──", Colors.YELLOW))
for warning in sorted(warnings_shown):
print(color(f" {warning}", Colors.YELLOW))
print()
print(color(" Note: OpenClaw config values may have different semantics in Hermes.", Colors.YELLOW))
print(color(" For example, OpenClaw's tool_call_execution: \"auto\" ≠ Hermes's yolo mode.", Colors.YELLOW))
print(color(" Instruction files (.md) from OpenClaw may contain incompatible procedures.", Colors.YELLOW))
print()
def _offer_openclaw_migration(hermes_home: Path) -> bool:
"""Detect ~/.openclaw and offer to migrate during first-time setup.
Runs a dry-run first to show the user exactly what would be imported,
overwritten, or taken over. Only executes after explicit confirmation.
Returns True if migration ran successfully, False otherwise.
"""
openclaw_dir = Path.home() / ".openclaw"
@@ -2557,12 +2587,12 @@ def _offer_openclaw_migration(hermes_home: Path) -> bool:
print()
print_header("OpenClaw Installation Detected")
print_info(f"Found OpenClaw data at {openclaw_dir}")
print_info("Hermes can preview what would be imported before making any changes.")
print_info("Hermes can import your settings, memories, skills, and API keys.")
print()
if not prompt_yes_no("Would you like to see what can be imported?", default=True):
if not prompt_yes_no("Would you like to import from OpenClaw?", default=True):
print_info(
"Skipping migration. You can run it later with: hermes claw migrate --dry-run"
"Skipping migration. You can run it later via the openclaw-migration skill."
)
return False
@@ -2571,71 +2601,34 @@ def _offer_openclaw_migration(hermes_home: Path) -> bool:
if not config_path.exists():
save_config(load_config())
# Load the migration module
# Dynamically load the migration script
try:
mod = _load_openclaw_migration_module()
if mod is None:
spec = importlib.util.spec_from_file_location(
"openclaw_to_hermes", _OPENCLAW_SCRIPT
)
if spec is None or spec.loader is None:
print_warning("Could not load migration script.")
return False
except Exception as e:
print_warning(f"Could not load migration script: {e}")
logger.debug("OpenClaw migration module load error", exc_info=True)
return False
# ── Phase 1: Dry-run preview ──
try:
mod = importlib.util.module_from_spec(spec)
# Register in sys.modules so @dataclass can resolve the module
# (Python 3.11+ requires this for dynamically loaded modules)
import sys as _sys
_sys.modules[spec.name] = mod
try:
spec.loader.exec_module(mod)
except Exception:
_sys.modules.pop(spec.name, None)
raise
# Run migration with the "full" preset, execute mode, no overwrite
selected = mod.resolve_selected_options(None, None, preset="full")
dry_migrator = mod.Migrator(
source_root=openclaw_dir.resolve(),
target_root=hermes_home.resolve(),
execute=False, # dry-run — no files modified
workspace_target=None,
overwrite=True, # show everything including conflicts
migrate_secrets=True,
output_dir=None,
selected_options=selected,
preset_name="full",
)
preview_report = dry_migrator.migrate()
except Exception as e:
print_warning(f"Migration preview failed: {e}")
logger.debug("OpenClaw migration preview error", exc_info=True)
return False
# Display the full preview
preview_summary = preview_report.get("summary", {})
preview_count = preview_summary.get("migrated", 0)
if preview_count == 0:
print()
print_info("Nothing to import from OpenClaw.")
return False
print()
print_header(f"Migration Preview — {preview_count} item(s) would be imported")
print_info("No changes have been made yet. Review the list below:")
print()
_print_migration_preview(preview_report)
# ── Phase 2: Confirm and execute ──
if not prompt_yes_no("Proceed with migration?", default=False):
print_info(
"Migration cancelled. You can run it later with: hermes claw migrate"
)
print_info(
"Use --dry-run to preview again, or --preset minimal for a lighter import."
)
return False
# Execute the migration — overwrite=False so existing Hermes configs are
# preserved. The user saw the preview; conflicts are skipped by default.
try:
migrator = mod.Migrator(
source_root=openclaw_dir.resolve(),
target_root=hermes_home.resolve(),
execute=True,
workspace_target=None,
overwrite=False, # preserve existing Hermes config
overwrite=True,
migrate_secrets=True,
output_dir=None,
selected_options=selected,
@@ -2647,7 +2640,7 @@ def _offer_openclaw_migration(hermes_home: Path) -> bool:
logger.debug("OpenClaw migration error", exc_info=True)
return False
# Print final summary
# Print summary
summary = report.get("summary", {})
migrated = summary.get("migrated", 0)
skipped = summary.get("skipped", 0)
@@ -2658,7 +2651,7 @@ def _offer_openclaw_migration(hermes_home: Path) -> bool:
if migrated:
print_success(f"Imported {migrated} item(s) from OpenClaw.")
if conflicts:
print_info(f"Skipped {conflicts} item(s) that already exist in Hermes (use hermes claw migrate --overwrite to force).")
print_info(f"Skipped {conflicts} item(s) that already exist in Hermes.")
if skipped:
print_info(f"Skipped {skipped} item(s) (not found or unchanged).")
if errors:
+6 -2
View File
@@ -72,13 +72,13 @@ def display_hermes_home() -> str:
return str(home)
VALID_REASONING_EFFORTS = ("minimal", "low", "medium", "high", "xhigh")
VALID_REASONING_EFFORTS = ("xhigh", "high", "medium", "low", "minimal")
def parse_reasoning_effort(effort: str) -> dict | None:
"""Parse a reasoning effort level into a config dict.
Valid levels: "none", "minimal", "low", "medium", "high", "xhigh".
Valid levels: "xhigh", "high", "medium", "low", "minimal", "none".
Returns None when the input is empty or unrecognized (caller uses default).
Returns {"enabled": False} for "none".
Returns {"enabled": True, "effort": <level>} for valid effort levels.
@@ -95,7 +95,11 @@ def parse_reasoning_effort(effort: str) -> dict | None:
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
OPENROUTER_MODELS_URL = f"{OPENROUTER_BASE_URL}/models"
OPENROUTER_CHAT_URL = f"{OPENROUTER_BASE_URL}/chat/completions"
AI_GATEWAY_BASE_URL = "https://ai-gateway.vercel.sh/v1"
AI_GATEWAY_MODELS_URL = f"{AI_GATEWAY_BASE_URL}/models"
AI_GATEWAY_CHAT_URL = f"{AI_GATEWAY_BASE_URL}/chat/completions"
NOUS_API_BASE_URL = "https://inference-api.nousresearch.com/v1"
NOUS_API_CHAT_URL = f"{NOUS_API_BASE_URL}/chat/completions"
+1 -34
View File
@@ -13,7 +13,6 @@ secrets are never written to disk.
"""
import logging
import os
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Optional
@@ -178,38 +177,6 @@ def setup_verbose_logging() -> None:
# Internal helpers
# ---------------------------------------------------------------------------
class _ManagedRotatingFileHandler(RotatingFileHandler):
"""RotatingFileHandler that ensures group-writable perms in managed mode.
In managed mode (NixOS), the stateDir uses setgid (2770) so new files
inherit the hermes group. However, both _open() (initial creation) and
doRollover() create files via open(), which uses the process umask
typically 0022, producing 0644. This subclass applies chmod 0660 after
both operations so the gateway and interactive users can share log files.
"""
def __init__(self, *args, **kwargs):
from hermes_cli.config import is_managed
self._managed = is_managed()
super().__init__(*args, **kwargs)
def _chmod_if_managed(self):
if self._managed:
try:
os.chmod(self.baseFilename, 0o660)
except OSError:
pass
def _open(self):
stream = super()._open()
self._chmod_if_managed()
return stream
def doRollover(self):
super().doRollover()
self._chmod_if_managed()
def _add_rotating_handler(
logger: logging.Logger,
path: Path,
@@ -231,7 +198,7 @@ def _add_rotating_handler(
return # already attached
path.parent.mkdir(parents=True, exist_ok=True)
handler = _ManagedRotatingFileHandler(
handler = RotatingFileHandler(
str(path), maxBytes=max_bytes, backupCount=backup_count,
)
handler.setLevel(level)
+95 -29
View File
@@ -520,6 +520,72 @@ class SessionDB:
)
self._execute_write(_do)
def set_token_counts(
self,
session_id: str,
input_tokens: int = 0,
output_tokens: int = 0,
model: str = None,
cache_read_tokens: int = 0,
cache_write_tokens: int = 0,
reasoning_tokens: int = 0,
estimated_cost_usd: Optional[float] = None,
actual_cost_usd: Optional[float] = None,
cost_status: Optional[str] = None,
cost_source: Optional[str] = None,
pricing_version: Optional[str] = None,
billing_provider: Optional[str] = None,
billing_base_url: Optional[str] = None,
billing_mode: Optional[str] = None,
) -> None:
"""Set token counters to absolute values (not increment).
Use this when the caller provides cumulative totals from a completed
conversation run (e.g. the gateway, where the cached agent's
session_prompt_tokens already reflects the running total).
"""
def _do(conn):
conn.execute(
"""UPDATE sessions SET
input_tokens = ?,
output_tokens = ?,
cache_read_tokens = ?,
cache_write_tokens = ?,
reasoning_tokens = ?,
estimated_cost_usd = ?,
actual_cost_usd = CASE
WHEN ? IS NULL THEN actual_cost_usd
ELSE ?
END,
cost_status = COALESCE(?, cost_status),
cost_source = COALESCE(?, cost_source),
pricing_version = COALESCE(?, pricing_version),
billing_provider = COALESCE(billing_provider, ?),
billing_base_url = COALESCE(billing_base_url, ?),
billing_mode = COALESCE(billing_mode, ?),
model = COALESCE(model, ?)
WHERE id = ?""",
(
input_tokens,
output_tokens,
cache_read_tokens,
cache_write_tokens,
reasoning_tokens,
estimated_cost_usd,
actual_cost_usd,
actual_cost_usd,
cost_status,
cost_source,
pricing_version,
billing_provider,
billing_base_url,
billing_mode,
model,
session_id,
),
)
self._execute_write(_do)
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Get a session by ID."""
with self._lock:
@@ -878,8 +944,7 @@ class SessionDB:
try:
msg["tool_calls"] = json.loads(msg["tool_calls"])
except (json.JSONDecodeError, TypeError):
logger.warning("Failed to deserialize tool_calls in get_messages, falling back to []")
msg["tool_calls"] = []
pass
result.append(msg)
return result
@@ -907,8 +972,7 @@ class SessionDB:
try:
msg["tool_calls"] = json.loads(row["tool_calls"])
except (json.JSONDecodeError, TypeError):
logger.warning("Failed to deserialize tool_calls in conversation replay, falling back to []")
msg["tool_calls"] = []
pass
# Restore reasoning fields on assistant messages so providers
# that replay reasoning (OpenRouter, OpenAI, Nous) receive
# coherent multi-turn reasoning context.
@@ -919,14 +983,12 @@ class SessionDB:
try:
msg["reasoning_details"] = json.loads(row["reasoning_details"])
except (json.JSONDecodeError, TypeError):
logger.warning("Failed to deserialize reasoning_details, falling back to None")
msg["reasoning_details"] = None
pass
if row["codex_reasoning_items"]:
try:
msg["codex_reasoning_items"] = json.loads(row["codex_reasoning_items"])
except (json.JSONDecodeError, TypeError):
logger.warning("Failed to deserialize codex_reasoning_items, falling back to None")
msg["codex_reasoning_items"] = None
pass
messages.append(msg)
return messages
@@ -1173,10 +1235,10 @@ class SessionDB:
self._execute_write(_do)
def delete_session(self, session_id: str) -> bool:
"""Delete a session and all its messages.
"""Delete a session, its child sessions, and all their messages.
Child sessions are orphaned (parent_session_id set to NULL) rather
than cascade-deleted, so they remain accessible independently.
Child sessions (subagent runs, compression continuations) are deleted
first to satisfy the ``parent_session_id`` foreign key constraint.
Returns True if the session was found and deleted.
"""
def _do(conn):
@@ -1185,12 +1247,15 @@ class SessionDB:
)
if cursor.fetchone()[0] == 0:
return False
# Orphan child sessions so FK constraint is satisfied
conn.execute(
"UPDATE sessions SET parent_session_id = NULL "
"WHERE parent_session_id = ?",
# Delete child sessions first (FK constraint)
child_ids = [r[0] for r in conn.execute(
"SELECT id FROM sessions WHERE parent_session_id = ?",
(session_id,),
)
).fetchall()]
for cid in child_ids:
conn.execute("DELETE FROM messages WHERE session_id = ?", (cid,))
conn.execute("DELETE FROM sessions WHERE id = ?", (cid,))
# Delete the session itself
conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
return True
@@ -1199,9 +1264,9 @@ class SessionDB:
def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int:
"""Delete sessions older than N days. Returns count of deleted sessions.
Only prunes ended sessions (not active ones). Child sessions outside
the prune window are orphaned (parent_session_id set to NULL) rather
than cascade-deleted.
Only prunes ended sessions (not active ones). Child sessions whose
parents are being pruned are deleted first to satisfy the
``parent_session_id`` foreign key constraint.
"""
cutoff = time.time() - (older_than_days * 86400)
@@ -1219,16 +1284,17 @@ class SessionDB:
)
session_ids = set(row["id"] for row in cursor.fetchall())
if not session_ids:
return 0
# Orphan any sessions whose parent is about to be deleted
placeholders = ",".join("?" * len(session_ids))
conn.execute(
f"UPDATE sessions SET parent_session_id = NULL "
f"WHERE parent_session_id IN ({placeholders})",
list(session_ids),
)
# Delete children first whose parents are in the prune set
# (avoids FK constraint errors)
for sid in list(session_ids):
child_ids = [r[0] for r in conn.execute(
"SELECT id FROM sessions WHERE parent_session_id = ?",
(sid,),
).fetchall()]
for cid in child_ids:
conn.execute("DELETE FROM messages WHERE session_id = ?", (cid,))
conn.execute("DELETE FROM sessions WHERE id = ?", (cid,))
session_ids.discard(cid) # don't double-delete
for sid in session_ids:
conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
+13
View File
@@ -89,6 +89,13 @@ def get_timezone() -> Optional[ZoneInfo]:
return _cached_tz
def get_timezone_name() -> str:
"""Return the IANA name of the configured timezone, or empty string."""
if not _cache_resolved:
get_timezone() # populates cache
return _cached_tz_name or ""
def now() -> datetime:
"""
Return the current time as a timezone-aware datetime.
@@ -103,3 +110,9 @@ def now() -> datetime:
return datetime.now().astimezone()
def reset_cache() -> None:
"""Clear the cached timezone. Used by tests and after config changes."""
global _cached_tz, _cached_tz_name, _cache_resolved
_cached_tz = None
_cached_tz_name = None
_cache_resolved = False
+5 -27
View File
@@ -560,40 +560,22 @@
# ── Directories ───────────────────────────────────────────────────
{
systemd.tmpfiles.rules = [
"d ${cfg.stateDir} 2770 ${cfg.user} ${cfg.group} - -"
"d ${cfg.stateDir}/.hermes 2770 ${cfg.user} ${cfg.group} - -"
"d ${cfg.stateDir}/.hermes/cron 2770 ${cfg.user} ${cfg.group} - -"
"d ${cfg.stateDir}/.hermes/sessions 2770 ${cfg.user} ${cfg.group} - -"
"d ${cfg.stateDir}/.hermes/logs 2770 ${cfg.user} ${cfg.group} - -"
"d ${cfg.stateDir}/.hermes/memories 2770 ${cfg.user} ${cfg.group} - -"
"d ${cfg.stateDir} 0750 ${cfg.user} ${cfg.group} - -"
"d ${cfg.stateDir}/.hermes 0750 ${cfg.user} ${cfg.group} - -"
"d ${cfg.stateDir}/home 0750 ${cfg.user} ${cfg.group} - -"
"d ${cfg.workingDirectory} 2770 ${cfg.user} ${cfg.group} - -"
"d ${cfg.workingDirectory} 0750 ${cfg.user} ${cfg.group} - -"
];
}
# ── Activation: link config + auth + documents ────────────────────
{
system.activationScripts."hermes-agent-setup" = lib.stringAfter ([ "users" ] ++ lib.optional (config.system.activationScripts ? setupSecrets) "setupSecrets") ''
system.activationScripts."hermes-agent-setup" = lib.stringAfter [ "users" "setupSecrets" ] ''
# Ensure directories exist (activation runs before tmpfiles)
mkdir -p ${cfg.stateDir}/.hermes
mkdir -p ${cfg.stateDir}/home
mkdir -p ${cfg.workingDirectory}
chown ${cfg.user}:${cfg.group} ${cfg.stateDir} ${cfg.stateDir}/.hermes ${cfg.stateDir}/home ${cfg.workingDirectory}
chmod 2770 ${cfg.stateDir} ${cfg.stateDir}/.hermes ${cfg.workingDirectory}
chmod 0750 ${cfg.stateDir}/home
# Create subdirs, set setgid + group-writable, migrate existing files.
# Nix-managed files (config.yaml, .env, .managed) stay 0640/0644.
find ${cfg.stateDir}/.hermes -maxdepth 1 \
\( -name "*.db" -o -name "*.db-wal" -o -name "*.db-shm" -o -name "SOUL.md" \) \
-exec chmod g+rw {} + 2>/dev/null || true
for _subdir in cron sessions logs memories; do
mkdir -p "${cfg.stateDir}/.hermes/$_subdir"
chown ${cfg.user}:${cfg.group} "${cfg.stateDir}/.hermes/$_subdir"
chmod 2770 "${cfg.stateDir}/.hermes/$_subdir"
find "${cfg.stateDir}/.hermes/$_subdir" -type f \
-exec chmod g+rw {} + 2>/dev/null || true
done
chmod 0750 ${cfg.stateDir} ${cfg.stateDir}/.hermes ${cfg.stateDir}/home ${cfg.workingDirectory}
# Merge Nix settings into existing config.yaml.
# Preserves user-added keys (skills, streaming, etc.); Nix keys win.
@@ -680,10 +662,6 @@ HERMES_NIX_ENV_EOF
Restart = cfg.restart;
RestartSec = cfg.restartSec;
# Shared-state: files created by the gateway should be group-writable
# so interactive users in the hermes group can read/write them.
UMask = "0007";
# Hardening
NoNewPrivileges = true;
ProtectSystem = "strict";
+1 -1
View File
@@ -14,7 +14,7 @@
};
runtimeDeps = with pkgs; [
nodejs_20 ripgrep git openssh ffmpeg tirith
nodejs_20 ripgrep git openssh ffmpeg
];
runtimePath = pkgs.lib.makeBinPath runtimeDeps;
@@ -1803,34 +1803,30 @@ class Migrator:
def migrate_cron_jobs(self, config: Optional[Dict[str, Any]] = None) -> None:
config = config or self.load_openclaw_config()
cron = config.get("cron") or {}
if not cron:
self.record("cron-jobs", None, None, "skipped", "No cron configuration found")
return
# Archive the full cron config
if self.archive_dir and self.execute:
self.archive_dir.mkdir(parents=True, exist_ok=True)
dest = self.archive_dir / "cron-config.json"
dest.write_text(json.dumps(cron, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
self.record("cron-jobs", "openclaw.json cron.*", str(dest), "archived",
"Cron config archived. Use 'hermes cron' to recreate jobs manually.")
else:
self.record("cron-jobs", "openclaw.json cron.*", "archive/cron-config.json",
"archived", "Would archive cron config")
# Also check for cron store files
cron_store = self.source_root / "cron"
found_any = False
# Archive the full cron config when present
if cron:
found_any = True
if self.archive_dir and self.execute:
self.archive_dir.mkdir(parents=True, exist_ok=True)
dest = self.archive_dir / "cron-config.json"
dest.write_text(json.dumps(cron, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
self.record("cron-jobs", "openclaw.json cron.*", str(dest), "archived",
"Cron config archived. Use 'hermes cron' to recreate jobs manually.")
else:
self.record("cron-jobs", "openclaw.json cron.*", "archive/cron-config.json",
"archived", "Would archive cron config")
# Also check for cron store files even when config.cron is missing
if cron_store.is_dir() and self.archive_dir:
found_any = True
dest_cron = self.archive_dir / "cron-store"
if self.execute:
shutil.copytree(cron_store, dest_cron, dirs_exist_ok=True)
self.record("cron-jobs", str(cron_store), str(dest_cron), "archived",
"Cron job store archived")
if not found_any:
self.record("cron-jobs", None, None, "skipped", "No cron configuration found")
# ── Hooks ─────────────────────────────────────────────────
def migrate_hooks_config(self, config: Optional[Dict[str, Any]] = None) -> None:
config = config or self.load_openclaw_config()
@@ -2458,15 +2454,6 @@ class Migrator:
notes.append(f"- **{item.kind}**: {item.reason}")
notes.append("")
has_cron_config_archive = any(
i.kind == "cron-jobs" and i.status == "archived" and i.destination and i.destination.endswith("cron-config.json")
for i in self.items
)
has_cron_store_archive = any(
i.kind == "cron-jobs" and i.status == "archived" and i.destination and i.destination.endswith("cron-store")
for i in self.items
)
notes.extend([
"## IMPORTANT: Archive the OpenClaw Directory",
"",
@@ -2488,14 +2475,7 @@ class Migrator:
"- Run `hermes claw cleanup` to archive the OpenClaw directory (prevents state confusion)",
"- Run `hermes setup` to configure any remaining settings",
"- Run `hermes mcp list` to verify MCP servers were imported correctly",
])
if has_cron_config_archive:
notes.append("- Run `hermes cron` to recreate scheduled tasks (see archive/cron-config.json)")
elif has_cron_store_archive:
notes.append("- Run `hermes cron` to recreate scheduled tasks (see archived cron-store)")
notes.extend([
"- Run `hermes cron` to recreate scheduled tasks (see archive/cron-config.json)",
"- Run `hermes gateway install` if you need the gateway service",
"- Review `~/.hermes/config.yaml` for any adjustments",
"",
+136 -218
View File
@@ -77,7 +77,6 @@ from hermes_constants import OPENROUTER_BASE_URL
# Agent internals extracted to agent/ package for modularity
from agent.memory_manager import build_memory_context_block
from agent.retry_utils import jittered_backoff
from agent.error_classifier import classify_api_error, FailoverReason
from agent.prompt_builder import (
DEFAULT_AGENT_IDENTITY, PLATFORM_HINTS,
MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, SKILLS_GUIDANCE,
@@ -87,7 +86,6 @@ from agent.model_metadata import (
fetch_model_metadata,
estimate_tokens_rough, estimate_messages_tokens_rough, estimate_request_tokens_rough,
get_next_probe_tier, parse_context_limit_from_error,
parse_available_output_tokens_from_error,
save_context_length, is_local_endpoint,
query_ollama_num_ctx,
)
@@ -624,6 +622,7 @@ class AIAgent:
self.tool_complete_callback = tool_complete_callback
self.thinking_callback = thinking_callback
self.reasoning_callback = reasoning_callback
self._reasoning_deltas_fired = False # Set by _fire_reasoning_delta, reset per API call
self.clarify_callback = clarify_callback
self.step_callback = step_callback
self.stream_delta_callback = stream_delta_callback
@@ -693,10 +692,6 @@ class AIAgent:
self._current_tool: str | None = None
self._api_call_count: int = 0
# Rate limit tracking — updated from x-ratelimit-* response headers
# after each API call. Accessed by /usage slash command.
self._rate_limit_state: Optional["RateLimitState"] = None
# Centralized logging — agent.log (INFO+) and errors.log (WARNING+)
# both live under ~/.hermes/logs/. Idempotent, so gateway mode
# (which creates a new AIAgent per message) won't duplicate handlers.
@@ -1298,6 +1293,7 @@ class AIAgent:
if hasattr(self, "context_compressor") and self.context_compressor:
self.context_compressor.last_prompt_tokens = 0
self.context_compressor.last_completion_tokens = 0
self.context_compressor.last_total_tokens = 0
self.context_compressor.compression_count = 0
self.context_compressor._context_probed = False
self.context_compressor._context_probe_persistable = False
@@ -2549,29 +2545,6 @@ class AIAgent:
self._last_activity_ts = time.time()
self._last_activity_desc = desc
def _capture_rate_limits(self, http_response: Any) -> None:
"""Parse x-ratelimit-* headers from an HTTP response and cache the state.
Called after each streaming API call. The httpx Response object is
available on the OpenAI SDK Stream via ``stream.response``.
"""
if http_response is None:
return
headers = getattr(http_response, "headers", None)
if not headers:
return
try:
from agent.rate_limit_tracker import parse_rate_limit_headers
state = parse_rate_limit_headers(headers, provider=self.provider)
if state is not None:
self._rate_limit_state = state
except Exception:
pass # Never let header parsing break the agent loop
def get_rate_limit_state(self):
"""Return the last captured RateLimitState, or None."""
return self._rate_limit_state
def get_activity_summary(self) -> dict:
"""Return a snapshot of the agent's current activity for diagnostics.
@@ -3847,6 +3820,7 @@ class AIAgent:
max_stream_retries = 1
has_tool_calls = False
first_delta_fired = False
self._reasoning_deltas_fired = False
# Accumulate streamed text so we can recover if get_final_response()
# returns empty output (e.g. chatgpt.com backend-api sends
# response.incomplete instead of response.completed).
@@ -4324,6 +4298,7 @@ class AIAgent:
def _fire_reasoning_delta(self, text: str) -> None:
"""Fire reasoning callback if registered."""
self._reasoning_deltas_fired = True
cb = self.reasoning_callback
if cb is not None:
try:
@@ -4424,11 +4399,6 @@ class AIAgent:
self._touch_activity("waiting for provider response (streaming)")
stream = request_client_holder["client"].chat.completions.create(**stream_kwargs)
# Capture rate limit headers from the initial HTTP response.
# The OpenAI SDK Stream object exposes the underlying httpx
# response via .response before any chunks are consumed.
self._capture_rate_limits(getattr(stream, "response", None))
content_parts: list = []
tool_calls_acc: dict = {}
tool_gen_notified: set = set()
@@ -4443,6 +4413,10 @@ class AIAgent:
role = "assistant"
reasoning_parts: list = []
usage_obj = None
# Reset per-call reasoning tracking so _build_assistant_message
# knows whether reasoning was already displayed during streaming.
self._reasoning_deltas_fired = False
_first_chunk_seen = False
for chunk in stream:
last_chunk_time["t"] = time.time()
@@ -4599,6 +4573,7 @@ class AIAgent:
works unchanged.
"""
has_tool_use = False
self._reasoning_deltas_fired = False
# Reset stale-stream timer for this attempt
last_chunk_time["t"] = time.time()
@@ -4960,21 +4935,9 @@ class AIAgent:
# Swap OpenAI client and config in-place
self.api_key = fb_client.api_key
self.client = fb_client
# Preserve provider-specific headers that
# resolve_provider_client() may have baked into
# fb_client via the default_headers kwarg. The OpenAI
# SDK stores these in _custom_headers. Without this,
# subsequent request-client rebuilds (via
# _create_request_openai_client) drop the headers,
# causing 403s from providers like Kimi Coding that
# require a User-Agent sentinel.
fb_headers = getattr(fb_client, "_custom_headers", None)
if not fb_headers:
fb_headers = getattr(fb_client, "default_headers", None)
self._client_kwargs = {
"api_key": fb_client.api_key,
"base_url": fb_base_url,
**({"default_headers": dict(fb_headers)} if fb_headers else {}),
}
# Re-evaluate prompt caching for the new provider/model
@@ -5389,22 +5352,15 @@ class AIAgent:
if self.api_mode == "anthropic_messages":
from agent.anthropic_adapter import build_anthropic_kwargs
anthropic_messages = self._prepare_anthropic_messages_for_api(api_messages)
# Pass context_length (total input+output window) so the adapter can
# clamp max_tokens (output cap) when the user configured a smaller
# context window than the model's native output limit.
# Pass context_length so the adapter can clamp max_tokens if the
# user configured a smaller context window than the model's output limit.
ctx_len = getattr(self, "context_compressor", None)
ctx_len = ctx_len.context_length if ctx_len else None
# _ephemeral_max_output_tokens is set for one call when the API
# returns "max_tokens too large given prompt" — it caps output to
# the available window space without touching context_length.
ephemeral_out = getattr(self, "_ephemeral_max_output_tokens", None)
if ephemeral_out is not None:
self._ephemeral_max_output_tokens = None # consume immediately
return build_anthropic_kwargs(
model=self.model,
messages=anthropic_messages,
tools=self.tools,
max_tokens=ephemeral_out if ephemeral_out is not None else self.max_tokens,
max_tokens=self.max_tokens,
reasoning_config=self.reasoning_config,
is_oauth=self._is_anthropic_oauth,
preserve_dots=self._anthropic_preserve_dots(),
@@ -7293,7 +7249,6 @@ class AIAgent:
length_continue_retries = 0
truncated_response_prefix = ""
compression_attempts = 0
_turn_exit_reason = "unknown" # Diagnostic: why the loop ended
# Clear any stale interrupt state at start
self.clear_interrupt()
@@ -7318,7 +7273,6 @@ class AIAgent:
# Check for interrupt request (e.g., user sent new message)
if self._interrupt_requested:
interrupted = True
_turn_exit_reason = "interrupted_by_user"
if not self.quiet_mode:
self._safe_print("\n⚡ Breaking out of tool loop due to interrupt...")
break
@@ -7327,7 +7281,6 @@ class AIAgent:
self._api_call_count = api_call_count
self._touch_activity(f"starting API call #{api_call_count}")
if not self.iteration_budget.consume():
_turn_exit_reason = "budget_exhausted"
if not self.quiet_mode:
self._safe_print(f"\n⚠️ Iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} iterations used)")
break
@@ -8032,25 +7985,6 @@ class AIAgent:
status_code = getattr(api_error, "status_code", None)
error_context = self._extract_api_error_context(api_error)
# ── Classify the error for structured recovery decisions ──
_compressor = getattr(self, "context_compressor", None)
_ctx_len = getattr(_compressor, "context_length", 200000) if _compressor else 200000
classified = classify_api_error(
api_error,
provider=getattr(self, "provider", "") or "",
model=getattr(self, "model", "") or "",
approx_tokens=approx_tokens,
context_length=_ctx_len,
num_messages=len(api_messages) if api_messages else 0,
)
logger.debug(
"Error classified: reason=%s status=%s retryable=%s compress=%s rotate=%s fallback=%s",
classified.reason.value, classified.status_code,
classified.retryable, classified.should_compress,
classified.should_rotate_credential, classified.should_fallback,
)
recovered_with_pool, has_retried_429 = self._recover_with_credential_pool(
status_code=status_code,
has_retried_429=has_retried_429,
@@ -8113,24 +8047,27 @@ class AIAgent:
# from all messages so the next retry sends no thinking
# blocks at all. One-shot — don't retry infinitely.
if (
classified.reason == FailoverReason.thinking_signature
self.api_mode == "anthropic_messages"
and status_code == 400
and not thinking_sig_retry_attempted
):
thinking_sig_retry_attempted = True
for _m in messages:
if isinstance(_m, dict):
_m.pop("reasoning_details", None)
self._vprint(
f"{self.log_prefix}⚠️ Thinking block signature invalid — "
f"stripped all thinking blocks, retrying...",
force=True,
)
logging.warning(
"%sThinking block signature recovery: stripped "
"reasoning_details from %d messages",
self.log_prefix, len(messages),
)
continue
_err_msg_lower = str(api_error).lower()
if "signature" in _err_msg_lower and "thinking" in _err_msg_lower:
thinking_sig_retry_attempted = True
for _m in messages:
if isinstance(_m, dict):
_m.pop("reasoning_details", None)
self._vprint(
f"{self.log_prefix}⚠️ Thinking block signature invalid — "
f"stripped all thinking blocks, retrying...",
force=True,
)
logging.warning(
"%sThinking block signature recovery: stripped "
"reasoning_details from %d messages",
self.log_prefix, len(messages),
)
continue
retry_count += 1
elapsed_time = time.time() - api_start_time
@@ -8187,7 +8124,14 @@ class AIAgent:
# is NOT a transient rate limit — retrying or switching
# credentials won't help. Reduce context to 200k (the
# standard tier) and compress.
if classified.reason == FailoverReason.long_context_tier:
# Only applies to Sonnet — Opus 1M is general access.
_is_long_context_tier_error = (
status_code == 429
and "extra usage" in error_msg
and "long context" in error_msg
and "sonnet" in self.model.lower()
)
if _is_long_context_tier_error:
_reduced_ctx = 200000
compressor = self.context_compressor
old_ctx = compressor.context_length
@@ -8232,9 +8176,13 @@ class AIAgent:
# When a fallback model is configured, switch immediately instead
# of burning through retries with exponential backoff -- the
# primary provider won't recover within the retry window.
is_rate_limited = classified.reason in (
FailoverReason.rate_limit,
FailoverReason.billing,
is_rate_limited = (
status_code == 429
or "rate limit" in error_msg
or "too many requests" in error_msg
or "rate_limit" in error_msg
or "usage limit" in error_msg
or "quota" in error_msg
)
if is_rate_limited and self._fallback_index < len(self._fallback_chain):
# Don't eagerly fallback if credential pool rotation may
@@ -8250,7 +8198,10 @@ class AIAgent:
continue
is_payload_too_large = (
classified.reason == FailoverReason.payload_too_large
status_code == 413
or 'request entity too large' in error_msg
or 'payload too large' in error_msg
or 'error code: 413' in error_msg
)
if is_payload_too_large:
@@ -8294,59 +8245,69 @@ class AIAgent:
}
# Check for context-length errors BEFORE generic 4xx handler.
# The classifier detects context overflow from: explicit error
# messages, generic 400 + large session heuristic (#1630), and
# server disconnect + large session pattern (#2153).
is_context_length_error = (
classified.reason == FailoverReason.context_overflow
)
# Local backends (LM Studio, Ollama, llama.cpp) often return
# HTTP 400 with messages like "Context size has been exceeded"
# which must trigger compression, not an immediate abort.
is_context_length_error = any(phrase in error_msg for phrase in [
'context length', 'context size', 'maximum context',
'token limit', 'too many tokens', 'reduce the length',
'exceeds the limit', 'context window',
'request entity too large', # OpenRouter/Nous 413 safety net
'prompt is too long', # Anthropic: "prompt is too long: N tokens > M maximum"
'prompt exceeds max length', # Z.AI / GLM: generic 400 overflow wording
])
# Fallback heuristic: Anthropic sometimes returns a generic
# 400 invalid_request_error with just "Error" as the message
# when the context is too large. If the error message is very
# short/generic AND the session is large, treat it as a
# probable context-length error and attempt compression rather
# than aborting. This prevents an infinite failure loop where
# each failed message gets persisted, making the session even
# larger. (#1630)
if not is_context_length_error and status_code == 400:
ctx_len = getattr(getattr(self, 'context_compressor', None), 'context_length', 200000)
is_large_session = approx_tokens > ctx_len * 0.4 or len(api_messages) > 80
is_generic_error = len(error_msg.strip()) < 30 # e.g. just "error"
if is_large_session and is_generic_error:
is_context_length_error = True
self._vprint(
f"{self.log_prefix}⚠️ Generic 400 with large session "
f"(~{approx_tokens:,} tokens, {len(api_messages)} msgs) — "
f"treating as probable context overflow.",
force=True,
)
# Server disconnects on large sessions are often caused by
# the request exceeding the provider's context/payload limit
# without a proper HTTP error response. Treat these as
# context-length errors to trigger compression rather than
# burning through retries that will all fail the same way.
# This breaks the death spiral: disconnect → no token data
# → no compression → bigger session → more disconnects.
# (#2153)
if not is_context_length_error and not status_code:
_is_server_disconnect = (
'server disconnected' in error_msg
or 'peer closed connection' in error_msg
or error_type in ('ReadError', 'RemoteProtocolError', 'ServerDisconnectedError')
)
if _is_server_disconnect:
ctx_len = getattr(getattr(self, 'context_compressor', None), 'context_length', 200000)
_is_large = approx_tokens > ctx_len * 0.6 or len(api_messages) > 200
if _is_large:
is_context_length_error = True
self._vprint(
f"{self.log_prefix}⚠️ Server disconnected with large session "
f"(~{approx_tokens:,} tokens, {len(api_messages)} msgs) — "
f"treating as context-length error, attempting compression.",
force=True,
)
if is_context_length_error:
compressor = self.context_compressor
old_ctx = compressor.context_length
# ── Distinguish two very different errors ───────────
# 1. "Prompt too long": the INPUT exceeds the context window.
# Fix: reduce context_length + compress history.
# 2. "max_tokens too large": input is fine, but
# input_tokens + requested max_tokens > context_window.
# Fix: reduce max_tokens (the OUTPUT cap) for this call.
# Do NOT shrink context_length — the window is unchanged.
#
# Note: max_tokens = output token cap (one response).
# context_length = total window (input + output combined).
available_out = parse_available_output_tokens_from_error(error_msg)
if available_out is not None:
# Error is purely about the output cap being too large.
# Cap output to the available space and retry without
# touching context_length or triggering compression.
safe_out = max(1, available_out - 64) # small safety margin
self._ephemeral_max_output_tokens = safe_out
self._vprint(
f"{self.log_prefix}⚠️ Output cap too large for current prompt — "
f"retrying with max_tokens={safe_out:,} "
f"(available_tokens={available_out:,}; context_length unchanged at {old_ctx:,})",
force=True,
)
# Still count against compression_attempts so we don't
# loop forever if the error keeps recurring.
compression_attempts += 1
if compression_attempts > max_compression_attempts:
self._vprint(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached.", force=True)
self._vprint(f"{self.log_prefix} 💡 Try /new to start a fresh conversation, or /compress to retry compression.", force=True)
logging.error(f"{self.log_prefix}Context compression failed after {max_compression_attempts} attempts.")
self._persist_session(messages, conversation_history)
return {
"messages": messages,
"completed": False,
"api_calls": api_call_count,
"error": f"Context length exceeded: max compression attempts ({max_compression_attempts}) reached.",
"partial": True
}
restart_with_compressed_messages = True
break
# Error is about the INPUT being too large — reduce context_length.
# Try to parse the actual limit from the error message
parsed_limit = parse_context_limit_from_error(error_msg)
if parsed_limit and parsed_limit < old_ctx:
@@ -8413,30 +8374,35 @@ class AIAgent:
"partial": True
}
# Check for non-retryable client errors. The classifier
# already accounts for 413, 429, 529 (transient), context
# overflow, and generic-400 heuristics. Local validation
# errors (ValueError, TypeError) are programming bugs.
# Check for non-retryable client errors (4xx HTTP status codes).
# These indicate a problem with the request itself (bad model ID,
# invalid API key, forbidden, etc.) and will never succeed on retry.
# Note: 413 and context-length errors are excluded — handled above.
# 429 (rate limit) is transient and MUST be retried with backoff.
# 529 (Anthropic overloaded) is also transient.
# Also catch local validation errors (ValueError, TypeError) — these
# are programming bugs, not transient failures.
# Exclude UnicodeEncodeError — it's a ValueError subclass but is
# handled separately by the surrogate sanitization path above.
_RETRYABLE_STATUS_CODES = {413, 429, 529}
is_local_validation_error = (
isinstance(api_error, (ValueError, TypeError))
and not isinstance(api_error, UnicodeEncodeError)
)
is_client_error = (
is_local_validation_error
or (
not classified.retryable
and not classified.should_compress
and classified.reason not in (
FailoverReason.rate_limit,
FailoverReason.billing,
FailoverReason.overloaded,
FailoverReason.context_overflow,
FailoverReason.payload_too_large,
FailoverReason.long_context_tier,
FailoverReason.thinking_signature,
)
)
) and not is_context_length_error
# Detect generic 400s from Anthropic OAuth (transient server-side failures).
# Real invalid_request_error responses include a descriptive message;
# transient ones contain only "Error" or are empty. (ref: issue #1608)
_err_body = getattr(api_error, "body", None) or {}
_err_message = (_err_body.get("error", {}).get("message", "") if isinstance(_err_body, dict) else "")
_is_generic_400 = (status_code == 400 and _err_message.strip().lower() in ("error", ""))
is_client_status_error = isinstance(status_code, int) and 400 <= status_code < 500 and status_code not in _RETRYABLE_STATUS_CODES and not _is_generic_400
is_client_error = (is_local_validation_error or is_client_status_error or any(phrase in error_msg for phrase in [
'error code: 401', 'error code: 403',
'error code: 404', 'error code: 422',
'is not a valid model', 'invalid model', 'model not found',
'invalid api key', 'invalid_api_key', 'authentication',
'unauthorized', 'forbidden', 'not found',
])) and not is_context_length_error
if is_client_error:
# Try fallback before aborting — a different provider
@@ -8456,7 +8422,7 @@ class AIAgent:
self._vprint(f"{self.log_prefix} 🔌 Provider: {_provider} Model: {_model}", force=True)
self._vprint(f"{self.log_prefix} 🌐 Endpoint: {_base}", force=True)
# Actionable guidance for common auth errors
if classified.is_auth or classified.reason == FailoverReason.billing:
if status_code in (401, 403) or "unauthorized" in error_msg or "forbidden" in error_msg or "permission" in error_msg:
if _provider == "openai-codex" and status_code == 401:
self._vprint(f"{self.log_prefix} 💡 Codex OAuth token was rejected (HTTP 401). Your token may have been", force=True)
self._vprint(f"{self.log_prefix} refreshed by another client (Codex CLI, VS Code). To fix:", force=True)
@@ -8616,7 +8582,6 @@ class AIAgent:
# If the API call was interrupted, skip response processing
if interrupted:
_turn_exit_reason = "interrupted_during_api_call"
break
if restart_with_compressed_messages:
@@ -8636,7 +8601,6 @@ class AIAgent:
# (e.g. repeated context-length errors that exhausted retry_count),
# the `response` variable is still None. Break out cleanly.
if response is None:
_turn_exit_reason = "all_retries_exhausted_no_response"
print(f"{self.log_prefix}❌ All API retries exhausted with no successful response.")
self._persist_session(messages, conversation_history)
break
@@ -9100,7 +9064,6 @@ class AIAgent:
# instead of wasting API calls on retries that won't help.
fallback = getattr(self, '_last_content_with_tools', None)
if fallback:
_turn_exit_reason = "fallback_prior_turn_content"
logger.debug("Empty follow-up after tool calls — using prior turn content as final response")
self._last_content_with_tools = None
self._empty_content_retries = 0
@@ -9167,7 +9130,6 @@ class AIAgent:
# Exhausted prefill attempts, empty retries, or
# structured reasoning with no content —
# fall through to "(empty)" terminal.
_turn_exit_reason = "empty_response_exhausted"
reasoning_text = self._extract_reasoning(assistant_message)
assistant_msg = self._build_assistant_message(assistant_message, finish_reason)
assistant_msg["content"] = "(empty)"
@@ -9185,6 +9147,7 @@ class AIAgent:
# Reset retry counter/signature on successful content
if hasattr(self, '_empty_content_retries'):
self._empty_content_retries = 0
self._last_empty_content_signature = None
self._thinking_prefill_retries = 0
if (
@@ -9238,7 +9201,6 @@ class AIAgent:
messages.append(final_msg)
_turn_exit_reason = f"text_response(finish_reason={finish_reason})"
if not self.quiet_mode:
self._safe_print(f"🎉 Conversation completed after {api_call_count} OpenAI-compatible API call(s)")
break
@@ -9256,6 +9218,7 @@ class AIAgent:
# If an assistant message with tool_calls was already appended,
# the API expects a role="tool" result for every tool_call_id.
# Fill in error results for any that weren't answered yet.
pending_handled = False
for idx in range(len(messages) - 1, -1, -1):
msg = messages[idx]
if not isinstance(msg, dict):
@@ -9287,7 +9250,6 @@ class AIAgent:
# If we're near the limit, break to avoid infinite loops
if api_call_count >= self.max_iterations - 1:
_turn_exit_reason = f"error_near_max_iterations({error_msg[:80]})"
final_response = f"I apologize, but I encountered repeated errors: {error_msg}"
# Append as assistant so the history stays valid for
# session resume (avoids consecutive user messages).
@@ -9298,7 +9260,6 @@ class AIAgent:
api_call_count >= self.max_iterations
or self.iteration_budget.remaining <= 0
):
_turn_exit_reason = f"max_iterations_reached({api_call_count}/{self.max_iterations})"
if self.iteration_budget.remaining <= 0 and not self.quiet_mode:
print(f"\n⚠️ Iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} iterations used)")
final_response = self._handle_max_iterations(messages, api_call_count)
@@ -9315,49 +9276,6 @@ class AIAgent:
# Persist session to both JSON log and SQLite
self._persist_session(messages, conversation_history)
# ── Turn-exit diagnostic log ─────────────────────────────────────
# Always logged at INFO so agent.log captures WHY every turn ended.
# When the last message is a tool result (agent was mid-work), log
# at WARNING — this is the "just stops" scenario users report.
_last_msg_role = messages[-1].get("role") if messages else None
_last_tool_name = None
if _last_msg_role == "tool":
# Walk back to find the assistant message with the tool call
for _m in reversed(messages):
if _m.get("role") == "assistant" and _m.get("tool_calls"):
_tcs = _m["tool_calls"]
if _tcs and isinstance(_tcs[0], dict):
_last_tool_name = _tcs[-1].get("function", {}).get("name")
break
_turn_tool_count = sum(
1 for m in messages
if isinstance(m, dict) and m.get("role") == "assistant" and m.get("tool_calls")
)
_resp_len = len(final_response) if final_response else 0
_budget_used = self.iteration_budget.used if self.iteration_budget else 0
_budget_max = self.iteration_budget.max_total if self.iteration_budget else 0
_diag_msg = (
"Turn ended: reason=%s model=%s api_calls=%d/%d budget=%d/%d "
"tool_turns=%d last_msg_role=%s response_len=%d session=%s"
)
_diag_args = (
_turn_exit_reason, self.model, api_call_count, self.max_iterations,
_budget_used, _budget_max,
_turn_tool_count, _last_msg_role, _resp_len,
self.session_id or "none",
)
if _last_msg_role == "tool" and not interrupted:
# Agent was mid-work — this is the "just stops" case.
logger.warning(
"Turn ended with pending tool result (agent may appear stuck). "
+ _diag_msg + " last_tool=%s",
*_diag_args, _last_tool_name,
)
else:
logger.info(_diag_msg, *_diag_args)
# Plugin hook: post_llm_call
# Fired once per turn after the tool-calling loop completes.
@@ -249,8 +249,9 @@ Type these during an interactive chat session.
/config Show config (CLI)
/model [name] Show or change model
/provider Show provider info
/prompt [text] View/set system prompt (CLI)
/personality [name] Set personality
/reasoning [level] Set reasoning (none|minimal|low|medium|high|xhigh|show|hide)
/reasoning [level] Set reasoning (none|low|medium|high|xhigh|show|hide)
/verbose Cycle: off → new → all → verbose
/voice [on|off|tts] Voice mode
/yolo Toggle approval bypass
+95 -82
View File
@@ -1,7 +1,7 @@
---
name: google-workspace
description: Gmail, Calendar, Drive, Contacts, Sheets, and Docs integration via gws CLI (googleworkspace/cli). Uses OAuth2 with automatic token refresh via bridge script. Requires gws binary.
version: 2.0.0
description: Gmail, Calendar, Drive, Contacts, Sheets, and Docs integration via Python. Uses OAuth2 with automatic token refresh. No external binaries needed — runs entirely with Google's Python client libraries in the Hermes venv.
version: 1.0.0
author: Nous Research
license: MIT
required_credential_files:
@@ -11,25 +11,14 @@ required_credential_files:
description: Google OAuth2 client credentials (downloaded from Google Cloud Console)
metadata:
hermes:
tags: [Google, Gmail, Calendar, Drive, Sheets, Docs, Contacts, Email, OAuth, gws]
tags: [Google, Gmail, Calendar, Drive, Sheets, Docs, Contacts, Email, OAuth]
homepage: https://github.com/NousResearch/hermes-agent
related_skills: [himalaya]
---
# Google Workspace
Gmail, Calendar, Drive, Contacts, Sheets, and Docs — powered by `gws` (Google's official Rust CLI). The skill provides a backward-compatible Python wrapper that handles OAuth token refresh and delegates to `gws`.
## Architecture
```
google_api.py → gws_bridge.py → gws CLI
(argparse compat) (token refresh) (Google APIs)
```
- `setup.py` handles OAuth2 (headless-compatible, works on CLI/Telegram/Discord)
- `gws_bridge.py` refreshes the Hermes token and injects it into `gws` via `GOOGLE_WORKSPACE_CLI_TOKEN`
- `google_api.py` provides the same CLI interface as v1 but delegates to `gws`
Gmail, Calendar, Drive, Contacts, Sheets, and Docs — all through Python scripts in this skill. No external binaries to install.
## References
@@ -38,22 +27,7 @@ google_api.py → gws_bridge.py → gws CLI
## Scripts
- `scripts/setup.py` — OAuth2 setup (run once to authorize)
- `scripts/gws_bridge.py` — Token refresh bridge to gws CLI
- `scripts/google_api.py` — Backward-compatible API wrapper (delegates to gws)
## Prerequisites
Install `gws`:
```bash
cargo install google-workspace-cli
# or via npm (recommended, downloads prebuilt binary):
npm install -g @googleworkspace/cli
# or via Homebrew:
brew install googleworkspace-cli
```
Verify: `gws --version`
- `scripts/google_api.py` — API wrapper CLI (agent uses this for all operations)
## First-Time Setup
@@ -82,29 +56,42 @@ If it prints `AUTHENTICATED`, skip to Usage — setup is already done.
### Step 1: Triage — ask the user what they need
Before starting OAuth setup, ask the user TWO questions:
**Question 1: "What Google services do you need? Just email, or also
Calendar/Drive/Sheets/Docs?"**
- **Email only** → Use the `himalaya` skill instead — simpler setup.
- **Calendar, Drive, Sheets, Docs (or email + these)** → Continue below.
- **Email only** They don't need this skill at all. Use the `himalaya` skill
instead — it works with a Gmail App Password (Settings → Security → App
Passwords) and takes 2 minutes to set up. No Google Cloud project needed.
Load the himalaya skill and follow its setup instructions.
**Partial scopes**: Users can authorize only a subset of services. The setup
script accepts partial scopes and warns about missing ones.
- **Calendar, Drive, Sheets, Docs (or email + these)** → Continue with this
skill's OAuth setup below.
**Question 2: "Does your Google account use Advanced Protection?"**
**Question 2: "Does your Google account use Advanced Protection (hardware
security keys required to sign in)? If you're not sure, you probably don't
— it's something you would have explicitly enrolled in."**
- **No / Not sure** → Normal setup.
- **Yes** → Workspace admin must add the OAuth client ID to allowed apps first.
- **No / Not sure** → Normal setup. Continue below.
- **Yes** Their Workspace admin must add the OAuth client ID to the org's
allowed apps list before Step 4 will work. Let them know upfront.
### Step 2: Create OAuth credentials (one-time, ~5 minutes)
Tell the user:
> You need a Google Cloud OAuth client. This is a one-time setup:
>
> 1. Go to https://console.cloud.google.com/apis/credentials
> 2. Create a project (or use an existing one)
> 3. Enable the APIs you need (Gmail, Calendar, Drive, Sheets, Docs, People)
> 4. Credentials → Create Credentials → OAuth 2.0 Client ID → Desktop app
> 5. Download JSON and tell me the file path
> 3. Click "Enable APIs" and enable: Gmail API, Google Calendar API,
> Google Drive API, Google Sheets API, Google Docs API, People API
> 4. Go to Credentials → Create Credentials → OAuth 2.0 Client ID
> 5. Application type: "Desktop app" → Create
> 6. Click "Download JSON" and tell me the file path
Once they provide the path:
```bash
$GSETUP --client-secret /path/to/client_secret.json
@@ -116,10 +103,20 @@ $GSETUP --client-secret /path/to/client_secret.json
$GSETUP --auth-url
```
Send the URL to the user. After authorizing, they paste back the redirect URL or code.
This prints a URL. **Send the URL to the user** and tell them:
> Open this link in your browser, sign in with your Google account, and
> authorize access. After authorizing, you'll be redirected to a page that
> may show an error — that's expected. Copy the ENTIRE URL from your
> browser's address bar and paste it back to me.
### Step 4: Exchange the code
The user will paste back either a URL like `http://localhost:1/?code=4/0A...&scope=...`
or just the code string. Either works. The `--auth-url` step stores a temporary
pending OAuth session locally so `--auth-code` can complete the PKCE exchange
later, even on headless systems:
```bash
$GSETUP --auth-code "THE_URL_OR_CODE_THE_USER_PASTED"
```
@@ -130,11 +127,18 @@ $GSETUP --auth-code "THE_URL_OR_CODE_THE_USER_PASTED"
$GSETUP --check
```
Should print `AUTHENTICATED`. Token refreshes automatically from now on.
Should print `AUTHENTICATED`. Setup is complete — token refreshes automatically from now on.
### Notes
- Token is stored at `google_token.json` under the active profile's `HERMES_HOME` and auto-refreshes.
- Pending OAuth session state/verifier are stored temporarily at `google_oauth_pending.json` under the active profile's `HERMES_HOME` until exchange completes.
- Hermes now refuses to overwrite a full Google Workspace token with a narrower re-auth token missing Gmail scopes, so one profile's partial consent cannot silently break email actions later.
- To revoke: `$GSETUP --revoke`
## Usage
All commands go through the API script:
All commands go through the API script. Set `GAPI` as a shorthand:
```bash
HERMES_HOME="${HERMES_HOME:-$HOME/.hermes}"
@@ -149,21 +153,40 @@ GAPI="$PYTHON_BIN $GWORKSPACE_SKILL_DIR/scripts/google_api.py"
### Gmail
```bash
# Search (returns JSON array with id, from, subject, date, snippet)
$GAPI gmail search "is:unread" --max 10
$GAPI gmail search "from:boss@company.com newer_than:1d"
$GAPI gmail search "has:attachment filename:pdf newer_than:7d"
# Read full message (returns JSON with body text)
$GAPI gmail get MESSAGE_ID
# Send
$GAPI gmail send --to user@example.com --subject "Hello" --body "Message text"
$GAPI gmail send --to user@example.com --subject "Report" --body "<h1>Q4</h1>" --html
$GAPI gmail send --to user@example.com --subject "Report" --body "<h1>Q4</h1><p>Details...</p>" --html
# Reply (automatically threads and sets In-Reply-To)
$GAPI gmail reply MESSAGE_ID --body "Thanks, that works for me."
# Labels
$GAPI gmail labels
$GAPI gmail modify MESSAGE_ID --add-labels LABEL_ID
$GAPI gmail modify MESSAGE_ID --remove-labels UNREAD
```
### Calendar
```bash
# List events (defaults to next 7 days)
$GAPI calendar list
$GAPI calendar create --summary "Standup" --start 2026-03-01T10:00:00+01:00 --end 2026-03-01T10:30:00+01:00
$GAPI calendar create --summary "Review" --start ... --end ... --attendees "alice@co.com,bob@co.com"
$GAPI calendar list --start 2026-03-01T00:00:00Z --end 2026-03-07T23:59:59Z
# Create event (ISO 8601 with timezone required)
$GAPI calendar create --summary "Team Standup" --start 2026-03-01T10:00:00-06:00 --end 2026-03-01T10:30:00-06:00
$GAPI calendar create --summary "Lunch" --start 2026-03-01T12:00:00Z --end 2026-03-01T13:00:00Z --location "Cafe"
$GAPI calendar create --summary "Review" --start 2026-03-01T14:00:00Z --end 2026-03-01T15:00:00Z --attendees "alice@co.com,bob@co.com"
# Delete event
$GAPI calendar delete EVENT_ID
```
@@ -183,8 +206,13 @@ $GAPI contacts list --max 20
### Sheets
```bash
# Read
$GAPI sheets get SHEET_ID "Sheet1!A1:D10"
# Write
$GAPI sheets update SHEET_ID "Sheet1!A1:B2" --values '[["Name","Score"],["Alice","95"]]'
# Append rows
$GAPI sheets append SHEET_ID "Sheet1!A:C" --values '[["new","row","data"]]'
```
@@ -194,52 +222,37 @@ $GAPI sheets append SHEET_ID "Sheet1!A:C" --values '[["new","row","data"]]'
$GAPI docs get DOC_ID
```
### Direct gws access (advanced)
For operations not covered by the wrapper, use `gws_bridge.py` directly:
```bash
GBRIDGE="$PYTHON_BIN $GWORKSPACE_SKILL_DIR/scripts/gws_bridge.py"
$GBRIDGE calendar +agenda --today --format table
$GBRIDGE gmail +triage --labels --format json
$GBRIDGE drive +upload ./report.pdf
$GBRIDGE sheets +read --spreadsheet SHEET_ID --range "Sheet1!A1:D10"
```
## Output Format
All commands return JSON via `gws --format json`. Key output shapes:
All commands return JSON. Parse with `jq` or read directly. Key fields:
- **Gmail search/triage**: Array of message summaries (sender, subject, date, snippet)
- **Gmail get/read**: Message object with headers and body text
- **Gmail send/reply**: Confirmation with message ID
- **Calendar list/agenda**: Array of event objects (summary, start, end, location)
- **Calendar create**: Confirmation with event ID and htmlLink
- **Drive search**: Array of file objects (id, name, mimeType, webViewLink)
- **Sheets get/read**: 2D array of cell values
- **Docs get**: Full document JSON (use `body.content` for text extraction)
- **Contacts list**: Array of person objects with names, emails, phones
Parse output with `jq` or read JSON directly.
- **Gmail search**: `[{id, threadId, from, to, subject, date, snippet, labels}]`
- **Gmail get**: `{id, threadId, from, to, subject, date, labels, body}`
- **Gmail send/reply**: `{status: "sent", id, threadId}`
- **Calendar list**: `[{id, summary, start, end, location, description, htmlLink}]`
- **Calendar create**: `{status: "created", id, summary, htmlLink}`
- **Drive search**: `[{id, name, mimeType, modifiedTime, webViewLink}]`
- **Contacts list**: `[{name, emails: [...], phones: [...]}]`
- **Sheets get**: `[[cell, cell, ...], ...]`
## Rules
1. **Never send email or create/delete events without confirming with the user first.**
2. **Check auth before first use** — run `setup.py --check`.
3. **Use the Gmail search syntax reference** for complex queries.
4. **Calendar times must include timezone** — ISO 8601 with offset or UTC.
5. **Respect rate limits** — avoid rapid-fire sequential API calls.
1. **Never send email or create/delete events without confirming with the user first.** Show the draft content and ask for approval.
2. **Check auth before first use** — run `setup.py --check`. If it fails, guide the user through setup.
3. **Use the Gmail search syntax reference** for complex queries — load it with `skill_view("google-workspace", file_path="references/gmail-search-syntax.md")`.
4. **Calendar times must include timezone** always use ISO 8601 with offset (e.g., `2026-03-01T10:00:00-06:00`) or UTC (`Z`).
5. **Respect rate limits** — avoid rapid-fire sequential API calls. Batch reads when possible.
## Troubleshooting
| Problem | Fix |
|---------|-----|
| `NOT_AUTHENTICATED` | Run setup Steps 2-5 |
| `REFRESH_FAILED` | Token revoked — redo Steps 3-5 |
| `gws: command not found` | Install: `npm install -g @googleworkspace/cli` |
| `HttpError 403` | Missing scope — `$GSETUP --revoke` then redo Steps 3-5 |
| `HttpError 403: Access Not Configured` | Enable API in Google Cloud Console |
| Advanced Protection blocks auth | Admin must allowlist the OAuth client ID |
| `NOT_AUTHENTICATED` | Run setup Steps 2-5 above |
| `REFRESH_FAILED` | Token revoked or expired — redo Steps 3-5 |
| `HttpError 403: Insufficient Permission` | Missing API scope — `$GSETUP --revoke` then redo Steps 3-5 |
| `HttpError 403: Access Not Configured` | API not enabled — user needs to enable it in Google Cloud Console |
| `ModuleNotFoundError` | Run `$GSETUP --install-deps` |
| Advanced Protection blocks auth | Workspace admin must allowlist the OAuth client ID |
## Revoking Access
@@ -1,17 +1,16 @@
#!/usr/bin/env python3
"""Google Workspace API CLI for Hermes Agent.
Thin wrapper that delegates to gws (googleworkspace/cli) via gws_bridge.py.
Maintains the same CLI interface for backward compatibility with Hermes skills.
A thin CLI wrapper around Google's Python client libraries.
Authenticates using the token stored by setup.py.
Usage:
python google_api.py gmail search "is:unread" [--max 10]
python google_api.py gmail get MESSAGE_ID
python google_api.py gmail send --to user@example.com --subject "Hi" --body "Hello"
python google_api.py gmail reply MESSAGE_ID --body "Thanks"
python google_api.py calendar list [--start DATE] [--end DATE] [--calendar primary]
python google_api.py calendar list [--from DATE] [--to DATE] [--calendar primary]
python google_api.py calendar create --summary "Meeting" --start DATETIME --end DATETIME
python google_api.py calendar delete EVENT_ID
python google_api.py drive search "budget report" [--max 10]
python google_api.py contacts list [--max 20]
python google_api.py sheets get SHEET_ID RANGE
@@ -21,193 +20,386 @@ Usage:
"""
import argparse
import base64
import json
import os
import subprocess
import sys
from datetime import datetime, timedelta, timezone
from email.mime.text import MIMEText
from pathlib import Path
BRIDGE = Path(__file__).parent / "gws_bridge.py"
PYTHON = sys.executable
try:
from hermes_constants import display_hermes_home, get_hermes_home
except ModuleNotFoundError:
HERMES_AGENT_ROOT = Path(__file__).resolve().parents[4]
if HERMES_AGENT_ROOT.exists():
sys.path.insert(0, str(HERMES_AGENT_ROOT))
from hermes_constants import display_hermes_home, get_hermes_home
HERMES_HOME = get_hermes_home()
TOKEN_PATH = HERMES_HOME / "google_token.json"
SCOPES = [
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.send",
"https://www.googleapis.com/auth/gmail.modify",
"https://www.googleapis.com/auth/calendar",
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/contacts.readonly",
"https://www.googleapis.com/auth/spreadsheets",
"https://www.googleapis.com/auth/documents.readonly",
]
def gws(*args: str) -> None:
"""Call gws via the bridge and exit with its return code."""
result = subprocess.run(
[PYTHON, str(BRIDGE)] + list(args),
env={**os.environ, "HERMES_HOME": os.environ.get("HERMES_HOME", str(Path.home() / ".hermes"))},
)
sys.exit(result.returncode)
def _missing_scopes() -> list[str]:
try:
payload = json.loads(TOKEN_PATH.read_text())
except Exception:
return []
raw = payload.get("scopes") or payload.get("scope")
if not raw:
return []
granted = {s.strip() for s in (raw.split() if isinstance(raw, str) else raw) if s.strip()}
return sorted(scope for scope in SCOPES if scope not in granted)
# -- Gmail --
def get_credentials():
"""Load and refresh credentials from token file."""
if not TOKEN_PATH.exists():
print("Not authenticated. Run the setup script first:", file=sys.stderr)
print(f" python {Path(__file__).parent / 'setup.py'}", file=sys.stderr)
sys.exit(1)
from google.oauth2.credentials import Credentials
from google.auth.transport.requests import Request
creds = Credentials.from_authorized_user_file(str(TOKEN_PATH), SCOPES)
if creds.expired and creds.refresh_token:
creds.refresh(Request())
TOKEN_PATH.write_text(creds.to_json())
if not creds.valid:
print("Token is invalid. Re-run setup.", file=sys.stderr)
sys.exit(1)
missing_scopes = _missing_scopes()
if missing_scopes:
print(
"Token is valid but missing Google Workspace scopes required by this skill.",
file=sys.stderr,
)
for scope in missing_scopes:
print(f" - {scope}", file=sys.stderr)
print(
f"Re-run setup.py from the active Hermes profile ({display_hermes_home()}) to restore full access.",
file=sys.stderr,
)
sys.exit(1)
return creds
def build_service(api, version):
from googleapiclient.discovery import build
return build(api, version, credentials=get_credentials())
# =========================================================================
# Gmail
# =========================================================================
def gmail_search(args):
cmd = ["gmail", "+triage", "--query", args.query, "--max", str(args.max), "--format", "json"]
gws(*cmd)
service = build_service("gmail", "v1")
results = service.users().messages().list(
userId="me", q=args.query, maxResults=args.max
).execute()
messages = results.get("messages", [])
if not messages:
print("No messages found.")
return
output = []
for msg_meta in messages:
msg = service.users().messages().get(
userId="me", id=msg_meta["id"], format="metadata",
metadataHeaders=["From", "To", "Subject", "Date"],
).execute()
headers = {h["name"]: h["value"] for h in msg.get("payload", {}).get("headers", [])}
output.append({
"id": msg["id"],
"threadId": msg["threadId"],
"from": headers.get("From", ""),
"to": headers.get("To", ""),
"subject": headers.get("Subject", ""),
"date": headers.get("Date", ""),
"snippet": msg.get("snippet", ""),
"labels": msg.get("labelIds", []),
})
print(json.dumps(output, indent=2, ensure_ascii=False))
def gmail_get(args):
gws("gmail", "+read", "--id", args.message_id, "--headers", "--format", "json")
service = build_service("gmail", "v1")
msg = service.users().messages().get(
userId="me", id=args.message_id, format="full"
).execute()
headers = {h["name"]: h["value"] for h in msg.get("payload", {}).get("headers", [])}
# Extract body text
body = ""
payload = msg.get("payload", {})
if payload.get("body", {}).get("data"):
body = base64.urlsafe_b64decode(payload["body"]["data"]).decode("utf-8", errors="replace")
elif payload.get("parts"):
for part in payload["parts"]:
if part.get("mimeType") == "text/plain" and part.get("body", {}).get("data"):
body = base64.urlsafe_b64decode(part["body"]["data"]).decode("utf-8", errors="replace")
break
if not body:
for part in payload["parts"]:
if part.get("mimeType") == "text/html" and part.get("body", {}).get("data"):
body = base64.urlsafe_b64decode(part["body"]["data"]).decode("utf-8", errors="replace")
break
result = {
"id": msg["id"],
"threadId": msg["threadId"],
"from": headers.get("From", ""),
"to": headers.get("To", ""),
"subject": headers.get("Subject", ""),
"date": headers.get("Date", ""),
"labels": msg.get("labelIds", []),
"body": body,
}
print(json.dumps(result, indent=2, ensure_ascii=False))
def gmail_send(args):
cmd = ["gmail", "+send", "--to", args.to, "--subject", args.subject, "--body", args.body, "--format", "json"]
service = build_service("gmail", "v1")
message = MIMEText(args.body, "html" if args.html else "plain")
message["to"] = args.to
message["subject"] = args.subject
if args.cc:
cmd += ["--cc", args.cc]
if args.html:
cmd.append("--html")
gws(*cmd)
message["cc"] = args.cc
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
body = {"raw": raw}
if args.thread_id:
body["threadId"] = args.thread_id
result = service.users().messages().send(userId="me", body=body).execute()
print(json.dumps({"status": "sent", "id": result["id"], "threadId": result.get("threadId", "")}, indent=2))
def gmail_reply(args):
gws("gmail", "+reply", "--message-id", args.message_id, "--body", args.body, "--format", "json")
service = build_service("gmail", "v1")
# Fetch original to get thread ID and headers
original = service.users().messages().get(
userId="me", id=args.message_id, format="metadata",
metadataHeaders=["From", "Subject", "Message-ID"],
).execute()
headers = {h["name"]: h["value"] for h in original.get("payload", {}).get("headers", [])}
subject = headers.get("Subject", "")
if not subject.startswith("Re:"):
subject = f"Re: {subject}"
message = MIMEText(args.body)
message["to"] = headers.get("From", "")
message["subject"] = subject
if headers.get("Message-ID"):
message["In-Reply-To"] = headers["Message-ID"]
message["References"] = headers["Message-ID"]
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
body = {"raw": raw, "threadId": original["threadId"]}
result = service.users().messages().send(userId="me", body=body).execute()
print(json.dumps({"status": "sent", "id": result["id"], "threadId": result.get("threadId", "")}, indent=2))
def gmail_labels(args):
gws("gmail", "users", "labels", "list", "--params", json.dumps({"userId": "me"}), "--format", "json")
service = build_service("gmail", "v1")
results = service.users().labels().list(userId="me").execute()
labels = [{"id": l["id"], "name": l["name"], "type": l.get("type", "")} for l in results.get("labels", [])]
print(json.dumps(labels, indent=2))
def gmail_modify(args):
service = build_service("gmail", "v1")
body = {}
if args.add_labels:
body["addLabelIds"] = args.add_labels.split(",")
if args.remove_labels:
body["removeLabelIds"] = args.remove_labels.split(",")
gws(
"gmail", "users", "messages", "modify",
"--params", json.dumps({"userId": "me", "id": args.message_id}),
"--json", json.dumps(body),
"--format", "json",
)
result = service.users().messages().modify(userId="me", id=args.message_id, body=body).execute()
print(json.dumps({"id": result["id"], "labels": result.get("labelIds", [])}, indent=2))
# -- Calendar --
# =========================================================================
# Calendar
# =========================================================================
def calendar_list(args):
if args.start or args.end:
# Specific date range — use raw Calendar API for precise timeMin/timeMax
from datetime import datetime, timedelta, timezone as tz
now = datetime.now(tz.utc)
time_min = args.start or now.isoformat()
time_max = args.end or (now + timedelta(days=7)).isoformat()
gws(
"calendar", "events", "list",
"--params", json.dumps({
"calendarId": args.calendar,
"timeMin": time_min,
"timeMax": time_max,
"maxResults": args.max,
"singleEvents": True,
"orderBy": "startTime",
}),
"--format", "json",
)
else:
# No date range — use +agenda helper (defaults to 7 days)
cmd = ["calendar", "+agenda", "--days", "7", "--format", "json"]
if args.calendar != "primary":
cmd += ["--calendar", args.calendar]
gws(*cmd)
service = build_service("calendar", "v3")
now = datetime.now(timezone.utc)
time_min = args.start or now.isoformat()
time_max = args.end or (now + timedelta(days=7)).isoformat()
# Ensure timezone info
for val in [time_min, time_max]:
if "T" in val and "Z" not in val and "+" not in val and "-" not in val[11:]:
val += "Z"
results = service.events().list(
calendarId=args.calendar, timeMin=time_min, timeMax=time_max,
maxResults=args.max, singleEvents=True, orderBy="startTime",
).execute()
events = []
for e in results.get("items", []):
events.append({
"id": e["id"],
"summary": e.get("summary", "(no title)"),
"start": e.get("start", {}).get("dateTime", e.get("start", {}).get("date", "")),
"end": e.get("end", {}).get("dateTime", e.get("end", {}).get("date", "")),
"location": e.get("location", ""),
"description": e.get("description", ""),
"status": e.get("status", ""),
"htmlLink": e.get("htmlLink", ""),
})
print(json.dumps(events, indent=2, ensure_ascii=False))
def calendar_create(args):
cmd = [
"calendar", "+insert",
"--summary", args.summary,
"--start", args.start,
"--end", args.end,
"--format", "json",
]
service = build_service("calendar", "v3")
event = {
"summary": args.summary,
"start": {"dateTime": args.start},
"end": {"dateTime": args.end},
}
if args.location:
cmd += ["--location", args.location]
event["location"] = args.location
if args.description:
cmd += ["--description", args.description]
event["description"] = args.description
if args.attendees:
for email in args.attendees.split(","):
cmd += ["--attendee", email.strip()]
if args.calendar != "primary":
cmd += ["--calendar", args.calendar]
gws(*cmd)
event["attendees"] = [{"email": e.strip()} for e in args.attendees.split(",")]
result = service.events().insert(calendarId=args.calendar, body=event).execute()
print(json.dumps({
"status": "created",
"id": result["id"],
"summary": result.get("summary", ""),
"htmlLink": result.get("htmlLink", ""),
}, indent=2))
def calendar_delete(args):
gws(
"calendar", "events", "delete",
"--params", json.dumps({"calendarId": args.calendar, "eventId": args.event_id}),
"--format", "json",
)
service = build_service("calendar", "v3")
service.events().delete(calendarId=args.calendar, eventId=args.event_id).execute()
print(json.dumps({"status": "deleted", "eventId": args.event_id}))
# -- Drive --
# =========================================================================
# Drive
# =========================================================================
def drive_search(args):
query = args.query if args.raw_query else f"fullText contains '{args.query}'"
gws(
"drive", "files", "list",
"--params", json.dumps({
"q": query,
"pageSize": args.max,
"fields": "files(id,name,mimeType,modifiedTime,webViewLink)",
}),
"--format", "json",
)
service = build_service("drive", "v3")
query = f"fullText contains '{args.query}'" if not args.raw_query else args.query
results = service.files().list(
q=query, pageSize=args.max, fields="files(id, name, mimeType, modifiedTime, webViewLink)",
).execute()
files = results.get("files", [])
print(json.dumps(files, indent=2, ensure_ascii=False))
# -- Contacts --
# =========================================================================
# Contacts
# =========================================================================
def contacts_list(args):
gws(
"people", "people", "connections", "list",
"--params", json.dumps({
"resourceName": "people/me",
"pageSize": args.max,
"personFields": "names,emailAddresses,phoneNumbers",
}),
"--format", "json",
)
service = build_service("people", "v1")
results = service.people().connections().list(
resourceName="people/me",
pageSize=args.max,
personFields="names,emailAddresses,phoneNumbers",
).execute()
contacts = []
for person in results.get("connections", []):
names = person.get("names", [{}])
emails = person.get("emailAddresses", [])
phones = person.get("phoneNumbers", [])
contacts.append({
"name": names[0].get("displayName", "") if names else "",
"emails": [e.get("value", "") for e in emails],
"phones": [p.get("value", "") for p in phones],
})
print(json.dumps(contacts, indent=2, ensure_ascii=False))
# -- Sheets --
# =========================================================================
# Sheets
# =========================================================================
def sheets_get(args):
gws(
"sheets", "+read",
"--spreadsheet", args.sheet_id,
"--range", args.range,
"--format", "json",
)
service = build_service("sheets", "v4")
result = service.spreadsheets().values().get(
spreadsheetId=args.sheet_id, range=args.range,
).execute()
print(json.dumps(result.get("values", []), indent=2, ensure_ascii=False))
def sheets_update(args):
service = build_service("sheets", "v4")
values = json.loads(args.values)
gws(
"sheets", "spreadsheets", "values", "update",
"--params", json.dumps({
"spreadsheetId": args.sheet_id,
"range": args.range,
"valueInputOption": "USER_ENTERED",
}),
"--json", json.dumps({"values": values}),
"--format", "json",
)
body = {"values": values}
result = service.spreadsheets().values().update(
spreadsheetId=args.sheet_id, range=args.range,
valueInputOption="USER_ENTERED", body=body,
).execute()
print(json.dumps({"updatedCells": result.get("updatedCells", 0), "updatedRange": result.get("updatedRange", "")}, indent=2))
def sheets_append(args):
service = build_service("sheets", "v4")
values = json.loads(args.values)
gws(
"sheets", "+append",
"--spreadsheet", args.sheet_id,
"--json-values", json.dumps(values),
"--format", "json",
)
body = {"values": values}
result = service.spreadsheets().values().append(
spreadsheetId=args.sheet_id, range=args.range,
valueInputOption="USER_ENTERED", insertDataOption="INSERT_ROWS", body=body,
).execute()
print(json.dumps({"updatedCells": result.get("updates", {}).get("updatedCells", 0)}, indent=2))
# -- Docs --
# =========================================================================
# Docs
# =========================================================================
def docs_get(args):
gws(
"docs", "documents", "get",
"--params", json.dumps({"documentId": args.doc_id}),
"--format", "json",
)
service = build_service("docs", "v1")
doc = service.documents().get(documentId=args.doc_id).execute()
# Extract plain text from the document structure
text_parts = []
for element in doc.get("body", {}).get("content", []):
paragraph = element.get("paragraph", {})
for pe in paragraph.get("elements", []):
text_run = pe.get("textRun", {})
if text_run.get("content"):
text_parts.append(text_run["content"])
result = {
"title": doc.get("title", ""),
"documentId": doc.get("documentId", ""),
"body": "".join(text_parts),
}
print(json.dumps(result, indent=2, ensure_ascii=False))
# -- CLI parser (backward-compatible interface) --
# =========================================================================
# CLI parser
# =========================================================================
def main():
parser = argparse.ArgumentParser(description="Google Workspace API for Hermes Agent (gws backend)")
parser = argparse.ArgumentParser(description="Google Workspace API for Hermes Agent")
sub = parser.add_subparsers(dest="service", required=True)
# --- Gmail ---
@@ -229,7 +421,7 @@ def main():
p.add_argument("--body", required=True)
p.add_argument("--cc", default="")
p.add_argument("--html", action="store_true", help="Send body as HTML")
p.add_argument("--thread-id", default="", help="Thread ID (unused with gws, kept for compat)")
p.add_argument("--thread-id", default="", help="Thread ID for threading")
p.set_defaults(func=gmail_send)
p = gmail_sub.add_parser("reply")
@@ -1,89 +0,0 @@
#!/usr/bin/env python3
"""Bridge between Hermes OAuth token and gws CLI.
Refreshes the token if expired, then executes gws with the valid access token.
"""
import json
import os
import subprocess
import sys
from datetime import datetime, timezone
from pathlib import Path
def get_hermes_home() -> Path:
return Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
def get_token_path() -> Path:
return get_hermes_home() / "google_token.json"
def refresh_token(token_data: dict) -> dict:
"""Refresh the access token using the refresh token."""
import urllib.error
import urllib.parse
import urllib.request
params = urllib.parse.urlencode({
"client_id": token_data["client_id"],
"client_secret": token_data["client_secret"],
"refresh_token": token_data["refresh_token"],
"grant_type": "refresh_token",
}).encode()
req = urllib.request.Request(token_data["token_uri"], data=params)
try:
with urllib.request.urlopen(req) as resp:
result = json.loads(resp.read())
except urllib.error.HTTPError as e:
body = e.read().decode("utf-8", errors="replace")
print(f"ERROR: Token refresh failed (HTTP {e.code}): {body}", file=sys.stderr)
print("Re-run setup.py to re-authenticate.", file=sys.stderr)
sys.exit(1)
token_data["token"] = result["access_token"]
token_data["expiry"] = datetime.fromtimestamp(
datetime.now(timezone.utc).timestamp() + result["expires_in"],
tz=timezone.utc,
).isoformat()
get_token_path().write_text(json.dumps(token_data, indent=2))
return token_data
def get_valid_token() -> str:
"""Return a valid access token, refreshing if needed."""
token_path = get_token_path()
if not token_path.exists():
print("ERROR: No Google token found. Run setup.py --auth-url first.", file=sys.stderr)
sys.exit(1)
token_data = json.loads(token_path.read_text())
expiry = token_data.get("expiry", "")
if expiry:
exp_dt = datetime.fromisoformat(expiry.replace("Z", "+00:00"))
now = datetime.now(timezone.utc)
if now >= exp_dt:
token_data = refresh_token(token_data)
return token_data["token"]
def main():
"""Refresh token if needed, then exec gws with remaining args."""
if len(sys.argv) < 2:
print("Usage: gws_bridge.py <gws args...>", file=sys.stderr)
sys.exit(1)
access_token = get_valid_token()
env = os.environ.copy()
env["GOOGLE_WORKSPACE_CLI_TOKEN"] = access_token
result = subprocess.run(["gws"] + sys.argv[1:], env=env)
sys.exit(result.returncode)
if __name__ == "__main__":
main()
@@ -23,7 +23,6 @@ Agent workflow:
import argparse
import json
import os
import subprocess
import sys
from pathlib import Path
@@ -129,11 +128,7 @@ def check_auth():
from google.auth.transport.requests import Request
try:
# Don't pass scopes — user may have authorized only a subset.
# Passing scopes forces google-auth to validate them on refresh,
# which fails with invalid_scope if the token has fewer scopes
# than requested.
creds = Credentials.from_authorized_user_file(str(TOKEN_PATH))
creds = Credentials.from_authorized_user_file(str(TOKEN_PATH), SCOPES)
except Exception as e:
print(f"TOKEN_CORRUPT: {e}")
return False
@@ -142,9 +137,8 @@ def check_auth():
if creds.valid:
missing_scopes = _missing_scopes_from_payload(payload)
if missing_scopes:
print(f"AUTHENTICATED (partial): Token valid but missing {len(missing_scopes)} scopes:")
for s in missing_scopes:
print(f" - {s}")
print(f"AUTH_SCOPE_MISMATCH: {_format_missing_scopes(missing_scopes)}")
return False
print(f"AUTHENTICATED: Token valid at {TOKEN_PATH}")
return True
@@ -154,9 +148,8 @@ def check_auth():
TOKEN_PATH.write_text(creds.to_json())
missing_scopes = _missing_scopes_from_payload(_load_token_payload(TOKEN_PATH))
if missing_scopes:
print(f"AUTHENTICATED (partial): Token refreshed but missing {len(missing_scopes)} scopes:")
for s in missing_scopes:
print(f" - {s}")
print(f"AUTH_SCOPE_MISMATCH: {_format_missing_scopes(missing_scopes)}")
return False
print(f"AUTHENTICATED: Token refreshed at {TOKEN_PATH}")
return True
except Exception as e:
@@ -279,33 +272,16 @@ def exchange_auth_code(code: str):
_ensure_deps()
from google_auth_oauthlib.flow import Flow
from urllib.parse import parse_qs, urlparse
# Extract granted scopes from the callback URL if present
if returned_state and "scope" in parse_qs(urlparse(code).query if isinstance(code, str) and code.startswith("http") else {}):
granted_scopes = parse_qs(urlparse(code).query)["scope"][0].split()
else:
# Try to extract from code_or_url parameter
if isinstance(code, str) and code.startswith("http"):
params = parse_qs(urlparse(code).query)
if "scope" in params:
granted_scopes = params["scope"][0].split()
else:
granted_scopes = SCOPES
else:
granted_scopes = SCOPES
flow = Flow.from_client_secrets_file(
str(CLIENT_SECRET_PATH),
scopes=granted_scopes,
scopes=SCOPES,
redirect_uri=pending_auth.get("redirect_uri", REDIRECT_URI),
state=pending_auth["state"],
code_verifier=pending_auth["code_verifier"],
)
try:
# Accept partial scopes — user may deselect some permissions in the consent screen
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
flow.fetch_token(code=code)
except Exception as e:
print(f"ERROR: Token exchange failed: {e}")
@@ -314,21 +290,11 @@ def exchange_auth_code(code: str):
creds = flow.credentials
token_payload = json.loads(creds.to_json())
# Store only the scopes actually granted by the user, not what was requested.
# creds.to_json() writes the requested scopes, which causes refresh to fail
# with invalid_scope if the user only authorized a subset.
actually_granted = list(creds.granted_scopes or []) if hasattr(creds, "granted_scopes") and creds.granted_scopes else []
if actually_granted:
token_payload["scopes"] = actually_granted
elif granted_scopes != SCOPES:
# granted_scopes was extracted from the callback URL
token_payload["scopes"] = granted_scopes
missing_scopes = _missing_scopes_from_payload(token_payload)
if missing_scopes:
print(f"WARNING: Token missing some Google Workspace scopes: {', '.join(missing_scopes)}")
print("Some services may not be available.")
print(f"ERROR: Refusing to save incomplete Google Workspace token. {_format_missing_scopes(missing_scopes)}")
print(f"Existing token at {TOKEN_PATH} was left unchanged.")
sys.exit(1)
TOKEN_PATH.write_text(json.dumps(token_payload, indent=2))
PENDING_AUTH_PATH.unlink(missing_ok=True)
+10
View File
@@ -17,6 +17,7 @@ from agent.anthropic_adapter import (
build_anthropic_kwargs,
convert_messages_to_anthropic,
convert_tools_to_anthropic,
get_anthropic_token_source,
is_claude_code_token_valid,
normalize_anthropic_response,
normalize_model_name,
@@ -164,6 +165,15 @@ class TestResolveAnthropicToken:
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
assert resolve_anthropic_token() == "sk-ant-oat01-mytoken"
def test_reports_claude_json_primary_key_source(self, monkeypatch, tmp_path):
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
(tmp_path / ".claude.json").write_text(json.dumps({"primaryApiKey": "sk-ant-api03-primary"}))
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
assert get_anthropic_token_source("sk-ant-api03-primary") == "claude_json_primary_api_key"
def test_does_not_resolve_primary_api_key_as_native_anthropic_token(self, monkeypatch, tmp_path):
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
+226
View File
@@ -9,6 +9,7 @@ import pytest
from agent.auxiliary_client import (
get_text_auxiliary_client,
get_vision_auxiliary_client,
get_available_vision_backends,
resolve_vision_provider_client,
resolve_provider_client,
@@ -19,6 +20,7 @@ from agent.auxiliary_client import (
_get_provider_chain,
_is_payment_error,
_try_payment_fallback,
_resolve_forced_provider,
_resolve_auto,
)
@@ -662,6 +664,15 @@ class TestGetTextAuxiliaryClient:
class TestVisionClientFallback:
"""Vision client auto mode resolves known-good multimodal backends."""
def test_vision_returns_none_without_any_credentials(self):
with (
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
patch("agent.auxiliary_client._try_anthropic", return_value=(None, None)),
):
client, model = get_vision_auxiliary_client()
assert client is None
assert model is None
def test_vision_auto_includes_active_provider_when_configured(self, monkeypatch):
"""Active provider appears in available backends when credentials exist."""
monkeypatch.setenv("ANTHROPIC_API_KEY", "***")
@@ -743,6 +754,21 @@ class TestAuxiliaryPoolAwareness:
assert call_kwargs["base_url"] == "https://api.githubcopilot.com"
assert call_kwargs["default_headers"]["Editor-Version"]
def test_vision_auto_uses_active_provider_as_fallback(self, monkeypatch):
"""When no OpenRouter/Nous available, vision auto falls back to active provider."""
monkeypatch.setenv("ANTHROPIC_API_KEY", "***")
with (
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
patch("agent.auxiliary_client._read_main_provider", return_value="anthropic"),
patch("agent.auxiliary_client._read_main_model", return_value="claude-sonnet-4"),
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="***"),
):
client, model = get_vision_auxiliary_client()
assert client is not None
assert client.__class__.__name__ == "AnthropicAuxiliaryClient"
def test_vision_auto_prefers_active_provider_over_openrouter(self, monkeypatch):
"""Active provider is tried before OpenRouter in vision auto."""
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
@@ -774,6 +800,43 @@ class TestAuxiliaryPoolAwareness:
assert client is not None
assert provider == "custom:local"
def test_vision_direct_endpoint_override(self, monkeypatch):
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
monkeypatch.setenv("AUXILIARY_VISION_BASE_URL", "http://localhost:4567/v1")
monkeypatch.setenv("AUXILIARY_VISION_API_KEY", "vision-key")
monkeypatch.setenv("AUXILIARY_VISION_MODEL", "vision-model")
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = get_vision_auxiliary_client()
assert model == "vision-model"
assert mock_openai.call_args.kwargs["base_url"] == "http://localhost:4567/v1"
assert mock_openai.call_args.kwargs["api_key"] == "vision-key"
def test_vision_direct_endpoint_without_key_uses_placeholder(self, monkeypatch):
"""Vision endpoint without API key should use 'no-key-required' placeholder."""
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
monkeypatch.setenv("AUXILIARY_VISION_BASE_URL", "http://localhost:4567/v1")
monkeypatch.setenv("AUXILIARY_VISION_MODEL", "vision-model")
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = get_vision_auxiliary_client()
assert client is not None
assert model == "vision-model"
assert mock_openai.call_args.kwargs["api_key"] == "no-key-required"
def test_vision_uses_openrouter_when_available(self, monkeypatch):
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = get_vision_auxiliary_client()
assert model == "google/gemini-3-flash-preview"
assert client is not None
def test_vision_uses_nous_when_available(self, monkeypatch):
with patch("agent.auxiliary_client._read_nous_auth") as mock_nous, \
patch("agent.auxiliary_client.OpenAI"):
mock_nous.return_value = {"access_token": "nous-tok"}
client, model = get_vision_auxiliary_client()
assert model == "google/gemini-3-flash-preview"
assert client is not None
def test_vision_config_google_provider_uses_gemini_credentials(self, monkeypatch):
config = {
"auxiliary": {
@@ -799,6 +862,53 @@ class TestAuxiliaryPoolAwareness:
assert mock_openai.call_args.kwargs["api_key"] == "gemini-key"
assert mock_openai.call_args.kwargs["base_url"] == "https://generativelanguage.googleapis.com/v1beta/openai"
def test_vision_forced_main_uses_custom_endpoint(self, monkeypatch):
"""When explicitly forced to 'main', vision CAN use custom endpoint."""
config = {
"model": {
"provider": "custom",
"base_url": "http://localhost:1234/v1",
"default": "my-local-model",
}
}
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "main")
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = get_vision_auxiliary_client()
assert client is not None
assert model == "my-local-model"
def test_vision_forced_main_returns_none_without_creds(self, monkeypatch):
"""Forced main with no credentials still returns None."""
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "main")
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
# Clear client cache to avoid stale entries from previous tests
from agent.auxiliary_client import _client_cache
_client_cache.clear()
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client._read_main_provider", return_value=""), \
patch("agent.auxiliary_client._read_main_model", return_value=""), \
patch("agent.auxiliary_client._select_pool_entry", return_value=(False, None)), \
patch("agent.auxiliary_client._resolve_custom_runtime", return_value=(None, None)), \
patch("agent.auxiliary_client._read_codex_access_token", return_value=None), \
patch("agent.auxiliary_client._resolve_api_key_provider", return_value=(None, None)):
client, model = get_vision_auxiliary_client()
assert client is None
assert model is None
def test_vision_forced_codex(self, monkeypatch, codex_auth_dir):
"""When forced to 'codex', vision uses Codex OAuth."""
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "codex")
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client.OpenAI"):
client, model = get_vision_auxiliary_client()
from agent.auxiliary_client import CodexAuxiliaryClient
assert isinstance(client, CodexAuxiliaryClient)
assert model == "gpt-5.2-codex"
class TestGetAuxiliaryProvider:
@@ -838,6 +948,122 @@ class TestGetAuxiliaryProvider:
assert _get_auxiliary_provider("web_extract") == "main"
class TestResolveForcedProvider:
"""Tests for _resolve_forced_provider with explicit provider selection."""
def test_forced_openrouter(self, monkeypatch):
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = _resolve_forced_provider("openrouter")
assert model == "google/gemini-3-flash-preview"
assert client is not None
def test_forced_openrouter_no_key(self, monkeypatch):
with patch("agent.auxiliary_client._read_nous_auth", return_value=None):
client, model = _resolve_forced_provider("openrouter")
assert client is None
assert model is None
def test_forced_nous(self, monkeypatch):
with patch("agent.auxiliary_client._read_nous_auth") as mock_nous, \
patch("agent.auxiliary_client.OpenAI"):
mock_nous.return_value = {"access_token": "nous-tok"}
client, model = _resolve_forced_provider("nous")
assert model == "google/gemini-3-flash-preview"
assert client is not None
def test_forced_nous_not_configured(self, monkeypatch):
with patch("agent.auxiliary_client._read_nous_auth", return_value=None):
client, model = _resolve_forced_provider("nous")
assert client is None
assert model is None
def test_forced_main_uses_custom(self, monkeypatch):
config = {
"model": {
"provider": "custom",
"base_url": "http://local:8080/v1",
"default": "my-local-model",
}
}
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = _resolve_forced_provider("main")
assert model == "my-local-model"
def test_forced_main_uses_config_saved_custom_endpoint(self, monkeypatch):
config = {
"model": {
"provider": "custom",
"base_url": "http://local:8080/v1",
"default": "my-local-model",
}
}
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client._read_codex_access_token", return_value=None), \
patch("agent.auxiliary_client._resolve_api_key_provider", return_value=(None, None)), \
patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = _resolve_forced_provider("main")
assert client is not None
assert model == "my-local-model"
call_kwargs = mock_openai.call_args
assert call_kwargs.kwargs["base_url"] == "http://local:8080/v1"
def test_forced_main_skips_openrouter_nous(self, monkeypatch):
"""Even if OpenRouter key is set, 'main' skips it."""
config = {
"model": {
"provider": "custom",
"base_url": "http://local:8080/v1",
"default": "my-local-model",
}
}
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client.OpenAI") as mock_openai:
client, model = _resolve_forced_provider("main")
# Should use custom endpoint, not OpenRouter
assert model == "my-local-model"
def test_forced_main_falls_to_codex(self, codex_auth_dir, monkeypatch):
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client.OpenAI"):
client, model = _resolve_forced_provider("main")
from agent.auxiliary_client import CodexAuxiliaryClient
assert isinstance(client, CodexAuxiliaryClient)
assert model == "gpt-5.2-codex"
def test_forced_codex(self, codex_auth_dir, monkeypatch):
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client.OpenAI"):
client, model = _resolve_forced_provider("codex")
from agent.auxiliary_client import CodexAuxiliaryClient
assert isinstance(client, CodexAuxiliaryClient)
assert model == "gpt-5.2-codex"
def test_forced_codex_no_token(self, monkeypatch):
with patch("agent.auxiliary_client._read_codex_access_token", return_value=None):
client, model = _resolve_forced_provider("codex")
assert client is None
assert model is None
def test_forced_unknown_returns_none(self, monkeypatch):
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
patch("agent.auxiliary_client._read_codex_access_token", return_value=None):
client, model = _resolve_forced_provider("invalid-provider")
assert client is None
assert model is None
class TestTaskSpecificOverrides:
"""Integration tests for per-task provider routing via get_text_auxiliary_client(task=...)."""
+25
View File
@@ -38,6 +38,16 @@ class TestShouldCompress:
assert compressor.should_compress(prompt_tokens=50000) is False
class TestShouldCompressPreflight:
def test_short_messages(self, compressor):
msgs = [{"role": "user", "content": "short"}]
assert compressor.should_compress_preflight(msgs) is False
def test_long_messages(self, compressor):
# Each message ~100k chars / 4 = 25k tokens, need >85k threshold
msgs = [{"role": "user", "content": "x" * 400000}]
assert compressor.should_compress_preflight(msgs) is True
class TestUpdateFromResponse:
def test_updates_fields(self, compressor):
@@ -48,12 +58,27 @@ class TestUpdateFromResponse:
})
assert compressor.last_prompt_tokens == 5000
assert compressor.last_completion_tokens == 1000
assert compressor.last_total_tokens == 6000
def test_missing_fields_default_zero(self, compressor):
compressor.update_from_response({})
assert compressor.last_prompt_tokens == 0
class TestGetStatus:
def test_returns_expected_keys(self, compressor):
status = compressor.get_status()
assert "last_prompt_tokens" in status
assert "threshold_tokens" in status
assert "context_length" in status
assert "usage_percent" in status
assert "compression_count" in status
def test_usage_percent_calculation(self, compressor):
compressor.last_prompt_tokens = 50000
status = compressor.get_status()
assert status["usage_percent"] == 50.0
class TestCompress:
def _make_messages(self, n):
-36
View File
@@ -214,42 +214,6 @@ def test_exhausted_entry_resets_after_ttl(tmp_path, monkeypatch):
assert entry.last_status == "ok"
def test_exhausted_402_entry_resets_after_one_hour(tmp_path, monkeypatch):
"""402-exhausted credentials recover after 1 hour, not 24."""
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
_write_auth_store(
tmp_path,
{
"version": 1,
"credential_pool": {
"openrouter": [
{
"id": "cred-1",
"label": "primary",
"auth_type": "api_key",
"priority": 0,
"source": "manual",
"access_token": "***",
"base_url": "https://openrouter.ai/api/v1",
"last_status": "exhausted",
"last_status_at": time.time() - 3700, # ~1h2m ago
"last_error_code": 402,
}
]
},
},
)
from agent.credential_pool import load_pool
pool = load_pool("openrouter")
entry = pool.select()
assert entry is not None
assert entry.id == "cred-1"
assert entry.last_status == "ok"
def test_explicit_reset_timestamp_overrides_default_429_ttl(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
_write_auth_store(
-782
View File
@@ -1,782 +0,0 @@
"""Tests for agent.error_classifier — structured API error classification."""
import pytest
from agent.error_classifier import (
ClassifiedError,
FailoverReason,
classify_api_error,
_extract_status_code,
_extract_error_body,
_extract_error_code,
_classify_402,
)
# ── Helper: mock API errors ────────────────────────────────────────────
class MockAPIError(Exception):
"""Simulates an OpenAI SDK APIStatusError."""
def __init__(self, message, status_code=None, body=None):
super().__init__(message)
self.status_code = status_code
self.body = body or {}
class MockTransportError(Exception):
"""Simulates a transport-level error with a specific type name."""
pass
class ReadTimeout(MockTransportError):
pass
class ConnectError(MockTransportError):
pass
class RemoteProtocolError(MockTransportError):
pass
class ServerDisconnectedError(MockTransportError):
pass
# ── Test: FailoverReason enum ──────────────────────────────────────────
class TestFailoverReason:
def test_all_reasons_have_string_values(self):
for reason in FailoverReason:
assert isinstance(reason.value, str)
def test_enum_members_exist(self):
expected = {
"auth", "auth_permanent", "billing", "rate_limit",
"overloaded", "server_error", "timeout",
"context_overflow", "payload_too_large",
"model_not_found", "format_error",
"thinking_signature", "long_context_tier", "unknown",
}
actual = {r.value for r in FailoverReason}
assert expected == actual
# ── Test: ClassifiedError ──────────────────────────────────────────────
class TestClassifiedError:
def test_is_auth_property(self):
e1 = ClassifiedError(reason=FailoverReason.auth)
assert e1.is_auth is True
e2 = ClassifiedError(reason=FailoverReason.auth_permanent)
assert e2.is_auth is True
e3 = ClassifiedError(reason=FailoverReason.billing)
assert e3.is_auth is False
def test_is_transient_property(self):
transient_reasons = [
FailoverReason.rate_limit,
FailoverReason.overloaded,
FailoverReason.server_error,
FailoverReason.timeout,
FailoverReason.unknown,
]
for reason in transient_reasons:
e = ClassifiedError(reason=reason)
assert e.is_transient is True, f"{reason} should be transient"
non_transient = [
FailoverReason.auth,
FailoverReason.billing,
FailoverReason.model_not_found,
FailoverReason.format_error,
]
for reason in non_transient:
e = ClassifiedError(reason=reason)
assert e.is_transient is False, f"{reason} should NOT be transient"
def test_defaults(self):
e = ClassifiedError(reason=FailoverReason.unknown)
assert e.retryable is True
assert e.should_compress is False
assert e.should_rotate_credential is False
assert e.should_fallback is False
assert e.status_code is None
assert e.message == ""
# ── Test: Status code extraction ───────────────────────────────────────
class TestExtractStatusCode:
def test_from_status_code_attr(self):
e = MockAPIError("fail", status_code=429)
assert _extract_status_code(e) == 429
def test_from_status_attr(self):
class ErrWithStatus(Exception):
status = 503
assert _extract_status_code(ErrWithStatus()) == 503
def test_from_cause_chain(self):
inner = MockAPIError("inner", status_code=401)
outer = Exception("outer")
outer.__cause__ = inner
assert _extract_status_code(outer) == 401
def test_none_when_missing(self):
assert _extract_status_code(Exception("generic")) is None
def test_rejects_non_http_status(self):
"""Integers outside 100-599 on .status should be ignored."""
class ErrWeirdStatus(Exception):
status = 42
assert _extract_status_code(ErrWeirdStatus()) is None
# ── Test: Error body extraction ────────────────────────────────────────
class TestExtractErrorBody:
def test_from_body_attr(self):
e = MockAPIError("fail", body={"error": {"message": "bad"}})
assert _extract_error_body(e) == {"error": {"message": "bad"}}
def test_empty_when_no_body(self):
assert _extract_error_body(Exception("generic")) == {}
# ── Test: Error code extraction ────────────────────────────────────────
class TestExtractErrorCode:
def test_from_nested_error_code(self):
body = {"error": {"code": "rate_limit_exceeded"}}
assert _extract_error_code(body) == "rate_limit_exceeded"
def test_from_nested_error_type(self):
body = {"error": {"type": "invalid_request_error"}}
assert _extract_error_code(body) == "invalid_request_error"
def test_from_top_level_code(self):
body = {"code": "model_not_found"}
assert _extract_error_code(body) == "model_not_found"
def test_empty_when_no_code(self):
assert _extract_error_code({}) == ""
assert _extract_error_code({"error": {"message": "oops"}}) == ""
# ── Test: 402 disambiguation ───────────────────────────────────────────
class TestClassify402:
"""The critical 402 billing vs rate_limit disambiguation."""
def test_billing_exhaustion(self):
"""Plain 402 = billing."""
result = _classify_402(
"payment required",
lambda reason, **kw: ClassifiedError(reason=reason, **kw),
)
assert result.reason == FailoverReason.billing
assert result.should_rotate_credential is True
def test_transient_usage_limit(self):
"""402 with 'usage limit' + 'try again' = rate limit, not billing."""
result = _classify_402(
"usage limit exceeded. try again in 5 minutes",
lambda reason, **kw: ClassifiedError(reason=reason, **kw),
)
assert result.reason == FailoverReason.rate_limit
assert result.should_rotate_credential is True
def test_quota_with_retry(self):
"""402 with 'quota' + 'retry' = rate limit."""
result = _classify_402(
"quota exceeded, please retry after the window resets",
lambda reason, **kw: ClassifiedError(reason=reason, **kw),
)
assert result.reason == FailoverReason.rate_limit
def test_quota_without_retry(self):
"""402 with just 'quota' but no transient signal = billing."""
result = _classify_402(
"quota exceeded",
lambda reason, **kw: ClassifiedError(reason=reason, **kw),
)
assert result.reason == FailoverReason.billing
def test_insufficient_credits(self):
result = _classify_402(
"insufficient credits to complete request",
lambda reason, **kw: ClassifiedError(reason=reason, **kw),
)
assert result.reason == FailoverReason.billing
# ── Test: Full classification pipeline ─────────────────────────────────
class TestClassifyApiError:
"""End-to-end classification tests."""
# ── Auth errors ──
def test_401_classified_as_auth(self):
e = MockAPIError("Unauthorized", status_code=401)
result = classify_api_error(e, provider="openrouter")
assert result.reason == FailoverReason.auth
assert result.should_rotate_credential is True
# 401 is non-retryable on its own — credential rotation runs
# before the retryability check in the agent loop.
assert result.retryable is False
assert result.should_fallback is True
def test_403_classified_as_auth(self):
e = MockAPIError("Forbidden", status_code=403)
result = classify_api_error(e, provider="anthropic")
assert result.reason == FailoverReason.auth
assert result.should_fallback is True
def test_403_key_limit_classified_as_billing(self):
"""OpenRouter 403 'key limit exceeded' is billing, not auth."""
e = MockAPIError("Key limit exceeded for this key", status_code=403)
result = classify_api_error(e, provider="openrouter")
assert result.reason == FailoverReason.billing
assert result.should_rotate_credential is True
assert result.should_fallback is True
def test_403_spending_limit_classified_as_billing(self):
e = MockAPIError("spending limit reached", status_code=403)
result = classify_api_error(e, provider="openrouter")
assert result.reason == FailoverReason.billing
# ── Billing ──
def test_402_plain_billing(self):
e = MockAPIError("Payment Required", status_code=402)
result = classify_api_error(e)
assert result.reason == FailoverReason.billing
assert result.retryable is False
def test_402_transient_usage_limit(self):
e = MockAPIError("usage limit exceeded, try again later", status_code=402)
result = classify_api_error(e)
assert result.reason == FailoverReason.rate_limit
assert result.retryable is True
# ── Rate limit ──
def test_429_rate_limit(self):
e = MockAPIError("Too Many Requests", status_code=429)
result = classify_api_error(e)
assert result.reason == FailoverReason.rate_limit
assert result.should_fallback is True
# ── Server errors ──
def test_500_server_error(self):
e = MockAPIError("Internal Server Error", status_code=500)
result = classify_api_error(e)
assert result.reason == FailoverReason.server_error
assert result.retryable is True
def test_502_server_error(self):
e = MockAPIError("Bad Gateway", status_code=502)
result = classify_api_error(e)
assert result.reason == FailoverReason.server_error
def test_503_overloaded(self):
e = MockAPIError("Service Unavailable", status_code=503)
result = classify_api_error(e)
assert result.reason == FailoverReason.overloaded
def test_529_anthropic_overloaded(self):
e = MockAPIError("Overloaded", status_code=529)
result = classify_api_error(e)
assert result.reason == FailoverReason.overloaded
# ── Model not found ──
def test_404_model_not_found(self):
e = MockAPIError("model not found", status_code=404)
result = classify_api_error(e)
assert result.reason == FailoverReason.model_not_found
assert result.should_fallback is True
assert result.retryable is False
def test_404_generic(self):
e = MockAPIError("Not Found", status_code=404)
result = classify_api_error(e)
assert result.reason == FailoverReason.model_not_found
# ── Payload too large ──
def test_413_payload_too_large(self):
e = MockAPIError("Request Entity Too Large", status_code=413)
result = classify_api_error(e)
assert result.reason == FailoverReason.payload_too_large
assert result.should_compress is True
# ── Context overflow ──
def test_400_context_length(self):
e = MockAPIError("context length exceeded: 250000 > 200000", status_code=400)
result = classify_api_error(e)
assert result.reason == FailoverReason.context_overflow
assert result.should_compress is True
def test_400_too_many_tokens(self):
e = MockAPIError("This model's maximum context is 128000 tokens, too many tokens", status_code=400)
result = classify_api_error(e)
assert result.reason == FailoverReason.context_overflow
def test_400_prompt_too_long(self):
e = MockAPIError("prompt is too long: 300000 tokens > 200000 maximum", status_code=400)
result = classify_api_error(e)
assert result.reason == FailoverReason.context_overflow
def test_400_generic_large_session(self):
"""Generic 400 with large session → context overflow heuristic."""
e = MockAPIError(
"Error",
status_code=400,
body={"error": {"message": "Error"}},
)
result = classify_api_error(e, approx_tokens=100000, context_length=200000)
assert result.reason == FailoverReason.context_overflow
def test_400_generic_small_session_is_format_error(self):
"""Generic 400 with small session → format error, not context overflow."""
e = MockAPIError(
"Error",
status_code=400,
body={"error": {"message": "Error"}},
)
result = classify_api_error(e, approx_tokens=1000, context_length=200000)
assert result.reason == FailoverReason.format_error
# ── Server disconnect + large session ──
def test_disconnect_large_session_context_overflow(self):
"""Server disconnect with large session → context overflow."""
e = Exception("server disconnected without sending complete message")
result = classify_api_error(e, approx_tokens=150000, context_length=200000)
assert result.reason == FailoverReason.context_overflow
assert result.should_compress is True
def test_disconnect_small_session_timeout(self):
"""Server disconnect with small session → timeout."""
e = Exception("server disconnected without sending complete message")
result = classify_api_error(e, approx_tokens=5000, context_length=200000)
assert result.reason == FailoverReason.timeout
# ── Provider-specific: Anthropic thinking signature ──
def test_anthropic_thinking_signature(self):
e = MockAPIError(
"thinking block has invalid signature",
status_code=400,
)
result = classify_api_error(e, provider="anthropic")
assert result.reason == FailoverReason.thinking_signature
assert result.retryable is True
def test_non_anthropic_400_with_signature_not_classified_as_thinking(self):
"""400 with 'signature' but from non-Anthropic → format error."""
e = MockAPIError("invalid signature", status_code=400)
result = classify_api_error(e, provider="openrouter", approx_tokens=0)
# Without "thinking" in the message, it shouldn't be thinking_signature
assert result.reason != FailoverReason.thinking_signature
# ── Provider-specific: Anthropic long-context tier ──
def test_anthropic_long_context_tier(self):
e = MockAPIError(
"Extra usage is required for long context requests over 200k tokens",
status_code=429,
)
result = classify_api_error(e, provider="anthropic", model="claude-sonnet-4")
assert result.reason == FailoverReason.long_context_tier
assert result.should_compress is True
def test_normal_429_not_long_context(self):
"""Normal 429 without 'extra usage' + 'long context' → rate_limit."""
e = MockAPIError("Too Many Requests", status_code=429)
result = classify_api_error(e, provider="anthropic")
assert result.reason == FailoverReason.rate_limit
# ── Transport errors ──
def test_read_timeout(self):
e = ReadTimeout("Read timed out")
result = classify_api_error(e)
assert result.reason == FailoverReason.timeout
assert result.retryable is True
def test_connect_error(self):
e = ConnectError("Connection refused")
result = classify_api_error(e)
assert result.reason == FailoverReason.timeout
def test_connection_error_builtin(self):
e = ConnectionError("Connection reset by peer")
result = classify_api_error(e)
assert result.reason == FailoverReason.timeout
def test_timeout_error_builtin(self):
e = TimeoutError("timed out")
result = classify_api_error(e)
assert result.reason == FailoverReason.timeout
# ── Error code classification ──
def test_error_code_resource_exhausted(self):
e = MockAPIError(
"Resource exhausted",
body={"error": {"code": "resource_exhausted", "message": "Too many requests"}},
)
result = classify_api_error(e)
assert result.reason == FailoverReason.rate_limit
def test_error_code_model_not_found(self):
e = MockAPIError(
"Model not available",
body={"error": {"code": "model_not_found"}},
)
result = classify_api_error(e)
assert result.reason == FailoverReason.model_not_found
def test_error_code_context_length_exceeded(self):
e = MockAPIError(
"Context too large",
body={"error": {"code": "context_length_exceeded"}},
)
result = classify_api_error(e)
assert result.reason == FailoverReason.context_overflow
# ── Message-only patterns (no status code) ──
def test_message_billing_pattern(self):
e = Exception("insufficient credits to complete this request")
result = classify_api_error(e)
assert result.reason == FailoverReason.billing
def test_message_rate_limit_pattern(self):
e = Exception("rate limit reached for this model")
result = classify_api_error(e)
assert result.reason == FailoverReason.rate_limit
def test_message_auth_pattern(self):
e = Exception("invalid api key provided")
result = classify_api_error(e)
assert result.reason == FailoverReason.auth
def test_message_model_not_found_pattern(self):
e = Exception("gpt-99 is not a valid model")
result = classify_api_error(e)
assert result.reason == FailoverReason.model_not_found
def test_message_context_overflow_pattern(self):
e = Exception("maximum context length exceeded")
result = classify_api_error(e)
assert result.reason == FailoverReason.context_overflow
# ── Unknown / fallback ──
def test_generic_exception_is_unknown(self):
e = Exception("something weird happened")
result = classify_api_error(e)
assert result.reason == FailoverReason.unknown
assert result.retryable is True
# ── Format error ──
def test_400_descriptive_format_error(self):
"""400 with descriptive message (not context overflow) → format error."""
e = MockAPIError(
"Invalid value for parameter 'temperature': must be between 0 and 2",
status_code=400,
body={"error": {"message": "Invalid value for parameter 'temperature': must be between 0 and 2"}},
)
result = classify_api_error(e, approx_tokens=1000)
assert result.reason == FailoverReason.format_error
assert result.retryable is False
def test_422_format_error(self):
e = MockAPIError("Unprocessable Entity", status_code=422)
result = classify_api_error(e)
assert result.reason == FailoverReason.format_error
assert result.retryable is False
def test_400_flat_body_descriptive_not_context_overflow(self):
"""Responses API flat body with descriptive error + large session → format error.
The Codex Responses API returns errors in flat body format:
{"message": "...", "type": "..."} without an "error" wrapper.
A descriptive 400 must NOT be misclassified as context overflow
just because the session is large.
"""
e = MockAPIError(
"Invalid 'input[index].name': string does not match pattern.",
status_code=400,
body={"message": "Invalid 'input[index].name': string does not match pattern.",
"type": "invalid_request_error"},
)
result = classify_api_error(e, approx_tokens=200000, context_length=400000, num_messages=500)
assert result.reason == FailoverReason.format_error
assert result.retryable is False
def test_400_flat_body_generic_large_session_still_context_overflow(self):
"""Flat body with generic 'Error' message + large session → context overflow.
Regression: the flat-body fallback must not break the existing heuristic
for genuinely generic errors from providers that use flat bodies.
"""
e = MockAPIError(
"Error",
status_code=400,
body={"message": "Error"},
)
result = classify_api_error(e, approx_tokens=100000, context_length=200000)
assert result.reason == FailoverReason.context_overflow
# ── Peer closed + large session ──
def test_peer_closed_large_session(self):
e = Exception("peer closed connection without sending complete message")
result = classify_api_error(e, approx_tokens=130000, context_length=200000)
assert result.reason == FailoverReason.context_overflow
# ── Chinese error messages ──
def test_chinese_context_overflow(self):
e = MockAPIError("超过最大长度限制", status_code=400)
result = classify_api_error(e)
assert result.reason == FailoverReason.context_overflow
# ── Result metadata ──
def test_provider_and_model_in_result(self):
e = MockAPIError("fail", status_code=500)
result = classify_api_error(e, provider="openrouter", model="gpt-5")
assert result.provider == "openrouter"
assert result.model == "gpt-5"
assert result.status_code == 500
def test_message_extracted(self):
e = MockAPIError(
"outer",
status_code=500,
body={"error": {"message": "Internal server error occurred"}},
)
result = classify_api_error(e)
assert result.message == "Internal server error occurred"
# ── Test: Adversarial / edge cases (from live testing) ─────────────────
class TestAdversarialEdgeCases:
"""Edge cases discovered during live testing with real SDK objects."""
def test_empty_exception_message(self):
result = classify_api_error(Exception(""))
assert result.reason == FailoverReason.unknown
assert result.retryable is True
def test_500_with_none_body(self):
e = MockAPIError("fail", status_code=500, body=None)
result = classify_api_error(e)
assert result.reason == FailoverReason.server_error
def test_non_dict_body(self):
"""Some providers return strings instead of JSON."""
class StringBodyError(Exception):
status_code = 400
body = "just a string"
result = classify_api_error(StringBodyError("bad"))
assert result.reason == FailoverReason.format_error
def test_list_body(self):
class ListBodyError(Exception):
status_code = 500
body = [{"error": "something"}]
result = classify_api_error(ListBodyError("server error"))
assert result.reason == FailoverReason.server_error
def test_circular_cause_chain(self):
"""Must not infinite-loop on circular __cause__."""
e = Exception("circular")
e.__cause__ = e
result = classify_api_error(e)
assert result.reason == FailoverReason.unknown
def test_three_level_cause_chain(self):
inner = MockAPIError("inner", status_code=429)
middle = Exception("middle")
middle.__cause__ = inner
outer = RuntimeError("outer")
outer.__cause__ = middle
result = classify_api_error(outer)
assert result.status_code == 429
assert result.reason == FailoverReason.rate_limit
def test_400_with_rate_limit_text(self):
"""Some providers send rate limits as 400 instead of 429."""
e = MockAPIError(
"rate limit policy",
status_code=400,
body={"error": {"message": "rate limit exceeded on this model"}},
)
result = classify_api_error(e, provider="openrouter")
assert result.reason == FailoverReason.rate_limit
def test_400_with_billing_text(self):
"""Some providers send billing errors as 400."""
e = MockAPIError(
"billing",
status_code=400,
body={"error": {"message": "insufficient credits for this request"}},
)
result = classify_api_error(e)
assert result.reason == FailoverReason.billing
def test_200_with_error_body(self):
"""200 status with error in body — should be unknown, not crash."""
class WeirdSuccess(Exception):
status_code = 200
body = {"error": {"message": "loading"}}
result = classify_api_error(WeirdSuccess("model loading"))
assert result.reason == FailoverReason.unknown
def test_ollama_context_size_exceeded(self):
e = MockAPIError(
"Error",
status_code=400,
body={"error": {"message": "context size has been exceeded"}},
)
result = classify_api_error(e, provider="ollama")
assert result.reason == FailoverReason.context_overflow
def test_connection_refused_error(self):
e = ConnectionRefusedError("Connection refused: localhost:11434")
result = classify_api_error(e, provider="ollama")
assert result.reason == FailoverReason.timeout
def test_body_message_enrichment(self):
"""Body message must be included in pattern matching even when
str(error) doesn't contain it (OpenAI SDK APIStatusError)."""
e = MockAPIError(
"Usage limit", # str(e) = "usage limit"
status_code=402,
body={"error": {"message": "Usage limit reached, try again in 5 minutes"}},
)
result = classify_api_error(e)
# "try again" is only in body, not in str(e)
assert result.reason == FailoverReason.rate_limit
def test_disconnect_pattern_ordering(self):
"""Disconnect + large session must beat generic transport catch."""
class FakeRemoteProtocol(Exception):
pass
# Type name isn't in _TRANSPORT_ERROR_TYPES but message has disconnect pattern
e = Exception("peer closed connection without sending complete message")
result = classify_api_error(e, approx_tokens=150000, context_length=200000)
assert result.reason == FailoverReason.context_overflow
assert result.should_compress is True
def test_credit_balance_too_low(self):
e = MockAPIError(
"Credits low",
status_code=402,
body={"error": {"message": "Your credit balance is too low"}},
)
result = classify_api_error(e, provider="anthropic")
assert result.reason == FailoverReason.billing
def test_deepseek_402_chinese(self):
"""Chinese billing message should still match billing patterns."""
# "余额不足" doesn't match English billing patterns, but 402 defaults to billing
e = MockAPIError("余额不足", status_code=402)
result = classify_api_error(e, provider="deepseek")
assert result.reason == FailoverReason.billing
def test_openrouter_wrapped_context_overflow_in_metadata_raw(self):
"""OpenRouter wraps provider errors in metadata.raw JSON string."""
e = MockAPIError(
"Provider returned error",
status_code=400,
body={
"error": {
"message": "Provider returned error",
"code": 400,
"metadata": {
"raw": '{"error":{"message":"context length exceeded: 50000 > 32768"}}'
}
}
},
)
result = classify_api_error(e, provider="openrouter", approx_tokens=10000)
assert result.reason == FailoverReason.context_overflow
assert result.should_compress is True
def test_openrouter_wrapped_rate_limit_in_metadata_raw(self):
e = MockAPIError(
"Provider returned error",
status_code=400,
body={
"error": {
"message": "Provider returned error",
"metadata": {
"raw": '{"error":{"message":"Rate limit exceeded. Please retry after 30s."}}'
}
}
},
)
result = classify_api_error(e, provider="openrouter")
assert result.reason == FailoverReason.rate_limit
def test_thinking_signature_via_openrouter(self):
"""Thinking signature errors proxied through OpenRouter must be caught."""
e = MockAPIError(
"thinking block has invalid signature",
status_code=400,
)
# provider is openrouter, not anthropic — old code missed this
result = classify_api_error(e, provider="openrouter", model="anthropic/claude-sonnet-4")
assert result.reason == FailoverReason.thinking_signature
def test_generic_400_large_by_message_count(self):
"""Many small messages (>80) should trigger context overflow heuristic."""
e = MockAPIError(
"Error",
status_code=400,
body={"error": {"message": "Error"}},
)
# Low token count but high message count
result = classify_api_error(
e, approx_tokens=5000, context_length=200000, num_messages=100,
)
assert result.reason == FailoverReason.context_overflow
def test_disconnect_large_by_message_count(self):
"""Server disconnect with 200+ messages should trigger context overflow."""
e = Exception("server disconnected without sending complete message")
result = classify_api_error(
e, approx_tokens=5000, context_length=200000, num_messages=250,
)
assert result.reason == FailoverReason.context_overflow
def test_openrouter_wrapped_model_not_found_in_metadata_raw(self):
e = MockAPIError(
"Provider returned error",
status_code=400,
body={
"error": {
"message": "Provider returned error",
"metadata": {
"raw": '{"error":{"message":"The model gpt-99 does not exist"}}'
}
}
},
)
result = classify_api_error(e, provider="openrouter")
assert result.reason == FailoverReason.model_not_found
+40
View File
@@ -7,6 +7,7 @@ from pathlib import Path
from hermes_state import SessionDB
from agent.insights import (
InsightsEngine,
_get_pricing,
_estimate_cost,
_format_duration,
_bar_chart,
@@ -117,6 +118,45 @@ def populated_db(db):
return db
# =========================================================================
# Pricing helpers
# =========================================================================
class TestPricing:
def test_provider_prefix_stripped(self):
pricing = _get_pricing("anthropic/claude-sonnet-4-20250514")
assert pricing["input"] == 3.00
assert pricing["output"] == 15.00
def test_unknown_models_do_not_use_heuristics(self):
pricing = _get_pricing("some-new-opus-model")
assert pricing == _DEFAULT_PRICING
pricing = _get_pricing("anthropic/claude-haiku-future")
assert pricing == _DEFAULT_PRICING
def test_unknown_model_returns_zero_cost(self):
"""Unknown/custom models should NOT have fabricated costs."""
pricing = _get_pricing("totally-unknown-model-xyz")
assert pricing == _DEFAULT_PRICING
assert pricing["input"] == 0.0
assert pricing["output"] == 0.0
def test_custom_endpoint_model_zero_cost(self):
"""Self-hosted models should return zero cost."""
for model in ["FP16_Hermes_4.5", "Hermes_4.5_1T_epoch2", "my-local-llama"]:
pricing = _get_pricing(model)
assert pricing["input"] == 0.0, f"{model} should have zero cost"
assert pricing["output"] == 0.0, f"{model} should have zero cost"
def test_none_model(self):
pricing = _get_pricing(None)
assert pricing == _DEFAULT_PRICING
def test_empty_model(self):
pricing = _get_pricing("")
assert pricing == _DEFAULT_PRICING
class TestHasKnownPricing:
def test_known_commercial_model(self):
assert _has_known_pricing("gpt-4o", provider="openai") is True
+299
View File
@@ -0,0 +1,299 @@
"""End-to-end test: a SQLite-backed memory plugin exercising the full interface.
This proves a real plugin can register as a MemoryProvider and get wired
into the agent loop via MemoryManager. Uses SQLite + FTS5 (stdlib, no
external deps, no API keys).
"""
import json
import os
import sqlite3
import tempfile
import pytest
from unittest.mock import patch, MagicMock
from agent.memory_provider import MemoryProvider
from agent.memory_manager import MemoryManager
from agent.builtin_memory_provider import BuiltinMemoryProvider
# ---------------------------------------------------------------------------
# SQLite FTS5 memory provider — a real, minimal plugin implementation
# ---------------------------------------------------------------------------
class SQLiteMemoryProvider(MemoryProvider):
"""Minimal SQLite + FTS5 memory provider for testing.
Demonstrates the full MemoryProvider interface with a real backend.
No external dependencies just stdlib sqlite3.
"""
def __init__(self, db_path: str = ":memory:"):
self._db_path = db_path
self._conn = None
@property
def name(self) -> str:
return "sqlite_memory"
def is_available(self) -> bool:
return True # SQLite is always available
def initialize(self, session_id: str, **kwargs) -> None:
self._conn = sqlite3.connect(self._db_path)
self._conn.execute("PRAGMA journal_mode=WAL")
self._conn.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS memories
USING fts5(content, context, session_id)
""")
self._session_id = session_id
def system_prompt_block(self) -> str:
if not self._conn:
return ""
count = self._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
if count == 0:
return ""
return (
f"# SQLite Memory Plugin\n"
f"Active. {count} memories stored.\n"
f"Use sqlite_recall to search, sqlite_retain to store."
)
def prefetch(self, query: str, *, session_id: str = "") -> str:
if not self._conn or not query:
return ""
# FTS5 search
try:
rows = self._conn.execute(
"SELECT content FROM memories WHERE memories MATCH ? LIMIT 5",
(query,)
).fetchall()
if not rows:
return ""
results = [row[0] for row in rows]
return "## SQLite Memory\n" + "\n".join(f"- {r}" for r in results)
except sqlite3.OperationalError:
return ""
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
if not self._conn:
return
combined = f"User: {user_content}\nAssistant: {assistant_content}"
self._conn.execute(
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
(combined, "conversation", self._session_id),
)
self._conn.commit()
def get_tool_schemas(self):
return [
{
"name": "sqlite_retain",
"description": "Store a fact to SQLite memory.",
"parameters": {
"type": "object",
"properties": {
"content": {"type": "string", "description": "What to remember"},
"context": {"type": "string", "description": "Category/context"},
},
"required": ["content"],
},
},
{
"name": "sqlite_recall",
"description": "Search SQLite memory.",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query"},
},
"required": ["query"],
},
},
]
def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
if tool_name == "sqlite_retain":
content = args.get("content", "")
context = args.get("context", "explicit")
if not content:
return json.dumps({"error": "content is required"})
self._conn.execute(
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
(content, context, self._session_id),
)
self._conn.commit()
return json.dumps({"result": "Stored."})
elif tool_name == "sqlite_recall":
query = args.get("query", "")
if not query:
return json.dumps({"error": "query is required"})
try:
rows = self._conn.execute(
"SELECT content, context FROM memories WHERE memories MATCH ? LIMIT 10",
(query,)
).fetchall()
results = [{"content": r[0], "context": r[1]} for r in rows]
return json.dumps({"results": results})
except sqlite3.OperationalError:
return json.dumps({"results": []})
return json.dumps({"error": f"Unknown tool: {tool_name}"})
def on_memory_write(self, action, target, content):
"""Mirror built-in memory writes to SQLite."""
if action == "add" and self._conn:
self._conn.execute(
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
(content, f"builtin_{target}", self._session_id),
)
self._conn.commit()
def shutdown(self):
if self._conn:
self._conn.close()
self._conn = None
# ---------------------------------------------------------------------------
# End-to-end tests
# ---------------------------------------------------------------------------
class TestSQLiteMemoryPlugin:
"""Full lifecycle test with the SQLite provider."""
def test_full_lifecycle(self):
"""Exercise init → store → recall → sync → prefetch → shutdown."""
mgr = MemoryManager()
builtin = BuiltinMemoryProvider()
sqlite_mem = SQLiteMemoryProvider()
mgr.add_provider(builtin)
mgr.add_provider(sqlite_mem)
# Initialize
mgr.initialize_all(session_id="test-session-1", platform="cli")
assert sqlite_mem._conn is not None
# System prompt — empty at first
prompt = mgr.build_system_prompt()
assert "SQLite Memory Plugin" not in prompt
# Store via tool call
result = json.loads(mgr.handle_tool_call(
"sqlite_retain", {"content": "User prefers dark mode", "context": "preference"}
))
assert result["result"] == "Stored."
# System prompt now shows count
prompt = mgr.build_system_prompt()
assert "1 memories stored" in prompt
# Recall via tool call
result = json.loads(mgr.handle_tool_call(
"sqlite_recall", {"query": "dark mode"}
))
assert len(result["results"]) == 1
assert "dark mode" in result["results"][0]["content"]
# Sync a turn (auto-stores conversation)
mgr.sync_all("What's my theme?", "You prefer dark mode.")
count = sqlite_mem._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
assert count == 2 # 1 explicit + 1 synced
# Prefetch for next turn
prefetched = mgr.prefetch_all("dark mode")
assert "dark mode" in prefetched
# Memory bridge — mirroring builtin writes
mgr.on_memory_write("add", "user", "Timezone: US Pacific")
count = sqlite_mem._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
assert count == 3
# Shutdown
mgr.shutdown_all()
assert sqlite_mem._conn is None
def test_tool_routing_with_builtin(self):
"""Verify builtin + plugin tools coexist without conflict."""
mgr = MemoryManager()
builtin = BuiltinMemoryProvider()
sqlite_mem = SQLiteMemoryProvider()
mgr.add_provider(builtin)
mgr.add_provider(sqlite_mem)
mgr.initialize_all(session_id="test-2")
# Builtin has no tools
assert len(builtin.get_tool_schemas()) == 0
# SQLite has 2 tools
schemas = mgr.get_all_tool_schemas()
names = {s["name"] for s in schemas}
assert names == {"sqlite_retain", "sqlite_recall"}
# Routing works
assert mgr.has_tool("sqlite_retain")
assert mgr.has_tool("sqlite_recall")
assert not mgr.has_tool("memory") # builtin doesn't register this
def test_second_external_plugin_rejected(self):
"""Only one external memory provider is allowed at a time."""
mgr = MemoryManager()
p1 = SQLiteMemoryProvider()
p2 = SQLiteMemoryProvider()
# Hack name for p2
p2._name_override = "sqlite_memory_2"
original_name = p2.__class__.name
type(p2).name = property(lambda self: getattr(self, '_name_override', 'sqlite_memory'))
mgr.add_provider(p1)
mgr.add_provider(p2) # should be rejected
# Only p1 was accepted
assert len(mgr.providers) == 1
assert mgr.provider_names == ["sqlite_memory"]
# Restore class
type(p2).name = original_name
mgr.shutdown_all()
def test_provider_failure_isolation(self):
"""Failing external provider doesn't break builtin."""
from agent.builtin_memory_provider import BuiltinMemoryProvider
mgr = MemoryManager()
builtin = BuiltinMemoryProvider() # name="builtin", always accepted
ext = SQLiteMemoryProvider()
mgr.add_provider(builtin)
mgr.add_provider(ext)
mgr.initialize_all(session_id="test-4")
# Break external provider's connection
ext._conn.close()
ext._conn = None
# Sync — external fails silently, builtin (no-op sync) succeeds
mgr.sync_all("user", "assistant") # should not raise
mgr.shutdown_all()
def test_plugin_registration_flow(self):
"""Simulate the full plugin load → agent init path."""
# Simulate what AIAgent.__init__ does via plugins/memory/ discovery
provider = SQLiteMemoryProvider()
mem_mgr = MemoryManager()
mem_mgr.add_provider(BuiltinMemoryProvider())
if provider.is_available():
mem_mgr.add_provider(provider)
mem_mgr.initialize_all(session_id="agent-session")
assert len(mem_mgr.providers) == 2
assert mem_mgr.provider_names == ["builtin", "sqlite_memory"]
assert provider._conn is not None # initialized = connection established
mem_mgr.shutdown_all()
+157 -4
View File
@@ -6,6 +6,8 @@ from unittest.mock import MagicMock, patch
from agent.memory_provider import MemoryProvider
from agent.memory_manager import MemoryManager
from agent.builtin_memory_provider import BuiltinMemoryProvider
# ---------------------------------------------------------------------------
# Concrete test provider
@@ -116,7 +118,7 @@ class TestMemoryManager:
def test_empty_manager(self):
mgr = MemoryManager()
assert mgr.providers == []
assert [p.name for p in mgr.providers] == []
assert mgr.provider_names == []
assert mgr.get_all_tool_schemas() == []
assert mgr.build_system_prompt() == ""
assert mgr.prefetch_all("test") == ""
@@ -126,7 +128,7 @@ class TestMemoryManager:
p = FakeMemoryProvider("test1")
mgr.add_provider(p)
assert len(mgr.providers) == 1
assert [p.name for p in mgr.providers] == ["test1"]
assert mgr.provider_names == ["test1"]
def test_get_provider_by_name(self):
mgr = MemoryManager()
@@ -141,7 +143,7 @@ class TestMemoryManager:
p2 = FakeMemoryProvider("external")
mgr.add_provider(p1)
mgr.add_provider(p2)
assert [p.name for p in mgr.providers] == ["builtin", "external"]
assert mgr.provider_names == ["builtin", "external"]
def test_second_external_rejected(self):
"""Only one non-builtin provider is allowed."""
@@ -152,7 +154,7 @@ class TestMemoryManager:
mgr.add_provider(builtin)
mgr.add_provider(ext1)
mgr.add_provider(ext2) # should be rejected
assert [p.name for p in mgr.providers] == ["builtin", "mem0"]
assert mgr.provider_names == ["builtin", "mem0"]
assert len(mgr.providers) == 2
def test_system_prompt_merges_blocks(self):
@@ -319,6 +321,17 @@ class TestMemoryManager:
mgr.on_pre_compress([{"role": "user", "content": "old"}])
assert p.pre_compress_called
def test_on_memory_write_skips_builtin(self):
"""on_memory_write should skip the builtin provider."""
mgr = MemoryManager()
builtin = BuiltinMemoryProvider()
external = FakeMemoryProvider("external")
mgr.add_provider(builtin)
mgr.add_provider(external)
mgr.on_memory_write("add", "memory", "test fact")
assert external.memory_writes == [("add", "memory", "test fact")]
def test_shutdown_all_reverse_order(self):
mgr = MemoryManager()
order = []
@@ -372,6 +385,146 @@ class TestMemoryManager:
assert result == "works fine"
# ---------------------------------------------------------------------------
# BuiltinMemoryProvider tests
# ---------------------------------------------------------------------------
class TestBuiltinMemoryProvider:
def test_name(self):
p = BuiltinMemoryProvider()
assert p.name == "builtin"
def test_always_available(self):
p = BuiltinMemoryProvider()
assert p.is_available()
def test_no_tools(self):
"""Builtin provider exposes no tools (memory tool is agent-level)."""
p = BuiltinMemoryProvider()
assert p.get_tool_schemas() == []
def test_system_prompt_with_store(self):
store = MagicMock()
store.format_for_system_prompt.side_effect = lambda t: f"BLOCK_{t}" if t == "memory" else f"BLOCK_{t}"
p = BuiltinMemoryProvider(
memory_store=store,
memory_enabled=True,
user_profile_enabled=True,
)
block = p.system_prompt_block()
assert "BLOCK_memory" in block
assert "BLOCK_user" in block
def test_system_prompt_memory_disabled(self):
store = MagicMock()
store.format_for_system_prompt.return_value = "content"
p = BuiltinMemoryProvider(
memory_store=store,
memory_enabled=False,
user_profile_enabled=False,
)
assert p.system_prompt_block() == ""
def test_system_prompt_no_store(self):
p = BuiltinMemoryProvider(memory_store=None, memory_enabled=True)
assert p.system_prompt_block() == ""
def test_prefetch_returns_empty(self):
p = BuiltinMemoryProvider()
assert p.prefetch("anything") == ""
def test_store_property(self):
store = MagicMock()
p = BuiltinMemoryProvider(memory_store=store)
assert p.store is store
def test_initialize_loads_from_disk(self):
store = MagicMock()
p = BuiltinMemoryProvider(memory_store=store)
p.initialize(session_id="test")
store.load_from_disk.assert_called_once()
# ---------------------------------------------------------------------------
# Plugin registration tests
# ---------------------------------------------------------------------------
class TestSingleProviderGating:
"""Only the configured provider should activate."""
def test_no_provider_configured_means_builtin_only(self):
"""When memory.provider is empty, no plugin providers activate."""
mgr = MemoryManager()
builtin = BuiltinMemoryProvider()
mgr.add_provider(builtin)
# Simulate what run_agent.py does when provider=""
configured = ""
available_plugins = [
FakeMemoryProvider("holographic"),
FakeMemoryProvider("mem0"),
]
# With empty config, no plugins should be added
if configured:
for p in available_plugins:
if p.name == configured and p.is_available():
mgr.add_provider(p)
assert mgr.provider_names == ["builtin"]
def test_configured_provider_activates(self):
"""Only the named provider should be added."""
mgr = MemoryManager()
builtin = BuiltinMemoryProvider()
mgr.add_provider(builtin)
configured = "holographic"
p1 = FakeMemoryProvider("holographic")
p2 = FakeMemoryProvider("mem0")
p3 = FakeMemoryProvider("hindsight")
for p in [p1, p2, p3]:
if p.name == configured and p.is_available():
mgr.add_provider(p)
assert mgr.provider_names == ["builtin", "holographic"]
assert p1.initialized is False # not initialized by the gating logic itself
def test_unavailable_provider_skipped(self):
"""If the configured provider is unavailable, it should be skipped."""
mgr = MemoryManager()
builtin = BuiltinMemoryProvider()
mgr.add_provider(builtin)
configured = "holographic"
p1 = FakeMemoryProvider("holographic", available=False)
for p in [p1]:
if p.name == configured and p.is_available():
mgr.add_provider(p)
assert mgr.provider_names == ["builtin"]
def test_nonexistent_provider_results_in_builtin_only(self):
"""If the configured name doesn't match any plugin, only builtin remains."""
mgr = MemoryManager()
builtin = BuiltinMemoryProvider()
mgr.add_provider(builtin)
configured = "nonexistent"
plugins = [FakeMemoryProvider("holographic"), FakeMemoryProvider("mem0")]
for p in plugins:
if p.name == configured and p.is_available():
mgr.add_provider(p)
assert mgr.provider_names == ["builtin"]
class TestPluginMemoryDiscovery:
"""Memory providers are discovered from plugins/memory/ directory."""
+56
View File
@@ -11,6 +11,7 @@ from agent.prompt_builder import (
_scan_context_content,
_truncate_content,
_parse_skill_file,
_read_skill_conditions,
_skill_should_show,
_find_hermes_md,
_find_git_root,
@@ -774,6 +775,61 @@ class TestPromptBuilderConstants:
# Conditional skill activation
# =========================================================================
class TestReadSkillConditions:
def test_no_conditions_returns_empty_lists(self, tmp_path):
skill_file = tmp_path / "SKILL.md"
skill_file.write_text("---\nname: test\ndescription: A skill\n---\n")
conditions = _read_skill_conditions(skill_file)
assert conditions["fallback_for_toolsets"] == []
assert conditions["requires_toolsets"] == []
assert conditions["fallback_for_tools"] == []
assert conditions["requires_tools"] == []
def test_reads_fallback_for_toolsets(self, tmp_path):
skill_file = tmp_path / "SKILL.md"
skill_file.write_text(
"---\nname: ddg\ndescription: DuckDuckGo\nmetadata:\n hermes:\n fallback_for_toolsets: [web]\n---\n"
)
conditions = _read_skill_conditions(skill_file)
assert conditions["fallback_for_toolsets"] == ["web"]
def test_reads_requires_toolsets(self, tmp_path):
skill_file = tmp_path / "SKILL.md"
skill_file.write_text(
"---\nname: openhue\ndescription: Hue lights\nmetadata:\n hermes:\n requires_toolsets: [terminal]\n---\n"
)
conditions = _read_skill_conditions(skill_file)
assert conditions["requires_toolsets"] == ["terminal"]
def test_reads_multiple_conditions(self, tmp_path):
skill_file = tmp_path / "SKILL.md"
skill_file.write_text(
"---\nname: test\ndescription: Test\nmetadata:\n hermes:\n fallback_for_toolsets: [browser]\n requires_tools: [terminal]\n---\n"
)
conditions = _read_skill_conditions(skill_file)
assert conditions["fallback_for_toolsets"] == ["browser"]
assert conditions["requires_tools"] == ["terminal"]
def test_missing_file_returns_empty(self, tmp_path):
conditions = _read_skill_conditions(tmp_path / "missing.md")
assert conditions == {}
def test_logs_condition_read_failures_and_returns_empty(self, tmp_path, monkeypatch, caplog):
skill_file = tmp_path / "SKILL.md"
skill_file.write_text("---\nname: broken\n---\n")
def boom(*args, **kwargs):
raise OSError("read exploded")
monkeypatch.setattr(type(skill_file), "read_text", boom)
with caplog.at_level(logging.DEBUG, logger="agent.prompt_builder"):
conditions = _read_skill_conditions(skill_file)
assert conditions == {}
assert "Failed to read skill conditions" in caplog.text
assert str(skill_file) in caplog.text
class TestSkillShouldShow:
def test_no_filter_info_always_shows(self):
assert _skill_should_show({}, None, None) is True
-212
View File
@@ -1,212 +0,0 @@
"""Tests for agent.rate_limit_tracker — header parsing and formatting."""
import time
import pytest
from agent.rate_limit_tracker import (
RateLimitBucket,
RateLimitState,
parse_rate_limit_headers,
format_rate_limit_display,
format_rate_limit_compact,
_fmt_count,
_fmt_seconds,
_bar,
)
# ── Sample headers from Nous inference API ──────────────────────────────
NOUS_HEADERS = {
"x-ratelimit-limit-requests": "800",
"x-ratelimit-limit-requests-1h": "33600",
"x-ratelimit-limit-tokens": "8000000",
"x-ratelimit-limit-tokens-1h": "336000000",
"x-ratelimit-remaining-requests": "795",
"x-ratelimit-remaining-requests-1h": "33590",
"x-ratelimit-remaining-tokens": "7999500",
"x-ratelimit-remaining-tokens-1h": "335999000",
"x-ratelimit-reset-requests": "45.5",
"x-ratelimit-reset-requests-1h": "3500.0",
"x-ratelimit-reset-tokens": "42.3",
"x-ratelimit-reset-tokens-1h": "3490.0",
}
class TestParseHeaders:
def test_basic_parsing(self):
state = parse_rate_limit_headers(NOUS_HEADERS, provider="nous")
assert state is not None
assert state.provider == "nous"
assert state.has_data
assert state.requests_min.limit == 800
assert state.requests_min.remaining == 795
assert state.requests_min.reset_seconds == 45.5
assert state.requests_hour.limit == 33600
assert state.requests_hour.remaining == 33590
assert state.tokens_min.limit == 8000000
assert state.tokens_min.remaining == 7999500
assert state.tokens_hour.limit == 336000000
assert state.tokens_hour.remaining == 335999000
assert state.tokens_hour.reset_seconds == 3490.0
def test_no_headers(self):
state = parse_rate_limit_headers({})
assert state is None
def test_partial_headers(self):
headers = {
"x-ratelimit-limit-requests": "100",
"x-ratelimit-remaining-requests": "50",
}
state = parse_rate_limit_headers(headers)
assert state is not None
assert state.requests_min.limit == 100
assert state.requests_min.remaining == 50
# Missing fields default to 0
assert state.tokens_min.limit == 0
def test_non_rate_limit_headers_ignored(self):
headers = {
"content-type": "application/json",
"server": "nginx",
}
state = parse_rate_limit_headers(headers)
assert state is None
def test_malformed_values(self):
headers = {
"x-ratelimit-limit-requests": "not-a-number",
"x-ratelimit-remaining-requests": "",
"x-ratelimit-reset-requests": "abc",
}
state = parse_rate_limit_headers(headers)
assert state is not None
assert state.requests_min.limit == 0
assert state.requests_min.remaining == 0
assert state.requests_min.reset_seconds == 0.0
class TestBucket:
def test_used(self):
b = RateLimitBucket(limit=800, remaining=795, reset_seconds=45.0, captured_at=time.time())
assert b.used == 5
def test_usage_pct(self):
b = RateLimitBucket(limit=100, remaining=20, reset_seconds=30.0, captured_at=time.time())
assert b.usage_pct == pytest.approx(80.0)
def test_usage_pct_zero_limit(self):
b = RateLimitBucket(limit=0, remaining=0)
assert b.usage_pct == 0.0
def test_remaining_seconds_now(self):
now = time.time()
b = RateLimitBucket(limit=800, remaining=795, reset_seconds=60.0, captured_at=now - 10)
# ~50 seconds should remain
assert 49 <= b.remaining_seconds_now <= 51
def test_remaining_seconds_expired(self):
b = RateLimitBucket(limit=800, remaining=795, reset_seconds=30.0, captured_at=time.time() - 60)
assert b.remaining_seconds_now == 0.0
class TestFormatting:
def test_fmt_count_millions(self):
assert _fmt_count(8000000) == "8.0M"
assert _fmt_count(336000000) == "336.0M"
def test_fmt_count_thousands(self):
assert _fmt_count(33600) == "33.6K"
assert _fmt_count(1500) == "1.5K"
def test_fmt_count_small(self):
assert _fmt_count(800) == "800"
assert _fmt_count(0) == "0"
def test_fmt_seconds_short(self):
assert _fmt_seconds(45) == "45s"
assert _fmt_seconds(0) == "0s"
def test_fmt_seconds_minutes(self):
assert _fmt_seconds(125) == "2m 5s"
assert _fmt_seconds(120) == "2m"
def test_fmt_seconds_hours(self):
assert _fmt_seconds(3660) == "1h 1m"
assert _fmt_seconds(3600) == "1h"
def test_bar(self):
bar = _bar(50.0, width=10)
assert bar == "[█████░░░░░]"
assert _bar(0.0, width=10) == "[░░░░░░░░░░]"
assert _bar(100.0, width=10) == "[██████████]"
def test_format_display_no_data(self):
state = RateLimitState()
result = format_rate_limit_display(state)
assert "No rate limit data" in result
def test_format_display_with_data(self):
state = parse_rate_limit_headers(NOUS_HEADERS, provider="nous")
result = format_rate_limit_display(state)
assert "Nous" in result
assert "Requests/min" in result
assert "Requests/hr" in result
assert "Tokens/min" in result
assert "Tokens/hr" in result
assert "resets in" in result
def test_format_display_warning_on_high_usage(self):
headers = {
**NOUS_HEADERS,
"x-ratelimit-remaining-requests": "50", # 750/800 used = 93.75%
}
state = parse_rate_limit_headers(headers)
result = format_rate_limit_display(state)
assert "" in result
def test_format_compact(self):
state = parse_rate_limit_headers(NOUS_HEADERS, provider="nous")
result = format_rate_limit_compact(state)
assert "RPM:" in result
assert "RPH:" in result
assert "TPM:" in result
assert "TPH:" in result
assert "resets" in result
def test_format_compact_no_data(self):
state = RateLimitState()
result = format_rate_limit_compact(state)
assert "No rate limit data" in result
class TestAgentIntegration:
"""Test that AIAgent captures rate limit state correctly."""
def test_capture_rate_limits_from_headers(self):
"""Simulate the header capture path without a real API call."""
import sys
import os
# Use a mock httpx-like response
class MockResponse:
headers = NOUS_HEADERS
# Import AIAgent minimally
from unittest.mock import MagicMock, patch
# Test the parsing directly
state = parse_rate_limit_headers(MockResponse.headers, provider="nous")
assert state is not None
assert state.requests_min.limit == 800
assert state.tokens_hour.limit == 336000000
def test_capture_rate_limits_none_response(self):
"""_capture_rate_limits should handle None gracefully."""
from agent.rate_limit_tracker import parse_rate_limit_headers
# None should not crash
result = parse_rate_limit_headers({})
assert result is None
-43
View File
@@ -3,7 +3,6 @@
import os
import pytest
from pathlib import Path
from unittest.mock import patch
from agent.subdirectory_hints import SubdirectoryHintTracker
@@ -190,45 +189,3 @@ class TestSubdirectoryHintTracker:
"terminal", {"command": "curl https://example.com/frontend/api"}
)
assert result is None
class TestPermissionErrorHandling:
"""Regression tests for PermissionError in filesystem checks (ref #6214)."""
def test_is_valid_subdir_permission_error(self, tmp_path):
"""_is_valid_subdir should return False when is_dir() raises PermissionError."""
tracker = SubdirectoryHintTracker(working_dir=str(tmp_path))
restricted = tmp_path / "restricted"
restricted.mkdir()
with patch.object(Path, "is_dir", side_effect=PermissionError("Permission denied")):
assert tracker._is_valid_subdir(restricted) is False
def test_load_hints_permission_error_on_is_file(self, tmp_path):
"""_load_hints_for_directory should skip files when is_file() raises PermissionError."""
tracker = SubdirectoryHintTracker(working_dir=str(tmp_path))
restricted = tmp_path / "restricted"
restricted.mkdir()
original_is_file = Path.is_file
def patched_is_file(self):
if "restricted" in str(self):
raise PermissionError("Permission denied")
return original_is_file(self)
with patch.object(Path, "is_file", patched_is_file):
result = tracker._load_hints_for_directory(restricted)
assert result is None
def test_check_tool_call_survives_inaccessible_path(self, project):
"""Full check_tool_call should not crash when a path is inaccessible."""
tracker = SubdirectoryHintTracker(working_dir=str(project))
original_is_dir = Path.is_dir
def patched_is_dir(self):
if "backend" in str(self) and "src" not in str(self):
raise PermissionError("Permission denied")
return original_is_dir(self)
with patch.object(Path, "is_dir", patched_is_dir):
# Should not raise — gracefully skip the inaccessible directory
result = tracker.check_tool_call(
"read_file", {"path": str(project / "backend" / "src" / "main.py")}
)
# Result may be None (backend skipped) — the key point is no crash
assert result is None or isinstance(result, str)
+2 -45
View File
@@ -2,65 +2,22 @@ import queue
import threading
import time
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import cli as cli_module
from cli import HermesCLI
class _FakeBuffer:
def __init__(self, text="", cursor_position=None):
self.text = text
self.cursor_position = len(text) if cursor_position is None else cursor_position
def reset(self, append_to_history=False):
self.text = ""
self.cursor_position = 0
def _make_cli_stub():
cli = HermesCLI.__new__(HermesCLI)
cli._approval_state = None
cli._approval_deadline = 0
cli._approval_lock = threading.Lock()
cli._sudo_state = None
cli._sudo_deadline = 0
cli._modal_input_snapshot = None
cli._invalidate = MagicMock()
cli._app = SimpleNamespace(invalidate=MagicMock(), current_buffer=_FakeBuffer())
cli._app = SimpleNamespace(invalidate=MagicMock())
return cli
class TestCliApprovalUi:
def test_sudo_prompt_restores_existing_draft_after_response(self):
cli = _make_cli_stub()
cli._app.current_buffer = _FakeBuffer("draft command", cursor_position=5)
result = {}
def _run_callback():
result["value"] = cli._sudo_password_callback()
with patch.object(cli_module, "_cprint"):
thread = threading.Thread(target=_run_callback, daemon=True)
thread.start()
deadline = time.time() + 2
while cli._sudo_state is None and time.time() < deadline:
time.sleep(0.01)
assert cli._sudo_state is not None
assert cli._app.current_buffer.text == ""
cli._app.current_buffer.text = "secret"
cli._app.current_buffer.cursor_position = len("secret")
cli._sudo_state["response_queue"].put("secret")
thread.join(timeout=2)
assert result["value"] == "secret"
assert cli._app.current_buffer.text == "draft command"
assert cli._app.current_buffer.cursor_position == 5
def test_approval_callback_includes_view_for_long_commands(self):
cli = _make_cli_stub()
command = "sudo dd if=/tmp/githubcli-keyring.gpg of=/usr/share/keyrings/githubcli-archive-keyring.gpg bs=4M status=progress"
+2 -13
View File
@@ -6,17 +6,6 @@ from unittest.mock import patch
from cli import HermesCLI
def _assert_chrome_debug_cmd(cmd, expected_chrome, expected_port):
"""Verify the auto-launch command has all required flags."""
assert cmd[0] == expected_chrome
assert f"--remote-debugging-port={expected_port}" in cmd
assert "--no-first-run" in cmd
assert "--no-default-browser-check" in cmd
user_data_args = [a for a in cmd if a.startswith("--user-data-dir=")]
assert len(user_data_args) == 1, "Expected exactly one --user-data-dir flag"
assert "chrome-debug" in user_data_args[0]
class TestChromeDebugLaunch:
def test_windows_launch_uses_browser_found_on_path(self):
captured = {}
@@ -31,7 +20,7 @@ class TestChromeDebugLaunch:
patch("subprocess.Popen", side_effect=fake_popen):
assert HermesCLI._try_launch_chrome_debug(9333, "Windows") is True
_assert_chrome_debug_cmd(captured["cmd"], r"C:\Chrome\chrome.exe", 9333)
assert captured["cmd"] == [r"C:\Chrome\chrome.exe", "--remote-debugging-port=9333"]
assert captured["kwargs"]["start_new_session"] is True
def test_windows_launch_falls_back_to_common_install_dirs(self, monkeypatch):
@@ -54,4 +43,4 @@ class TestChromeDebugLaunch:
patch("subprocess.Popen", side_effect=fake_popen):
assert HermesCLI._try_launch_chrome_debug(9222, "Windows") is True
_assert_chrome_debug_cmd(captured["cmd"], installed, 9222)
assert captured["cmd"] == [installed, "--remote-debugging-port=9222"]
-1
View File
@@ -41,7 +41,6 @@ def _attach_agent(
session_completion_tokens=completion_tokens,
session_total_tokens=total_tokens,
session_api_calls=api_calls,
get_rate_limit_state=lambda: None,
context_compressor=SimpleNamespace(
last_prompt_tokens=context_tokens,
context_length=context_length,
+10 -4
View File
@@ -619,14 +619,17 @@ class TestReasoningDeltasFiredFlag(unittest.TestCase):
agent = AIAgent.__new__(AIAgent)
agent.reasoning_callback = None
agent.stream_delta_callback = None
agent._reasoning_deltas_fired = False
agent.verbose_logging = False
return agent
def test_fire_reasoning_delta_calls_callback(self):
def test_fire_reasoning_delta_sets_flag(self):
agent = self._make_agent()
captured = []
agent.reasoning_callback = lambda t: captured.append(t)
self.assertFalse(agent._reasoning_deltas_fired)
agent._fire_reasoning_delta("thinking...")
self.assertTrue(agent._reasoning_deltas_fired)
self.assertEqual(captured, ["thinking..."])
def test_build_assistant_message_skips_callback_when_already_streamed(self):
@@ -637,7 +640,8 @@ class TestReasoningDeltasFiredFlag(unittest.TestCase):
agent.reasoning_callback = lambda t: captured.append(t)
agent.stream_delta_callback = lambda t: None # streaming is active
# Simulate streaming having already fired reasoning
# Simulate streaming having fired reasoning
agent._reasoning_deltas_fired = True
msg = SimpleNamespace(
content="I'll merge that.",
@@ -661,8 +665,9 @@ class TestReasoningDeltasFiredFlag(unittest.TestCase):
agent.reasoning_callback = lambda t: captured.append(t)
agent.stream_delta_callback = lambda t: None # streaming active
# Reasoning came through content tags, not reasoning_content deltas.
# Callback should not fire since streaming is active.
# Even though _reasoning_deltas_fired is False (reasoning came through
# content tags, not reasoning_content deltas), callback should not fire
agent._reasoning_deltas_fired = False
msg = SimpleNamespace(
content="I'll merge that.",
@@ -684,6 +689,7 @@ class TestReasoningDeltasFiredFlag(unittest.TestCase):
agent.reasoning_callback = lambda t: captured.append(t)
# No streaming
agent.stream_delta_callback = None
agent._reasoning_deltas_fired = False
msg = SimpleNamespace(
content="I'll merge that.",
-2
View File
@@ -38,8 +38,6 @@ def _isolate_hermes_home(tmp_path, monkeypatch):
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
monkeypatch.delenv("HERMES_GATEWAY_SESSION", raising=False)
# Avoid making real calls during tests if this key is set in the env files
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
@pytest.fixture()
+33 -4
View File
@@ -141,7 +141,7 @@ class TestBlockingGatewayApproval:
def test_resolve_single_pops_oldest_fifo(self):
"""resolve_gateway_approval without resolve_all resolves oldest first."""
from tools.approval import (
resolve_gateway_approval,
resolve_gateway_approval, pending_approval_count,
_ApprovalEntry, _gateway_queues,
)
session_key = "test-fifo"
@@ -154,7 +154,7 @@ class TestBlockingGatewayApproval:
assert e1.event.is_set()
assert e1.result == "once"
assert not e2.event.is_set()
assert len(_gateway_queues[session_key]) == 1
assert pending_approval_count(session_key) == 1
def test_unregister_signals_all_entries(self):
"""unregister_gateway_notify signals all waiting entries to prevent hangs."""
@@ -173,6 +173,35 @@ class TestBlockingGatewayApproval:
assert e1.event.is_set()
assert e2.event.is_set()
def test_clear_session_signals_all_entries(self):
"""clear_session should unblock all waiting approval threads."""
from tools.approval import (
register_gateway_notify, clear_session,
_ApprovalEntry, _gateway_queues,
)
session_key = "test-clear"
register_gateway_notify(session_key, lambda d: None)
e1 = _ApprovalEntry({"command": "cmd1"})
e2 = _ApprovalEntry({"command": "cmd2"})
_gateway_queues[session_key] = [e1, e2]
clear_session(session_key)
assert e1.event.is_set()
assert e2.event.is_set()
def test_pending_approval_count(self):
from tools.approval import (
pending_approval_count, _ApprovalEntry, _gateway_queues,
)
session_key = "test-count"
assert pending_approval_count(session_key) == 0
_gateway_queues[session_key] = [
_ApprovalEntry({"command": "a"}),
_ApprovalEntry({"command": "b"}),
]
assert pending_approval_count(session_key) == 2
# ------------------------------------------------------------------
# /approve command
@@ -477,7 +506,7 @@ class TestBlockingApprovalE2E:
from tools.approval import (
register_gateway_notify, unregister_gateway_notify,
resolve_gateway_approval, check_all_command_guards,
_gateway_queues,
pending_approval_count,
)
session_key = "e2e-parallel"
@@ -516,7 +545,7 @@ class TestBlockingApprovalE2E:
time.sleep(0.05)
assert len(notified) == 3
assert len(_gateway_queues.get(session_key, [])) == 3
assert pending_approval_count(session_key) == 3
# Approve all at once
count = resolve_gateway_approval(session_key, "session", resolve_all=True)
+23 -1
View File
@@ -1,7 +1,7 @@
"""Tests for the delivery routing module."""
from gateway.config import Platform, GatewayConfig, PlatformConfig, HomeChannel
from gateway.delivery import DeliveryRouter, DeliveryTarget
from gateway.delivery import DeliveryRouter, DeliveryTarget, parse_deliver_spec
from gateway.session import SessionSource
@@ -41,6 +41,28 @@ class TestParseTargetPlatformChat:
assert target.platform == Platform.LOCAL
class TestParseDeliverSpec:
def test_none_returns_default(self):
result = parse_deliver_spec(None)
assert result == "origin"
def test_empty_string_returns_default(self):
result = parse_deliver_spec("")
assert result == "origin"
def test_custom_default(self):
result = parse_deliver_spec(None, default="local")
assert result == "local"
def test_passthrough_string(self):
result = parse_deliver_spec("telegram")
assert result == "telegram"
def test_passthrough_list(self):
result = parse_deliver_spec(["local", "telegram"])
assert result == ["local", "telegram"]
class TestTargetToStringRoundtrip:
def test_origin_roundtrip(self):
origin = SessionSource(platform=Platform.TELEGRAM, chat_id="111", thread_id="42")
+3 -3
View File
@@ -56,7 +56,7 @@ class FakeTree:
class FakeBot:
def __init__(self, *, intents, proxy=None):
def __init__(self, *, intents):
self.intents = intents
self.user = SimpleNamespace(id=999, name="Hermes")
self._events = {}
@@ -95,7 +95,7 @@ async def test_connect_only_requests_members_intent_when_needed(monkeypatch, all
created = {}
def fake_bot_factory(*, command_prefix, intents, proxy=None):
def fake_bot_factory(*, command_prefix, intents):
created["bot"] = FakeBot(intents=intents)
return created["bot"]
@@ -124,7 +124,7 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch):
monkeypatch.setattr(
discord_platform.commands,
"Bot",
lambda **kwargs: FakeBot(intents=kwargs["intents"], proxy=kwargs.get("proxy")),
lambda **kwargs: FakeBot(intents=kwargs["intents"]),
)
async def fake_wait_for(awaitable, timeout):
@@ -209,31 +209,14 @@ class TestIncomingDocumentHandling:
assert "[Content of readme.md]:" in event.text
assert "# Title" in event.text
@pytest.mark.asyncio
async def test_log_content_injected(self, adapter):
""".log file under 100KB should be treated as text/plain and injected."""
file_content = b"BLE trace line 1\nBLE trace line 2"
with _mock_aiohttp_download(file_content):
msg = make_message(
attachments=[make_attachment(filename="btsnoop_hci.log", content_type="text/plain")],
content="please inspect this",
)
await adapter._handle_message(msg)
event = adapter.handle_message.call_args[0][0]
assert "[Content of btsnoop_hci.log]:" in event.text
assert "BLE trace line 1" in event.text
assert "please inspect this" in event.text
@pytest.mark.asyncio
async def test_oversized_document_skipped(self, adapter):
"""A document over 32MB should be skipped — media_urls stays empty."""
"""A document over 20MB should be skipped — media_urls stays empty."""
msg = make_message([
make_attachment(
filename="huge.pdf",
content_type="application/pdf",
size=33 * 1024 * 1024,
size=25 * 1024 * 1024,
)
])
await adapter._handle_message(msg)
@@ -243,24 +226,6 @@ class TestIncomingDocumentHandling:
# handler must still be called
adapter.handle_message.assert_called_once()
@pytest.mark.asyncio
async def test_mid_sized_zip_under_32mb_is_cached(self, adapter):
"""A 25MB .zip should be accepted now that Discord documents allow up to 32MB."""
msg = make_message([
make_attachment(
filename="bugreport.zip",
content_type="application/zip",
size=25 * 1024 * 1024,
)
])
with _mock_aiohttp_download(b"PK\x03\x04test"):
await adapter._handle_message(msg)
event = adapter.handle_message.call_args[0][0]
assert len(event.media_urls) == 1
assert event.media_types == ["application/zip"]
@pytest.mark.asyncio
async def test_zip_document_cached(self, adapter):
"""A .zip file should be cached as a supported document."""
+11 -13
View File
@@ -38,11 +38,10 @@ def _make_timeout_error() -> httpx.TimeoutException:
# cache_image_from_url (base.py)
# ---------------------------------------------------------------------------
@patch("tools.url_safety.is_safe_url", return_value=True)
class TestCacheImageFromUrl:
"""Tests for gateway.platforms.base.cache_image_from_url"""
def test_success_on_first_attempt(self, _mock_safe, tmp_path, monkeypatch):
def test_success_on_first_attempt(self, tmp_path, monkeypatch):
"""A clean 200 response caches the image and returns a path."""
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
@@ -66,7 +65,7 @@ class TestCacheImageFromUrl:
assert path.endswith(".jpg")
mock_client.get.assert_called_once()
def test_retries_on_timeout_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
def test_retries_on_timeout_then_succeeds(self, tmp_path, monkeypatch):
"""A timeout on the first attempt is retried; second attempt succeeds."""
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
@@ -96,7 +95,7 @@ class TestCacheImageFromUrl:
assert mock_client.get.call_count == 2
mock_sleep.assert_called_once()
def test_retries_on_429_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
def test_retries_on_429_then_succeeds(self, tmp_path, monkeypatch):
"""A 429 response on the first attempt is retried; second attempt succeeds."""
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
@@ -123,7 +122,7 @@ class TestCacheImageFromUrl:
assert path.endswith(".jpg")
assert mock_client.get.call_count == 2
def test_raises_after_max_retries_exhausted(self, _mock_safe, tmp_path, monkeypatch):
def test_raises_after_max_retries_exhausted(self, tmp_path, monkeypatch):
"""Timeout on every attempt raises after all retries are consumed."""
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
@@ -146,7 +145,7 @@ class TestCacheImageFromUrl:
# 3 total calls: initial + 2 retries
assert mock_client.get.call_count == 3
def test_non_retryable_4xx_raises_immediately(self, _mock_safe, tmp_path, monkeypatch):
def test_non_retryable_4xx_raises_immediately(self, tmp_path, monkeypatch):
"""A 404 (non-retryable) is raised immediately without any retry."""
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
@@ -176,11 +175,10 @@ class TestCacheImageFromUrl:
# cache_audio_from_url (base.py)
# ---------------------------------------------------------------------------
@patch("tools.url_safety.is_safe_url", return_value=True)
class TestCacheAudioFromUrl:
"""Tests for gateway.platforms.base.cache_audio_from_url"""
def test_success_on_first_attempt(self, _mock_safe, tmp_path, monkeypatch):
def test_success_on_first_attempt(self, tmp_path, monkeypatch):
"""A clean 200 response caches the audio and returns a path."""
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
@@ -204,7 +202,7 @@ class TestCacheAudioFromUrl:
assert path.endswith(".ogg")
mock_client.get.assert_called_once()
def test_retries_on_timeout_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
def test_retries_on_timeout_then_succeeds(self, tmp_path, monkeypatch):
"""A timeout on the first attempt is retried; second attempt succeeds."""
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
@@ -234,7 +232,7 @@ class TestCacheAudioFromUrl:
assert mock_client.get.call_count == 2
mock_sleep.assert_called_once()
def test_retries_on_429_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
def test_retries_on_429_then_succeeds(self, tmp_path, monkeypatch):
"""A 429 response on the first attempt is retried; second attempt succeeds."""
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
@@ -261,7 +259,7 @@ class TestCacheAudioFromUrl:
assert path.endswith(".ogg")
assert mock_client.get.call_count == 2
def test_retries_on_500_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
def test_retries_on_500_then_succeeds(self, tmp_path, monkeypatch):
"""A 500 response on the first attempt is retried; second attempt succeeds."""
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
@@ -288,7 +286,7 @@ class TestCacheAudioFromUrl:
assert path.endswith(".ogg")
assert mock_client.get.call_count == 2
def test_raises_after_max_retries_exhausted(self, _mock_safe, tmp_path, monkeypatch):
def test_raises_after_max_retries_exhausted(self, tmp_path, monkeypatch):
"""Timeout on every attempt raises after all retries are consumed."""
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
@@ -311,7 +309,7 @@ class TestCacheAudioFromUrl:
# 3 total calls: initial + 2 retries
assert mock_client.get.call_count == 3
def test_non_retryable_4xx_raises_immediately(self, _mock_safe, tmp_path, monkeypatch):
def test_non_retryable_4xx_raises_immediately(self, tmp_path, monkeypatch):
"""A 404 (non-retryable) is raised immediately without any retry."""
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
+9
View File
@@ -7,6 +7,7 @@ from gateway.session import (
_hash_id,
_hash_sender_id,
_hash_chat_id,
_looks_like_phone,
)
from gateway.config import Platform, HomeChannel
@@ -38,6 +39,14 @@ class TestHashHelpers:
assert len(result) == 12
assert "12345" not in result
def test_looks_like_phone(self):
assert _looks_like_phone("+15551234567")
assert _looks_like_phone("15551234567")
assert _looks_like_phone("+1-555-123-4567")
assert not _looks_like_phone("alice")
assert not _looks_like_phone("user-123")
assert not _looks_like_phone("")
# ---------------------------------------------------------------------------
# Integration: build_session_context_prompt
-373
View File
@@ -619,18 +619,6 @@ class TestFormatMessage:
result = adapter.format_message("[click here](https://example.com)")
assert result == "<https://example.com|click here>"
def test_link_conversion_strips_markdown_angle_brackets(self, adapter):
result = adapter.format_message("[click here](<https://example.com>)")
assert result == "<https://example.com|click here>"
def test_escapes_control_characters(self, adapter):
result = adapter.format_message("AT&T < 5 > 3")
assert result == "AT&amp;T &lt; 5 &gt; 3"
def test_preserves_existing_slack_entities(self, adapter):
text = "Hey <@U123>, see <https://example.com|example> and <!here>"
assert adapter.format_message(text) == text
def test_strikethrough(self, adapter):
assert adapter.format_message("~~deleted~~") == "~deleted~"
@@ -655,325 +643,6 @@ class TestFormatMessage:
def test_none_passthrough(self, adapter):
assert adapter.format_message(None) is None
def test_blockquote_preserved(self, adapter):
"""Single-line blockquote > marker is preserved."""
assert adapter.format_message("> quoted text") == "> quoted text"
def test_multiline_blockquote(self, adapter):
"""Multi-line blockquote preserves > on each line."""
text = "> line one\n> line two"
assert adapter.format_message(text) == "> line one\n> line two"
def test_blockquote_with_formatting(self, adapter):
"""Blockquote containing bold text."""
assert adapter.format_message("> **bold quote**") == "> *bold quote*"
def test_nested_blockquote(self, adapter):
"""Multiple > characters for nested quotes."""
assert adapter.format_message(">> deeply quoted") == ">> deeply quoted"
def test_blockquote_mixed_with_plain(self, adapter):
"""Blockquote lines interleaved with plain text."""
text = "normal\n> quoted\nnormal again"
result = adapter.format_message(text)
assert "> quoted" in result
assert "normal" in result
def test_non_prefix_gt_still_escaped(self, adapter):
"""Greater-than in mid-line is still escaped."""
assert adapter.format_message("5 > 3") == "5 &gt; 3"
def test_blockquote_with_code(self, adapter):
"""Blockquote containing inline code."""
result = adapter.format_message("> use `fmt.Println`")
assert result.startswith(">")
assert "`fmt.Println`" in result
def test_bold_italic_combined(self, adapter):
"""Triple-star ***text*** converts to Slack bold+italic *_text_*."""
assert adapter.format_message("***hello***") == "*_hello_*"
def test_bold_italic_with_surrounding_text(self, adapter):
"""Bold+italic in a sentence."""
result = adapter.format_message("This is ***important*** stuff")
assert "*_important_*" in result
def test_bold_italic_does_not_break_plain_bold(self, adapter):
"""**bold** still works after adding ***bold italic*** support."""
assert adapter.format_message("**bold**") == "*bold*"
def test_bold_italic_does_not_break_plain_italic(self, adapter):
"""*italic* still works after adding ***bold italic*** support."""
assert adapter.format_message("*italic*") == "_italic_"
def test_bold_italic_mixed_with_bold(self, adapter):
"""Both ***bold italic*** and **bold** in the same message."""
result = adapter.format_message("***important*** and **bold**")
assert "*_important_*" in result
assert "*bold*" in result
def test_pre_escaped_ampersand_not_double_escaped(self, adapter):
"""Already-escaped &amp; must not become &amp;amp;."""
assert adapter.format_message("&amp;") == "&amp;"
def test_pre_escaped_lt_not_double_escaped(self, adapter):
"""Already-escaped &lt; must not become &amp;lt;."""
assert adapter.format_message("&lt;") == "&lt;"
def test_pre_escaped_gt_not_double_escaped(self, adapter):
"""Already-escaped &gt; in plain text must not become &amp;gt;."""
assert adapter.format_message("5 &gt; 3") == "5 &gt; 3"
def test_mixed_raw_and_escaped_entities(self, adapter):
"""Raw & and pre-escaped &amp; coexist correctly."""
result = adapter.format_message("AT&T and &amp; entity")
assert result == "AT&amp;T and &amp; entity"
def test_link_with_parentheses_in_url(self, adapter):
"""Wikipedia-style URL with balanced parens is not truncated."""
result = adapter.format_message("[Foo](https://en.wikipedia.org/wiki/Foo_(bar))")
assert result == "<https://en.wikipedia.org/wiki/Foo_(bar)|Foo>"
def test_link_with_multiple_paren_pairs(self, adapter):
"""URL with multiple balanced paren pairs."""
result = adapter.format_message("[text](https://example.com/a_(b)_c_(d))")
assert result == "<https://example.com/a_(b)_c_(d)|text>"
def test_link_without_parens_still_works(self, adapter):
"""Normal URL without parens is unaffected by regex change."""
result = adapter.format_message("[click](https://example.com/path?q=1)")
assert result == "<https://example.com/path?q=1|click>"
def test_link_with_angle_brackets_and_parens(self, adapter):
"""Angle-bracket URL with parens (CommonMark syntax)."""
result = adapter.format_message("[Foo](<https://en.wikipedia.org/wiki/Foo_(bar)>)")
assert result == "<https://en.wikipedia.org/wiki/Foo_(bar)|Foo>"
def test_escaping_is_idempotent(self, adapter):
"""Formatting already-formatted text produces the same result."""
original = "AT&T < 5 > 3"
once = adapter.format_message(original)
twice = adapter.format_message(once)
assert once == twice
# --- Entity preservation (spec-compliance) ---
def test_channel_mention_preserved(self, adapter):
"""<!channel> special mention passes through unchanged."""
assert adapter.format_message("Attention <!channel>") == "Attention <!channel>"
def test_everyone_mention_preserved(self, adapter):
"""<!everyone> special mention passes through unchanged."""
assert adapter.format_message("Hey <!everyone>") == "Hey <!everyone>"
def test_subteam_mention_preserved(self, adapter):
"""<!subteam^ID> user group mention passes through unchanged."""
assert adapter.format_message("Paging <!subteam^S12345>") == "Paging <!subteam^S12345>"
def test_date_formatting_preserved(self, adapter):
"""<!date^...> formatting token passes through unchanged."""
text = "Posted <!date^1392734382^{date_pretty}|Feb 18, 2014>"
assert adapter.format_message(text) == text
def test_channel_link_preserved(self, adapter):
"""<#CHANNEL_ID> channel link passes through unchanged."""
assert adapter.format_message("Join <#C12345>") == "Join <#C12345>"
# --- Additional edge cases ---
def test_message_only_code_block(self, adapter):
"""Entire message is a fenced code block — no conversion."""
code = "```python\nx = 1\n```"
assert adapter.format_message(code) == code
def test_multiline_mixed_formatting(self, adapter):
"""Multi-line message with headers, bold, links, code, and blockquotes."""
text = "## Title\n**bold** and [link](https://x.com)\n> quote\n`code`"
result = adapter.format_message(text)
assert result.startswith("*Title*")
assert "*bold*" in result
assert "<https://x.com|link>" in result
assert "> quote" in result
assert "`code`" in result
def test_markdown_unordered_list_with_asterisk(self, adapter):
"""Asterisk list items must not trigger italic conversion."""
text = "* item one\n* item two"
result = adapter.format_message(text)
assert "item one" in result
assert "item two" in result
def test_nested_bold_in_link(self, adapter):
"""Bold inside link label — label is stashed before bold pass."""
result = adapter.format_message("[**bold**](https://example.com)")
assert "https://example.com" in result
assert "bold" in result
def test_url_with_query_string_and_ampersand(self, adapter):
"""Ampersand in URL query string must not be escaped."""
result = adapter.format_message("[link](https://x.com?a=1&b=2)")
assert result == "<https://x.com?a=1&b=2|link>"
def test_emoji_shortcodes_passthrough(self, adapter):
"""Emoji shortcodes like :smile: pass through unchanged."""
assert adapter.format_message(":smile: hello :wave:") == ":smile: hello :wave:"
# ---------------------------------------------------------------------------
# TestEditMessage
# ---------------------------------------------------------------------------
class TestEditMessage:
"""Verify that edit_message() applies mrkdwn formatting before sending."""
@pytest.mark.asyncio
async def test_edit_message_formats_bold(self, adapter):
"""edit_message converts **bold** to Slack *bold*."""
adapter._app.client.chat_update = AsyncMock(return_value={"ok": True})
await adapter.edit_message("C123", "1234.5678", "**hello world**")
kwargs = adapter._app.client.chat_update.call_args.kwargs
assert kwargs["text"] == "*hello world*"
@pytest.mark.asyncio
async def test_edit_message_formats_links(self, adapter):
"""edit_message converts markdown links to Slack format."""
adapter._app.client.chat_update = AsyncMock(return_value={"ok": True})
await adapter.edit_message("C123", "1234.5678", "[click](https://example.com)")
kwargs = adapter._app.client.chat_update.call_args.kwargs
assert kwargs["text"] == "<https://example.com|click>"
@pytest.mark.asyncio
async def test_edit_message_preserves_blockquotes(self, adapter):
"""edit_message preserves blockquote > markers."""
adapter._app.client.chat_update = AsyncMock(return_value={"ok": True})
await adapter.edit_message("C123", "1234.5678", "> quoted text")
kwargs = adapter._app.client.chat_update.call_args.kwargs
assert kwargs["text"] == "> quoted text"
@pytest.mark.asyncio
async def test_edit_message_escapes_control_chars(self, adapter):
"""edit_message escapes & < > in plain text."""
adapter._app.client.chat_update = AsyncMock(return_value={"ok": True})
await adapter.edit_message("C123", "1234.5678", "AT&T < 5 > 3")
kwargs = adapter._app.client.chat_update.call_args.kwargs
assert kwargs["text"] == "AT&amp;T &lt; 5 &gt; 3"
# ---------------------------------------------------------------------------
# TestEditMessageStreamingPipeline
# ---------------------------------------------------------------------------
class TestEditMessageStreamingPipeline:
"""E2E: verify that sequential streaming edits all go through format_message.
Simulates the GatewayStreamConsumer pattern where edit_message is called
repeatedly with progressively longer accumulated text. Every call must
produce properly formatted mrkdwn in the chat_update payload.
"""
@pytest.mark.asyncio
async def test_edit_message_formats_streaming_updates(self, adapter):
"""Simulates streaming: multiple edits, each should be formatted."""
adapter._app.client.chat_update = AsyncMock(return_value={"ok": True})
# First streaming update — bold
result1 = await adapter.edit_message("C123", "ts1", "**Processing**...")
assert result1.success is True
kwargs1 = adapter._app.client.chat_update.call_args.kwargs
assert kwargs1["text"] == "*Processing*..."
# Second streaming update — bold + link
result2 = await adapter.edit_message(
"C123", "ts1", "**Done!** See [results](https://example.com)"
)
assert result2.success is True
kwargs2 = adapter._app.client.chat_update.call_args.kwargs
assert kwargs2["text"] == "*Done!* See <https://example.com|results>"
@pytest.mark.asyncio
async def test_edit_message_formats_code_and_bold(self, adapter):
"""Streaming update with code block and bold — code must be preserved."""
adapter._app.client.chat_update = AsyncMock(return_value={"ok": True})
content = "**Result:**\n```python\nprint('hello')\n```"
result = await adapter.edit_message("C123", "ts1", content)
assert result.success is True
kwargs = adapter._app.client.chat_update.call_args.kwargs
assert kwargs["text"].startswith("*Result:*")
assert "```python\nprint('hello')\n```" in kwargs["text"]
@pytest.mark.asyncio
async def test_edit_message_formats_blockquote_in_stream(self, adapter):
"""Streaming update with blockquote — '>' marker must survive."""
adapter._app.client.chat_update = AsyncMock(return_value={"ok": True})
content = "> **Important:** do this\nnormal line"
result = await adapter.edit_message("C123", "ts1", content)
assert result.success is True
kwargs = adapter._app.client.chat_update.call_args.kwargs
assert kwargs["text"].startswith("> *Important:*")
assert "normal line" in kwargs["text"]
@pytest.mark.asyncio
async def test_edit_message_formats_progressive_accumulation(self, adapter):
"""Simulate real streaming: text grows with each edit, all formatted."""
adapter._app.client.chat_update = AsyncMock(return_value={"ok": True})
updates = [
("**Step 1**", "*Step 1*"),
("**Step 1**\n**Step 2**", "*Step 1*\n*Step 2*"),
(
"**Step 1**\n**Step 2**\nSee [docs](https://docs.example.com)",
"*Step 1*\n*Step 2*\nSee <https://docs.example.com|docs>",
),
]
for raw, expected in updates:
result = await adapter.edit_message("C123", "ts1", raw)
assert result.success is True
kwargs = adapter._app.client.chat_update.call_args.kwargs
assert kwargs["text"] == expected, f"Failed for input: {raw!r}"
# Total edit count should match number of updates
assert adapter._app.client.chat_update.call_count == len(updates)
@pytest.mark.asyncio
async def test_edit_message_formats_bold_italic(self, adapter):
"""Bold+italic ***text*** is formatted as *_text_* in edited messages."""
adapter._app.client.chat_update = AsyncMock(return_value={"ok": True})
await adapter.edit_message("C123", "ts1", "***important*** update")
kwargs = adapter._app.client.chat_update.call_args.kwargs
assert "*_important_*" in kwargs["text"]
@pytest.mark.asyncio
async def test_edit_message_does_not_double_escape(self, adapter):
"""Pre-escaped entities in edited messages must not get double-escaped."""
adapter._app.client.chat_update = AsyncMock(return_value={"ok": True})
await adapter.edit_message("C123", "ts1", "5 &gt; 3 and &amp; entity")
kwargs = adapter._app.client.chat_update.call_args.kwargs
assert "&amp;gt;" not in kwargs["text"]
assert "&amp;amp;" not in kwargs["text"]
assert "&gt;" in kwargs["text"]
assert "&amp;" in kwargs["text"]
@pytest.mark.asyncio
async def test_edit_message_formats_url_with_parens(self, adapter):
"""Wikipedia-style URL with parens survives edit pipeline."""
adapter._app.client.chat_update = AsyncMock(return_value={"ok": True})
await adapter.edit_message("C123", "ts1", "See [Foo](https://en.wikipedia.org/wiki/Foo_(bar))")
kwargs = adapter._app.client.chat_update.call_args.kwargs
assert "<https://en.wikipedia.org/wiki/Foo_(bar)|Foo>" in kwargs["text"]
@pytest.mark.asyncio
async def test_edit_message_not_connected(self, adapter):
"""edit_message returns failure when adapter is not connected."""
adapter._app = None
result = await adapter.edit_message("C123", "ts1", "**hello**")
assert result.success is False
assert "Not connected" in result.error
# ---------------------------------------------------------------------------
# TestReactions
@@ -1416,48 +1085,6 @@ class TestMessageSplitting:
await adapter.send("C123", "hello world")
assert adapter._app.client.chat_postMessage.call_count == 1
@pytest.mark.asyncio
async def test_send_preserves_blockquote_formatting(self, adapter):
"""Blockquote '>' markers must survive format → chunk → send pipeline."""
adapter._app.client.chat_postMessage = AsyncMock(return_value={"ts": "ts1"})
await adapter.send("C123", "> quoted text\nnormal text")
kwargs = adapter._app.client.chat_postMessage.call_args.kwargs
sent_text = kwargs["text"]
assert sent_text.startswith("> quoted text")
assert "normal text" in sent_text
@pytest.mark.asyncio
async def test_send_formats_bold_italic(self, adapter):
"""Bold+italic ***text*** is formatted as *_text_* in sent messages."""
adapter._app.client.chat_postMessage = AsyncMock(return_value={"ts": "ts1"})
await adapter.send("C123", "***important*** update")
kwargs = adapter._app.client.chat_postMessage.call_args.kwargs
assert "*_important_*" in kwargs["text"]
@pytest.mark.asyncio
async def test_send_explicitly_enables_mrkdwn(self, adapter):
adapter._app.client.chat_postMessage = AsyncMock(return_value={"ts": "ts1"})
await adapter.send("C123", "**hello**")
kwargs = adapter._app.client.chat_postMessage.call_args.kwargs
assert kwargs.get("mrkdwn") is True
@pytest.mark.asyncio
async def test_send_does_not_double_escape_entities(self, adapter):
"""Pre-escaped &amp; in sent messages must not become &amp;amp;."""
adapter._app.client.chat_postMessage = AsyncMock(return_value={"ts": "ts1"})
await adapter.send("C123", "Use &amp; for ampersand")
kwargs = adapter._app.client.chat_postMessage.call_args.kwargs
assert "&amp;amp;" not in kwargs["text"]
assert "&amp;" in kwargs["text"]
@pytest.mark.asyncio
async def test_send_formats_url_with_parens(self, adapter):
"""Wikipedia-style URL with parens survives send pipeline."""
adapter._app.client.chat_postMessage = AsyncMock(return_value={"ts": "ts1"})
await adapter.send("C123", "See [Foo](https://en.wikipedia.org/wiki/Foo_(bar))")
kwargs = adapter._app.client.chat_postMessage.call_args.kwargs
assert "<https://en.wikipedia.org/wiki/Foo_(bar)|Foo>" in kwargs["text"]
# ---------------------------------------------------------------------------
# TestReplyBroadcast
-312
View File
@@ -1,312 +0,0 @@
"""
Tests for Slack mention gating (require_mention / free_response_channels).
Follows the same pattern as test_whatsapp_group_gating.py.
"""
import sys
from unittest.mock import MagicMock
from gateway.config import Platform, PlatformConfig
# ---------------------------------------------------------------------------
# Mock slack-bolt if not installed (same as test_slack.py)
# ---------------------------------------------------------------------------
def _ensure_slack_mock():
if "slack_bolt" in sys.modules and hasattr(sys.modules["slack_bolt"], "__file__"):
return
slack_bolt = MagicMock()
slack_bolt.async_app.AsyncApp = MagicMock
slack_bolt.adapter.socket_mode.async_handler.AsyncSocketModeHandler = MagicMock
slack_sdk = MagicMock()
slack_sdk.web.async_client.AsyncWebClient = MagicMock
for name, mod in [
("slack_bolt", slack_bolt),
("slack_bolt.async_app", slack_bolt.async_app),
("slack_bolt.adapter", slack_bolt.adapter),
("slack_bolt.adapter.socket_mode", slack_bolt.adapter.socket_mode),
("slack_bolt.adapter.socket_mode.async_handler", slack_bolt.adapter.socket_mode.async_handler),
("slack_sdk", slack_sdk),
("slack_sdk.web", slack_sdk.web),
("slack_sdk.web.async_client", slack_sdk.web.async_client),
]:
sys.modules.setdefault(name, mod)
_ensure_slack_mock()
import gateway.platforms.slack as _slack_mod
_slack_mod.SLACK_AVAILABLE = True
from gateway.platforms.slack import SlackAdapter # noqa: E402
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
BOT_USER_ID = "U_BOT_123"
CHANNEL_ID = "C0AQWDLHY9M"
OTHER_CHANNEL_ID = "C9999999999"
def _make_adapter(require_mention=None, free_response_channels=None):
extra = {}
if require_mention is not None:
extra["require_mention"] = require_mention
if free_response_channels is not None:
extra["free_response_channels"] = free_response_channels
adapter = object.__new__(SlackAdapter)
adapter.platform = Platform.SLACK
adapter.config = PlatformConfig(enabled=True, extra=extra)
adapter._bot_user_id = BOT_USER_ID
adapter._team_bot_user_ids = {}
return adapter
# ---------------------------------------------------------------------------
# Tests: _slack_require_mention
# ---------------------------------------------------------------------------
def test_require_mention_defaults_to_true(monkeypatch):
monkeypatch.delenv("SLACK_REQUIRE_MENTION", raising=False)
adapter = _make_adapter()
assert adapter._slack_require_mention() is True
def test_require_mention_false():
adapter = _make_adapter(require_mention=False)
assert adapter._slack_require_mention() is False
def test_require_mention_true():
adapter = _make_adapter(require_mention=True)
assert adapter._slack_require_mention() is True
def test_require_mention_string_true():
adapter = _make_adapter(require_mention="true")
assert adapter._slack_require_mention() is True
def test_require_mention_string_false():
adapter = _make_adapter(require_mention="false")
assert adapter._slack_require_mention() is False
def test_require_mention_string_no():
adapter = _make_adapter(require_mention="no")
assert adapter._slack_require_mention() is False
def test_require_mention_string_yes():
adapter = _make_adapter(require_mention="yes")
assert adapter._slack_require_mention() is True
def test_require_mention_empty_string_stays_true():
"""Empty/malformed strings keep gating ON (explicit-false parser)."""
adapter = _make_adapter(require_mention="")
assert adapter._slack_require_mention() is True
def test_require_mention_malformed_string_stays_true():
"""Unrecognised values keep gating ON (fail-closed)."""
adapter = _make_adapter(require_mention="maybe")
assert adapter._slack_require_mention() is True
def test_require_mention_env_var_fallback(monkeypatch):
monkeypatch.setenv("SLACK_REQUIRE_MENTION", "false")
adapter = _make_adapter() # no config value -> falls back to env
assert adapter._slack_require_mention() is False
def test_require_mention_env_var_default_true(monkeypatch):
monkeypatch.delenv("SLACK_REQUIRE_MENTION", raising=False)
adapter = _make_adapter()
assert adapter._slack_require_mention() is True
# ---------------------------------------------------------------------------
# Tests: _slack_free_response_channels
# ---------------------------------------------------------------------------
def test_free_response_channels_default_empty(monkeypatch):
monkeypatch.delenv("SLACK_FREE_RESPONSE_CHANNELS", raising=False)
adapter = _make_adapter()
assert adapter._slack_free_response_channels() == set()
def test_free_response_channels_list():
adapter = _make_adapter(free_response_channels=[CHANNEL_ID, OTHER_CHANNEL_ID])
result = adapter._slack_free_response_channels()
assert CHANNEL_ID in result
assert OTHER_CHANNEL_ID in result
def test_free_response_channels_csv_string():
adapter = _make_adapter(free_response_channels=f"{CHANNEL_ID}, {OTHER_CHANNEL_ID}")
result = adapter._slack_free_response_channels()
assert CHANNEL_ID in result
assert OTHER_CHANNEL_ID in result
def test_free_response_channels_empty_string():
adapter = _make_adapter(free_response_channels="")
assert adapter._slack_free_response_channels() == set()
def test_free_response_channels_env_var_fallback(monkeypatch):
monkeypatch.setenv("SLACK_FREE_RESPONSE_CHANNELS", f"{CHANNEL_ID},{OTHER_CHANNEL_ID}")
adapter = _make_adapter() # no config value → falls back to env
result = adapter._slack_free_response_channels()
assert CHANNEL_ID in result
assert OTHER_CHANNEL_ID in result
# ---------------------------------------------------------------------------
# Tests: mention gating integration (simulating _handle_slack_message logic)
# ---------------------------------------------------------------------------
def _would_process(adapter, *, is_dm=False, channel_id=CHANNEL_ID,
text="hello", mentioned=False, thread_reply=False,
active_session=False):
"""Simulate the mention gating logic from _handle_slack_message.
Returns True if the message would be processed, False if it would be
skipped (returned early).
"""
bot_uid = adapter._team_bot_user_ids.get("T1", adapter._bot_user_id)
if mentioned:
text = f"<@{bot_uid}> {text}"
is_mentioned = bot_uid and f"<@{bot_uid}>" in text
if not is_dm:
if channel_id in adapter._slack_free_response_channels():
return True
elif not adapter._slack_require_mention():
return True
elif not is_mentioned:
if thread_reply and active_session:
return True
else:
return False
return True
def test_default_require_mention_channel_without_mention_ignored():
adapter = _make_adapter() # default: require_mention=True
assert _would_process(adapter, text="hello everyone") is False
def test_require_mention_false_channel_without_mention_processed():
adapter = _make_adapter(require_mention=False)
assert _would_process(adapter, text="hello everyone") is True
def test_channel_in_free_response_processed_without_mention():
adapter = _make_adapter(
require_mention=True,
free_response_channels=[CHANNEL_ID],
)
assert _would_process(adapter, channel_id=CHANNEL_ID, text="hello") is True
def test_other_channel_not_in_free_response_still_gated():
adapter = _make_adapter(
require_mention=True,
free_response_channels=[CHANNEL_ID],
)
assert _would_process(adapter, channel_id=OTHER_CHANNEL_ID, text="hello") is False
def test_dm_always_processed_regardless_of_setting():
adapter = _make_adapter(require_mention=True)
assert _would_process(adapter, is_dm=True, text="hello") is True
def test_mentioned_message_always_processed():
adapter = _make_adapter(require_mention=True)
assert _would_process(adapter, mentioned=True, text="what's up") is True
def test_thread_reply_with_active_session_processed():
adapter = _make_adapter(require_mention=True)
assert _would_process(
adapter, text="followup",
thread_reply=True, active_session=True,
) is True
def test_thread_reply_without_active_session_ignored():
adapter = _make_adapter(require_mention=True)
assert _would_process(
adapter, text="followup",
thread_reply=True, active_session=False,
) is False
def test_bot_uid_none_processes_channel_message():
"""When bot_uid is None (before auth_test), channel messages pass through.
This preserves the old behavior: the gating block is skipped entirely
when bot_uid is falsy, so messages are not silently dropped during
startup or for new workspaces.
"""
adapter = _make_adapter(require_mention=True)
adapter._bot_user_id = None
adapter._team_bot_user_ids = {}
# With bot_uid=None, the `if not is_dm and bot_uid:` condition is False,
# so the gating block is skipped — message passes through.
bot_uid = adapter._team_bot_user_ids.get("T1", adapter._bot_user_id)
assert bot_uid is None
# Simulate: gating block not entered when bot_uid is falsy
is_dm = False
if not is_dm and bot_uid:
result = False # would enter gating
else:
result = True # gating skipped, message processed
assert result is True
# ---------------------------------------------------------------------------
# Tests: config bridging
# ---------------------------------------------------------------------------
def test_config_bridges_slack_free_response_channels(monkeypatch, tmp_path):
from gateway.config import load_gateway_config
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir()
(hermes_home / "config.yaml").write_text(
"slack:\n"
" require_mention: false\n"
" free_response_channels:\n"
" - C0AQWDLHY9M\n"
" - C9999999999\n",
encoding="utf-8",
)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
monkeypatch.delenv("SLACK_REQUIRE_MENTION", raising=False)
monkeypatch.delenv("SLACK_FREE_RESPONSE_CHANNELS", raising=False)
config = load_gateway_config()
assert config is not None
slack_extra = config.platforms[Platform.SLACK].extra
assert slack_extra.get("require_mention") is False
assert slack_extra.get("free_response_channels") == ["C0AQWDLHY9M", "C9999999999"]
# Verify env vars were set by config bridging
import os as _os
assert _os.environ["SLACK_REQUIRE_MENTION"] == "false"
assert _os.environ["SLACK_FREE_RESPONSE_CHANNELS"] == "C0AQWDLHY9M,C9999999999"
+2 -3
View File
@@ -4,7 +4,7 @@ import base64
import os
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock
import pytest
@@ -355,8 +355,7 @@ class TestMediaUpload:
assert calls[3][1]["chunk_index"] == 2
@pytest.mark.asyncio
@patch("tools.url_safety.is_safe_url", return_value=True)
async def test_download_remote_bytes_rejects_large_content_length(self, _mock_safe):
async def test_download_remote_bytes_rejects_large_content_length(self):
from gateway.platforms.wecom import WeComAdapter
class FakeResponse:
+5 -20
View File
@@ -628,21 +628,14 @@ class TestHasAnyProviderConfigured:
def test_claude_code_creds_ignored_on_fresh_install(self, monkeypatch, tmp_path):
"""Claude Code credentials should NOT skip the wizard when Hermes is unconfigured."""
from hermes_cli import config as config_module
from hermes_cli.auth import PROVIDER_REGISTRY
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir()
monkeypatch.setattr(config_module, "get_env_path", lambda: hermes_home / ".env")
monkeypatch.setattr(config_module, "get_hermes_home", lambda: hermes_home)
# Clear all provider env vars so earlier checks don't short-circuit
_all_vars = {"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
"ANTHROPIC_TOKEN", "OPENAI_BASE_URL"}
for pconfig in PROVIDER_REGISTRY.values():
if pconfig.auth_type == "api_key":
_all_vars.update(pconfig.api_key_env_vars)
for var in _all_vars:
for var in ("OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
"ANTHROPIC_TOKEN", "OPENAI_BASE_URL"):
monkeypatch.delenv(var, raising=False)
# Prevent gh-cli / copilot auth fallback from leaking in
monkeypatch.setattr("hermes_cli.auth.get_auth_status", lambda _pid: {})
# Simulate valid Claude Code credentials
monkeypatch.setattr(
"agent.anthropic_adapter.read_claude_code_credentials",
@@ -717,7 +710,6 @@ class TestHasAnyProviderConfigured:
"""config.yaml model dict with empty default and no creds stays false."""
import yaml
from hermes_cli import config as config_module
from hermes_cli.auth import PROVIDER_REGISTRY
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir()
config_file = hermes_home / "config.yaml"
@@ -727,15 +719,9 @@ class TestHasAnyProviderConfigured:
monkeypatch.setattr(config_module, "get_env_path", lambda: hermes_home / ".env")
monkeypatch.setattr(config_module, "get_hermes_home", lambda: hermes_home)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
_all_vars = {"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
"ANTHROPIC_TOKEN", "OPENAI_BASE_URL"}
for pconfig in PROVIDER_REGISTRY.values():
if pconfig.auth_type == "api_key":
_all_vars.update(pconfig.api_key_env_vars)
for var in _all_vars:
for var in ("OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
"ANTHROPIC_TOKEN", "OPENAI_BASE_URL"):
monkeypatch.delenv(var, raising=False)
# Prevent gh-cli / copilot auth fallback from leaking in
monkeypatch.setattr("hermes_cli.auth.get_auth_status", lambda _pid: {})
from hermes_cli.main import _has_any_provider_configured
assert _has_any_provider_configured() is False
@@ -955,10 +941,9 @@ class TestHuggingFaceModels:
"""Every HF model should have a context length entry."""
from hermes_cli.models import _PROVIDER_MODELS
from agent.model_metadata import DEFAULT_CONTEXT_LENGTHS
lower_keys = {k.lower() for k in DEFAULT_CONTEXT_LENGTHS}
hf_models = _PROVIDER_MODELS["huggingface"]
for model in hf_models:
assert model.lower() in lower_keys, (
assert model in DEFAULT_CONTEXT_LENGTHS, (
f"HF model {model!r} missing from DEFAULT_CONTEXT_LENGTHS"
)
+2 -13
View File
@@ -68,17 +68,6 @@ class TestCommandRegistry:
for cmd in COMMAND_REGISTRY:
assert cmd.category in valid_categories, f"{cmd.name} has invalid category '{cmd.category}'"
def test_reasoning_subcommands_are_in_logical_order(self):
reasoning = next(cmd for cmd in COMMAND_REGISTRY if cmd.name == "reasoning")
assert reasoning.subcommands[:6] == (
"none",
"minimal",
"low",
"medium",
"high",
"xhigh",
)
def test_cli_only_and_gateway_only_are_mutually_exclusive(self):
for cmd in COMMAND_REGISTRY:
assert not (cmd.cli_only and cmd.gateway_only), \
@@ -436,8 +425,8 @@ class TestSlashCommandCompleter:
class TestSubcommands:
def test_explicit_subcommands_extracted(self):
"""Commands with explicit subcommands on CommandDef are extracted."""
assert "/skills" in SUBCOMMANDS
assert "install" in SUBCOMMANDS["/skills"]
assert "/prompt" in SUBCOMMANDS
assert "clear" in SUBCOMMANDS["/prompt"]
def test_reasoning_has_subcommands(self):
assert "/reasoning" in SUBCOMMANDS
+6
View File
@@ -35,6 +35,12 @@ class TestTokenValidation:
valid, msg = validate_copilot_token("")
assert valid is False
def test_is_classic_pat(self):
from hermes_cli.copilot_auth import is_classic_pat
assert is_classic_pat("ghp_abc123") is True
assert is_classic_pat("gho_abc123") is False
assert is_classic_pat("github_pat_abc") is False
assert is_classic_pat("") is False
class TestResolveToken:
@@ -0,0 +1,50 @@
"""Tests for detect_external_credentials() -- Phase 2 credential sync."""
import json
from pathlib import Path
from unittest.mock import patch
import pytest
from hermes_cli.auth import detect_external_credentials
class TestDetectCodexCLI:
def test_detects_valid_codex_auth(self, tmp_path, monkeypatch):
codex_dir = tmp_path / ".codex"
codex_dir.mkdir()
auth = codex_dir / "auth.json"
auth.write_text(json.dumps({
"tokens": {"access_token": "tok-123", "refresh_token": "ref-456"}
}))
monkeypatch.setenv("CODEX_HOME", str(codex_dir))
result = detect_external_credentials()
codex_hits = [c for c in result if c["provider"] == "openai-codex"]
assert len(codex_hits) == 1
assert "Codex CLI" in codex_hits[0]["label"]
def test_skips_codex_without_access_token(self, tmp_path, monkeypatch):
codex_dir = tmp_path / ".codex"
codex_dir.mkdir()
(codex_dir / "auth.json").write_text(json.dumps({"tokens": {}}))
monkeypatch.setenv("CODEX_HOME", str(codex_dir))
result = detect_external_credentials()
assert not any(c["provider"] == "openai-codex" for c in result)
def test_skips_missing_codex_dir(self, tmp_path, monkeypatch):
monkeypatch.setenv("CODEX_HOME", str(tmp_path / "nonexistent"))
result = detect_external_credentials()
assert not any(c["provider"] == "openai-codex" for c in result)
def test_skips_malformed_codex_auth(self, tmp_path, monkeypatch):
codex_dir = tmp_path / ".codex"
codex_dir.mkdir()
(codex_dir / "auth.json").write_text("{bad json")
monkeypatch.setenv("CODEX_HOME", str(codex_dir))
result = detect_external_credentials()
assert not any(c["provider"] == "openai-codex" for c in result)
def test_returns_empty_when_nothing_found(self, tmp_path, monkeypatch):
monkeypatch.setenv("CODEX_HOME", str(tmp_path / "nonexistent"))
result = detect_external_credentials()
assert result == []
+38 -4
View File
@@ -3,13 +3,15 @@
from unittest.mock import patch, MagicMock
from hermes_cli.models import (
OPENROUTER_MODELS, model_ids, detect_provider_for_model,
OPENROUTER_MODELS, menu_labels, model_ids, detect_provider_for_model,
filter_nous_free_models, _NOUS_ALLOWED_FREE_MODELS,
is_nous_free_tier, partition_nous_models_by_tier,
check_nous_free_tier, _FREE_TIER_CACHE_TTL,
check_nous_free_tier, clear_nous_free_tier_cache,
_FREE_TIER_CACHE_TTL,
)
import hermes_cli.models as _models_mod
class TestModelIds:
def test_returns_non_empty_list(self):
ids = model_ids()
@@ -31,6 +33,25 @@ class TestModelIds:
assert len(ids) == len(set(ids)), "Duplicate model IDs found"
class TestMenuLabels:
def test_same_length_as_model_ids(self):
assert len(menu_labels()) == len(model_ids())
def test_first_label_marked_recommended(self):
labels = menu_labels()
assert "recommended" in labels[0].lower()
def test_each_label_contains_its_model_id(self):
for label, mid in zip(menu_labels(), model_ids()):
assert mid in label, f"Label '{label}' doesn't contain model ID '{mid}'"
def test_non_recommended_labels_have_no_tag(self):
"""Only the first model should have (recommended)."""
labels = menu_labels()
for label in labels[1:]:
assert "recommended" not in label.lower(), f"Unexpected 'recommended' in '{label}'"
class TestOpenRouterModels:
def test_structure_is_list_of_tuples(self):
for entry in OPENROUTER_MODELS:
@@ -281,10 +302,12 @@ class TestCheckNousFreeTierCache:
"""Tests for the TTL cache on check_nous_free_tier()."""
def setup_method(self):
_models_mod._free_tier_cache = None
"""Reset cache before each test."""
clear_nous_free_tier_cache()
def teardown_method(self):
_models_mod._free_tier_cache = None
"""Reset cache after each test."""
clear_nous_free_tier_cache()
@patch("hermes_cli.models.fetch_nous_account_tier")
@patch("hermes_cli.models.is_nous_free_tier", return_value=True)
@@ -298,6 +321,7 @@ class TestCheckNousFreeTierCache:
assert result1 is True
assert result2 is True
# fetch_nous_account_tier should only be called once (cached on second call)
assert mock_fetch.call_count == 1
@patch("hermes_cli.models.fetch_nous_account_tier")
@@ -310,6 +334,7 @@ class TestCheckNousFreeTierCache:
result1 = check_nous_free_tier()
assert mock_fetch.call_count == 1
# Simulate TTL expiry by backdating the cache timestamp
cached_result, cached_at = _models_mod._free_tier_cache
_models_mod._free_tier_cache = (cached_result, cached_at - _FREE_TIER_CACHE_TTL - 1)
@@ -319,6 +344,15 @@ class TestCheckNousFreeTierCache:
assert result1 is False
assert result2 is False
def test_clear_cache_forces_refresh(self):
"""clear_nous_free_tier_cache() invalidates the cached result."""
# Manually seed the cache
import time
_models_mod._free_tier_cache = (True, time.monotonic())
clear_nous_free_tier_cache()
assert _models_mod._free_tier_cache is None
def test_cache_ttl_is_short(self):
"""TTL should be short enough to catch upgrades quickly (<=5 min)."""
assert _FREE_TIER_CACHE_TTL <= 300
@@ -1,34 +0,0 @@
import sys
import types
from hermes_cli.main import _prompt_reasoning_effort_selection
class _FakeTerminalMenu:
last_choices = None
def __init__(self, choices, **kwargs):
_FakeTerminalMenu.last_choices = choices
self._cursor_index = kwargs.get("cursor_index")
def show(self):
return self._cursor_index
def test_reasoning_menu_orders_minimal_before_low(monkeypatch):
fake_module = types.SimpleNamespace(TerminalMenu=_FakeTerminalMenu)
monkeypatch.setitem(sys.modules, "simple_term_menu", fake_module)
selected = _prompt_reasoning_effort_selection(
["low", "minimal", "medium", "high"],
current_effort="medium",
)
assert selected == "medium"
assert _FakeTerminalMenu.last_choices[:4] == [
" minimal",
" low",
" medium ← currently in use",
" high",
]
@@ -305,6 +305,7 @@ def test_setup_copilot_acp_skips_same_provider_pool_step(tmp_path, monkeypatch):
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", fake_prompt_yes_no)
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: [])
setup_model_provider(config)
@@ -0,0 +1,155 @@
"""Tests for _setup_provider_model_selection and the zai/kimi/minimax branch.
Regression test for the is_coding_plan NameError that crashed setup when
selecting zai, kimi-coding, minimax, or minimax-cn providers.
"""
import pytest
from unittest.mock import patch, MagicMock
@pytest.fixture
def mock_provider_registry():
"""Minimal PROVIDER_REGISTRY entries for tested providers."""
class FakePConfig:
def __init__(self, name, env_vars, base_url_env, inference_url):
self.name = name
self.api_key_env_vars = env_vars
self.base_url_env_var = base_url_env
self.inference_base_url = inference_url
return {
"zai": FakePConfig("ZAI", ["ZAI_API_KEY"], "ZAI_BASE_URL", "https://api.zai.example"),
"kimi-coding": FakePConfig("Kimi Coding", ["KIMI_API_KEY"], "KIMI_BASE_URL", "https://api.kimi.example"),
"minimax": FakePConfig("MiniMax", ["MINIMAX_API_KEY"], "MINIMAX_BASE_URL", "https://api.minimax.example"),
"minimax-cn": FakePConfig("MiniMax CN", ["MINIMAX_API_KEY"], "MINIMAX_CN_BASE_URL", "https://api.minimax-cn.example"),
"opencode-zen": FakePConfig("OpenCode Zen", ["OPENCODE_ZEN_API_KEY"], "OPENCODE_ZEN_BASE_URL", "https://opencode.ai/zen/v1"),
"opencode-go": FakePConfig("OpenCode Go", ["OPENCODE_GO_API_KEY"], "OPENCODE_GO_BASE_URL", "https://opencode.ai/zen/go/v1"),
}
class TestSetupProviderModelSelection:
"""Verify _setup_provider_model_selection works for all providers
that previously hit the is_coding_plan NameError."""
@pytest.mark.parametrize("provider_id,expected_defaults", [
("zai", ["glm-5", "glm-4.7", "glm-4.5", "glm-4.5-flash"]),
("kimi-coding", ["kimi-k2.5", "kimi-k2-thinking", "kimi-k2-turbo-preview"]),
("minimax", ["MiniMax-M1", "MiniMax-M1-40k", "MiniMax-M1-80k", "MiniMax-M1-128k", "MiniMax-M1-256k", "MiniMax-M2.5", "MiniMax-M2.7"]),
("minimax-cn", ["MiniMax-M1", "MiniMax-M1-40k", "MiniMax-M1-80k", "MiniMax-M1-128k", "MiniMax-M1-256k", "MiniMax-M2.5", "MiniMax-M2.7"]),
("opencode-zen", ["gpt-5.4", "gpt-5.3-codex", "claude-sonnet-4-6", "gemini-3-flash"]),
("opencode-go", ["glm-5", "kimi-k2.5", "minimax-m2.5", "minimax-m2.7"]),
])
@patch("hermes_cli.models.fetch_api_models", return_value=[])
@patch("hermes_cli.config.get_env_value", return_value="fake-key")
def test_falls_back_to_default_models_without_crashing(
self, mock_env, mock_fetch, provider_id, expected_defaults, mock_provider_registry
):
"""Previously this code path raised NameError: 'is_coding_plan'.
Now it delegates to _setup_provider_model_selection which uses
_DEFAULT_PROVIDER_MODELS -- no crash, correct model list."""
from hermes_cli.setup import _setup_provider_model_selection
captured_choices = {}
def fake_prompt_choice(label, choices, default):
captured_choices["choices"] = choices
# Select "Keep current" (last item)
return len(choices) - 1
with patch("hermes_cli.auth.PROVIDER_REGISTRY", mock_provider_registry):
_setup_provider_model_selection(
config={"model": {}},
provider_id=provider_id,
current_model="some-model",
prompt_choice=fake_prompt_choice,
prompt_fn=lambda _: None,
)
# The offered model list should start with the default models
offered = captured_choices["choices"]
for model in expected_defaults:
assert model in offered, f"{model} not in choices for {provider_id}"
@patch("hermes_cli.models.fetch_api_models")
@patch("hermes_cli.config.get_env_value", return_value="fake-key")
def test_live_models_used_when_available(
self, mock_env, mock_fetch, mock_provider_registry
):
"""When fetch_api_models returns results, those are used instead of defaults."""
from hermes_cli.setup import _setup_provider_model_selection
live = ["live-model-1", "live-model-2"]
mock_fetch.return_value = live
captured_choices = {}
def fake_prompt_choice(label, choices, default):
captured_choices["choices"] = choices
return len(choices) - 1
with patch("hermes_cli.auth.PROVIDER_REGISTRY", mock_provider_registry):
_setup_provider_model_selection(
config={"model": {}},
provider_id="zai",
current_model="some-model",
prompt_choice=fake_prompt_choice,
prompt_fn=lambda _: None,
)
offered = captured_choices["choices"]
assert "live-model-1" in offered
assert "live-model-2" in offered
@patch("hermes_cli.models.fetch_api_models", return_value=[])
@patch("hermes_cli.config.get_env_value", return_value="fake-key")
def test_custom_model_selection(
self, mock_env, mock_fetch, mock_provider_registry
):
"""Selecting 'Custom model' lets user type a model name."""
from hermes_cli.setup import _setup_provider_model_selection, _DEFAULT_PROVIDER_MODELS
defaults = _DEFAULT_PROVIDER_MODELS["zai"]
custom_model_idx = len(defaults) # "Custom model" is right after defaults
config = {"model": {}}
def fake_prompt_choice(label, choices, default):
return custom_model_idx
with patch("hermes_cli.auth.PROVIDER_REGISTRY", mock_provider_registry):
_setup_provider_model_selection(
config=config,
provider_id="zai",
current_model="some-model",
prompt_choice=fake_prompt_choice,
prompt_fn=lambda _: "my-custom-model",
)
assert config["model"]["default"] == "my-custom-model"
@patch("hermes_cli.models.fetch_api_models", return_value=["opencode-go/kimi-k2.5", "opencode-go/minimax-m2.7"])
@patch("hermes_cli.config.get_env_value", return_value="fake-key")
def test_opencode_live_models_are_normalized_for_selection(
self, mock_env, mock_fetch, mock_provider_registry
):
from hermes_cli.setup import _setup_provider_model_selection
captured_choices = {}
def fake_prompt_choice(label, choices, default):
captured_choices["choices"] = choices
return len(choices) - 1
with patch("hermes_cli.auth.PROVIDER_REGISTRY", mock_provider_registry):
_setup_provider_model_selection(
config={"model": {}},
provider_id="opencode-go",
current_model="opencode-go/kimi-k2.5",
prompt_choice=fake_prompt_choice,
prompt_fn=lambda _: None,
)
offered = captured_choices["choices"]
assert "kimi-k2.5" in offered
assert "minimax-m2.7" in offered
assert all("opencode-go/" not in choice for choice in offered)
@@ -44,7 +44,7 @@ class TestOfferOpenclawMigration:
assert setup_mod._offer_openclaw_migration(tmp_path / ".hermes") is False
def test_runs_migration_when_user_accepts(self, tmp_path):
"""Should run dry-run preview first, then execute after confirmation."""
"""Should dynamically load the script and run the Migrator."""
openclaw_dir = tmp_path / ".openclaw"
openclaw_dir.mkdir()
@@ -60,7 +60,6 @@ class TestOfferOpenclawMigration:
fake_migrator = MagicMock()
fake_migrator.migrate.return_value = {
"summary": {"migrated": 3, "skipped": 1, "conflict": 0, "error": 0},
"items": [{"kind": "config", "status": "migrated", "destination": "/tmp/x"}],
"output_dir": str(hermes_home / "migration"),
}
fake_mod.Migrator = MagicMock(return_value=fake_migrator)
@@ -71,7 +70,6 @@ class TestOfferOpenclawMigration:
with (
patch("hermes_cli.setup.Path.home", return_value=tmp_path),
patch.object(setup_mod, "_OPENCLAW_SCRIPT", script),
# Both prompts answered Yes: preview offer + proceed confirmation
patch.object(setup_mod, "prompt_yes_no", return_value=True),
patch.object(setup_mod, "get_config_path", return_value=config_path),
patch("importlib.util.spec_from_file_location") as mock_spec_fn,
@@ -93,75 +91,13 @@ class TestOfferOpenclawMigration:
fake_mod.resolve_selected_options.assert_called_once_with(
None, None, preset="full"
)
# Migrator called twice: once for dry-run preview, once for execution
assert fake_mod.Migrator.call_count == 2
# First call: dry-run preview (execute=False, overwrite=True to show all)
preview_kwargs = fake_mod.Migrator.call_args_list[0][1]
assert preview_kwargs["execute"] is False
assert preview_kwargs["overwrite"] is True
assert preview_kwargs["migrate_secrets"] is True
assert preview_kwargs["preset_name"] == "full"
# Second call: actual execution (execute=True, overwrite=False to preserve)
exec_kwargs = fake_mod.Migrator.call_args_list[1][1]
assert exec_kwargs["execute"] is True
assert exec_kwargs["overwrite"] is False
assert exec_kwargs["migrate_secrets"] is True
assert exec_kwargs["preset_name"] == "full"
# migrate() called twice (once per Migrator instance)
assert fake_migrator.migrate.call_count == 2
def test_user_declines_after_preview(self, tmp_path):
"""Should return False when user sees preview but declines to proceed."""
openclaw_dir = tmp_path / ".openclaw"
openclaw_dir.mkdir()
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir()
config_path = hermes_home / "config.yaml"
config_path.write_text("agent:\n max_turns: 90\n")
fake_mod = ModuleType("openclaw_to_hermes")
fake_mod.resolve_selected_options = MagicMock(return_value={"soul", "memory"})
fake_migrator = MagicMock()
fake_migrator.migrate.return_value = {
"summary": {"migrated": 3, "skipped": 0, "conflict": 0, "error": 0},
"items": [{"kind": "config", "status": "migrated", "destination": "/tmp/x"}],
}
fake_mod.Migrator = MagicMock(return_value=fake_migrator)
script = tmp_path / "openclaw_to_hermes.py"
script.write_text("# placeholder")
# First prompt (preview): Yes, Second prompt (proceed): No
prompt_responses = iter([True, False])
with (
patch("hermes_cli.setup.Path.home", return_value=tmp_path),
patch.object(setup_mod, "_OPENCLAW_SCRIPT", script),
patch.object(setup_mod, "prompt_yes_no", side_effect=prompt_responses),
patch.object(setup_mod, "get_config_path", return_value=config_path),
patch("importlib.util.spec_from_file_location") as mock_spec_fn,
):
mock_spec = MagicMock()
mock_spec.loader = MagicMock()
mock_spec_fn.return_value = mock_spec
def exec_module(mod):
mod.resolve_selected_options = fake_mod.resolve_selected_options
mod.Migrator = fake_mod.Migrator
mock_spec.loader.exec_module = exec_module
result = setup_mod._offer_openclaw_migration(hermes_home)
assert result is False
# Only dry-run Migrator was created, not the execute one
assert fake_mod.Migrator.call_count == 1
preview_kwargs = fake_mod.Migrator.call_args[1]
assert preview_kwargs["execute"] is False
fake_mod.Migrator.assert_called_once()
call_kwargs = fake_mod.Migrator.call_args[1]
assert call_kwargs["execute"] is True
assert call_kwargs["overwrite"] is True
assert call_kwargs["migrate_secrets"] is True
assert call_kwargs["preset_name"] == "full"
fake_migrator.migrate.assert_called_once()
def test_handles_migration_error_gracefully(self, tmp_path):
"""Should catch exceptions and return False."""
+25
View File
@@ -196,6 +196,31 @@ class TestDisplayIntegration:
set_active_skin("ares")
assert get_skin_tool_prefix() == ""
def test_get_skin_faces_default(self):
from agent.display import get_skin_faces, KawaiiSpinner
faces = get_skin_faces("waiting_faces", KawaiiSpinner.KAWAII_WAITING)
# Default skin has no custom faces, so should return the default list
assert faces == KawaiiSpinner.KAWAII_WAITING
def test_get_skin_faces_ares(self):
from hermes_cli.skin_engine import set_active_skin
from agent.display import get_skin_faces, KawaiiSpinner
set_active_skin("ares")
faces = get_skin_faces("waiting_faces", KawaiiSpinner.KAWAII_WAITING)
assert "(⚔)" in faces
def test_get_skin_verbs_default(self):
from agent.display import get_skin_verbs, KawaiiSpinner
verbs = get_skin_verbs()
assert verbs == KawaiiSpinner.THINKING_VERBS
def test_get_skin_verbs_ares(self):
from hermes_cli.skin_engine import set_active_skin
from agent.display import get_skin_verbs
set_active_skin("ares")
verbs = get_skin_verbs()
assert "forging" in verbs
def test_tool_message_uses_skin_prefix(self):
from hermes_cli.skin_engine import set_active_skin
from agent.display import get_cute_tool_message
-8
View File
@@ -354,14 +354,6 @@ def test_first_install_nous_auto_configures_managed_defaults(monkeypatch):
lambda *args, **kwargs: {"web", "image_gen", "tts", "browser"},
)
monkeypatch.setattr("hermes_cli.tools_config.save_config", lambda config: None)
# Prevent leaked platform tokens (e.g. DISCORD_BOT_TOKEN from gateway.run
# import) from adding extra platforms. The loop in tools_command runs
# apply_nous_managed_defaults per platform; a second iteration sees values
# set by the first as "explicit" and skips them.
monkeypatch.setattr(
"hermes_cli.tools_config._get_enabled_platforms",
lambda: ["cli"],
)
monkeypatch.setattr(
"hermes_cli.nous_subscription.get_nous_auth_status",
lambda: {"logged_in": True},
@@ -368,9 +368,6 @@ class TestCmdUpdateLaunchdRestart:
monkeypatch.setattr(
gateway_cli, "is_macos", lambda: False,
)
monkeypatch.setattr(
gateway_cli, "is_linux", lambda: True,
)
mock_run.side_effect = _make_run_side_effect(
commit_count="3",
+9 -12
View File
@@ -211,15 +211,14 @@ class TestExchangeAuthCode:
assert setup_module.PENDING_AUTH_PATH.exists()
assert not setup_module.TOKEN_PATH.exists()
def test_accepts_narrower_scopes_with_warning(self, setup_module, capsys):
"""Partial scopes are accepted with a warning (gws migration: v2.0)."""
def test_refuses_to_overwrite_existing_token_with_narrower_scopes(self, setup_module, capsys):
setup_module.PENDING_AUTH_PATH.write_text(
json.dumps({"state": "saved-state", "code_verifier": "saved-verifier"})
)
setup_module.TOKEN_PATH.write_text(json.dumps({"token": "***", "scopes": setup_module.SCOPES}))
setup_module.TOKEN_PATH.write_text(json.dumps({"token": "existing-token", "scopes": setup_module.SCOPES}))
FakeFlow.credentials_payload = {
"token": "***",
"refresh_token": "***",
"token": "narrow-token",
"refresh_token": "refresh-token",
"token_uri": "https://oauth2.googleapis.com/token",
"client_id": "client-id",
"client_secret": "client-secret",
@@ -229,12 +228,10 @@ class TestExchangeAuthCode:
],
}
setup_module.exchange_auth_code("4/test-auth-code")
with pytest.raises(SystemExit):
setup_module.exchange_auth_code("4/test-auth-code")
out = capsys.readouterr().out
assert "warning" in out.lower()
assert "missing" in out.lower()
# Token is saved (partial scopes accepted)
assert setup_module.TOKEN_PATH.exists()
# Pending auth is cleaned up
assert not setup_module.PENDING_AUTH_PATH.exists()
assert "refusing to save incomplete google workspace token" in out.lower()
assert json.loads(setup_module.TOKEN_PATH.read_text())["token"] == "existing-token"
assert setup_module.PENDING_AUTH_PATH.exists()
+82 -140
View File
@@ -1,175 +1,117 @@
"""Tests for Google Workspace gws bridge and CLI wrapper."""
"""Regression tests for Google Workspace API credential validation."""
import importlib.util
import json
import os
import subprocess
import sys
import types
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
BRIDGE_PATH = (
Path(__file__).resolve().parents[2]
/ "skills/productivity/google-workspace/scripts/gws_bridge.py"
)
API_PATH = (
SCRIPT_PATH = (
Path(__file__).resolve().parents[2]
/ "skills/productivity/google-workspace/scripts/google_api.py"
)
@pytest.fixture
def bridge_module(monkeypatch, tmp_path):
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir()
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
class FakeAuthorizedCredentials:
def __init__(self, *, valid=True, expired=False, refresh_token="refresh-token"):
self.valid = valid
self.expired = expired
self.refresh_token = refresh_token
self.refresh_calls = 0
spec = importlib.util.spec_from_file_location("gws_bridge_test", BRIDGE_PATH)
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(module)
return module
def refresh(self, _request):
self.refresh_calls += 1
self.valid = True
self.expired = False
def to_json(self):
return json.dumps({
"token": "refreshed-token",
"refresh_token": self.refresh_token,
"token_uri": "https://oauth2.googleapis.com/token",
"client_id": "client-id",
"client_secret": "client-secret",
"scopes": [
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.send",
"https://www.googleapis.com/auth/gmail.modify",
"https://www.googleapis.com/auth/calendar",
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/contacts.readonly",
"https://www.googleapis.com/auth/spreadsheets",
"https://www.googleapis.com/auth/documents.readonly",
],
})
class FakeCredentialsFactory:
creds = FakeAuthorizedCredentials()
@classmethod
def from_authorized_user_file(cls, _path, _scopes):
return cls.creds
@pytest.fixture
def api_module(monkeypatch, tmp_path):
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir()
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
def google_api_module(monkeypatch, tmp_path):
google_module = types.ModuleType("google")
oauth2_module = types.ModuleType("google.oauth2")
credentials_module = types.ModuleType("google.oauth2.credentials")
credentials_module.Credentials = FakeCredentialsFactory
auth_module = types.ModuleType("google.auth")
transport_module = types.ModuleType("google.auth.transport")
requests_module = types.ModuleType("google.auth.transport.requests")
requests_module.Request = object
spec = importlib.util.spec_from_file_location("gws_api_test", API_PATH)
monkeypatch.setitem(sys.modules, "google", google_module)
monkeypatch.setitem(sys.modules, "google.oauth2", oauth2_module)
monkeypatch.setitem(sys.modules, "google.oauth2.credentials", credentials_module)
monkeypatch.setitem(sys.modules, "google.auth", auth_module)
monkeypatch.setitem(sys.modules, "google.auth.transport", transport_module)
monkeypatch.setitem(sys.modules, "google.auth.transport.requests", requests_module)
spec = importlib.util.spec_from_file_location("google_workspace_api_test", SCRIPT_PATH)
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(module)
monkeypatch.setattr(module, "TOKEN_PATH", tmp_path / "google_token.json")
return module
def _write_token(path: Path, *, token="ya29.test", expiry=None, **extra):
data = {
"token": token,
"refresh_token": "1//refresh",
"client_id": "123.apps.googleusercontent.com",
"client_secret": "secret",
def _write_token(path: Path, scopes):
path.write_text(json.dumps({
"token": "access-token",
"refresh_token": "refresh-token",
"token_uri": "https://oauth2.googleapis.com/token",
**extra,
}
if expiry is not None:
data["expiry"] = expiry
path.write_text(json.dumps(data))
"client_id": "client-id",
"client_secret": "client-secret",
"scopes": scopes,
}))
def test_bridge_returns_valid_token(bridge_module, tmp_path):
"""Non-expired token is returned without refresh."""
future = (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat()
token_path = bridge_module.get_token_path()
_write_token(token_path, token="ya29.valid", expiry=future)
def test_get_credentials_rejects_missing_scopes(google_api_module, capsys):
FakeCredentialsFactory.creds = FakeAuthorizedCredentials(valid=True)
_write_token(google_api_module.TOKEN_PATH, [
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/spreadsheets",
])
result = bridge_module.get_valid_token()
assert result == "ya29.valid"
def test_bridge_refreshes_expired_token(bridge_module, tmp_path):
"""Expired token triggers a refresh via token_uri."""
past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat()
token_path = bridge_module.get_token_path()
_write_token(token_path, token="ya29.old", expiry=past)
mock_resp = MagicMock()
mock_resp.read.return_value = json.dumps({
"access_token": "ya29.refreshed",
"expires_in": 3600,
}).encode()
mock_resp.__enter__ = lambda s: s
mock_resp.__exit__ = MagicMock(return_value=False)
with patch("urllib.request.urlopen", return_value=mock_resp):
result = bridge_module.get_valid_token()
assert result == "ya29.refreshed"
# Verify persisted
saved = json.loads(token_path.read_text())
assert saved["token"] == "ya29.refreshed"
def test_bridge_exits_on_missing_token(bridge_module):
"""Missing token file causes exit with code 1."""
with pytest.raises(SystemExit):
bridge_module.get_valid_token()
google_api_module.get_credentials()
err = capsys.readouterr().err
assert "missing google workspace scopes" in err.lower()
assert "gmail.send" in err
def test_bridge_main_injects_token_env(bridge_module, tmp_path):
"""main() sets GOOGLE_WORKSPACE_CLI_TOKEN in subprocess env."""
future = (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat()
token_path = bridge_module.get_token_path()
_write_token(token_path, token="ya29.injected", expiry=future)
def test_get_credentials_accepts_full_scope_token(google_api_module):
FakeCredentialsFactory.creds = FakeAuthorizedCredentials(valid=True)
_write_token(google_api_module.TOKEN_PATH, list(google_api_module.SCOPES))
captured = {}
creds = google_api_module.get_credentials()
def capture_run(cmd, **kwargs):
captured["cmd"] = cmd
captured["env"] = kwargs.get("env", {})
return MagicMock(returncode=0)
with patch.object(sys, "argv", ["gws_bridge.py", "gmail", "+triage"]):
with patch.object(subprocess, "run", side_effect=capture_run):
with pytest.raises(SystemExit):
bridge_module.main()
assert captured["env"]["GOOGLE_WORKSPACE_CLI_TOKEN"] == "ya29.injected"
assert captured["cmd"] == ["gws", "gmail", "+triage"]
def test_api_calendar_list_uses_agenda_by_default(api_module):
"""calendar list without dates uses +agenda helper."""
captured = {}
def capture_run(cmd, **kwargs):
captured["cmd"] = cmd
return MagicMock(returncode=0)
args = api_module.argparse.Namespace(
start="", end="", max=25, calendar="primary", func=api_module.calendar_list,
)
with patch.object(subprocess, "run", side_effect=capture_run):
with pytest.raises(SystemExit):
api_module.calendar_list(args)
gws_args = captured["cmd"][2:] # skip python + bridge path
assert "calendar" in gws_args
assert "+agenda" in gws_args
assert "--days" in gws_args
def test_api_calendar_list_respects_date_range(api_module):
"""calendar list with --start/--end uses raw events list API."""
captured = {}
def capture_run(cmd, **kwargs):
captured["cmd"] = cmd
return MagicMock(returncode=0)
args = api_module.argparse.Namespace(
start="2026-04-01T00:00:00Z",
end="2026-04-07T23:59:59Z",
max=25,
calendar="primary",
func=api_module.calendar_list,
)
with patch.object(subprocess, "run", side_effect=capture_run):
with pytest.raises(SystemExit):
api_module.calendar_list(args)
gws_args = captured["cmd"][2:]
assert "events" in gws_args
assert "list" in gws_args
params_idx = gws_args.index("--params")
params = json.loads(gws_args[params_idx + 1])
assert params["timeMin"] == "2026-04-01T00:00:00Z"
assert params["timeMax"] == "2026-04-07T23:59:59Z"
assert creds is FakeCredentialsFactory.creds

Some files were not shown because too many files have changed in this diff Show More