Compare commits
97 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6aba50f5ba | |||
| a562550af3 | |||
| 37c478cf2f | |||
| 27eeea0555 | |||
| fd73937ec8 | |||
| 723b5bec85 | |||
| 14ccd32cee | |||
| 06f862fa1b | |||
| 39cd57083a | |||
| d99e2a29d6 | |||
| cab814af15 | |||
| 5c2ecdec49 | |||
| 6d272ba477 | |||
| 97b0cd51ee | |||
| 6ee0005e8c | |||
| c8aff74632 | |||
| 08f35076c9 | |||
| 289d2745af | |||
| fc417ed049 | |||
| 32519066dc | |||
| 689c515090 | |||
| 758c4ad1ef | |||
| 000a881fcf | |||
| 5f0caf54d6 | |||
| 90352b2adf | |||
| ee39e88b03 | |||
| b53f681993 | |||
| 8c3935ebe8 | |||
| 1e5056ec30 | |||
| d82580b25b | |||
| b80e318168 | |||
| 72b345e068 | |||
| 8160d7a03d | |||
| dfe7386a58 | |||
| ef73babea1 | |||
| f2893fe51a | |||
| 255f59de18 | |||
| 4bede272cf | |||
| 0e6354df50 | |||
| b0892375cd | |||
| 0a922bf218 | |||
| d053845703 | |||
| 0970f1de50 | |||
| 8ce6aaac23 | |||
| ad1e8804a6 | |||
| c22bffc92e | |||
| cc4b1f0007 | |||
| dfc820345d | |||
| 75380de430 | |||
| 885123d44b | |||
| 04c1c5d53f | |||
| cf53e2676b | |||
| f4f4078ad9 | |||
| 59e630a64d | |||
| 2d328d5c70 | |||
| 151654851c | |||
| 5910412002 | |||
| 39da23a129 | |||
| cac6178104 | |||
| dafe443beb | |||
| da9f96bf51 | |||
| 3ec8809b78 | |||
| 4e3e87b677 | |||
| 26bbb422b1 | |||
| 976bad5bde | |||
| d4bb44d4b9 | |||
| 6693e2a497 | |||
| 55fac8a386 | |||
| 50bb4fe010 | |||
| 06e1d9cdd4 | |||
| 69f3aaa1d6 | |||
| c94936839c | |||
| d7607292d9 | |||
| af9caec44f | |||
| f459214010 | |||
| a2f9f04c06 | |||
| 671d5068e7 | |||
| 1a40073a3a | |||
| 3dd76d2718 | |||
| 50ad66aee6 | |||
| 80d82c2f5c | |||
| 7241e6134b | |||
| ae9a713a0a | |||
| eb8071bbc1 | |||
| 086d92a0e0 | |||
| 4e56eacdce | |||
| 1909877e6e | |||
| 307697688e | |||
| 4d1f1dccf9 | |||
| 640441b865 | |||
| 5a55d54ee2 | |||
| 424b62aa16 | |||
| c89719ad9c | |||
| d3c5d65563 | |||
| 4f5e8b22a7 | |||
| eeb8b4b00f | |||
| ffbd80f5fc |
@@ -89,6 +89,15 @@
|
||||
# Optional base URL override:
|
||||
# HERMES_QWEN_BASE_URL=https://portal.qwen.ai/v1
|
||||
|
||||
# =============================================================================
|
||||
# LLM PROVIDER (Xiaomi MiMo)
|
||||
# =============================================================================
|
||||
# Xiaomi MiMo models (mimo-v2-pro, mimo-v2-omni, mimo-v2-flash).
|
||||
# Get your key at: https://platform.xiaomimimo.com
|
||||
# XIAOMI_API_KEY=your_key_here
|
||||
# Optional base URL override:
|
||||
# XIAOMI_BASE_URL=https://api.xiaomimimo.com/v1
|
||||
|
||||
# =============================================================================
|
||||
# TOOL API KEYS
|
||||
# =============================================================================
|
||||
|
||||
@@ -351,8 +351,9 @@ Cache-breaking forces dramatically higher costs. The ONLY time we alter context
|
||||
|
||||
### Background Process Notifications (Gateway)
|
||||
|
||||
When `terminal(background=true, check_interval=...)` is used, the gateway runs a watcher that
|
||||
pushes status updates to the user's chat. Control verbosity with `display.background_process_notifications`
|
||||
When `terminal(background=true, notify_on_complete=true)` is used, the gateway runs a watcher that
|
||||
detects process completion and triggers a new agent turn. Control verbosity of background process
|
||||
messages with `display.background_process_notifications`
|
||||
in config.yaml (or `HERMES_BACKGROUND_NOTIFICATIONS` env var):
|
||||
|
||||
- `all` — running-output updates + final message (default)
|
||||
|
||||
+245
-62
@@ -23,17 +23,13 @@ Resolution order for vision/multimodal tasks (auto mode):
|
||||
6. Custom endpoint (for local vision models: Qwen-VL, LLaVA, Pixtral, etc.)
|
||||
7. None
|
||||
|
||||
Per-task provider overrides (e.g. AUXILIARY_VISION_PROVIDER,
|
||||
CONTEXT_COMPRESSION_PROVIDER) can force a specific provider for each task.
|
||||
Per-task overrides are configured in config.yaml under the ``auxiliary:`` section
|
||||
(e.g. ``auxiliary.vision.provider``, ``auxiliary.compression.model``).
|
||||
Default "auto" follows the chains above.
|
||||
|
||||
Per-task model overrides (e.g. AUXILIARY_VISION_MODEL,
|
||||
AUXILIARY_WEB_EXTRACT_MODEL) let callers use a different model slug
|
||||
than the provider's default.
|
||||
|
||||
Per-task direct endpoint overrides (e.g. AUXILIARY_VISION_BASE_URL,
|
||||
AUXILIARY_VISION_API_KEY) let callers route a specific auxiliary task to a
|
||||
custom OpenAI-compatible endpoint without touching the main model settings.
|
||||
Legacy env var overrides (AUXILIARY_{TASK}_PROVIDER, AUXILIARY_{TASK}_MODEL,
|
||||
AUXILIARY_{TASK}_BASE_URL, etc.) are still read as a backward-compat fallback
|
||||
but config.yaml takes priority. New configuration should always use config.yaml.
|
||||
|
||||
Payment / credit exhaustion fallback:
|
||||
When a resolved provider returns HTTP 402 or a credit-related error,
|
||||
@@ -59,6 +55,9 @@ from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level flag: only warn once per process about stale OPENAI_BASE_URL.
|
||||
_stale_base_url_warned = False
|
||||
|
||||
_PROVIDER_ALIASES = {
|
||||
"google": "gemini",
|
||||
"google-gemini": "gemini",
|
||||
@@ -108,6 +107,14 @@ _API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = {
|
||||
"kilocode": "google/gemini-3-flash-preview",
|
||||
}
|
||||
|
||||
# Vision-specific model overrides for direct providers.
|
||||
# When the user's main provider has a dedicated vision/multimodal model that
|
||||
# differs from their main chat model, map it here. The vision auto-detect
|
||||
# "exotic provider" branch checks this before falling back to the main model.
|
||||
_PROVIDER_VISION_MODELS: Dict[str, str] = {
|
||||
"xiaomi": "mimo-v2-omni",
|
||||
}
|
||||
|
||||
# OpenRouter app attribution headers
|
||||
_OR_HEADERS = {
|
||||
"HTTP-Referer": "https://hermes-agent.nousresearch.com",
|
||||
@@ -707,7 +714,9 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
base_url = _to_openai_base_url(
|
||||
_pool_runtime_base_url(entry, pconfig.inference_base_url) or pconfig.inference_base_url
|
||||
)
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default")
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id)
|
||||
if model is None:
|
||||
continue # skip provider if we don't know a valid aux model
|
||||
logger.debug("Auxiliary text client: %s (%s) via pool", pconfig.name, model)
|
||||
extra = {}
|
||||
if "api.kimi.com" in base_url.lower():
|
||||
@@ -726,7 +735,9 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
base_url = _to_openai_base_url(
|
||||
str(creds.get("base_url", "")).strip().rstrip("/") or pconfig.inference_base_url
|
||||
)
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default")
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id)
|
||||
if model is None:
|
||||
continue # skip provider if we don't know a valid aux model
|
||||
logger.debug("Auxiliary text client: %s (%s)", pconfig.name, model)
|
||||
extra = {}
|
||||
if "api.kimi.com" in base_url.lower():
|
||||
@@ -1075,11 +1086,12 @@ def _is_connection_error(exc: Exception) -> bool:
|
||||
def _try_payment_fallback(
|
||||
failed_provider: str,
|
||||
task: str = None,
|
||||
reason: str = "payment error",
|
||||
) -> Tuple[Optional[Any], Optional[str], str]:
|
||||
"""Try alternative providers after a payment/credit error.
|
||||
"""Try alternative providers after a payment/credit or connection error.
|
||||
|
||||
Iterates the standard auto-detection chain, skipping the provider that
|
||||
returned a payment error.
|
||||
failed.
|
||||
|
||||
Returns:
|
||||
(client, model, provider_label) or (None, None, "") if no fallback.
|
||||
@@ -1105,15 +1117,15 @@ def _try_payment_fallback(
|
||||
client, model = try_fn()
|
||||
if client is not None:
|
||||
logger.info(
|
||||
"Auxiliary %s: payment error on %s — falling back to %s (%s)",
|
||||
task or "call", failed_provider, label, model or "default",
|
||||
"Auxiliary %s: %s on %s — falling back to %s (%s)",
|
||||
task or "call", reason, failed_provider, label, model or "default",
|
||||
)
|
||||
return client, model, label
|
||||
tried.append(label)
|
||||
|
||||
logger.warning(
|
||||
"Auxiliary %s: payment error on %s and no fallback available (tried: %s)",
|
||||
task or "call", failed_provider, ", ".join(tried),
|
||||
"Auxiliary %s: %s on %s and no fallback available (tried: %s)",
|
||||
task or "call", reason, failed_provider, ", ".join(tried),
|
||||
)
|
||||
return None, None, ""
|
||||
|
||||
@@ -1128,9 +1140,28 @@ def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
provider they already have credentials for — no OpenRouter key needed.
|
||||
2. OpenRouter → Nous → custom → Codex → API-key providers (original chain).
|
||||
"""
|
||||
global auxiliary_is_nous
|
||||
global auxiliary_is_nous, _stale_base_url_warned
|
||||
auxiliary_is_nous = False # Reset — _try_nous() will set True if it wins
|
||||
|
||||
# ── Warn once if OPENAI_BASE_URL is set but config.yaml uses a named
|
||||
# provider (not 'custom'). This catches the common "env poisoning"
|
||||
# scenario where a user switches providers via `hermes model` but the
|
||||
# old OPENAI_BASE_URL lingers in ~/.hermes/.env. ──
|
||||
if not _stale_base_url_warned:
|
||||
_env_base = os.getenv("OPENAI_BASE_URL", "").strip()
|
||||
_cfg_provider = _read_main_provider()
|
||||
if (_env_base and _cfg_provider
|
||||
and _cfg_provider != "custom"
|
||||
and not _cfg_provider.startswith("custom:")):
|
||||
logger.warning(
|
||||
"OPENAI_BASE_URL is set (%s) but model.provider is '%s'. "
|
||||
"Auxiliary clients may route to the wrong endpoint. "
|
||||
"Run: hermes model to reconfigure, or remove "
|
||||
"OPENAI_BASE_URL from ~/.hermes/.env",
|
||||
_env_base, _cfg_provider,
|
||||
)
|
||||
_stale_base_url_warned = True
|
||||
|
||||
# ── Step 1: non-aggregator main provider → use main model directly ──
|
||||
main_provider = _read_main_provider()
|
||||
main_model = _read_main_model()
|
||||
@@ -1217,6 +1248,7 @@ def resolve_provider_client(
|
||||
raw_codex: bool = False,
|
||||
explicit_base_url: str = None,
|
||||
explicit_api_key: str = None,
|
||||
api_mode: str = None,
|
||||
) -> Tuple[Optional[Any], Optional[str]]:
|
||||
"""Central router: given a provider name and optional model, return a
|
||||
configured client with the correct auth, base URL, and API format.
|
||||
@@ -1240,6 +1272,10 @@ def resolve_provider_client(
|
||||
the main agent loop).
|
||||
explicit_base_url: Optional direct OpenAI-compatible endpoint.
|
||||
explicit_api_key: Optional API key paired with explicit_base_url.
|
||||
api_mode: API mode override. One of "chat_completions",
|
||||
"codex_responses", or None (auto-detect). When set to
|
||||
"codex_responses", the client is wrapped in
|
||||
CodexAuxiliaryClient to route through the Responses API.
|
||||
|
||||
Returns:
|
||||
(client, resolved_model) or (None, None) if auth is unavailable.
|
||||
@@ -1247,6 +1283,40 @@ def resolve_provider_client(
|
||||
# Normalise aliases
|
||||
provider = _normalize_aux_provider(provider)
|
||||
|
||||
def _needs_codex_wrap(client_obj, base_url_str: str, model_str: str) -> bool:
|
||||
"""Decide if a plain OpenAI client should be wrapped for Responses API.
|
||||
|
||||
Returns True when api_mode is explicitly "codex_responses", or when
|
||||
auto-detection (api.openai.com + codex-family model) suggests it.
|
||||
Already-wrapped clients (CodexAuxiliaryClient) are skipped.
|
||||
"""
|
||||
if isinstance(client_obj, CodexAuxiliaryClient):
|
||||
return False
|
||||
if raw_codex:
|
||||
return False
|
||||
if api_mode == "codex_responses":
|
||||
return True
|
||||
# Auto-detect: api.openai.com + codex model name pattern
|
||||
if api_mode and api_mode != "codex_responses":
|
||||
return False # explicit non-codex mode
|
||||
normalized_base = (base_url_str or "").strip().lower()
|
||||
if "api.openai.com" in normalized_base and "openrouter" not in normalized_base:
|
||||
model_lower = (model_str or "").lower()
|
||||
if "codex" in model_lower:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _wrap_if_needed(client_obj, final_model_str: str, base_url_str: str = ""):
|
||||
"""Wrap a plain OpenAI client in CodexAuxiliaryClient if Responses API is needed."""
|
||||
if _needs_codex_wrap(client_obj, base_url_str, final_model_str):
|
||||
logger.debug(
|
||||
"resolve_provider_client: wrapping client in CodexAuxiliaryClient "
|
||||
"(api_mode=%s, model=%s, base_url=%s)",
|
||||
api_mode or "auto-detected", final_model_str,
|
||||
base_url_str[:60] if base_url_str else "")
|
||||
return CodexAuxiliaryClient(client_obj, final_model_str)
|
||||
return client_obj
|
||||
|
||||
# ── Auto: try all providers in priority order ────────────────────
|
||||
if provider == "auto":
|
||||
client, resolved = _resolve_auto()
|
||||
@@ -1336,6 +1406,7 @@ def resolve_provider_client(
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
extra["default_headers"] = copilot_default_headers()
|
||||
client = OpenAI(api_key=custom_key, base_url=custom_base, **extra)
|
||||
client = _wrap_if_needed(client, final_model, custom_base)
|
||||
return (_to_async_client(client, final_model) if async_mode
|
||||
else (client, final_model))
|
||||
# Try custom first, then codex, then API-key providers
|
||||
@@ -1344,6 +1415,8 @@ def resolve_provider_client(
|
||||
client, default = try_fn()
|
||||
if client is not None:
|
||||
final_model = _normalize_resolved_model(model or default, provider)
|
||||
_cbase = str(getattr(client, "base_url", "") or "")
|
||||
client = _wrap_if_needed(client, final_model, _cbase)
|
||||
return (_to_async_client(client, final_model) if async_mode
|
||||
else (client, final_model))
|
||||
logger.warning("resolve_provider_client: custom/main requested "
|
||||
@@ -1363,6 +1436,7 @@ def resolve_provider_client(
|
||||
provider,
|
||||
)
|
||||
client = OpenAI(api_key=custom_key, base_url=custom_base)
|
||||
client = _wrap_if_needed(client, final_model, custom_base)
|
||||
logger.debug(
|
||||
"resolve_provider_client: named custom provider %r (%s)",
|
||||
provider, final_model)
|
||||
@@ -1442,6 +1516,11 @@ def resolve_provider_client(
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Honor api_mode for any API-key provider (e.g. direct OpenAI with
|
||||
# codex-family models). The copilot-specific wrapping above handles
|
||||
# copilot; this covers the general case (#6800).
|
||||
client = _wrap_if_needed(client, final_model, base_url)
|
||||
|
||||
logger.debug("resolve_provider_client: %s (%s)", provider, final_model)
|
||||
return (_to_async_client(client, final_model) if async_mode
|
||||
else (client, final_model))
|
||||
@@ -1474,12 +1553,13 @@ def get_text_auxiliary_client(task: str = "") -> Tuple[Optional[OpenAI], Optiona
|
||||
Callers may override the returned model with a per-task env var
|
||||
(e.g. CONTEXT_COMPRESSION_MODEL, AUXILIARY_WEB_EXTRACT_MODEL).
|
||||
"""
|
||||
provider, model, base_url, api_key = _resolve_task_provider_model(task or None)
|
||||
provider, model, base_url, api_key, api_mode = _resolve_task_provider_model(task or None)
|
||||
return resolve_provider_client(
|
||||
provider,
|
||||
model=model,
|
||||
explicit_base_url=base_url,
|
||||
explicit_api_key=api_key,
|
||||
api_mode=api_mode,
|
||||
)
|
||||
|
||||
|
||||
@@ -1490,13 +1570,14 @@ def get_async_text_auxiliary_client(task: str = ""):
|
||||
(AsyncCodexAuxiliaryClient, model) which wraps the Responses API.
|
||||
Returns (None, None) when no provider is available.
|
||||
"""
|
||||
provider, model, base_url, api_key = _resolve_task_provider_model(task or None)
|
||||
provider, model, base_url, api_key, api_mode = _resolve_task_provider_model(task or None)
|
||||
return resolve_provider_client(
|
||||
provider,
|
||||
model=model,
|
||||
async_mode=True,
|
||||
explicit_base_url=base_url,
|
||||
explicit_api_key=api_key,
|
||||
api_mode=api_mode,
|
||||
)
|
||||
|
||||
|
||||
@@ -1569,7 +1650,7 @@ def resolve_vision_provider_client(
|
||||
backends, so users can intentionally force experimental providers. Auto mode
|
||||
stays conservative and only tries vision backends known to work today.
|
||||
"""
|
||||
requested, resolved_model, resolved_base_url, resolved_api_key = _resolve_task_provider_model(
|
||||
requested, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model(
|
||||
"vision", provider, model, base_url, api_key
|
||||
)
|
||||
requested = _normalize_vision_provider(requested)
|
||||
@@ -1610,16 +1691,18 @@ def resolve_vision_provider_client(
|
||||
if sync_client is not None:
|
||||
return _finalize(main_provider, sync_client, default_model)
|
||||
else:
|
||||
# Exotic provider (DeepSeek, Alibaba, named custom, etc.)
|
||||
# Exotic provider (DeepSeek, Alibaba, Xiaomi, named custom, etc.)
|
||||
# Use provider-specific vision model if available, otherwise main model.
|
||||
vision_model = _PROVIDER_VISION_MODELS.get(main_provider, main_model)
|
||||
rpc_client, rpc_model = resolve_provider_client(
|
||||
main_provider, main_model)
|
||||
main_provider, vision_model)
|
||||
if rpc_client is not None:
|
||||
logger.info(
|
||||
"Vision auto-detect: using active provider %s (%s)",
|
||||
main_provider, rpc_model or main_model,
|
||||
main_provider, rpc_model or vision_model,
|
||||
)
|
||||
return _finalize(
|
||||
main_provider, rpc_client, rpc_model or main_model)
|
||||
main_provider, rpc_client, rpc_model or vision_model)
|
||||
|
||||
# Fall back through aggregators.
|
||||
for candidate in _VISION_AUTO_PROVIDER_ORDER:
|
||||
@@ -1785,12 +1868,30 @@ def cleanup_stale_async_clients() -> None:
|
||||
del _client_cache[key]
|
||||
|
||||
|
||||
def _is_openrouter_client(client: Any) -> bool:
|
||||
for obj in (client, getattr(client, "_client", None), getattr(client, "client", None)):
|
||||
if obj and "openrouter" in str(getattr(obj, "base_url", "") or "").lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _compat_model(client: Any, model: Optional[str], cached_default: Optional[str]) -> Optional[str]:
|
||||
"""Drop OpenRouter-format model slugs (with '/') for non-OpenRouter clients.
|
||||
|
||||
Mirrors the guard in resolve_provider_client() which is skipped on cache hits.
|
||||
"""
|
||||
if model and "/" in model and not _is_openrouter_client(client):
|
||||
return cached_default
|
||||
return model or cached_default
|
||||
|
||||
|
||||
def _get_cached_client(
|
||||
provider: str,
|
||||
model: str = None,
|
||||
async_mode: bool = False,
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
api_mode: str = None,
|
||||
) -> Tuple[Optional[Any], Optional[str]]:
|
||||
"""Get or create a cached client for the given provider.
|
||||
|
||||
@@ -1814,7 +1915,7 @@ def _get_cached_client(
|
||||
loop_id = id(current_loop)
|
||||
except RuntimeError:
|
||||
pass
|
||||
cache_key = (provider, async_mode, base_url or "", api_key or "", loop_id)
|
||||
cache_key = (provider, async_mode, base_url or "", api_key or "", api_mode or "", loop_id)
|
||||
with _client_cache_lock:
|
||||
if cache_key in _client_cache:
|
||||
cached_client, cached_default, cached_loop = _client_cache[cache_key]
|
||||
@@ -1826,9 +1927,11 @@ def _get_cached_client(
|
||||
_force_close_async_httpx(cached_client)
|
||||
del _client_cache[cache_key]
|
||||
else:
|
||||
return cached_client, model or cached_default
|
||||
effective = _compat_model(cached_client, model, cached_default)
|
||||
return cached_client, effective
|
||||
else:
|
||||
return cached_client, model or cached_default
|
||||
effective = _compat_model(cached_client, model, cached_default)
|
||||
return cached_client, effective
|
||||
# Build outside the lock
|
||||
client, default_model = resolve_provider_client(
|
||||
provider,
|
||||
@@ -1836,6 +1939,7 @@ def _get_cached_client(
|
||||
async_mode,
|
||||
explicit_base_url=base_url,
|
||||
explicit_api_key=api_key,
|
||||
api_mode=api_mode,
|
||||
)
|
||||
if client is not None:
|
||||
# For async clients, remember which loop they were created on so we
|
||||
@@ -1855,24 +1959,26 @@ def _resolve_task_provider_model(
|
||||
model: str = None,
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
) -> Tuple[str, Optional[str], Optional[str], Optional[str]]:
|
||||
) -> Tuple[str, Optional[str], Optional[str], Optional[str], Optional[str]]:
|
||||
"""Determine provider + model for a call.
|
||||
|
||||
Priority:
|
||||
1. Explicit provider/model/base_url/api_key args (always win)
|
||||
2. Env var overrides (AUXILIARY_{TASK}_*, CONTEXT_{TASK}_*)
|
||||
3. Config file (auxiliary.{task}.* or compression.*)
|
||||
2. Config file (auxiliary.{task}.* or compression.*)
|
||||
3. Env var overrides (backward-compat: AUXILIARY_{TASK}_*, CONTEXT_{TASK}_*)
|
||||
4. "auto" (full auto-detection chain)
|
||||
|
||||
Returns (provider, model, base_url, api_key) where model may be None
|
||||
(use provider default). When base_url is set, provider is forced to
|
||||
"custom" and the task uses that direct endpoint.
|
||||
Returns (provider, model, base_url, api_key, api_mode) where model may
|
||||
be None (use provider default). When base_url is set, provider is forced
|
||||
to "custom" and the task uses that direct endpoint. api_mode is one of
|
||||
"chat_completions", "codex_responses", or None (auto-detect).
|
||||
"""
|
||||
config = {}
|
||||
cfg_provider = None
|
||||
cfg_model = None
|
||||
cfg_base_url = None
|
||||
cfg_api_key = None
|
||||
cfg_api_mode = None
|
||||
|
||||
if task:
|
||||
try:
|
||||
@@ -1889,6 +1995,7 @@ def _resolve_task_provider_model(
|
||||
cfg_model = str(task_config.get("model", "")).strip() or None
|
||||
cfg_base_url = str(task_config.get("base_url", "")).strip() or None
|
||||
cfg_api_key = str(task_config.get("api_key", "")).strip() or None
|
||||
cfg_api_mode = str(task_config.get("api_mode", "")).strip() or None
|
||||
|
||||
# Backwards compat: compression section has its own keys.
|
||||
# The auxiliary.compression defaults to provider="auto", so treat
|
||||
@@ -1901,31 +2008,38 @@ def _resolve_task_provider_model(
|
||||
_sbu = comp.get("summary_base_url") or ""
|
||||
cfg_base_url = cfg_base_url or _sbu.strip() or None
|
||||
|
||||
# Env vars are backward-compat fallback only — config.yaml is primary.
|
||||
env_model = _get_auxiliary_env_override(task, "MODEL") if task else None
|
||||
resolved_model = model or env_model or cfg_model
|
||||
env_api_mode = _get_auxiliary_env_override(task, "API_MODE") if task else None
|
||||
resolved_model = model or cfg_model or env_model
|
||||
resolved_api_mode = cfg_api_mode or env_api_mode
|
||||
|
||||
if base_url:
|
||||
return "custom", resolved_model, base_url, api_key
|
||||
return "custom", resolved_model, base_url, api_key, resolved_api_mode
|
||||
if provider:
|
||||
return provider, resolved_model, base_url, api_key
|
||||
return provider, resolved_model, base_url, api_key, resolved_api_mode
|
||||
|
||||
if task:
|
||||
# Config.yaml is the primary source for per-task overrides.
|
||||
if cfg_base_url:
|
||||
return "custom", resolved_model, cfg_base_url, cfg_api_key, resolved_api_mode
|
||||
if cfg_provider and cfg_provider != "auto":
|
||||
return cfg_provider, resolved_model, None, None, resolved_api_mode
|
||||
|
||||
# Env vars are backward-compat fallback for users who haven't
|
||||
# migrated to config.yaml yet.
|
||||
env_base_url = _get_auxiliary_env_override(task, "BASE_URL")
|
||||
env_api_key = _get_auxiliary_env_override(task, "API_KEY")
|
||||
if env_base_url:
|
||||
return "custom", resolved_model, env_base_url, env_api_key or cfg_api_key
|
||||
return "custom", resolved_model, env_base_url, env_api_key, resolved_api_mode
|
||||
|
||||
env_provider = _get_auxiliary_provider(task)
|
||||
if env_provider != "auto":
|
||||
return env_provider, resolved_model, None, None
|
||||
return env_provider, resolved_model, None, None, resolved_api_mode
|
||||
|
||||
if cfg_base_url:
|
||||
return "custom", resolved_model, cfg_base_url, cfg_api_key
|
||||
if cfg_provider and cfg_provider != "auto":
|
||||
return cfg_provider, resolved_model, None, None
|
||||
return "auto", resolved_model, None, None
|
||||
return "auto", resolved_model, None, None, resolved_api_mode
|
||||
|
||||
return "auto", resolved_model, None, None
|
||||
return "auto", resolved_model, None, None, resolved_api_mode
|
||||
|
||||
|
||||
_DEFAULT_AUX_TIMEOUT = 30.0
|
||||
@@ -1997,6 +2111,37 @@ def _build_call_kwargs(
|
||||
return kwargs
|
||||
|
||||
|
||||
def _validate_llm_response(response: Any, task: str = None) -> Any:
|
||||
"""Validate that an LLM response has the expected .choices[0].message shape.
|
||||
|
||||
Fails fast with a clear error instead of letting malformed payloads
|
||||
propagate to downstream consumers where they crash with misleading
|
||||
AttributeError (e.g. "'str' object has no attribute 'choices'").
|
||||
|
||||
See #7264.
|
||||
"""
|
||||
if response is None:
|
||||
raise RuntimeError(
|
||||
f"Auxiliary {task or 'call'}: LLM returned None response"
|
||||
)
|
||||
# Allow SimpleNamespace responses from adapters (CodexAuxiliaryClient,
|
||||
# AnthropicAuxiliaryClient) — they have .choices[0].message.
|
||||
try:
|
||||
choices = response.choices
|
||||
if not choices or not hasattr(choices[0], "message"):
|
||||
raise AttributeError("missing choices[0].message")
|
||||
except (AttributeError, TypeError, IndexError) as exc:
|
||||
response_type = type(response).__name__
|
||||
response_preview = str(response)[:120]
|
||||
raise RuntimeError(
|
||||
f"Auxiliary {task or 'call'}: LLM returned invalid response "
|
||||
f"(type={response_type}): {response_preview!r}. "
|
||||
f"Expected object with .choices[0].message — check provider "
|
||||
f"adapter or custom endpoint compatibility."
|
||||
) from exc
|
||||
return response
|
||||
|
||||
|
||||
def call_llm(
|
||||
task: str = None,
|
||||
*,
|
||||
@@ -2035,7 +2180,7 @@ def call_llm(
|
||||
Raises:
|
||||
RuntimeError: If no provider is configured.
|
||||
"""
|
||||
resolved_provider, resolved_model, resolved_base_url, resolved_api_key = _resolve_task_provider_model(
|
||||
resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model(
|
||||
task, provider, model, base_url, api_key)
|
||||
|
||||
if task == "vision":
|
||||
@@ -2068,6 +2213,7 @@ def call_llm(
|
||||
resolved_model,
|
||||
base_url=resolved_base_url,
|
||||
api_key=resolved_api_key,
|
||||
api_mode=resolved_api_mode,
|
||||
)
|
||||
if client is None:
|
||||
# When the user explicitly chose a non-OpenRouter provider but no
|
||||
@@ -2111,18 +2257,20 @@ def call_llm(
|
||||
|
||||
# Handle max_tokens vs max_completion_tokens retry, then payment fallback.
|
||||
try:
|
||||
return client.chat.completions.create(**kwargs)
|
||||
return _validate_llm_response(
|
||||
client.chat.completions.create(**kwargs), task)
|
||||
except Exception as first_err:
|
||||
err_str = str(first_err)
|
||||
if "max_tokens" in err_str or "unsupported_parameter" in err_str:
|
||||
kwargs.pop("max_tokens", None)
|
||||
kwargs["max_completion_tokens"] = max_tokens
|
||||
try:
|
||||
return client.chat.completions.create(**kwargs)
|
||||
return _validate_llm_response(
|
||||
client.chat.completions.create(**kwargs), task)
|
||||
except Exception as retry_err:
|
||||
# If the max_tokens retry also hits a payment error,
|
||||
# fall through to the payment fallback below.
|
||||
if not _is_payment_error(retry_err):
|
||||
# If the max_tokens retry also hits a payment or connection
|
||||
# error, fall through to the fallback chain below.
|
||||
if not (_is_payment_error(retry_err) or _is_connection_error(retry_err)):
|
||||
raise
|
||||
first_err = retry_err
|
||||
|
||||
@@ -2139,19 +2287,24 @@ def call_llm(
|
||||
# and providers the user never configured that got picked up by
|
||||
# the auto-detection chain.
|
||||
should_fallback = _is_payment_error(first_err) or _is_connection_error(first_err)
|
||||
if should_fallback:
|
||||
# Only try alternative providers when the user didn't explicitly
|
||||
# configure this task's provider. Explicit provider = hard constraint;
|
||||
# auto (the default) = best-effort fallback chain. (#7559)
|
||||
is_auto = resolved_provider in ("auto", "", None)
|
||||
if should_fallback and is_auto:
|
||||
reason = "payment error" if _is_payment_error(first_err) else "connection error"
|
||||
logger.info("Auxiliary %s: %s on %s (%s), trying fallback",
|
||||
task or "call", reason, resolved_provider, first_err)
|
||||
fb_client, fb_model, fb_label = _try_payment_fallback(
|
||||
resolved_provider, task)
|
||||
resolved_provider, task, reason=reason)
|
||||
if fb_client is not None:
|
||||
fb_kwargs = _build_call_kwargs(
|
||||
fb_label, fb_model, messages,
|
||||
temperature=temperature, max_tokens=max_tokens,
|
||||
tools=tools, timeout=effective_timeout,
|
||||
extra_body=extra_body)
|
||||
return fb_client.chat.completions.create(**fb_kwargs)
|
||||
return _validate_llm_response(
|
||||
fb_client.chat.completions.create(**fb_kwargs), task)
|
||||
raise
|
||||
|
||||
|
||||
@@ -2229,7 +2382,7 @@ async def async_call_llm(
|
||||
|
||||
Same as call_llm() but async. See call_llm() for full documentation.
|
||||
"""
|
||||
resolved_provider, resolved_model, resolved_base_url, resolved_api_key = _resolve_task_provider_model(
|
||||
resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model(
|
||||
task, provider, model, base_url, api_key)
|
||||
|
||||
if task == "vision":
|
||||
@@ -2263,6 +2416,7 @@ async def async_call_llm(
|
||||
async_mode=True,
|
||||
base_url=resolved_base_url,
|
||||
api_key=resolved_api_key,
|
||||
api_mode=resolved_api_mode,
|
||||
)
|
||||
if client is None:
|
||||
_explicit = (resolved_provider or "").strip().lower()
|
||||
@@ -2273,11 +2427,9 @@ async def async_call_llm(
|
||||
f"variable, or switch to a different provider with `hermes model`."
|
||||
)
|
||||
if not resolved_base_url:
|
||||
logger.warning("Provider %s unavailable, falling back to openrouter",
|
||||
resolved_provider)
|
||||
client, final_model = _get_cached_client(
|
||||
"openrouter", resolved_model or _OPENROUTER_MODEL,
|
||||
async_mode=True)
|
||||
logger.info("Auxiliary %s: provider %s unavailable, trying auto-detection chain",
|
||||
task or "call", resolved_provider)
|
||||
client, final_model = _get_cached_client("auto", async_mode=True)
|
||||
if client is None:
|
||||
raise RuntimeError(
|
||||
f"No LLM provider configured for task={task} provider={resolved_provider}. "
|
||||
@@ -2292,11 +2444,42 @@ async def async_call_llm(
|
||||
base_url=resolved_base_url)
|
||||
|
||||
try:
|
||||
return await client.chat.completions.create(**kwargs)
|
||||
return _validate_llm_response(
|
||||
await client.chat.completions.create(**kwargs), task)
|
||||
except Exception as first_err:
|
||||
err_str = str(first_err)
|
||||
if "max_tokens" in err_str or "unsupported_parameter" in err_str:
|
||||
kwargs.pop("max_tokens", None)
|
||||
kwargs["max_completion_tokens"] = max_tokens
|
||||
return await client.chat.completions.create(**kwargs)
|
||||
try:
|
||||
return _validate_llm_response(
|
||||
await client.chat.completions.create(**kwargs), task)
|
||||
except Exception as retry_err:
|
||||
# If the max_tokens retry also hits a payment or connection
|
||||
# error, fall through to the fallback chain below.
|
||||
if not (_is_payment_error(retry_err) or _is_connection_error(retry_err)):
|
||||
raise
|
||||
first_err = retry_err
|
||||
|
||||
# ── Payment / connection fallback (mirrors sync call_llm) ─────
|
||||
should_fallback = _is_payment_error(first_err) or _is_connection_error(first_err)
|
||||
is_auto = resolved_provider in ("auto", "", None)
|
||||
if should_fallback and is_auto:
|
||||
reason = "payment error" if _is_payment_error(first_err) else "connection error"
|
||||
logger.info("Auxiliary %s (async): %s on %s (%s), trying fallback",
|
||||
task or "call", reason, resolved_provider, first_err)
|
||||
fb_client, fb_model, fb_label = _try_payment_fallback(
|
||||
resolved_provider, task, reason=reason)
|
||||
if fb_client is not None:
|
||||
fb_kwargs = _build_call_kwargs(
|
||||
fb_label, fb_model, messages,
|
||||
temperature=temperature, max_tokens=max_tokens,
|
||||
tools=tools, timeout=effective_timeout,
|
||||
extra_body=extra_body)
|
||||
# Convert sync fallback client to async
|
||||
async_fb, async_fb_model = _to_async_client(fb_client, fb_model or "")
|
||||
if async_fb_model and async_fb_model != fb_kwargs.get("model"):
|
||||
fb_kwargs["model"] = async_fb_model
|
||||
return _validate_llm_response(
|
||||
await async_fb.chat.completions.create(**fb_kwargs), task)
|
||||
raise
|
||||
|
||||
@@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional
|
||||
from agent.auxiliary_client import call_llm
|
||||
from agent.context_engine import ContextEngine
|
||||
from agent.model_metadata import (
|
||||
MINIMUM_CONTEXT_LENGTH,
|
||||
get_model_context_length,
|
||||
estimate_messages_tokens_rough,
|
||||
)
|
||||
@@ -87,7 +88,10 @@ class ContextCompressor(ContextEngine):
|
||||
self.api_key = api_key
|
||||
self.provider = provider
|
||||
self.context_length = context_length
|
||||
self.threshold_tokens = int(context_length * self.threshold_percent)
|
||||
self.threshold_tokens = max(
|
||||
int(context_length * self.threshold_percent),
|
||||
MINIMUM_CONTEXT_LENGTH,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -118,7 +122,14 @@ class ContextCompressor(ContextEngine):
|
||||
config_context_length=config_context_length,
|
||||
provider=provider,
|
||||
)
|
||||
self.threshold_tokens = int(self.context_length * threshold_percent)
|
||||
# Floor: never compress below MINIMUM_CONTEXT_LENGTH tokens even if
|
||||
# the percentage would suggest a lower value. This prevents premature
|
||||
# compression on large-context models at 50% while keeping the % sane
|
||||
# for models right at the minimum.
|
||||
self.threshold_tokens = max(
|
||||
int(self.context_length * threshold_percent),
|
||||
MINIMUM_CONTEXT_LENGTH,
|
||||
)
|
||||
self.compression_count = 0
|
||||
|
||||
# Derive token budgets: ratio is relative to the threshold, not total context
|
||||
|
||||
+9
-16
@@ -4,7 +4,6 @@ Pure display functions and classes with no AIAgent dependency.
|
||||
Used by AIAgent._execute_tool_calls for CLI feedback.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -14,6 +13,8 @@ from dataclasses import dataclass, field
|
||||
from difflib import unified_diff
|
||||
from pathlib import Path
|
||||
|
||||
from utils import safe_json_loads
|
||||
|
||||
# ANSI escape codes for coloring tool failure indicators
|
||||
_RED = "\033[31m"
|
||||
_RESET = "\033[0m"
|
||||
@@ -372,9 +373,8 @@ def _result_succeeded(result: str | None) -> bool:
|
||||
"""Conservatively detect whether a tool result represents success."""
|
||||
if not result:
|
||||
return False
|
||||
try:
|
||||
data = json.loads(result)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
data = safe_json_loads(result)
|
||||
if data is None:
|
||||
return False
|
||||
if not isinstance(data, dict):
|
||||
return False
|
||||
@@ -423,10 +423,7 @@ def extract_edit_diff(
|
||||
) -> str | None:
|
||||
"""Extract a unified diff from a file-edit tool result."""
|
||||
if tool_name == "patch" and result:
|
||||
try:
|
||||
data = json.loads(result)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
data = None
|
||||
data = safe_json_loads(result)
|
||||
if isinstance(data, dict):
|
||||
diff = data.get("diff")
|
||||
if isinstance(diff, str) and diff.strip():
|
||||
@@ -780,23 +777,19 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]
|
||||
return False, ""
|
||||
|
||||
if tool_name == "terminal":
|
||||
try:
|
||||
data = json.loads(result)
|
||||
data = safe_json_loads(result)
|
||||
if isinstance(data, dict):
|
||||
exit_code = data.get("exit_code")
|
||||
if exit_code is not None and exit_code != 0:
|
||||
return True, f" [exit {exit_code}]"
|
||||
except (json.JSONDecodeError, TypeError, AttributeError):
|
||||
logger.debug("Could not parse terminal result as JSON for exit code check")
|
||||
return False, ""
|
||||
|
||||
# Memory-specific: distinguish "full" from real errors
|
||||
if tool_name == "memory":
|
||||
try:
|
||||
data = json.loads(result)
|
||||
data = safe_json_loads(result)
|
||||
if isinstance(data, dict):
|
||||
if data.get("success") is False and "exceed the limit" in data.get("error", ""):
|
||||
return True, " [full]"
|
||||
except (json.JSONDecodeError, TypeError, AttributeError):
|
||||
logger.debug("Could not parse memory result as JSON for capacity check")
|
||||
|
||||
# Generic heuristic for non-terminal tools
|
||||
lower = result[:500].lower()
|
||||
|
||||
+35
-8
@@ -27,12 +27,14 @@ _PROVIDER_PREFIXES: frozenset[str] = frozenset({
|
||||
"gemini", "zai", "kimi-coding", "minimax", "minimax-cn", "anthropic", "deepseek",
|
||||
"opencode-zen", "opencode-go", "ai-gateway", "kilocode", "alibaba",
|
||||
"qwen-oauth",
|
||||
"xiaomi",
|
||||
"custom", "local",
|
||||
# Common aliases
|
||||
"google", "google-gemini", "google-ai-studio",
|
||||
"glm", "z-ai", "z.ai", "zhipu", "github", "github-copilot",
|
||||
"github-models", "kimi", "moonshot", "claude", "deep-seek",
|
||||
"opencode", "zen", "go", "vercel", "kilo", "dashscope", "aliyun", "qwen",
|
||||
"mimo", "xiaomi-mimo",
|
||||
"qwen-portal",
|
||||
})
|
||||
|
||||
@@ -83,6 +85,11 @@ CONTEXT_PROBE_TIERS = [
|
||||
# Default context length when no detection method succeeds.
|
||||
DEFAULT_FALLBACK_CONTEXT = CONTEXT_PROBE_TIERS[0]
|
||||
|
||||
# Minimum context length required to run Hermes Agent. Models with fewer
|
||||
# tokens cannot maintain enough working memory for tool-calling workflows.
|
||||
# Sessions, model switches, and cron jobs should reject models below this.
|
||||
MINIMUM_CONTEXT_LENGTH = 64_000
|
||||
|
||||
# Thin fallback defaults — only broad model family patterns.
|
||||
# These fire only when provider is unknown AND models.dev/OpenRouter/Anthropic
|
||||
# all miss. Replaced the previous 80+ entry dict.
|
||||
@@ -113,7 +120,10 @@ DEFAULT_CONTEXT_LENGTHS = {
|
||||
"deepseek": 128000,
|
||||
# Meta
|
||||
"llama": 131072,
|
||||
# Qwen
|
||||
# Qwen — specific model families before the catch-all.
|
||||
# Official docs: https://help.aliyun.com/zh/model-studio/developer-reference/
|
||||
"qwen3-coder-plus": 1000000, # 1M context
|
||||
"qwen3-coder": 262144, # 256K context
|
||||
"qwen": 131072,
|
||||
# MiniMax — official docs: 204,800 context for all models
|
||||
# https://platform.minimax.io/docs/api-reference/text-anthropic-api
|
||||
@@ -146,9 +156,10 @@ DEFAULT_CONTEXT_LENGTHS = {
|
||||
"moonshotai/Kimi-K2.5": 262144,
|
||||
"moonshotai/Kimi-K2-Thinking": 262144,
|
||||
"MiniMaxAI/MiniMax-M2.5": 204800,
|
||||
"XiaomiMiMo/MiMo-V2-Flash": 32768,
|
||||
"mimo-v2-pro": 1048576,
|
||||
"mimo-v2-omni": 1048576,
|
||||
"XiaomiMiMo/MiMo-V2-Flash": 256000,
|
||||
"mimo-v2-pro": 1000000,
|
||||
"mimo-v2-omni": 256000,
|
||||
"mimo-v2-flash": 256000,
|
||||
"zai-org/GLM-5": 202752,
|
||||
}
|
||||
|
||||
@@ -173,6 +184,12 @@ _MAX_COMPLETION_KEYS = (
|
||||
|
||||
# Local server hostnames / address patterns
|
||||
_LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1", "0.0.0.0")
|
||||
# Docker / Podman / Lima DNS names that resolve to the host machine
|
||||
_CONTAINER_LOCAL_SUFFIXES = (
|
||||
".docker.internal",
|
||||
".containers.internal",
|
||||
".lima.internal",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_base_url(base_url: str) -> str:
|
||||
@@ -208,6 +225,8 @@ _URL_TO_PROVIDER: Dict[str, str] = {
|
||||
"api.fireworks.ai": "fireworks",
|
||||
"opencode.ai": "opencode-go",
|
||||
"api.x.ai": "xai",
|
||||
"api.xiaomimimo.com": "xiaomi",
|
||||
"xiaomimimo.com": "xiaomi",
|
||||
}
|
||||
|
||||
|
||||
@@ -246,6 +265,9 @@ def is_local_endpoint(base_url: str) -> bool:
|
||||
return False
|
||||
if host in _LOCAL_HOSTS:
|
||||
return True
|
||||
# Docker / Podman / Lima internal DNS names (e.g. host.docker.internal)
|
||||
if any(host.endswith(suffix) for suffix in _CONTAINER_LOCAL_SUFFIXES):
|
||||
return True
|
||||
# RFC-1918 private ranges and link-local
|
||||
import ipaddress
|
||||
try:
|
||||
@@ -1023,16 +1045,21 @@ def get_model_context_length(
|
||||
|
||||
|
||||
def estimate_tokens_rough(text: str) -> int:
|
||||
"""Rough token estimate (~4 chars/token) for pre-flight checks."""
|
||||
"""Rough token estimate (~4 chars/token) for pre-flight checks.
|
||||
|
||||
Uses ceiling division so short texts (1-3 chars) never estimate as
|
||||
0 tokens, which would cause the compressor and pre-flight checks to
|
||||
systematically undercount when many short tool results are present.
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
return len(text) // 4
|
||||
return (len(text) + 3) // 4
|
||||
|
||||
|
||||
def estimate_messages_tokens_rough(messages: List[Dict[str, Any]]) -> int:
|
||||
"""Rough token estimate for a message list (pre-flight only)."""
|
||||
total_chars = sum(len(str(msg)) for msg in messages)
|
||||
return total_chars // 4
|
||||
return (total_chars + 3) // 4
|
||||
|
||||
|
||||
def estimate_request_tokens_rough(
|
||||
@@ -1055,4 +1082,4 @@ def estimate_request_tokens_rough(
|
||||
total_chars += sum(len(str(msg)) for msg in messages)
|
||||
if tools:
|
||||
total_chars += len(str(tools))
|
||||
return total_chars // 4
|
||||
return (total_chars + 3) // 4
|
||||
|
||||
+9
-1
@@ -161,6 +161,7 @@ PROVIDER_TO_MODELS_DEV: Dict[str, str] = {
|
||||
"gemini": "google",
|
||||
"google": "google",
|
||||
"xai": "xai",
|
||||
"xiaomi": "xiaomi",
|
||||
"nvidia": "nvidia",
|
||||
"groq": "groq",
|
||||
"mistral": "mistral",
|
||||
@@ -383,7 +384,14 @@ def get_model_capabilities(provider: str, model: str) -> Optional[ModelCapabilit
|
||||
|
||||
# Extract capability flags (default to False if missing)
|
||||
supports_tools = bool(entry.get("tool_call", False))
|
||||
supports_vision = bool(entry.get("attachment", False))
|
||||
# Vision: check both the `attachment` flag and `modalities.input` for "image".
|
||||
# Some models (e.g. gemma-4) list image in input modalities but not attachment.
|
||||
input_mods = entry.get("modalities", {})
|
||||
if isinstance(input_mods, dict):
|
||||
input_mods = input_mods.get("input", [])
|
||||
else:
|
||||
input_mods = []
|
||||
supports_vision = bool(entry.get("attachment", False)) or "image" in input_mods
|
||||
supports_reasoning = bool(entry.get("reasoning", False))
|
||||
|
||||
# Extract limits
|
||||
|
||||
@@ -12,7 +12,7 @@ import threading
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from hermes_constants import get_hermes_home, get_skills_dir
|
||||
from typing import Optional
|
||||
|
||||
from agent.skill_utils import (
|
||||
@@ -548,8 +548,7 @@ def build_skills_system_prompt(
|
||||
are read-only — they appear in the index but new skills are always created
|
||||
in the local dir. Local skills take precedence when names collide.
|
||||
"""
|
||||
hermes_home = get_hermes_home()
|
||||
skills_dir = hermes_home / "skills"
|
||||
skills_dir = get_skills_dir()
|
||||
external_dirs = get_all_skills_dirs()[1:] # skip local (index 0)
|
||||
|
||||
if not skills_dir.exists() and not external_dirs:
|
||||
|
||||
@@ -168,7 +168,7 @@ def _build_skill_message(
|
||||
subdir_path = skill_dir / subdir
|
||||
if subdir_path.exists():
|
||||
for f in sorted(subdir_path.rglob("*")):
|
||||
if f.is_file():
|
||||
if f.is_file() and not f.is_symlink():
|
||||
rel = str(f.relative_to(skill_dir))
|
||||
supporting.append(rel)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from hermes_constants import get_config_path, get_skills_dir
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -130,7 +130,7 @@ def get_disabled_skill_names(platform: str | None = None) -> Set[str]:
|
||||
Reads the config file directly (no CLI config imports) to stay
|
||||
lightweight.
|
||||
"""
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
config_path = get_config_path()
|
||||
if not config_path.exists():
|
||||
return set()
|
||||
try:
|
||||
@@ -178,7 +178,7 @@ def get_external_skills_dirs() -> List[Path]:
|
||||
path. Only directories that actually exist are returned. Duplicates and
|
||||
paths that resolve to the local ``~/.hermes/skills/`` are silently skipped.
|
||||
"""
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
config_path = get_config_path()
|
||||
if not config_path.exists():
|
||||
return []
|
||||
try:
|
||||
@@ -200,7 +200,7 @@ def get_external_skills_dirs() -> List[Path]:
|
||||
if not isinstance(raw_dirs, list):
|
||||
return []
|
||||
|
||||
local_skills = (get_hermes_home() / "skills").resolve()
|
||||
local_skills = get_skills_dir().resolve()
|
||||
seen: Set[Path] = set()
|
||||
result: List[Path] = []
|
||||
|
||||
@@ -230,7 +230,7 @@ def get_all_skills_dirs() -> List[Path]:
|
||||
The local dir is always first (and always included even if it doesn't exist
|
||||
yet — callers handle that). External dirs follow in config order.
|
||||
"""
|
||||
dirs = [get_hermes_home() / "skills"]
|
||||
dirs = [get_skills_dir()]
|
||||
dirs.extend(get_external_skills_dirs())
|
||||
return dirs
|
||||
|
||||
@@ -384,7 +384,7 @@ def resolve_skill_config_values(
|
||||
current values (or the declared default if the key isn't set).
|
||||
Path values are expanded via ``os.path.expanduser``.
|
||||
"""
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
config_path = get_config_path()
|
||||
config: Dict[str, Any] = {}
|
||||
if config_path.exists():
|
||||
try:
|
||||
|
||||
@@ -24,6 +24,7 @@ model:
|
||||
# "minimax" - MiniMax global (requires: MINIMAX_API_KEY)
|
||||
# "minimax-cn" - MiniMax China (requires: MINIMAX_CN_API_KEY)
|
||||
# "huggingface" - Hugging Face Inference (requires: HF_TOKEN)
|
||||
# "xiaomi" - Xiaomi MiMo (requires: XIAOMI_API_KEY)
|
||||
# "kilocode" - KiloCode gateway (requires: KILOCODE_API_KEY)
|
||||
# "ai-gateway" - Vercel AI Gateway (requires: AI_GATEWAY_API_KEY)
|
||||
#
|
||||
@@ -588,7 +589,7 @@ platform_toolsets:
|
||||
# skills_hub - skill_hub (search/install/manage from online registries — user-driven only)
|
||||
# moa - mixture_of_agents (requires OPENROUTER_API_KEY)
|
||||
# todo - todo (in-memory task planning, no deps)
|
||||
# tts - text_to_speech (Edge TTS free, or ELEVENLABS/OPENAI/MINIMAX key)
|
||||
# tts - text_to_speech (Edge TTS free, or ELEVENLABS/OPENAI/MINIMAX/MISTRAL key)
|
||||
# cronjob - cronjob (create/list/update/pause/resume/run/remove scheduled tasks)
|
||||
# rl - rl_list_environments, rl_start_training, etc. (requires TINKER_API_KEY)
|
||||
#
|
||||
@@ -617,7 +618,7 @@ platform_toolsets:
|
||||
# todo - Task planning and tracking for multi-step work
|
||||
# memory - Persistent memory across sessions (personal notes + user profile)
|
||||
# session_search - Search and recall past conversations (FTS5 + Gemini Flash summarization)
|
||||
# tts - Text-to-speech (Edge TTS free, ElevenLabs, OpenAI, MiniMax)
|
||||
# tts - Text-to-speech (Edge TTS free, ElevenLabs, OpenAI, MiniMax, Mistral)
|
||||
# cronjob - Schedule and manage automated tasks (CLI-only)
|
||||
# rl - RL training tools (Tinker-Atropos)
|
||||
#
|
||||
@@ -773,6 +774,11 @@ display:
|
||||
# Toggle at runtime with /verbose in the CLI
|
||||
tool_progress: all
|
||||
|
||||
# Gateway-only natural mid-turn assistant updates.
|
||||
# When true, completed assistant status messages are sent as separate chat
|
||||
# messages. This is independent of tool_progress and gateway streaming.
|
||||
interim_assistant_messages: true
|
||||
|
||||
# What Enter does when Hermes is already busy in the CLI.
|
||||
# interrupt: Interrupt the current run and redirect Hermes (default)
|
||||
# queue: Queue your message for the next turn
|
||||
@@ -781,7 +787,7 @@ display:
|
||||
|
||||
# Background process notifications (gateway/messaging only).
|
||||
# Controls how chatty the process watcher is when you use
|
||||
# terminal(background=true, check_interval=...) from Telegram/Discord/etc.
|
||||
# terminal(background=true, notify_on_complete=true) from Telegram/Discord/etc.
|
||||
# off: No watcher messages at all
|
||||
# result: Only the final completion message
|
||||
# error: Only the final message when exit code != 0
|
||||
|
||||
@@ -1171,6 +1171,45 @@ def _resolve_attachment_path(raw_path: str) -> Path | None:
|
||||
return resolved
|
||||
|
||||
|
||||
def _format_process_notification(evt: dict) -> "str | None":
|
||||
"""Format a process notification event into a [SYSTEM: ...] message.
|
||||
|
||||
Handles both completion events (notify_on_complete) and watch pattern
|
||||
match events from the unified completion_queue.
|
||||
"""
|
||||
evt_type = evt.get("type", "completion")
|
||||
_sid = evt.get("session_id", "unknown")
|
||||
_cmd = evt.get("command", "unknown")
|
||||
|
||||
if evt_type == "watch_disabled":
|
||||
return f"[SYSTEM: {evt.get('message', '')}]"
|
||||
|
||||
if evt_type == "watch_match":
|
||||
_pat = evt.get("pattern", "?")
|
||||
_out = evt.get("output", "")
|
||||
_sup = evt.get("suppressed", 0)
|
||||
text = (
|
||||
f"[SYSTEM: Background process {_sid} matched "
|
||||
f"watch pattern \"{_pat}\".\n"
|
||||
f"Command: {_cmd}\n"
|
||||
f"Matched output:\n{_out}"
|
||||
)
|
||||
if _sup:
|
||||
text += f"\n({_sup} earlier matches were suppressed by rate limit)"
|
||||
text += "]"
|
||||
return text
|
||||
|
||||
# Default: completion event
|
||||
_exit = evt.get("exit_code", "?")
|
||||
_out = evt.get("output", "")
|
||||
return (
|
||||
f"[SYSTEM: Background process {_sid} completed "
|
||||
f"(exit code {_exit}).\n"
|
||||
f"Command: {_cmd}\n"
|
||||
f"Output:\n{_out}]"
|
||||
)
|
||||
|
||||
|
||||
def _detect_file_drop(user_input: str) -> "dict | None":
|
||||
"""Detect if *user_input* starts with a real local file path.
|
||||
|
||||
@@ -1316,6 +1355,19 @@ class ChatConsole:
|
||||
for line in output.rstrip("\n").split("\n"):
|
||||
_cprint(line)
|
||||
|
||||
@contextmanager
|
||||
def status(self, *_args, **_kwargs):
|
||||
"""Provide a no-op Rich-compatible status context.
|
||||
|
||||
Some slash command helpers use ``console.status(...)`` when running in
|
||||
the standalone CLI. Interactive chat routes those helpers through
|
||||
``ChatConsole()``, which historically only implemented ``print()``.
|
||||
Returning a silent context manager keeps slash commands compatible
|
||||
without duplicating the higher-level busy indicator already shown by
|
||||
``HermesCLI._busy_command()``.
|
||||
"""
|
||||
yield self
|
||||
|
||||
# ASCII Art - HERMES-AGENT logo (full width, single line - requires ~95 char terminal)
|
||||
HERMES_AGENT_LOGO = """[bold #FFD700]██╗ ██╗███████╗██████╗ ███╗ ███╗███████╗███████╗ █████╗ ██████╗ ███████╗███╗ ██╗████████╗[/]
|
||||
[bold #FFD700]██║ ██║██╔════╝██╔══██╗████╗ ████║██╔════╝██╔════╝ ██╔══██╗██╔════╝ ██╔════╝████╗ ██║╚══██╔══╝[/]
|
||||
@@ -1765,6 +1817,7 @@ class HermesCLI:
|
||||
self._approval_state = None
|
||||
self._approval_deadline = 0
|
||||
self._approval_lock = threading.Lock()
|
||||
self._model_picker_state = None
|
||||
self._secret_state = None
|
||||
self._secret_deadline = 0
|
||||
self._spinner_text: str = "" # thinking spinner text for TUI
|
||||
@@ -2007,7 +2060,7 @@ class HermesCLI:
|
||||
return f"⚕ {self.model if getattr(self, 'model', None) else 'Hermes'}"
|
||||
|
||||
def _get_status_bar_fragments(self):
|
||||
if not self._status_bar_visible:
|
||||
if not self._status_bar_visible or getattr(self, '_model_picker_state', None):
|
||||
return []
|
||||
try:
|
||||
snapshot = self._get_status_bar_snapshot()
|
||||
@@ -2671,6 +2724,15 @@ class HermesCLI:
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
# When a custom_provider entry carries an explicit `model` field,
|
||||
# use it as the effective model name. Without this, running
|
||||
# `hermes chat --model <provider-name>` sends the provider name
|
||||
# (e.g. "my-provider") as the model string to the API instead of
|
||||
# the configured model (e.g. "qwen3.6-plus"), causing 400 errors.
|
||||
runtime_model = runtime.get("model")
|
||||
if runtime_model and isinstance(runtime_model, str):
|
||||
self.model = runtime_model
|
||||
|
||||
# Normalize model for the resolved provider (e.g. swap non-Codex
|
||||
# models when provider is openai-codex). Fixes #651.
|
||||
model_changed = self._normalize_model_for_provider(resolved_provider)
|
||||
@@ -4230,6 +4292,265 @@ class HermesCLI:
|
||||
remaining = len(self.conversation_history)
|
||||
print(f" {remaining} message(s) remaining in history.")
|
||||
|
||||
def _run_curses_picker(self, title: str, items: list[str], default_index: int = 0) -> int | None:
|
||||
"""Run curses_single_select via run_in_terminal so prompt_toolkit handles terminal ownership cleanly."""
|
||||
import threading
|
||||
from hermes_cli.curses_ui import curses_single_select
|
||||
|
||||
result = [None]
|
||||
|
||||
def _pick():
|
||||
result[0] = curses_single_select(title, items, default_index=default_index)
|
||||
|
||||
# run_in_terminal requires an asyncio event loop — only exists in the
|
||||
# main prompt_toolkit thread. If we're in a background thread (e.g.
|
||||
# process_loop), fall back to direct curses call.
|
||||
in_main_thread = threading.current_thread() is threading.main_thread()
|
||||
|
||||
if self._app and in_main_thread:
|
||||
from prompt_toolkit.application import run_in_terminal
|
||||
was_visible = self._status_bar_visible
|
||||
self._status_bar_visible = False
|
||||
self._app.invalidate()
|
||||
try:
|
||||
run_in_terminal(_pick)
|
||||
finally:
|
||||
self._status_bar_visible = was_visible
|
||||
self._app.invalidate()
|
||||
else:
|
||||
_pick()
|
||||
|
||||
return result[0]
|
||||
|
||||
def _prompt_text_input(self, prompt_text: str) -> str | None:
|
||||
"""Prompt for free-text input safely inside or outside prompt_toolkit."""
|
||||
result = [None]
|
||||
|
||||
def _ask():
|
||||
try:
|
||||
result[0] = input(prompt_text).strip() or None
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
pass
|
||||
|
||||
if self._app:
|
||||
from prompt_toolkit.application import run_in_terminal
|
||||
was_visible = self._status_bar_visible
|
||||
self._status_bar_visible = False
|
||||
self._app.invalidate()
|
||||
try:
|
||||
run_in_terminal(_ask)
|
||||
finally:
|
||||
self._status_bar_visible = was_visible
|
||||
self._app.invalidate()
|
||||
else:
|
||||
_ask()
|
||||
return result[0]
|
||||
|
||||
def _interactive_provider_selection(
|
||||
self, providers: list, current_model: str, current_provider: str
|
||||
) -> str | None:
|
||||
"""Show provider picker, return slug or None on cancel."""
|
||||
choices = []
|
||||
for p in providers:
|
||||
count = p.get("total_models", len(p.get("models", [])))
|
||||
label = f"{p['name']} ({count} model{'s' if count != 1 else ''})"
|
||||
if p.get("is_current"):
|
||||
label += " ← current"
|
||||
choices.append(label)
|
||||
|
||||
default_idx = next(
|
||||
(i for i, p in enumerate(providers) if p.get("is_current")), 0
|
||||
)
|
||||
|
||||
idx = self._run_curses_picker(
|
||||
f"Select a provider (current: {current_model} on {current_provider}):",
|
||||
choices,
|
||||
default_index=default_idx,
|
||||
)
|
||||
if idx is None:
|
||||
return None
|
||||
return providers[idx]["slug"]
|
||||
|
||||
def _interactive_model_selection(
|
||||
self, model_list: list, provider_data: dict
|
||||
) -> str | None:
|
||||
"""Show model picker for a given provider, return model_id or None on cancel."""
|
||||
pname = provider_data.get("name", provider_data.get("slug", ""))
|
||||
total = provider_data.get("total_models", len(model_list))
|
||||
|
||||
if not model_list:
|
||||
_cprint(f"\n No models listed for {pname}.")
|
||||
return self._prompt_text_input(" Enter model name manually (or Enter to cancel): ")
|
||||
|
||||
choices = list(model_list) + ["Enter custom model name"]
|
||||
idx = self._run_curses_picker(
|
||||
f"Select model from {pname} ({len(model_list)} of {total}):",
|
||||
choices,
|
||||
)
|
||||
if idx is None:
|
||||
return None
|
||||
if idx < len(model_list):
|
||||
return model_list[idx]
|
||||
return self._prompt_text_input(" Enter model name: ")
|
||||
|
||||
def _open_model_picker(self, providers: list, current_model: str, current_provider: str, user_provs=None, custom_provs=None) -> None:
|
||||
"""Open prompt_toolkit-native /model picker modal."""
|
||||
self._capture_modal_input_snapshot()
|
||||
default_idx = next((i for i, p in enumerate(providers) if p.get("is_current")), 0)
|
||||
self._model_picker_state = {
|
||||
"stage": "provider",
|
||||
"providers": providers,
|
||||
"selected": default_idx,
|
||||
"current_model": current_model,
|
||||
"current_provider": current_provider,
|
||||
"user_provs": user_provs,
|
||||
"custom_provs": custom_provs,
|
||||
}
|
||||
self._invalidate(min_interval=0.0)
|
||||
|
||||
def _close_model_picker(self) -> None:
|
||||
self._model_picker_state = None
|
||||
self._restore_modal_input_snapshot()
|
||||
self._invalidate(min_interval=0.0)
|
||||
|
||||
def _apply_model_switch_result(self, result, persist_global: bool) -> None:
|
||||
if not result.success:
|
||||
_cprint(f" ✗ {result.error_message}")
|
||||
return
|
||||
|
||||
old_model = self.model
|
||||
self.model = result.new_model
|
||||
self.provider = result.target_provider
|
||||
self.requested_provider = result.target_provider
|
||||
if result.api_key:
|
||||
self.api_key = result.api_key
|
||||
self._explicit_api_key = result.api_key
|
||||
if result.base_url:
|
||||
self.base_url = result.base_url
|
||||
self._explicit_base_url = result.base_url
|
||||
if result.api_mode:
|
||||
self.api_mode = result.api_mode
|
||||
|
||||
if self.agent is not None:
|
||||
try:
|
||||
self.agent.switch_model(
|
||||
new_model=result.new_model,
|
||||
new_provider=result.target_provider,
|
||||
api_key=result.api_key,
|
||||
base_url=result.base_url,
|
||||
api_mode=result.api_mode,
|
||||
)
|
||||
except Exception as exc:
|
||||
_cprint(f" ⚠ Agent swap failed ({exc}); change applied to next session.")
|
||||
|
||||
self._pending_model_switch_note = (
|
||||
f"[Note: model was just switched from {old_model} to {result.new_model} "
|
||||
f"via {result.provider_label or result.target_provider}. "
|
||||
f"Adjust your self-identification accordingly.]"
|
||||
)
|
||||
|
||||
provider_label = result.provider_label or result.target_provider
|
||||
_cprint(f" ✓ Model switched: {result.new_model}")
|
||||
_cprint(f" Provider: {provider_label}")
|
||||
|
||||
mi = result.model_info
|
||||
if mi:
|
||||
if mi.context_window:
|
||||
_cprint(f" Context: {mi.context_window:,} tokens")
|
||||
if mi.max_output:
|
||||
_cprint(f" Max output: {mi.max_output:,} tokens")
|
||||
if mi.has_cost_data():
|
||||
_cprint(f" Cost: {mi.format_cost()}")
|
||||
_cprint(f" Capabilities: {mi.format_capabilities()}")
|
||||
else:
|
||||
try:
|
||||
from agent.model_metadata import get_model_context_length
|
||||
ctx = get_model_context_length(
|
||||
result.new_model,
|
||||
base_url=result.base_url or self.base_url,
|
||||
api_key=result.api_key or self.api_key,
|
||||
provider=result.target_provider,
|
||||
)
|
||||
_cprint(f" Context: {ctx:,} tokens")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cache_enabled = (
|
||||
("openrouter" in (result.base_url or "").lower() and "claude" in result.new_model.lower())
|
||||
or result.api_mode == "anthropic_messages"
|
||||
)
|
||||
if cache_enabled:
|
||||
_cprint(" Prompt caching: enabled")
|
||||
if result.warning_message:
|
||||
_cprint(f" ⚠ {result.warning_message}")
|
||||
if persist_global:
|
||||
save_config_value("model.default", result.new_model)
|
||||
if result.provider_changed:
|
||||
save_config_value("model.provider", result.target_provider)
|
||||
_cprint(" Saved to config.yaml (--global)")
|
||||
else:
|
||||
_cprint(" (session only — add --global to persist)")
|
||||
|
||||
def _handle_model_picker_selection(self, persist_global: bool = False) -> None:
|
||||
state = self._model_picker_state
|
||||
if not state:
|
||||
return
|
||||
selected = state.get("selected", 0)
|
||||
stage = state.get("stage")
|
||||
if stage == "provider":
|
||||
providers = state.get("providers") or []
|
||||
if selected >= len(providers):
|
||||
self._close_model_picker()
|
||||
return
|
||||
provider_data = providers[selected]
|
||||
model_list = []
|
||||
try:
|
||||
from hermes_cli.models import provider_model_ids
|
||||
live = provider_model_ids(provider_data["slug"])
|
||||
if live:
|
||||
model_list = live
|
||||
except Exception:
|
||||
pass
|
||||
if not model_list:
|
||||
model_list = provider_data.get("models", [])
|
||||
state["stage"] = "model"
|
||||
state["provider_data"] = provider_data
|
||||
state["model_list"] = model_list
|
||||
state["selected"] = 0
|
||||
self._invalidate(min_interval=0.0)
|
||||
return
|
||||
if stage == "model":
|
||||
provider_data = state.get("provider_data") or {}
|
||||
model_list = state.get("model_list") or []
|
||||
back_idx = len(model_list)
|
||||
cancel_idx = len(model_list) + 1
|
||||
if selected == back_idx:
|
||||
state["stage"] = "provider"
|
||||
state["selected"] = next((i for i, p in enumerate(state.get("providers") or []) if p.get("slug") == provider_data.get("slug")), 0)
|
||||
self._invalidate(min_interval=0.0)
|
||||
return
|
||||
if selected >= cancel_idx:
|
||||
self._close_model_picker()
|
||||
return
|
||||
if selected < len(model_list):
|
||||
from hermes_cli.model_switch import switch_model
|
||||
chosen_model = model_list[selected]
|
||||
result = switch_model(
|
||||
raw_input=chosen_model,
|
||||
current_provider=self.provider or "",
|
||||
current_model=self.model or "",
|
||||
current_base_url=self.base_url or "",
|
||||
current_api_key=self.api_key or "",
|
||||
is_global=persist_global,
|
||||
explicit_provider=provider_data.get("slug"),
|
||||
user_providers=state.get("user_provs"),
|
||||
custom_providers=state.get("custom_provs"),
|
||||
)
|
||||
self._close_model_picker()
|
||||
self._apply_model_switch_result(result, persist_global)
|
||||
return
|
||||
self._close_model_picker()
|
||||
|
||||
def _handle_model_switch(self, cmd_original: str):
|
||||
"""Handle /model command — switch model for this session.
|
||||
|
||||
@@ -4252,56 +4573,46 @@ class HermesCLI:
|
||||
|
||||
user_provs = None
|
||||
custom_provs = None
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
cfg = load_config()
|
||||
user_provs = cfg.get("providers")
|
||||
custom_provs = cfg.get("custom_providers")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# No args at all: show available providers + models
|
||||
# No args at all: open prompt_toolkit-native picker modal
|
||||
if not model_input and not explicit_provider:
|
||||
model_display = self.model or "unknown"
|
||||
provider_display = get_label(self.provider) if self.provider else "unknown"
|
||||
_cprint(f" Current: {model_display} on {provider_display}")
|
||||
_cprint("")
|
||||
|
||||
# Show authenticated providers with top models
|
||||
user_provs = None
|
||||
custom_provs = None
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
cfg = load_config()
|
||||
user_provs = cfg.get("providers")
|
||||
custom_provs = cfg.get("custom_providers")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
providers = list_authenticated_providers(
|
||||
current_provider=self.provider or "",
|
||||
user_providers=user_provs,
|
||||
custom_providers=custom_provs,
|
||||
max_models=6,
|
||||
max_models=50,
|
||||
)
|
||||
if providers:
|
||||
for p in providers:
|
||||
tag = " (current)" if p["is_current"] else ""
|
||||
_cprint(f" {p['name']} [--provider {p['slug']}]{tag}:")
|
||||
if p["models"]:
|
||||
model_strs = ", ".join(p["models"])
|
||||
extra = f" (+{p['total_models'] - len(p['models'])} more)" if p["total_models"] > len(p["models"]) else ""
|
||||
_cprint(f" {model_strs}{extra}")
|
||||
elif p.get("api_url"):
|
||||
_cprint(f" {p['api_url']} (use /model <name> --provider {p['slug']})")
|
||||
else:
|
||||
_cprint(f" (no models listed)")
|
||||
_cprint("")
|
||||
else:
|
||||
_cprint(" No authenticated providers found.")
|
||||
_cprint("")
|
||||
except Exception:
|
||||
pass
|
||||
providers = []
|
||||
|
||||
# Aliases
|
||||
from hermes_cli.model_switch import MODEL_ALIASES
|
||||
alias_list = ", ".join(sorted(MODEL_ALIASES.keys()))
|
||||
_cprint(f" Aliases: {alias_list}")
|
||||
_cprint("")
|
||||
_cprint(" /model <name> switch model")
|
||||
_cprint(" /model <name> --provider <slug> switch provider")
|
||||
_cprint(" /model <name> --global persist to config")
|
||||
if not providers:
|
||||
_cprint(" No authenticated providers found.")
|
||||
_cprint("")
|
||||
_cprint(" /model <name> switch model")
|
||||
_cprint(" /model --provider <slug> switch provider")
|
||||
return
|
||||
|
||||
self._open_model_picker(
|
||||
providers,
|
||||
model_display,
|
||||
provider_display,
|
||||
user_provs=user_provs,
|
||||
custom_provs=custom_provs,
|
||||
)
|
||||
return
|
||||
|
||||
# Perform the switch
|
||||
@@ -4409,6 +4720,18 @@ class HermesCLI:
|
||||
else:
|
||||
_cprint(" (session only — add --global to persist)")
|
||||
|
||||
def _should_handle_model_command_inline(self, text: str, has_images: bool = False) -> bool:
|
||||
"""Return True when /model should be handled immediately on the UI thread."""
|
||||
if not text or has_images or not _looks_like_slash_command(text):
|
||||
return False
|
||||
try:
|
||||
from hermes_cli.commands import resolve_command
|
||||
base = text.split(None, 1)[0].lower().lstrip('/')
|
||||
cmd = resolve_command(base)
|
||||
return bool(cmd and cmd.name == "model")
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _show_model_and_providers(self):
|
||||
"""Show current model + provider and list all authenticated providers.
|
||||
|
||||
@@ -7631,7 +7954,8 @@ class HermesCLI:
|
||||
secret_widget,
|
||||
approval_widget,
|
||||
clarify_widget,
|
||||
spinner_widget,
|
||||
model_picker_widget=None,
|
||||
spinner_widget=None,
|
||||
spacer,
|
||||
status_bar,
|
||||
input_rule_top,
|
||||
@@ -7648,21 +7972,24 @@ class HermesCLI:
|
||||
ordering.
|
||||
"""
|
||||
return [
|
||||
Window(height=0),
|
||||
sudo_widget,
|
||||
secret_widget,
|
||||
approval_widget,
|
||||
clarify_widget,
|
||||
spinner_widget,
|
||||
spacer,
|
||||
*self._get_extra_tui_widgets(),
|
||||
status_bar,
|
||||
input_rule_top,
|
||||
image_bar,
|
||||
input_area,
|
||||
input_rule_bot,
|
||||
voice_status_bar,
|
||||
completions_menu,
|
||||
item for item in [
|
||||
Window(height=0),
|
||||
sudo_widget,
|
||||
secret_widget,
|
||||
approval_widget,
|
||||
clarify_widget,
|
||||
model_picker_widget,
|
||||
spinner_widget,
|
||||
spacer,
|
||||
*self._get_extra_tui_widgets(),
|
||||
status_bar,
|
||||
input_rule_top,
|
||||
image_bar,
|
||||
input_area,
|
||||
input_rule_bot,
|
||||
voice_status_bar,
|
||||
completions_menu,
|
||||
] if item is not None
|
||||
]
|
||||
|
||||
def run(self):
|
||||
@@ -7823,6 +8150,12 @@ class HermesCLI:
|
||||
event.app.invalidate()
|
||||
return
|
||||
|
||||
# --- /model picker modal ---
|
||||
if self._model_picker_state:
|
||||
self._handle_model_picker_selection()
|
||||
event.app.invalidate()
|
||||
return
|
||||
|
||||
# --- Clarify freetext mode: user typed their own answer ---
|
||||
if self._clarify_freetext and self._clarify_state:
|
||||
text = event.app.current_buffer.text.strip()
|
||||
@@ -7853,6 +8186,16 @@ class HermesCLI:
|
||||
text = event.app.current_buffer.text.strip()
|
||||
has_images = bool(self._attached_images)
|
||||
if text or has_images:
|
||||
# Handle /model directly on the UI thread so interactive pickers
|
||||
# can safely use prompt_toolkit terminal handoff helpers.
|
||||
if self._should_handle_model_command_inline(text, has_images=has_images):
|
||||
if not self.process_command(text):
|
||||
self._should_exit = True
|
||||
if event.app.is_running:
|
||||
event.app.exit()
|
||||
event.app.current_buffer.reset(append_to_history=True)
|
||||
return
|
||||
|
||||
# Snapshot and clear attached images
|
||||
images = list(self._attached_images)
|
||||
self._attached_images.clear()
|
||||
@@ -7956,12 +8299,31 @@ class HermesCLI:
|
||||
self._approval_state["selected"] = min(max_idx, self._approval_state["selected"] + 1)
|
||||
event.app.invalidate()
|
||||
|
||||
# --- /model picker: arrow-key navigation ---
|
||||
@kb.add('up', filter=Condition(lambda: bool(self._model_picker_state)))
|
||||
def model_picker_up(event):
|
||||
if self._model_picker_state:
|
||||
self._model_picker_state["selected"] = max(0, self._model_picker_state.get("selected", 0) - 1)
|
||||
event.app.invalidate()
|
||||
|
||||
@kb.add('down', filter=Condition(lambda: bool(self._model_picker_state)))
|
||||
def model_picker_down(event):
|
||||
state = self._model_picker_state
|
||||
if not state:
|
||||
return
|
||||
if state.get("stage") == "provider":
|
||||
max_idx = len(state.get("providers") or [])
|
||||
else:
|
||||
max_idx = len(state.get("model_list") or []) + 1
|
||||
state["selected"] = min(max_idx, state.get("selected", 0) + 1)
|
||||
event.app.invalidate()
|
||||
|
||||
# --- History navigation: up/down browse history in normal input mode ---
|
||||
# The TextArea is multiline, so by default up/down only move the cursor.
|
||||
# Buffer.auto_up/auto_down handle both: cursor movement when multi-line,
|
||||
# history browsing when on the first/last line (or single-line input).
|
||||
_normal_input = Condition(
|
||||
lambda: not self._clarify_state and not self._approval_state and not self._sudo_state and not self._secret_state
|
||||
lambda: not self._clarify_state and not self._approval_state and not self._sudo_state and not self._secret_state and not self._model_picker_state
|
||||
)
|
||||
|
||||
@kb.add('up', filter=_normal_input)
|
||||
@@ -8027,6 +8389,13 @@ class HermesCLI:
|
||||
event.app.invalidate()
|
||||
return
|
||||
|
||||
# Cancel /model picker
|
||||
if self._model_picker_state:
|
||||
self._close_model_picker()
|
||||
event.app.current_buffer.reset()
|
||||
event.app.invalidate()
|
||||
return
|
||||
|
||||
# Cancel clarify prompt
|
||||
if self._clarify_state:
|
||||
self._clarify_state["response_queue"].put(
|
||||
@@ -8079,7 +8448,7 @@ class HermesCLI:
|
||||
agent_name = get_active_skin().get_branding("agent_name", "Hermes Agent")
|
||||
msg = f"\n{agent_name} has been suspended. Run `fg` to bring {agent_name} back."
|
||||
def _suspend():
|
||||
os.write(1, msg.encode("utf-8", errors="replace"))
|
||||
os.write(1, msg.encode())
|
||||
os.kill(0, _sig.SIGTSTP)
|
||||
run_in_terminal(_suspend)
|
||||
|
||||
@@ -8644,6 +9013,60 @@ class HermesCLI:
|
||||
filter=Condition(lambda: cli_ref._approval_state is not None),
|
||||
)
|
||||
|
||||
# --- /model picker: display widget ---
|
||||
def _get_model_picker_display():
|
||||
state = cli_ref._model_picker_state
|
||||
if not state:
|
||||
return []
|
||||
stage = state.get("stage", "provider")
|
||||
if stage == "provider":
|
||||
title = "⚙ Model Picker — Select Provider"
|
||||
choices = []
|
||||
for p in state.get("providers") or []:
|
||||
count = p.get("total_models", len(p.get("models", [])))
|
||||
label = f"{p['name']} ({count} model{'s' if count != 1 else ''})"
|
||||
if p.get("is_current"):
|
||||
label += " ← current"
|
||||
choices.append(label)
|
||||
choices.append("Cancel")
|
||||
hint = f"Current: {state.get('current_model', 'unknown')} on {state.get('current_provider', 'unknown')}"
|
||||
else:
|
||||
provider_data = state.get("provider_data") or {}
|
||||
model_list = state.get("model_list") or []
|
||||
title = f"⚙ Model Picker — {provider_data.get('name', provider_data.get('slug', 'Provider'))}"
|
||||
choices = list(model_list) + ["← Back", "Cancel"]
|
||||
if model_list:
|
||||
hint = f"Select a model ({len(model_list)} available)"
|
||||
else:
|
||||
hint = "No models listed for this provider. Use Back or Cancel."
|
||||
|
||||
box_width = _panel_box_width(title, [hint] + choices, min_width=46, max_width=84)
|
||||
inner_text_width = max(8, box_width - 6)
|
||||
lines = []
|
||||
lines.append(('class:clarify-border', '╭─ '))
|
||||
lines.append(('class:clarify-title', title))
|
||||
lines.append(('class:clarify-border', ' ' + ('─' * max(0, box_width - len(title) - 3)) + '╮\n'))
|
||||
_append_blank_panel_line(lines, 'class:clarify-border', box_width)
|
||||
_append_panel_line(lines, 'class:clarify-border', 'class:clarify-hint', hint, box_width)
|
||||
_append_blank_panel_line(lines, 'class:clarify-border', box_width)
|
||||
selected = state.get("selected", 0)
|
||||
for idx, choice in enumerate(choices):
|
||||
style = 'class:clarify-selected' if idx == selected else 'class:clarify-choice'
|
||||
prefix = '❯ ' if idx == selected else ' '
|
||||
for wrapped in _wrap_panel_text(prefix + choice, inner_text_width, subsequent_indent=' '):
|
||||
_append_panel_line(lines, 'class:clarify-border', style, wrapped, box_width)
|
||||
_append_blank_panel_line(lines, 'class:clarify-border', box_width)
|
||||
lines.append(('class:clarify-border', '╰' + ('─' * box_width) + '╯\n'))
|
||||
return lines
|
||||
|
||||
model_picker_widget = ConditionalContainer(
|
||||
Window(
|
||||
FormattedTextControl(_get_model_picker_display),
|
||||
wrap_lines=True,
|
||||
),
|
||||
filter=Condition(lambda: cli_ref._model_picker_state is not None),
|
||||
)
|
||||
|
||||
# Horizontal rules above and below the input.
|
||||
# On narrow/mobile terminals we keep the top separator for structure but
|
||||
# hide the bottom one to recover a full row for conversation content.
|
||||
@@ -8719,6 +9142,7 @@ class HermesCLI:
|
||||
secret_widget=secret_widget,
|
||||
approval_widget=approval_widget,
|
||||
clarify_widget=clarify_widget,
|
||||
model_picker_widget=model_picker_widget,
|
||||
spinner_widget=spinner_widget,
|
||||
spacer=spacer,
|
||||
status_bar=status_bar,
|
||||
@@ -8870,23 +9294,15 @@ class HermesCLI:
|
||||
# Periodic config watcher — auto-reload MCP on mcp_servers change
|
||||
if not self._agent_running:
|
||||
self._check_config_mcp_changes()
|
||||
# Check for background process completion notifications
|
||||
# while the agent is idle (user hasn't typed anything yet).
|
||||
# Check for background process notifications (completions
|
||||
# and watch pattern matches) while agent is idle.
|
||||
try:
|
||||
from tools.process_registry import process_registry
|
||||
if not process_registry.completion_queue.empty():
|
||||
completion = process_registry.completion_queue.get_nowait()
|
||||
_exit = completion.get("exit_code", "?")
|
||||
_cmd = completion.get("command", "unknown")
|
||||
_sid = completion.get("session_id", "unknown")
|
||||
_out = completion.get("output", "")
|
||||
_synth = (
|
||||
f"[SYSTEM: Background process {_sid} completed "
|
||||
f"(exit code {_exit}).\n"
|
||||
f"Command: {_cmd}\n"
|
||||
f"Output:\n{_out}]"
|
||||
)
|
||||
self._pending_input.put(_synth)
|
||||
evt = process_registry.completion_queue.get_nowait()
|
||||
_synth = _format_process_notification(evt)
|
||||
if _synth:
|
||||
self._pending_input.put(_synth)
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
@@ -9004,25 +9420,15 @@ class HermesCLI:
|
||||
_cprint(f"{_DIM}Voice auto-restart failed: {e}{_RST}")
|
||||
threading.Thread(target=_restart_recording, daemon=True).start()
|
||||
|
||||
# Drain process completion notifications — any background
|
||||
# process that finished with notify_on_complete while the
|
||||
# agent was running (or before) gets auto-injected as a
|
||||
# new user message so the agent can react to it.
|
||||
# Drain process notifications (completions + watch matches)
|
||||
# that arrived while the agent was running.
|
||||
try:
|
||||
from tools.process_registry import process_registry
|
||||
while not process_registry.completion_queue.empty():
|
||||
completion = process_registry.completion_queue.get_nowait()
|
||||
_exit = completion.get("exit_code", "?")
|
||||
_cmd = completion.get("command", "unknown")
|
||||
_sid = completion.get("session_id", "unknown")
|
||||
_out = completion.get("output", "")
|
||||
_synth = (
|
||||
f"[SYSTEM: Background process {_sid} completed "
|
||||
f"(exit code {_exit}).\n"
|
||||
f"Command: {_cmd}\n"
|
||||
f"Output:\n{_out}]"
|
||||
)
|
||||
self._pending_input.put(_synth)
|
||||
evt = process_registry.completion_queue.get_nowait()
|
||||
_synth = _format_process_notification(evt)
|
||||
if _synth:
|
||||
self._pending_input.put(_synth)
|
||||
except Exception:
|
||||
pass # Non-fatal — don't break the main loop
|
||||
|
||||
|
||||
+11
-8
@@ -44,7 +44,7 @@ logger = logging.getLogger(__name__)
|
||||
_KNOWN_DELIVERY_PLATFORMS = frozenset({
|
||||
"telegram", "discord", "slack", "whatsapp", "signal",
|
||||
"matrix", "mattermost", "homeassistant", "dingtalk", "feishu",
|
||||
"wecom", "weixin", "sms", "email", "webhook", "bluebubbles",
|
||||
"wecom", "wecom_callback", "weixin", "sms", "email", "webhook", "bluebubbles",
|
||||
})
|
||||
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run
|
||||
@@ -234,6 +234,7 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option
|
||||
"dingtalk": Platform.DINGTALK,
|
||||
"feishu": Platform.FEISHU,
|
||||
"wecom": Platform.WECOM,
|
||||
"wecom_callback": Platform.WECOM_CALLBACK,
|
||||
"weixin": Platform.WEIXIN,
|
||||
"email": Platform.EMAIL,
|
||||
"sms": Platform.SMS,
|
||||
@@ -442,6 +443,14 @@ def _run_job_script(script_path: str) -> tuple[bool, str]:
|
||||
stdout = (result.stdout or "").strip()
|
||||
stderr = (result.stderr or "").strip()
|
||||
|
||||
# Redact secrets from both stdout and stderr before any return path.
|
||||
try:
|
||||
from agent.redact import redact_sensitive_text
|
||||
stdout = redact_sensitive_text(stdout)
|
||||
stderr = redact_sensitive_text(stderr)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if result.returncode != 0:
|
||||
parts = [f"Script exited with code {result.returncode}"]
|
||||
if stderr:
|
||||
@@ -450,13 +459,6 @@ def _run_job_script(script_path: str) -> tuple[bool, str]:
|
||||
parts.append(f"stdout:\n{stdout}")
|
||||
return False, "\n".join(parts)
|
||||
|
||||
# Redact any secrets that may appear in script output before
|
||||
# they are injected into the LLM prompt context.
|
||||
try:
|
||||
from agent.redact import redact_sensitive_text
|
||||
stdout = redact_sensitive_text(stdout)
|
||||
except Exception:
|
||||
pass
|
||||
return True, stdout
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
@@ -721,6 +723,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
provider_sort=pr.get("sort"),
|
||||
disabled_toolsets=["cronjob", "messaging", "clarify"],
|
||||
quiet_mode=True,
|
||||
skip_context_files=True, # Don't inject SOUL.md/AGENTS.md from scheduler cwd
|
||||
skip_memory=True, # Cron system prompts would corrupt user representations
|
||||
platform="cron",
|
||||
session_id=_cron_session_id,
|
||||
|
||||
+44
-12
@@ -11,12 +11,14 @@ When you run `hermes setup` for the first time and Hermes detects `~/.openclaw`,
|
||||
### 2. CLI Command (quick, scriptable)
|
||||
|
||||
```bash
|
||||
hermes claw migrate # Full migration with confirmation prompt
|
||||
hermes claw migrate --dry-run # Preview what would happen
|
||||
hermes claw migrate # Preview then migrate (always shows preview first)
|
||||
hermes claw migrate --dry-run # Preview only, no changes
|
||||
hermes claw migrate --preset user-data # Migrate without API keys/secrets
|
||||
hermes claw migrate --yes # Skip confirmation prompt
|
||||
```
|
||||
|
||||
The migration always shows a full preview of what will be imported before making any changes. You review the preview and confirm before anything is written.
|
||||
|
||||
**All options:**
|
||||
|
||||
| Flag | Description |
|
||||
@@ -39,7 +41,7 @@ Ask the agent to run the migration for you:
|
||||
```
|
||||
|
||||
The agent will use the `openclaw-migration` skill to:
|
||||
1. Run a dry-run first to preview changes
|
||||
1. Run a preview first to show what would change
|
||||
2. Ask about conflict resolution (SOUL.md, skills, etc.)
|
||||
3. Let you choose between `user-data` and `full` presets
|
||||
4. Execute the migration with your choices
|
||||
@@ -58,16 +60,31 @@ The agent will use the `openclaw-migration` skill to:
|
||||
| Messaging settings | `~/.openclaw/config.yaml` (TELEGRAM_ALLOWED_USERS, MESSAGING_CWD) | `~/.hermes/.env` |
|
||||
| TTS assets | `~/.openclaw/workspace/tts/` | `~/.hermes/tts/` |
|
||||
|
||||
Workspace files are also checked at `workspace.default/` and `workspace-main/` as fallback paths (OpenClaw renamed `workspace/` to `workspace-main/` in recent versions).
|
||||
|
||||
### `full` preset (adds to `user-data`)
|
||||
| Item | Source | Destination |
|
||||
|------|--------|-------------|
|
||||
| Telegram bot token | `~/.openclaw/config.yaml` | `~/.hermes/.env` |
|
||||
| OpenRouter API key | `~/.openclaw/.env` or config | `~/.hermes/.env` |
|
||||
| OpenAI API key | `~/.openclaw/.env` or config | `~/.hermes/.env` |
|
||||
| Anthropic API key | `~/.openclaw/.env` or config | `~/.hermes/.env` |
|
||||
| ElevenLabs API key | `~/.openclaw/.env` or config | `~/.hermes/.env` |
|
||||
| Telegram bot token | `openclaw.json` channels config | `~/.hermes/.env` |
|
||||
| OpenRouter API key | `.env`, `openclaw.json`, or `openclaw.json["env"]` | `~/.hermes/.env` |
|
||||
| OpenAI API key | `.env`, `openclaw.json`, or `openclaw.json["env"]` | `~/.hermes/.env` |
|
||||
| Anthropic API key | `.env`, `openclaw.json`, or `openclaw.json["env"]` | `~/.hermes/.env` |
|
||||
| ElevenLabs API key | `.env`, `openclaw.json`, or `openclaw.json["env"]` | `~/.hermes/.env` |
|
||||
|
||||
Only these 6 allowlisted secrets are ever imported. Other credentials are skipped and reported.
|
||||
API keys are searched across four sources: inline config values, `~/.openclaw/.env`, the `openclaw.json` `"env"` sub-object, and per-agent auth profiles.
|
||||
|
||||
Only allowlisted secrets are ever imported. Other credentials are skipped and reported.
|
||||
|
||||
## OpenClaw Schema Compatibility
|
||||
|
||||
The migration handles both old and current OpenClaw config layouts:
|
||||
|
||||
- **Channel tokens**: Reads from flat paths (`channels.telegram.botToken`) and the newer `accounts.default` layout (`channels.telegram.accounts.default.botToken`)
|
||||
- **TTS provider**: OpenClaw renamed "edge" to "microsoft" — both are recognized and mapped to Hermes' "edge"
|
||||
- **Provider API types**: Both short (`openai`, `anthropic`) and hyphenated (`openai-completions`, `anthropic-messages`, `google-generative-ai`) values are mapped correctly
|
||||
- **thinkingDefault**: All enum values are handled including newer ones (`minimal`, `xhigh`, `adaptive`)
|
||||
- **Matrix**: Uses `accessToken` field (not `botToken`)
|
||||
- **SecretRef formats**: Plain strings, env templates (`${VAR}`), and `source: "env"` SecretRefs are resolved. `source: "file"` and `source: "exec"` SecretRefs produce a warning — add those keys manually after migration.
|
||||
|
||||
## Conflict Handling
|
||||
|
||||
@@ -84,18 +101,24 @@ For skills, you can also use `--skill-conflict rename` to import conflicting ski
|
||||
|
||||
## Migration Report
|
||||
|
||||
Every migration (including dry runs) produces a report showing:
|
||||
Every migration produces a report showing:
|
||||
- **Migrated items** — what was successfully imported
|
||||
- **Conflicts** — items skipped because they already exist
|
||||
- **Skipped items** — items not found in the source
|
||||
- **Errors** — items that failed to import
|
||||
|
||||
For execute runs, the full report is saved to `~/.hermes/migration/openclaw/<timestamp>/`.
|
||||
For executed migrations, the full report is saved to `~/.hermes/migration/openclaw/<timestamp>/`.
|
||||
|
||||
## Post-Migration Notes
|
||||
|
||||
- **Skills require a new session** — imported skills take effect after restarting your agent or starting a new chat.
|
||||
- **WhatsApp requires re-pairing** — WhatsApp uses QR-code pairing, not token-based auth. Run `hermes whatsapp` to pair.
|
||||
- **Archive cleanup** — after migration, you'll be offered to rename `~/.openclaw/` to `.openclaw.pre-migration/` to prevent state confusion. You can also run `hermes claw cleanup` later.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "OpenClaw directory not found"
|
||||
The migration looks for `~/.openclaw` by default. If your OpenClaw is installed elsewhere, use `--source`:
|
||||
The migration looks for `~/.openclaw` by default, then tries `~/.clawdbot` and `~/.moldbot`. If your OpenClaw is installed elsewhere, use `--source`:
|
||||
```bash
|
||||
hermes claw migrate --source /path/to/.openclaw
|
||||
```
|
||||
@@ -108,3 +131,12 @@ hermes skills install openclaw-migration
|
||||
|
||||
### Memory overflow
|
||||
If your OpenClaw MEMORY.md or USER.md exceeds Hermes' character limits, excess entries are exported to an overflow file in the migration report directory. You can manually review and add the most important ones.
|
||||
|
||||
### API keys not found
|
||||
Keys might be stored in different places depending on your OpenClaw setup:
|
||||
- `~/.openclaw/.env` file
|
||||
- Inline in `openclaw.json` under `models.providers.*.apiKey`
|
||||
- In `openclaw.json` under the `"env"` or `"env.vars"` sub-objects
|
||||
- In `~/.openclaw/agents/main/agent/auth-profiles.json`
|
||||
|
||||
The migration checks all four. If keys use `source: "file"` or `source: "exec"` SecretRefs, they can't be resolved automatically — add them via `hermes config set`.
|
||||
|
||||
@@ -0,0 +1,329 @@
|
||||
# Container-Aware CLI Review Fixes Spec
|
||||
|
||||
**PR:** NousResearch/hermes-agent#7543
|
||||
**Review:** cursor[bot] bugbot review (4094049442) + two prior rounds
|
||||
**Date:** 2026-04-12
|
||||
**Branch:** `feat/container-aware-cli-clean`
|
||||
|
||||
## Review Issues Summary
|
||||
|
||||
Six issues were raised across three bugbot review rounds. Three were fixed in intermediate commits (38277a6a, 726cf90f). This spec addresses remaining design concerns surfaced by those reviews and simplifies the implementation based on interview decisions.
|
||||
|
||||
| # | Issue | Severity | Status |
|
||||
|---|-------|----------|--------|
|
||||
| 1 | `os.execvp` retry loop unreachable | Medium | Fixed in 79e8cd12 (switched to subprocess.run) |
|
||||
| 2 | Redundant `shutil.which("sudo")` | Medium | Fixed in 38277a6a (reuses `sudo` var) |
|
||||
| 3 | Missing `chown -h` on symlink update | Low | Fixed in 38277a6a |
|
||||
| 4 | Container routing after `parse_args()` | High | Fixed in 726cf90f |
|
||||
| 5 | Hardcoded `/home/${user}` | Medium | Fixed in 726cf90f |
|
||||
| 6 | Group membership not gated on `container.enable` | Low | Fixed in 726cf90f |
|
||||
|
||||
The mechanical fixes are in place but the overall design needs revision. The retry loop, error swallowing, and process model have deeper issues than what the bugbot flagged.
|
||||
|
||||
---
|
||||
|
||||
## Spec: Revised `_exec_in_container`
|
||||
|
||||
### Design Principles
|
||||
|
||||
1. **Let it crash.** No silent fallbacks. If `.container-mode` exists but something goes wrong, the error propagates naturally (Python traceback). The only case where container routing is skipped is when `.container-mode` doesn't exist or `HERMES_DEV=1`.
|
||||
2. **No retries.** Probe once for sudo, exec once. If it fails, docker/podman's stderr reaches the user verbatim.
|
||||
3. **Completely transparent.** No error wrapping, no prefixes, no spinners. Docker's output goes straight through.
|
||||
4. **`os.execvp` on the happy path.** Replace the Python process entirely so there's no idle parent during interactive sessions. Note: `execvp` never returns on success (process is replaced) and raises `OSError` on failure (it does not return a value). The container process's exit code becomes the process exit code by definition — no explicit propagation needed.
|
||||
5. **One human-readable exception to "let it crash".** `subprocess.TimeoutExpired` from the sudo probe gets a specific catch with a readable message, since a raw traceback for "your Docker daemon is slow" is confusing. All other exceptions propagate naturally.
|
||||
|
||||
### Execution Flow
|
||||
|
||||
```
|
||||
1. get_container_exec_info()
|
||||
- HERMES_DEV=1 → return None (skip routing)
|
||||
- Inside container → return None (skip routing)
|
||||
- .container-mode doesn't exist → return None (skip routing)
|
||||
- .container-mode exists → parse and return dict
|
||||
- .container-mode exists but malformed/unreadable → LET IT CRASH (no try/except)
|
||||
|
||||
2. _exec_in_container(container_info, sys.argv[1:])
|
||||
a. shutil.which(backend) → if None, print "{backend} not found on PATH" and sys.exit(1)
|
||||
b. Sudo probe: subprocess.run([runtime, "inspect", "--format", "ok", container_name], timeout=15)
|
||||
- If succeeds → needs_sudo = False
|
||||
- If fails → try subprocess.run([sudo, "-n", runtime, "inspect", ...], timeout=15)
|
||||
- If succeeds → needs_sudo = True
|
||||
- If fails → print error with sudoers hint (including why -n is required) and sys.exit(1)
|
||||
- If TimeoutExpired → catch specifically, print human-readable message about slow daemon
|
||||
c. Build exec_cmd: [sudo? + runtime, "exec", tty_flags, "-u", exec_user, env_flags, container, hermes_bin, *cli_args]
|
||||
d. os.execvp(exec_cmd[0], exec_cmd)
|
||||
- On success: process is replaced — Python is gone, container exit code IS the process exit code
|
||||
- On OSError: let it crash (natural traceback)
|
||||
```
|
||||
|
||||
### Changes to `hermes_cli/main.py`
|
||||
|
||||
#### `_exec_in_container` — rewrite
|
||||
|
||||
Remove:
|
||||
- The entire retry loop (`max_retries`, `for attempt in range(...)`)
|
||||
- Spinner logic (`"Waiting for container..."`, dots)
|
||||
- Exit code classification (125/126/127 handling)
|
||||
- `subprocess.run` for the exec call (keep it only for the sudo probe)
|
||||
- Special TTY vs non-TTY retry counts
|
||||
- The `time` import (no longer needed)
|
||||
|
||||
Change:
|
||||
- Use `os.execvp(exec_cmd[0], exec_cmd)` as the final call
|
||||
- Keep the `subprocess` import only for the sudo probe
|
||||
- Keep TTY detection for the `-it` vs `-i` flag
|
||||
- Keep env var forwarding (TERM, COLORTERM, LANG, LC_ALL)
|
||||
- Keep the sudo probe as-is (it's the one "smart" part)
|
||||
- Bump probe `timeout` from 5s to 15s — cold podman on a loaded machine needs headroom
|
||||
- Catch `subprocess.TimeoutExpired` specifically on both probe calls — print a readable message about the daemon being unresponsive instead of a raw traceback
|
||||
- Expand the sudoers hint error message to explain *why* `-n` (non-interactive) is required: a password prompt would hang the CLI or break piped commands
|
||||
|
||||
The function becomes roughly:
|
||||
|
||||
```python
|
||||
def _exec_in_container(container_info: dict, cli_args: list):
|
||||
"""Replace the current process with a command inside the managed container.
|
||||
|
||||
Probes whether sudo is needed (rootful containers), then os.execvp
|
||||
into the container. If exec fails, the OS error propagates naturally.
|
||||
"""
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
backend = container_info["backend"]
|
||||
container_name = container_info["container_name"]
|
||||
exec_user = container_info["exec_user"]
|
||||
hermes_bin = container_info["hermes_bin"]
|
||||
|
||||
runtime = shutil.which(backend)
|
||||
if not runtime:
|
||||
print(f"Error: {backend} not found on PATH. Cannot route to container.",
|
||||
file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Probe whether we need sudo to see the rootful container.
|
||||
# Timeout is 15s — cold podman on a loaded machine can take a while.
|
||||
# TimeoutExpired is caught specifically for a human-readable message;
|
||||
# all other exceptions propagate naturally.
|
||||
needs_sudo = False
|
||||
sudo = None
|
||||
try:
|
||||
probe = subprocess.run(
|
||||
[runtime, "inspect", "--format", "ok", container_name],
|
||||
capture_output=True, text=True, timeout=15,
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
print(
|
||||
f"Error: timed out waiting for {backend} to respond.\n"
|
||||
f"The {backend} daemon may be unresponsive or starting up.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if probe.returncode != 0:
|
||||
sudo = shutil.which("sudo")
|
||||
if sudo:
|
||||
try:
|
||||
probe2 = subprocess.run(
|
||||
[sudo, "-n", runtime, "inspect", "--format", "ok", container_name],
|
||||
capture_output=True, text=True, timeout=15,
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
print(
|
||||
f"Error: timed out waiting for sudo {backend} to respond.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if probe2.returncode == 0:
|
||||
needs_sudo = True
|
||||
else:
|
||||
print(
|
||||
f"Error: container '{container_name}' not found via {backend}.\n"
|
||||
f"\n"
|
||||
f"The NixOS service runs the container as root. Your user cannot\n"
|
||||
f"see it because {backend} uses per-user namespaces.\n"
|
||||
f"\n"
|
||||
f"Fix: grant passwordless sudo for {backend}. The -n (non-interactive)\n"
|
||||
f"flag is required because the CLI calls sudo non-interactively —\n"
|
||||
f"a password prompt would hang or break piped commands:\n"
|
||||
f"\n"
|
||||
f' security.sudo.extraRules = [{{\n'
|
||||
f' users = [ "{os.getenv("USER", "your-user")}" ];\n'
|
||||
f' commands = [{{ command = "{runtime}"; options = [ "NOPASSWD" ]; }}];\n'
|
||||
f' }}];\n'
|
||||
f"\n"
|
||||
f"Or run: sudo hermes {' '.join(cli_args)}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(
|
||||
f"Error: container '{container_name}' not found via {backend}.\n"
|
||||
f"The container may be running under root. Try: sudo hermes {' '.join(cli_args)}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
is_tty = sys.stdin.isatty()
|
||||
tty_flags = ["-it"] if is_tty else ["-i"]
|
||||
|
||||
env_flags = []
|
||||
for var in ("TERM", "COLORTERM", "LANG", "LC_ALL"):
|
||||
val = os.environ.get(var)
|
||||
if val:
|
||||
env_flags.extend(["-e", f"{var}={val}"])
|
||||
|
||||
cmd_prefix = [sudo, "-n", runtime] if needs_sudo else [runtime]
|
||||
exec_cmd = (
|
||||
cmd_prefix + ["exec"]
|
||||
+ tty_flags
|
||||
+ ["-u", exec_user]
|
||||
+ env_flags
|
||||
+ [container_name, hermes_bin]
|
||||
+ cli_args
|
||||
)
|
||||
|
||||
# execvp replaces this process entirely — it never returns on success.
|
||||
# On failure it raises OSError, which propagates naturally.
|
||||
os.execvp(exec_cmd[0], exec_cmd)
|
||||
```
|
||||
|
||||
#### Container routing call site in `main()` — remove try/except
|
||||
|
||||
Current:
|
||||
```python
|
||||
try:
|
||||
from hermes_cli.config import get_container_exec_info
|
||||
container_info = get_container_exec_info()
|
||||
if container_info:
|
||||
_exec_in_container(container_info, sys.argv[1:])
|
||||
sys.exit(1) # exec failed if we reach here
|
||||
except SystemExit:
|
||||
raise
|
||||
except Exception:
|
||||
pass # Container routing unavailable, proceed locally
|
||||
```
|
||||
|
||||
Revised:
|
||||
```python
|
||||
from hermes_cli.config import get_container_exec_info
|
||||
container_info = get_container_exec_info()
|
||||
if container_info:
|
||||
_exec_in_container(container_info, sys.argv[1:])
|
||||
# Unreachable: os.execvp never returns on success (process is replaced)
|
||||
# and raises OSError on failure (which propagates as a traceback).
|
||||
# This line exists only as a defensive assertion.
|
||||
sys.exit(1)
|
||||
```
|
||||
|
||||
No try/except. If `.container-mode` doesn't exist, `get_container_exec_info()` returns `None` and we skip routing. If it exists but is broken, the exception propagates with a natural traceback.
|
||||
|
||||
Note: `sys.exit(1)` after `_exec_in_container` is dead code in all paths — `os.execvp` either replaces the process or raises. It's kept as a belt-and-suspenders assertion with a comment marking it unreachable, not as actual error handling.
|
||||
|
||||
### Changes to `hermes_cli/config.py`
|
||||
|
||||
#### `get_container_exec_info` — remove inner try/except
|
||||
|
||||
Current code catches `(OSError, IOError)` and returns `None`. This silently hides permission errors, corrupt files, etc.
|
||||
|
||||
Change: Remove the try/except around file reading. Keep the early returns for `HERMES_DEV=1` and `_is_inside_container()`. The `FileNotFoundError` from `open()` when `.container-mode` doesn't exist should still return `None` (this is the "container mode not enabled" case). All other exceptions propagate.
|
||||
|
||||
```python
|
||||
def get_container_exec_info() -> Optional[dict]:
|
||||
if os.environ.get("HERMES_DEV") == "1":
|
||||
return None
|
||||
if _is_inside_container():
|
||||
return None
|
||||
|
||||
container_mode_file = get_hermes_home() / ".container-mode"
|
||||
|
||||
try:
|
||||
with open(container_mode_file, "r") as f:
|
||||
# ... parse key=value lines ...
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
# All other exceptions (PermissionError, malformed data, etc.) propagate
|
||||
|
||||
return { ... }
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Spec: NixOS Module Changes
|
||||
|
||||
### Symlink creation — simplify to two branches
|
||||
|
||||
Current: 4 branches (symlink exists, directory exists, other file, doesn't exist).
|
||||
|
||||
Revised: 2 branches.
|
||||
|
||||
```bash
|
||||
if [ -d "${symlinkPath}" ] && [ ! -L "${symlinkPath}" ]; then
|
||||
# Real directory — back it up, then create symlink
|
||||
_backup="${symlinkPath}.bak.$(date +%s)"
|
||||
echo "hermes-agent: backing up existing ${symlinkPath} to $_backup"
|
||||
mv "${symlinkPath}" "$_backup"
|
||||
fi
|
||||
# For everything else (symlink, doesn't exist, etc.) — just force-create
|
||||
ln -sfn "${target}" "${symlinkPath}"
|
||||
chown -h ${user}:${cfg.group} "${symlinkPath}"
|
||||
```
|
||||
|
||||
`ln -sfn` handles: existing symlink (replaces), doesn't exist (creates), and after the `mv` above (creates). The only case that needs special handling is a real directory, because `ln -sfn` cannot atomically replace a directory.
|
||||
|
||||
Note: there is a theoretical race between the `[ -d ... ]` check and the `mv` (something could create/remove the directory in between). In practice this is a NixOS activation script running as root during `nixos-rebuild switch` — no other process should be touching `~/.hermes` at that moment. Not worth adding locking for.
|
||||
|
||||
### Sudoers — document, don't auto-configure
|
||||
|
||||
Do NOT add `security.sudo.extraRules` to the module. Document the sudoers requirement in the module's description/comments and in the error message the CLI prints when sudo probe fails.
|
||||
|
||||
### Group membership gating — keep as-is
|
||||
|
||||
The fix in 726cf90f (`cfg.container.enable && cfg.container.hostUsers != []`) is correct. Leftover group membership when container mode is disabled is harmless. No cleanup needed.
|
||||
|
||||
---
|
||||
|
||||
## Spec: Test Rewrite
|
||||
|
||||
The existing test file (`tests/hermes_cli/test_container_aware_cli.py`) has 16 tests. With the simplified exec model, several are obsolete.
|
||||
|
||||
### Tests to keep (update as needed)
|
||||
|
||||
- `test_is_inside_container_dockerenv` — unchanged
|
||||
- `test_is_inside_container_containerenv` — unchanged
|
||||
- `test_is_inside_container_cgroup_docker` — unchanged
|
||||
- `test_is_inside_container_false_on_host` — unchanged
|
||||
- `test_get_container_exec_info_returns_metadata` — unchanged
|
||||
- `test_get_container_exec_info_none_inside_container` — unchanged
|
||||
- `test_get_container_exec_info_none_without_file` — unchanged
|
||||
- `test_get_container_exec_info_skipped_when_hermes_dev` — unchanged
|
||||
- `test_get_container_exec_info_not_skipped_when_hermes_dev_zero` — unchanged
|
||||
- `test_get_container_exec_info_defaults` — unchanged
|
||||
- `test_get_container_exec_info_docker_backend` — unchanged
|
||||
|
||||
### Tests to add
|
||||
|
||||
- `test_get_container_exec_info_crashes_on_permission_error` — verify that `PermissionError` propagates (no silent `None` return)
|
||||
- `test_exec_in_container_calls_execvp` — verify `os.execvp` is called with correct args (runtime, tty flags, user, env, container, binary, cli args)
|
||||
- `test_exec_in_container_sudo_probe_sets_prefix` — verify that when first probe fails and sudo probe succeeds, `os.execvp` is called with `sudo -n` prefix
|
||||
- `test_exec_in_container_no_runtime_hard_fails` — keep existing, verify `sys.exit(1)` when `shutil.which` returns None
|
||||
- `test_exec_in_container_non_tty_uses_i_only` — update to check `os.execvp` args instead of `subprocess.run` args
|
||||
- `test_exec_in_container_probe_timeout_prints_message` — verify that `subprocess.TimeoutExpired` from the probe produces a human-readable error and `sys.exit(1)`, not a raw traceback
|
||||
- `test_exec_in_container_container_not_running_no_sudo` — verify the path where runtime exists (`shutil.which` returns a path) but probe returns non-zero and no sudo is available. Should print the "container may be running under root" error. This is distinct from `no_runtime_hard_fails` which covers `shutil.which` returning None.
|
||||
|
||||
### Tests to delete
|
||||
|
||||
- `test_exec_in_container_tty_retries_on_container_failure` — retry loop removed
|
||||
- `test_exec_in_container_non_tty_retries_silently_exits_126` — retry loop removed
|
||||
- `test_exec_in_container_propagates_hermes_exit_code` — no subprocess.run to check exit codes; execvp replaces the process. Note: exit code propagation still works correctly — when `os.execvp` succeeds, the container's process *becomes* this process, so its exit code is the process exit code by OS semantics. No application code needed, no test needed. A comment in the function docstring documents this intent for future readers.
|
||||
|
||||
---
|
||||
|
||||
## Out of Scope
|
||||
|
||||
- Auto-configuring sudoers rules in the NixOS module
|
||||
- Any changes to `get_container_exec_info` parsing logic beyond the try/except narrowing
|
||||
- Changes to `.container-mode` file format
|
||||
- Changes to the `HERMES_DEV=1` bypass
|
||||
- Changes to container detection logic (`_is_inside_container`)
|
||||
@@ -49,6 +49,8 @@ class HermesToolCallParser(ToolCallParser):
|
||||
continue
|
||||
|
||||
tc_data = json.loads(raw_json)
|
||||
if "name" not in tc_data:
|
||||
continue
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
|
||||
@@ -89,6 +89,8 @@ class MistralToolCallParser(ToolCallParser):
|
||||
parsed = [parsed]
|
||||
|
||||
for tc in parsed:
|
||||
if "name" not in tc:
|
||||
continue
|
||||
args = tc.get("arguments", {})
|
||||
if isinstance(args, dict):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
|
||||
+29
-3
@@ -63,6 +63,7 @@ class Platform(Enum):
|
||||
WEBHOOK = "webhook"
|
||||
FEISHU = "feishu"
|
||||
WECOM = "wecom"
|
||||
WECOM_CALLBACK = "wecom_callback"
|
||||
WEIXIN = "weixin"
|
||||
BLUEBUBBLES = "bluebubbles"
|
||||
|
||||
@@ -190,7 +191,7 @@ class StreamingConfig:
|
||||
"""Configuration for real-time token streaming to messaging platforms."""
|
||||
enabled: bool = False
|
||||
transport: str = "edit" # "edit" (progressive editMessageText) or "off"
|
||||
edit_interval: float = 0.3 # Seconds between message edits
|
||||
edit_interval: float = 1.0 # Seconds between message edits (Telegram rate-limits at ~1/s)
|
||||
buffer_threshold: int = 40 # Chars before forcing an edit
|
||||
cursor: str = " ▉" # Cursor shown during streaming
|
||||
|
||||
@@ -210,7 +211,7 @@ class StreamingConfig:
|
||||
return cls(
|
||||
enabled=data.get("enabled", False),
|
||||
transport=data.get("transport", "edit"),
|
||||
edit_interval=float(data.get("edit_interval", 0.3)),
|
||||
edit_interval=float(data.get("edit_interval", 1.0)),
|
||||
buffer_threshold=int(data.get("buffer_threshold", 40)),
|
||||
cursor=data.get("cursor", " ▉"),
|
||||
)
|
||||
@@ -291,9 +292,14 @@ class GatewayConfig:
|
||||
# Feishu uses extra dict for app credentials
|
||||
elif platform == Platform.FEISHU and config.extra.get("app_id"):
|
||||
connected.append(platform)
|
||||
# WeCom uses extra dict for bot credentials
|
||||
# WeCom bot mode uses extra dict for bot credentials
|
||||
elif platform == Platform.WECOM and config.extra.get("bot_id"):
|
||||
connected.append(platform)
|
||||
# WeCom callback mode uses corp_id or apps list
|
||||
elif platform == Platform.WECOM_CALLBACK and (
|
||||
config.extra.get("corp_id") or config.extra.get("apps")
|
||||
):
|
||||
connected.append(platform)
|
||||
# BlueBubbles uses extra dict for local server config
|
||||
elif platform == Platform.BLUEBUBBLES and config.extra.get("server_url") and config.extra.get("password"):
|
||||
connected.append(platform)
|
||||
@@ -987,6 +993,23 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
name=os.getenv("WECOM_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# WeCom callback mode (self-built apps)
|
||||
wecom_callback_corp_id = os.getenv("WECOM_CALLBACK_CORP_ID")
|
||||
wecom_callback_corp_secret = os.getenv("WECOM_CALLBACK_CORP_SECRET")
|
||||
if wecom_callback_corp_id and wecom_callback_corp_secret:
|
||||
if Platform.WECOM_CALLBACK not in config.platforms:
|
||||
config.platforms[Platform.WECOM_CALLBACK] = PlatformConfig()
|
||||
config.platforms[Platform.WECOM_CALLBACK].enabled = True
|
||||
config.platforms[Platform.WECOM_CALLBACK].extra.update({
|
||||
"corp_id": wecom_callback_corp_id,
|
||||
"corp_secret": wecom_callback_corp_secret,
|
||||
"agent_id": os.getenv("WECOM_CALLBACK_AGENT_ID", ""),
|
||||
"token": os.getenv("WECOM_CALLBACK_TOKEN", ""),
|
||||
"encoding_aes_key": os.getenv("WECOM_CALLBACK_ENCODING_AES_KEY", ""),
|
||||
"host": os.getenv("WECOM_CALLBACK_HOST", "0.0.0.0"),
|
||||
"port": int(os.getenv("WECOM_CALLBACK_PORT", "8645")),
|
||||
})
|
||||
|
||||
# Weixin (personal WeChat via iLink Bot API)
|
||||
weixin_token = os.getenv("WEIXIN_TOKEN")
|
||||
weixin_account_id = os.getenv("WEIXIN_ACCOUNT_ID")
|
||||
@@ -1017,6 +1040,9 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
weixin_group_allowed_users = os.getenv("WEIXIN_GROUP_ALLOWED_USERS", "").strip()
|
||||
if weixin_group_allowed_users:
|
||||
extra["group_allow_from"] = weixin_group_allowed_users
|
||||
weixin_split_multiline = os.getenv("WEIXIN_SPLIT_MULTILINE_MESSAGES", "").strip()
|
||||
if weixin_split_multiline:
|
||||
extra["split_multiline_messages"] = weixin_split_multiline
|
||||
weixin_home = os.getenv("WEIXIN_HOME_CHANNEL", "").strip()
|
||||
if weixin_home:
|
||||
config.platforms[Platform.WEIXIN].home_channel = HomeChannel(
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
"""Per-platform display/verbosity configuration resolver.
|
||||
|
||||
Provides ``resolve_display_setting()`` — the single entry-point for reading
|
||||
display settings with platform-specific overrides and sensible defaults.
|
||||
|
||||
Resolution order (first non-None wins):
|
||||
1. ``display.platforms.<platform>.<key>`` — explicit per-platform user override
|
||||
2. ``display.<key>`` — global user setting
|
||||
3. ``_PLATFORM_DEFAULTS[<platform>][<key>]`` — built-in sensible default
|
||||
4. ``_GLOBAL_DEFAULTS[<key>]`` — built-in global default
|
||||
|
||||
Backward compatibility: ``display.tool_progress_overrides`` is still read as a
|
||||
fallback for ``tool_progress`` when no ``display.platforms`` entry exists. A
|
||||
config migration (version bump) automatically moves the old format into the new
|
||||
``display.platforms`` structure.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Overrideable display settings and their global defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
# These are the settings that can be configured per-platform.
|
||||
# Other display settings (compact, personality, skin, etc.) are CLI-only
|
||||
# and don't participate in per-platform resolution.
|
||||
|
||||
_GLOBAL_DEFAULTS: dict[str, Any] = {
|
||||
"tool_progress": "all",
|
||||
"show_reasoning": False,
|
||||
"tool_preview_length": 0,
|
||||
"streaming": None, # None = follow top-level streaming config
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sensible per-platform defaults — tiered by platform capability
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tier 1 (high): Supports message editing, typically personal/team use
|
||||
# Tier 2 (medium): Supports editing but often workspace/customer-facing
|
||||
# Tier 3 (low): No edit support — each progress msg is permanent
|
||||
# Tier 4 (minimal): Batch/non-interactive delivery
|
||||
|
||||
_TIER_HIGH = {
|
||||
"tool_progress": "all",
|
||||
"show_reasoning": False,
|
||||
"tool_preview_length": 40,
|
||||
"streaming": None, # follow global
|
||||
}
|
||||
|
||||
_TIER_MEDIUM = {
|
||||
"tool_progress": "new",
|
||||
"show_reasoning": False,
|
||||
"tool_preview_length": 40,
|
||||
"streaming": None,
|
||||
}
|
||||
|
||||
_TIER_LOW = {
|
||||
"tool_progress": "off",
|
||||
"show_reasoning": False,
|
||||
"tool_preview_length": 40,
|
||||
"streaming": False,
|
||||
}
|
||||
|
||||
_TIER_MINIMAL = {
|
||||
"tool_progress": "off",
|
||||
"show_reasoning": False,
|
||||
"tool_preview_length": 0,
|
||||
"streaming": False,
|
||||
}
|
||||
|
||||
_PLATFORM_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
# Tier 1 — full edit support, personal/team use
|
||||
"telegram": _TIER_HIGH,
|
||||
"discord": _TIER_HIGH,
|
||||
|
||||
# Tier 2 — edit support, often customer/workspace channels
|
||||
"slack": _TIER_MEDIUM,
|
||||
"mattermost": _TIER_MEDIUM,
|
||||
"matrix": _TIER_MEDIUM,
|
||||
"feishu": _TIER_MEDIUM,
|
||||
|
||||
# Tier 3 — no edit support, progress messages are permanent
|
||||
"signal": _TIER_LOW,
|
||||
"whatsapp": _TIER_LOW,
|
||||
"bluebubbles": _TIER_LOW,
|
||||
"weixin": _TIER_LOW,
|
||||
"wecom": _TIER_LOW,
|
||||
"wecom_callback": _TIER_LOW,
|
||||
"dingtalk": _TIER_LOW,
|
||||
|
||||
# Tier 4 — batch or non-interactive delivery
|
||||
"email": _TIER_MINIMAL,
|
||||
"sms": _TIER_MINIMAL,
|
||||
"webhook": _TIER_MINIMAL,
|
||||
"homeassistant": _TIER_MINIMAL,
|
||||
"api_server": {**_TIER_HIGH, "tool_preview_length": 0},
|
||||
}
|
||||
|
||||
# Canonical set of per-platform overrideable keys (for validation).
|
||||
OVERRIDEABLE_KEYS = frozenset(_GLOBAL_DEFAULTS.keys())
|
||||
|
||||
|
||||
def resolve_display_setting(
|
||||
user_config: dict,
|
||||
platform_key: str,
|
||||
setting: str,
|
||||
fallback: Any = None,
|
||||
) -> Any:
|
||||
"""Resolve a display setting with per-platform override support.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
user_config : dict
|
||||
The full parsed config.yaml dict.
|
||||
platform_key : str
|
||||
Platform config key (e.g. ``"telegram"``, ``"slack"``). Use
|
||||
``_platform_config_key(source.platform)`` from gateway/run.py.
|
||||
setting : str
|
||||
Display setting name (e.g. ``"tool_progress"``, ``"show_reasoning"``).
|
||||
fallback : Any
|
||||
Fallback value when the setting isn't found anywhere.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The resolved value, or *fallback* if nothing is configured.
|
||||
"""
|
||||
display_cfg = user_config.get("display") or {}
|
||||
|
||||
# 1. Explicit per-platform override (display.platforms.<platform>.<key>)
|
||||
platforms = display_cfg.get("platforms") or {}
|
||||
plat_overrides = platforms.get(platform_key)
|
||||
if isinstance(plat_overrides, dict):
|
||||
val = plat_overrides.get(setting)
|
||||
if val is not None:
|
||||
return _normalise(setting, val)
|
||||
|
||||
# 1b. Backward compat: display.tool_progress_overrides.<platform>
|
||||
if setting == "tool_progress":
|
||||
legacy = display_cfg.get("tool_progress_overrides")
|
||||
if isinstance(legacy, dict):
|
||||
val = legacy.get(platform_key)
|
||||
if val is not None:
|
||||
return _normalise(setting, val)
|
||||
|
||||
# 2. Global user setting (display.<key>)
|
||||
val = display_cfg.get(setting)
|
||||
if val is not None:
|
||||
return _normalise(setting, val)
|
||||
|
||||
# 3. Built-in platform default
|
||||
plat_defaults = _PLATFORM_DEFAULTS.get(platform_key)
|
||||
if plat_defaults:
|
||||
val = plat_defaults.get(setting)
|
||||
if val is not None:
|
||||
return val
|
||||
|
||||
# 4. Built-in global default
|
||||
val = _GLOBAL_DEFAULTS.get(setting)
|
||||
if val is not None:
|
||||
return val
|
||||
|
||||
return fallback
|
||||
|
||||
|
||||
def get_platform_defaults(platform_key: str) -> dict[str, Any]:
|
||||
"""Return the built-in default display settings for a platform.
|
||||
|
||||
Falls back to ``_GLOBAL_DEFAULTS`` for unknown platforms.
|
||||
"""
|
||||
return dict(_PLATFORM_DEFAULTS.get(platform_key, _GLOBAL_DEFAULTS))
|
||||
|
||||
|
||||
def get_effective_display(user_config: dict, platform_key: str) -> dict[str, Any]:
|
||||
"""Return the fully-resolved display settings for a platform.
|
||||
|
||||
Useful for status commands that want to show all effective settings.
|
||||
"""
|
||||
return {
|
||||
key: resolve_display_setting(user_config, platform_key, key)
|
||||
for key in OVERRIDEABLE_KEYS
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _normalise(setting: str, value: Any) -> Any:
|
||||
"""Normalise YAML quirks (bare ``off`` → False in YAML 1.1)."""
|
||||
if setting == "tool_progress":
|
||||
if value is False:
|
||||
return "off"
|
||||
if value is True:
|
||||
return "all"
|
||||
return str(value).lower()
|
||||
if setting in ("show_reasoning", "streaming"):
|
||||
if isinstance(value, str):
|
||||
return value.lower() in ("true", "1", "yes", "on")
|
||||
return bool(value)
|
||||
if setting == "tool_preview_length":
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
return value
|
||||
@@ -53,6 +53,7 @@ DEFAULT_HOST = "127.0.0.1"
|
||||
DEFAULT_PORT = 8642
|
||||
MAX_STORED_RESPONSES = 100
|
||||
MAX_REQUEST_BYTES = 1_000_000 # 1 MB default limit for POST bodies
|
||||
CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS = 30.0
|
||||
|
||||
|
||||
def check_api_server_requirements() -> bool:
|
||||
@@ -762,7 +763,11 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
import queue as _q
|
||||
|
||||
sse_headers = {"Content-Type": "text/event-stream", "Cache-Control": "no-cache"}
|
||||
sse_headers = {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
}
|
||||
# CORS middleware can't inject headers into StreamResponse after
|
||||
# prepare() flushes them, so resolve CORS headers up front.
|
||||
origin = request.headers.get("Origin", "")
|
||||
@@ -775,6 +780,8 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
await response.prepare(request)
|
||||
|
||||
try:
|
||||
last_activity = time.monotonic()
|
||||
|
||||
# Role chunk
|
||||
role_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
@@ -782,6 +789,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
"choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode())
|
||||
last_activity = time.monotonic()
|
||||
|
||||
# Helper — route a queue item to the correct SSE event.
|
||||
async def _emit(item):
|
||||
@@ -805,6 +813,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
"choices": [{"index": 0, "delta": {"content": item}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
|
||||
return time.monotonic()
|
||||
|
||||
# Stream content chunks as they arrive from the agent
|
||||
loop = asyncio.get_event_loop()
|
||||
@@ -819,16 +828,19 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
delta = stream_q.get_nowait()
|
||||
if delta is None:
|
||||
break
|
||||
await _emit(delta)
|
||||
last_activity = await _emit(delta)
|
||||
except _q.Empty:
|
||||
break
|
||||
break
|
||||
if time.monotonic() - last_activity >= CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS:
|
||||
await response.write(b": keepalive\n\n")
|
||||
last_activity = time.monotonic()
|
||||
continue
|
||||
|
||||
if delta is None: # End of stream sentinel
|
||||
break
|
||||
|
||||
await _emit(delta)
|
||||
last_activity = await _emit(delta)
|
||||
|
||||
# Get usage from completed agent
|
||||
usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
|
||||
@@ -823,7 +823,36 @@ class BasePlatformAdapter(ABC):
|
||||
result = handler(self)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
|
||||
|
||||
def _acquire_platform_lock(self, scope: str, identity: str, resource_desc: str) -> bool:
|
||||
"""Acquire a scoped lock for this adapter. Returns True on success."""
|
||||
from gateway.status import acquire_scoped_lock
|
||||
self._platform_lock_scope = scope
|
||||
self._platform_lock_identity = identity
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
scope, identity, metadata={'platform': self.platform.value}
|
||||
)
|
||||
if acquired:
|
||||
return True
|
||||
owner_pid = existing.get('pid') if isinstance(existing, dict) else None
|
||||
message = (
|
||||
f'{resource_desc} already in use'
|
||||
+ (f' (PID {owner_pid})' if owner_pid else '')
|
||||
+ '. Stop the other gateway first.'
|
||||
)
|
||||
logger.error('[%s] %s', self.name, message)
|
||||
self._set_fatal_error(f'{scope}_lock', message, retryable=False)
|
||||
return False
|
||||
|
||||
def _release_platform_lock(self) -> None:
|
||||
"""Release the scoped lock acquired by _acquire_platform_lock."""
|
||||
identity = getattr(self, '_platform_lock_identity', None)
|
||||
if not identity:
|
||||
return
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock(self._platform_lock_scope, identity)
|
||||
self._platform_lock_identity = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Human-readable name for this adapter."""
|
||||
|
||||
@@ -30,6 +30,7 @@ from gateway.platforms.base import (
|
||||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
)
|
||||
from gateway.platforms.helpers import strip_markdown
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -89,18 +90,7 @@ def _normalize_server_url(raw: str) -> str:
|
||||
return value.rstrip("/")
|
||||
|
||||
|
||||
def _strip_markdown(text: str) -> str:
|
||||
"""Strip common markdown formatting for iMessage plain-text delivery."""
|
||||
text = re.sub(r"\*\*(.+?)\*\*", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"\*(.+?)\*", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"__(.+?)__", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"_(.+?)_", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"```[a-zA-Z0-9_+-]*\n?", "", text)
|
||||
text = re.sub(r"`(.+?)`", r"\1", text)
|
||||
text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE)
|
||||
text = re.sub(r"\[([^\]]+)\]\(([^\)]+)\)", r"\1", text)
|
||||
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -393,7 +383,7 @@ class BlueBubblesAdapter(BasePlatformAdapter):
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
text = _strip_markdown(content or "")
|
||||
text = strip_markdown(content or "")
|
||||
if not text:
|
||||
return SendResult(success=False, error="BlueBubbles send requires text")
|
||||
chunks = self.truncate_message(text, max_length=self.MAX_MESSAGE_LENGTH)
|
||||
@@ -679,7 +669,7 @@ class BlueBubblesAdapter(BasePlatformAdapter):
|
||||
return info
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
return _strip_markdown(content)
|
||||
return strip_markdown(content)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inbound attachment downloading (from #4588)
|
||||
|
||||
@@ -42,6 +42,7 @@ except ImportError:
|
||||
httpx = None # type: ignore[assignment]
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.helpers import MessageDeduplicator
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
@@ -52,8 +53,6 @@ from gateway.platforms.base import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_MESSAGE_LENGTH = 20000
|
||||
DEDUP_WINDOW_SECONDS = 300
|
||||
DEDUP_MAX_SIZE = 1000
|
||||
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
|
||||
_SESSION_WEBHOOKS_MAX = 500
|
||||
_DINGTALK_WEBHOOK_RE = re.compile(r'^https://api\.dingtalk\.com/')
|
||||
@@ -89,8 +88,8 @@ class DingTalkAdapter(BasePlatformAdapter):
|
||||
self._stream_task: Optional[asyncio.Task] = None
|
||||
self._http_client: Optional["httpx.AsyncClient"] = None
|
||||
|
||||
# Message deduplication: msg_id -> timestamp
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
# Message deduplication
|
||||
self._dedup = MessageDeduplicator(max_size=1000)
|
||||
# Map chat_id -> session_webhook for reply routing
|
||||
self._session_webhooks: Dict[str, str] = {}
|
||||
|
||||
@@ -170,7 +169,7 @@ class DingTalkAdapter(BasePlatformAdapter):
|
||||
|
||||
self._stream_client = None
|
||||
self._session_webhooks.clear()
|
||||
self._seen_messages.clear()
|
||||
self._dedup.clear()
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
# -- Inbound message processing -----------------------------------------
|
||||
@@ -178,7 +177,7 @@ class DingTalkAdapter(BasePlatformAdapter):
|
||||
async def _on_message(self, message: "ChatbotMessage") -> None:
|
||||
"""Process an incoming DingTalk chatbot message."""
|
||||
msg_id = getattr(message, "message_id", None) or uuid.uuid4().hex
|
||||
if self._is_duplicate(msg_id):
|
||||
if self._dedup.is_duplicate(msg_id):
|
||||
logger.debug("[%s] Duplicate message %s, skipping", self.name, msg_id)
|
||||
return
|
||||
|
||||
@@ -256,20 +255,6 @@ class DingTalkAdapter(BasePlatformAdapter):
|
||||
content = " ".join(parts).strip()
|
||||
return content
|
||||
|
||||
# -- Deduplication ------------------------------------------------------
|
||||
|
||||
def _is_duplicate(self, msg_id: str) -> bool:
|
||||
"""Check and record a message ID. Returns True if already seen."""
|
||||
now = time.time()
|
||||
if len(self._seen_messages) > DEDUP_MAX_SIZE:
|
||||
cutoff = now - DEDUP_WINDOW_SECONDS
|
||||
self._seen_messages = {k: v for k, v in self._seen_messages.items() if v > cutoff}
|
||||
|
||||
if msg_id in self._seen_messages:
|
||||
return True
|
||||
self._seen_messages[msg_id] = now
|
||||
return False
|
||||
|
||||
# -- Outbound messaging -------------------------------------------------
|
||||
|
||||
async def send(
|
||||
|
||||
@@ -45,6 +45,7 @@ sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
import re
|
||||
|
||||
from gateway.platforms.helpers import MessageDeduplicator, ThreadParticipationTracker
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
@@ -450,18 +451,14 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# Track threads where the bot has participated so follow-up messages
|
||||
# in those threads don't require @mention. Persisted to disk so the
|
||||
# set survives gateway restarts.
|
||||
self._bot_participated_threads: set = self._load_participated_threads()
|
||||
self._threads = ThreadParticipationTracker("discord")
|
||||
# Persistent typing indicator loops per channel (DMs don't reliably
|
||||
# show the standard typing gateway event for bots)
|
||||
self._typing_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._bot_task: Optional[asyncio.Task] = None
|
||||
# Cap to prevent unbounded growth (Discord threads get archived).
|
||||
self._MAX_TRACKED_THREADS = 500
|
||||
# Dedup cache: message_id → timestamp. Prevents duplicate bot
|
||||
# responses when Discord RESUME replays events after reconnects.
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
self._SEEN_TTL = 300 # 5 minutes
|
||||
self._SEEN_MAX = 2000 # prune threshold
|
||||
# Dedup cache: prevents duplicate bot responses when Discord
|
||||
# RESUME replays events after reconnects.
|
||||
self._dedup = MessageDeduplicator()
|
||||
# Reply threading mode: "off" (no replies), "first" (reply on first
|
||||
# chunk only, default), "all" (reply-reference on every chunk).
|
||||
self._reply_to_mode: str = getattr(config, 'reply_to_mode', 'first') or 'first'
|
||||
@@ -502,18 +499,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
return False
|
||||
|
||||
try:
|
||||
# Acquire scoped lock to prevent duplicate bot token usage
|
||||
from gateway.status import acquire_scoped_lock
|
||||
self._token_lock_identity = self.config.token
|
||||
acquired, existing = acquire_scoped_lock('discord-bot-token', self._token_lock_identity, metadata={'platform': 'discord'})
|
||||
if not acquired:
|
||||
owner_pid = existing.get('pid') if isinstance(existing, dict) else None
|
||||
message = f'Discord bot token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.'
|
||||
logger.error('[%s] %s', self.name, message)
|
||||
self._set_fatal_error('discord_token_lock', message, retryable=False)
|
||||
if not self._acquire_platform_lock('discord-bot-token', self.config.token, 'Discord bot token'):
|
||||
return False
|
||||
|
||||
|
||||
# Parse allowed user entries (may contain usernames or IDs)
|
||||
allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "")
|
||||
if allowed_env:
|
||||
@@ -569,17 +557,8 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
@self._client.event
|
||||
async def on_message(message: DiscordMessage):
|
||||
# Dedup: Discord RESUME replays events after reconnects (#4777)
|
||||
msg_id = str(message.id)
|
||||
now = time.time()
|
||||
if msg_id in adapter_self._seen_messages:
|
||||
if adapter_self._dedup.is_duplicate(str(message.id)):
|
||||
return
|
||||
adapter_self._seen_messages[msg_id] = now
|
||||
if len(adapter_self._seen_messages) > adapter_self._SEEN_MAX:
|
||||
cutoff = now - adapter_self._SEEN_TTL
|
||||
adapter_self._seen_messages = {
|
||||
k: v for k, v in adapter_self._seen_messages.items()
|
||||
if v > cutoff
|
||||
}
|
||||
|
||||
# Always ignore our own messages
|
||||
if message.author == self._client.user:
|
||||
@@ -685,23 +664,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("[%s] Timeout waiting for connection to Discord", self.name, exc_info=True)
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
if getattr(self, '_token_lock_identity', None):
|
||||
release_scoped_lock('discord-bot-token', self._token_lock_identity)
|
||||
self._token_lock_identity = None
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
return False
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.error("[%s] Failed to connect to Discord: %s", self.name, e, exc_info=True)
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
if getattr(self, '_token_lock_identity', None):
|
||||
release_scoped_lock('discord-bot-token', self._token_lock_identity)
|
||||
self._token_lock_identity = None
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
@@ -723,14 +690,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self._client = None
|
||||
self._ready_event.clear()
|
||||
|
||||
# Release the token lock
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
if getattr(self, '_token_lock_identity', None):
|
||||
release_scoped_lock('discord-bot-token', self._token_lock_identity)
|
||||
self._token_lock_identity = None
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
@@ -1870,7 +1830,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
# Track thread participation so follow-ups don't require @mention
|
||||
if thread_id:
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
# If a message was provided, kick off a new Hermes session in the thread
|
||||
starter = (message or "").strip()
|
||||
@@ -2241,49 +2201,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
return f"{parent_name} / {thread_name}"
|
||||
return thread_name
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Thread participation persistence
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _thread_state_path() -> Path:
|
||||
"""Path to the persisted thread participation set."""
|
||||
from hermes_cli.config import get_hermes_home
|
||||
return get_hermes_home() / "discord_threads.json"
|
||||
|
||||
@classmethod
|
||||
def _load_participated_threads(cls) -> set:
|
||||
"""Load persisted thread IDs from disk."""
|
||||
path = cls._thread_state_path()
|
||||
try:
|
||||
if path.exists():
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, list):
|
||||
return set(data)
|
||||
except Exception as e:
|
||||
logger.debug("Could not load discord thread state: %s", e)
|
||||
return set()
|
||||
|
||||
def _save_participated_threads(self) -> None:
|
||||
"""Persist the current thread set to disk (best-effort)."""
|
||||
path = self._thread_state_path()
|
||||
try:
|
||||
# Trim to most recent entries if over cap
|
||||
thread_list = list(self._bot_participated_threads)
|
||||
if len(thread_list) > self._MAX_TRACKED_THREADS:
|
||||
thread_list = thread_list[-self._MAX_TRACKED_THREADS:]
|
||||
self._bot_participated_threads = set(thread_list)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(thread_list), encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.debug("Could not save discord thread state: %s", e)
|
||||
|
||||
def _track_thread(self, thread_id: str) -> None:
|
||||
"""Add a thread to the participation set and persist."""
|
||||
if thread_id not in self._bot_participated_threads:
|
||||
self._bot_participated_threads.add(thread_id)
|
||||
self._save_participated_threads()
|
||||
|
||||
async def _handle_message(self, message: DiscordMessage) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
# In server channels (not DMs), require the bot to be @mentioned
|
||||
@@ -2335,7 +2252,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
# Skip the mention check if the message is in a thread where
|
||||
# the bot has previously participated (auto-created or replied in).
|
||||
in_bot_thread = is_thread and thread_id in self._bot_participated_threads
|
||||
in_bot_thread = is_thread and thread_id in self._threads
|
||||
|
||||
if require_mention and not is_free_channel and not in_bot_thread:
|
||||
if self._client.user not in message.mentions:
|
||||
@@ -2361,7 +2278,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
is_thread = True
|
||||
thread_id = str(thread.id)
|
||||
auto_threaded_channel = thread
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
@@ -2545,7 +2462,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# Track thread participation so the bot won't require @mention for
|
||||
# follow-up messages in threads it has already engaged in.
|
||||
if thread_id:
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
# Only batch plain text messages — commands, media, etc. dispatch
|
||||
# immediately since they won't be split by the Discord client.
|
||||
|
||||
@@ -360,19 +360,21 @@ def _render_code_block_element(element: Dict[str, Any]) -> str:
|
||||
|
||||
|
||||
def _strip_markdown_to_plain_text(text: str) -> str:
|
||||
"""Strip markdown formatting to plain text for Feishu text fallbacks.
|
||||
|
||||
Delegates common markdown stripping to the shared helper and adds
|
||||
Feishu-specific patterns (blockquotes, strikethrough, underline tags,
|
||||
horizontal rules, \\r\\n normalisation).
|
||||
"""
|
||||
from gateway.platforms.helpers import strip_markdown
|
||||
plain = text.replace("\r\n", "\n")
|
||||
plain = _MARKDOWN_LINK_RE.sub(lambda m: f"{m.group(1)} ({m.group(2).strip()})", plain)
|
||||
plain = re.sub(r"^#{1,6}\s+", "", plain, flags=re.MULTILINE)
|
||||
plain = re.sub(r"^>\s?", "", plain, flags=re.MULTILINE)
|
||||
plain = re.sub(r"^\s*---+\s*$", "---", plain, flags=re.MULTILINE)
|
||||
plain = re.sub(r"```(?:[^\n]*\n)?([\s\S]*?)```", lambda m: m.group(1).strip("\n"), plain)
|
||||
plain = re.sub(r"`([^`\n]+)`", r"\1", plain)
|
||||
plain = re.sub(r"\*\*([^*\n]+)\*\*", r"\1", plain)
|
||||
plain = re.sub(r"\*([^*\n]+)\*", r"\1", plain)
|
||||
plain = re.sub(r"~~([^~\n]+)~~", r"\1", plain)
|
||||
plain = re.sub(r"<u>([\s\S]*?)</u>", r"\1", plain)
|
||||
plain = re.sub(r"\n{3,}", "\n\n", plain)
|
||||
return plain.strip()
|
||||
plain = strip_markdown(plain)
|
||||
return plain
|
||||
|
||||
|
||||
def _coerce_int(value: Any, default: Optional[int] = None, min_value: int = 0) -> Optional[int]:
|
||||
|
||||
@@ -0,0 +1,261 @@
|
||||
"""Shared helper classes for gateway platform adapters.
|
||||
|
||||
Extracts common patterns that were duplicated across 5-7 adapters:
|
||||
message deduplication, text batch aggregation, markdown stripping,
|
||||
and thread participation tracking.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ─── Message Deduplication ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class MessageDeduplicator:
|
||||
"""TTL-based message deduplication cache.
|
||||
|
||||
Replaces the identical ``_seen_messages`` / ``_is_duplicate()`` pattern
|
||||
previously duplicated in discord, slack, dingtalk, wecom, weixin,
|
||||
mattermost, and feishu adapters.
|
||||
|
||||
Usage::
|
||||
|
||||
self._dedup = MessageDeduplicator()
|
||||
|
||||
# In message handler:
|
||||
if self._dedup.is_duplicate(msg_id):
|
||||
return
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int = 2000, ttl_seconds: float = 300):
|
||||
self._seen: Dict[str, float] = {}
|
||||
self._max_size = max_size
|
||||
self._ttl = ttl_seconds
|
||||
|
||||
def is_duplicate(self, msg_id: str) -> bool:
|
||||
"""Return True if *msg_id* was already seen within the TTL window."""
|
||||
if not msg_id:
|
||||
return False
|
||||
now = time.time()
|
||||
if msg_id in self._seen:
|
||||
return True
|
||||
self._seen[msg_id] = now
|
||||
if len(self._seen) > self._max_size:
|
||||
cutoff = now - self._ttl
|
||||
self._seen = {k: v for k, v in self._seen.items() if v > cutoff}
|
||||
return False
|
||||
|
||||
def clear(self):
|
||||
"""Clear all tracked messages."""
|
||||
self._seen.clear()
|
||||
|
||||
|
||||
# ─── Text Batch Aggregation ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TextBatchAggregator:
|
||||
"""Aggregates rapid-fire text events into single messages.
|
||||
|
||||
Replaces the ``_enqueue_text_event`` / ``_flush_text_batch`` pattern
|
||||
previously duplicated in telegram, discord, matrix, wecom, and feishu.
|
||||
|
||||
Usage::
|
||||
|
||||
self._text_batcher = TextBatchAggregator(
|
||||
handler=self._message_handler,
|
||||
batch_delay=0.6,
|
||||
split_threshold=1900,
|
||||
)
|
||||
|
||||
# In message dispatch:
|
||||
if msg_type == MessageType.TEXT and self._text_batcher.is_enabled():
|
||||
self._text_batcher.enqueue(event, session_key)
|
||||
return
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler,
|
||||
*,
|
||||
batch_delay: float = 0.6,
|
||||
split_delay: float = 2.0,
|
||||
split_threshold: int = 4000,
|
||||
):
|
||||
self._handler = handler
|
||||
self._batch_delay = batch_delay
|
||||
self._split_delay = split_delay
|
||||
self._split_threshold = split_threshold
|
||||
self._pending: Dict[str, "MessageEvent"] = {}
|
||||
self._pending_tasks: Dict[str, asyncio.Task] = {}
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
"""Return True if batching is active (delay > 0)."""
|
||||
return self._batch_delay > 0
|
||||
|
||||
def enqueue(self, event: "MessageEvent", key: str) -> None:
|
||||
"""Add *event* to the pending batch for *key*."""
|
||||
chunk_len = len(event.text or "")
|
||||
existing = self._pending.get(key)
|
||||
if not existing:
|
||||
event._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||||
self._pending[key] = event
|
||||
else:
|
||||
existing.text = f"{existing.text}\n{event.text}"
|
||||
existing._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||||
|
||||
# Cancel prior flush timer, start a new one
|
||||
prior = self._pending_tasks.get(key)
|
||||
if prior and not prior.done():
|
||||
prior.cancel()
|
||||
self._pending_tasks[key] = asyncio.create_task(self._flush(key))
|
||||
|
||||
async def _flush(self, key: str) -> None:
|
||||
"""Wait then dispatch the batched event for *key*."""
|
||||
current_task = self._pending_tasks.get(key)
|
||||
pending = self._pending.get(key)
|
||||
last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0
|
||||
|
||||
# Use longer delay when the last chunk looks like a split message
|
||||
delay = self._split_delay if last_len >= self._split_threshold else self._batch_delay
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
event = self._pending.pop(key, None)
|
||||
if event:
|
||||
try:
|
||||
await self._handler(event)
|
||||
except Exception:
|
||||
logger.exception("[TextBatchAggregator] Error dispatching batched event for %s", key)
|
||||
|
||||
if self._pending_tasks.get(key) is current_task:
|
||||
self._pending_tasks.pop(key, None)
|
||||
|
||||
def cancel_all(self) -> None:
|
||||
"""Cancel all pending flush tasks."""
|
||||
for task in self._pending_tasks.values():
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
self._pending_tasks.clear()
|
||||
self._pending.clear()
|
||||
|
||||
|
||||
# ─── Markdown Stripping ──────────────────────────────────────────────────────
|
||||
|
||||
# Pre-compiled regexes for performance
|
||||
_RE_BOLD = re.compile(r"\*\*(.+?)\*\*", re.DOTALL)
|
||||
_RE_ITALIC_STAR = re.compile(r"\*(.+?)\*", re.DOTALL)
|
||||
_RE_BOLD_UNDER = re.compile(r"__(.+?)__", re.DOTALL)
|
||||
_RE_ITALIC_UNDER = re.compile(r"_(.+?)_", re.DOTALL)
|
||||
_RE_CODE_BLOCK = re.compile(r"```[a-zA-Z0-9_+-]*\n?")
|
||||
_RE_INLINE_CODE = re.compile(r"`(.+?)`")
|
||||
_RE_HEADING = re.compile(r"^#{1,6}\s+", re.MULTILINE)
|
||||
_RE_LINK = re.compile(r"\[([^\]]+)\]\([^\)]+\)")
|
||||
_RE_MULTI_NEWLINE = re.compile(r"\n{3,}")
|
||||
|
||||
|
||||
def strip_markdown(text: str) -> str:
|
||||
"""Strip markdown formatting for plain-text platforms (SMS, iMessage, etc.).
|
||||
|
||||
Replaces the identical ``_strip_markdown()`` functions previously
|
||||
duplicated in sms.py, bluebubbles.py, and feishu.py.
|
||||
"""
|
||||
text = _RE_BOLD.sub(r"\1", text)
|
||||
text = _RE_ITALIC_STAR.sub(r"\1", text)
|
||||
text = _RE_BOLD_UNDER.sub(r"\1", text)
|
||||
text = _RE_ITALIC_UNDER.sub(r"\1", text)
|
||||
text = _RE_CODE_BLOCK.sub("", text)
|
||||
text = _RE_INLINE_CODE.sub(r"\1", text)
|
||||
text = _RE_HEADING.sub("", text)
|
||||
text = _RE_LINK.sub(r"\1", text)
|
||||
text = _RE_MULTI_NEWLINE.sub("\n\n", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
# ─── Thread Participation Tracking ───────────────────────────────────────────
|
||||
|
||||
|
||||
class ThreadParticipationTracker:
|
||||
"""Persistent tracking of threads the bot has participated in.
|
||||
|
||||
Replaces the identical ``_load/_save_participated_threads`` +
|
||||
``_mark_thread_participated`` pattern previously duplicated in
|
||||
discord.py and matrix.py.
|
||||
|
||||
Usage::
|
||||
|
||||
self._threads = ThreadParticipationTracker("discord")
|
||||
|
||||
# Check membership:
|
||||
if thread_id in self._threads:
|
||||
...
|
||||
|
||||
# Mark participation:
|
||||
self._threads.mark(thread_id)
|
||||
"""
|
||||
|
||||
_MAX_TRACKED = 500
|
||||
|
||||
def __init__(self, platform_name: str, max_tracked: int = 500):
|
||||
self._platform = platform_name
|
||||
self._max_tracked = max_tracked
|
||||
self._threads: set = self._load()
|
||||
|
||||
def _state_path(self) -> Path:
|
||||
from hermes_constants import get_hermes_home
|
||||
return get_hermes_home() / f"{self._platform}_threads.json"
|
||||
|
||||
def _load(self) -> set:
|
||||
path = self._state_path()
|
||||
if path.exists():
|
||||
try:
|
||||
return set(json.loads(path.read_text(encoding="utf-8")))
|
||||
except Exception:
|
||||
pass
|
||||
return set()
|
||||
|
||||
def _save(self) -> None:
|
||||
path = self._state_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
thread_list = list(self._threads)
|
||||
if len(thread_list) > self._max_tracked:
|
||||
thread_list = thread_list[-self._max_tracked:]
|
||||
self._threads = set(thread_list)
|
||||
path.write_text(json.dumps(thread_list), encoding="utf-8")
|
||||
|
||||
def mark(self, thread_id: str) -> None:
|
||||
"""Mark *thread_id* as participated and persist."""
|
||||
if thread_id not in self._threads:
|
||||
self._threads.add(thread_id)
|
||||
self._save()
|
||||
|
||||
def __contains__(self, thread_id: str) -> bool:
|
||||
return thread_id in self._threads
|
||||
|
||||
def clear(self) -> None:
|
||||
self._threads.clear()
|
||||
|
||||
|
||||
# ─── Phone Number Redaction ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def redact_phone(phone: str) -> str:
|
||||
"""Redact a phone number for logging, preserving country code and last 4.
|
||||
|
||||
Replaces the identical ``_redact_phone()`` functions in signal.py,
|
||||
sms.py, and bluebubbles.py.
|
||||
"""
|
||||
if not phone:
|
||||
return "<none>"
|
||||
if len(phone) <= 8:
|
||||
return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****"
|
||||
return phone[:4] + "****" + phone[-4:]
|
||||
+46
-53
@@ -92,6 +92,7 @@ from gateway.platforms.base import (
|
||||
ProcessingOutcome,
|
||||
SendResult,
|
||||
)
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -216,8 +217,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
self._pending_megolm: list = []
|
||||
|
||||
# Thread participation tracking (for require_mention bypass)
|
||||
self._bot_participated_threads: set = self._load_participated_threads()
|
||||
self._MAX_TRACKED_THREADS = 500
|
||||
self._threads = ThreadParticipationTracker("matrix")
|
||||
|
||||
# Mention/thread gating — parsed once from env vars.
|
||||
self._require_mention: bool = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
|
||||
@@ -352,7 +352,16 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
from mautrix.crypto import OlmMachine
|
||||
from mautrix.crypto.store import MemoryCryptoStore
|
||||
|
||||
crypto_store = MemoryCryptoStore()
|
||||
# account_id and pickle_key are required by mautrix ≥0.21.
|
||||
# Use the Matrix user ID as account_id for stable identity.
|
||||
# pickle_key secures in-memory serialisation; derive from
|
||||
# the same user_id:device_id pair used for the on-disk HMAC.
|
||||
_acct_id = self._user_id or "hermes"
|
||||
_pickle_key = f"{_acct_id}:{self._device_id}"
|
||||
crypto_store = MemoryCryptoStore(
|
||||
account_id=_acct_id,
|
||||
pickle_key=_pickle_key,
|
||||
)
|
||||
|
||||
# Restore persisted crypto state from a previous run.
|
||||
# Uses HMAC to verify integrity before unpickling.
|
||||
@@ -418,6 +427,11 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
if isinstance(sync_data, dict):
|
||||
rooms_join = sync_data.get("rooms", {}).get("join", {})
|
||||
self._joined_rooms = set(rooms_join.keys())
|
||||
# Store the next_batch token so incremental syncs start
|
||||
# from where the initial sync left off.
|
||||
nb = sync_data.get("next_batch")
|
||||
if nb:
|
||||
await client.sync_store.put_next_batch(nb)
|
||||
logger.info(
|
||||
"Matrix: initial sync complete, joined %d rooms",
|
||||
len(self._joined_rooms),
|
||||
@@ -809,19 +823,40 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
|
||||
async def _sync_loop(self) -> None:
|
||||
"""Continuously sync with the homeserver."""
|
||||
client = self._client
|
||||
# Resume from the token stored during the initial sync.
|
||||
next_batch = await client.sync_store.get_next_batch()
|
||||
while not self._closing:
|
||||
try:
|
||||
sync_data = await self._client.sync(timeout=30000)
|
||||
sync_data = await client.sync(
|
||||
since=next_batch, timeout=30000,
|
||||
)
|
||||
if isinstance(sync_data, dict):
|
||||
# Update joined rooms from sync response.
|
||||
rooms_join = sync_data.get("rooms", {}).get("join", {})
|
||||
if rooms_join:
|
||||
self._joined_rooms.update(rooms_join.keys())
|
||||
|
||||
# Share keys periodically if E2EE is enabled.
|
||||
if self._encryption and getattr(self._client, "crypto", None):
|
||||
# Advance the sync token so the next request is
|
||||
# incremental instead of a full initial sync.
|
||||
nb = sync_data.get("next_batch")
|
||||
if nb:
|
||||
next_batch = nb
|
||||
await client.sync_store.put_next_batch(nb)
|
||||
|
||||
# Dispatch events to registered handlers so that
|
||||
# _on_room_message / _on_reaction / _on_invite fire.
|
||||
try:
|
||||
await self._client.crypto.share_keys()
|
||||
tasks = client.handle_sync(sync_data)
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
except Exception as exc:
|
||||
logger.warning("Matrix: sync event dispatch error: %s", exc)
|
||||
|
||||
# Share keys periodically if E2EE is enabled.
|
||||
if self._encryption and getattr(client, "crypto", None):
|
||||
try:
|
||||
await client.crypto.share_keys()
|
||||
except Exception as exc:
|
||||
logger.warning("Matrix: E2EE key share failed: %s", exc)
|
||||
|
||||
@@ -984,7 +1019,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
# Require-mention gating.
|
||||
if not is_dm:
|
||||
is_free_room = room_id in self._free_rooms
|
||||
in_bot_thread = bool(thread_id and thread_id in self._bot_participated_threads)
|
||||
in_bot_thread = bool(thread_id and thread_id in self._threads)
|
||||
if self._require_mention and not is_free_room and not in_bot_thread:
|
||||
if not is_mentioned:
|
||||
return None
|
||||
@@ -992,7 +1027,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
# DM mention-thread.
|
||||
if is_dm and not thread_id and self._dm_mention_threads and is_mentioned:
|
||||
thread_id = event_id
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
# Strip mention from body.
|
||||
if is_mentioned:
|
||||
@@ -1001,7 +1036,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
# Auto-thread.
|
||||
if not is_dm and not thread_id and self._auto_thread:
|
||||
thread_id = event_id
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
display_name = await self._get_display_name(room_id, sender)
|
||||
source = self.build_source(
|
||||
@@ -1013,7 +1048,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
)
|
||||
|
||||
if thread_id:
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
self._background_read_receipt(room_id, event_id)
|
||||
|
||||
@@ -1662,48 +1697,6 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
for rid in self._joined_rooms
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Thread participation tracking
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _thread_state_path() -> Path:
|
||||
"""Path to the persisted thread participation set."""
|
||||
from hermes_cli.config import get_hermes_home
|
||||
return get_hermes_home() / "matrix_threads.json"
|
||||
|
||||
@classmethod
|
||||
def _load_participated_threads(cls) -> set:
|
||||
"""Load persisted thread IDs from disk."""
|
||||
path = cls._thread_state_path()
|
||||
try:
|
||||
if path.exists():
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, list):
|
||||
return set(data)
|
||||
except Exception as e:
|
||||
logger.debug("Could not load matrix thread state: %s", e)
|
||||
return set()
|
||||
|
||||
def _save_participated_threads(self) -> None:
|
||||
"""Persist the current thread set to disk (best-effort)."""
|
||||
path = self._thread_state_path()
|
||||
try:
|
||||
thread_list = list(self._bot_participated_threads)
|
||||
if len(thread_list) > self._MAX_TRACKED_THREADS:
|
||||
thread_list = thread_list[-self._MAX_TRACKED_THREADS:]
|
||||
self._bot_participated_threads = set(thread_list)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(thread_list), encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.debug("Could not save matrix thread state: %s", e)
|
||||
|
||||
def _track_thread(self, thread_id: str) -> None:
|
||||
"""Add a thread to the participation set and persist."""
|
||||
if thread_id not in self._bot_participated_threads:
|
||||
self._bot_participated_threads.add(thread_id)
|
||||
self._save_participated_threads()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Mention detection helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -18,11 +18,11 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.helpers import MessageDeduplicator
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
@@ -96,10 +96,8 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
or os.getenv("MATTERMOST_REPLY_MODE", "off")
|
||||
).lower()
|
||||
|
||||
# Dedup cache: post_id → timestamp (prevent reprocessing)
|
||||
self._seen_posts: Dict[str, float] = {}
|
||||
self._SEEN_MAX = 2000
|
||||
self._SEEN_TTL = 300 # 5 minutes
|
||||
# Dedup cache (prevent reprocessing)
|
||||
self._dedup = MessageDeduplicator()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# HTTP helpers
|
||||
@@ -604,10 +602,8 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
post_id = post.get("id", "")
|
||||
|
||||
# Dedup.
|
||||
self._prune_seen()
|
||||
if post_id in self._seen_posts:
|
||||
if self._dedup.is_duplicate(post_id):
|
||||
return
|
||||
self._seen_posts[post_id] = time.time()
|
||||
|
||||
# Build message event.
|
||||
channel_id = post.get("channel_id", "")
|
||||
@@ -734,13 +730,4 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
|
||||
await self.handle_message(msg_event)
|
||||
|
||||
def _prune_seen(self) -> None:
|
||||
"""Remove expired entries from the dedup cache."""
|
||||
if len(self._seen_posts) < self._SEEN_MAX:
|
||||
return
|
||||
now = time.time()
|
||||
self._seen_posts = {
|
||||
pid: ts
|
||||
for pid, ts in self._seen_posts.items()
|
||||
if now - ts < self._SEEN_TTL
|
||||
}
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ from gateway.platforms.base import (
|
||||
cache_document_from_bytes,
|
||||
cache_image_from_url,
|
||||
)
|
||||
from gateway.platforms.helpers import redact_phone
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,22 +52,10 @@ SSE_RETRY_DELAY_MAX = 60.0
|
||||
HEALTH_CHECK_INTERVAL = 30.0 # seconds between health checks
|
||||
HEALTH_CHECK_STALE_THRESHOLD = 120.0 # seconds without SSE activity before concern
|
||||
|
||||
# E.164 phone number pattern for redaction
|
||||
_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _redact_phone(phone: str) -> str:
|
||||
"""Redact a phone number for logging: +15551234567 -> +155****4567."""
|
||||
if not phone:
|
||||
return "<none>"
|
||||
if len(phone) <= 8:
|
||||
return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****"
|
||||
return phone[:4] + "****" + phone[-4:]
|
||||
|
||||
|
||||
def _parse_comma_list(value: str) -> List[str]:
|
||||
"""Split a comma-separated string into a list, stripping whitespace."""
|
||||
@@ -184,10 +173,8 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
self._recent_sent_timestamps: set = set()
|
||||
self._max_recent_timestamps = 50
|
||||
|
||||
self._phone_lock_identity: Optional[str] = None
|
||||
|
||||
logger.info("Signal adapter initialized: url=%s account=%s groups=%s",
|
||||
self.http_url, _redact_phone(self.account),
|
||||
self.http_url, redact_phone(self.account),
|
||||
"enabled" if self.group_allow_from else "disabled")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -202,23 +189,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
|
||||
# Acquire scoped lock to prevent duplicate Signal listeners for the same phone
|
||||
try:
|
||||
from gateway.status import acquire_scoped_lock
|
||||
|
||||
self._phone_lock_identity = self.account
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
"signal-phone",
|
||||
self._phone_lock_identity,
|
||||
metadata={"platform": self.platform.value},
|
||||
)
|
||||
if not acquired:
|
||||
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
|
||||
message = (
|
||||
"Another local Hermes gateway is already using this Signal account"
|
||||
+ (f" (PID {owner_pid})." if owner_pid else ".")
|
||||
+ " Stop the other gateway before starting a second Signal listener."
|
||||
)
|
||||
logger.error("Signal: %s", message)
|
||||
self._set_fatal_error("signal_phone_lock", message, retryable=False)
|
||||
if not self._acquire_platform_lock('signal-phone', self.account, 'Signal account'):
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning("Signal: Could not acquire phone lock (non-fatal): %s", e)
|
||||
@@ -270,13 +241,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
await self.client.aclose()
|
||||
self.client = None
|
||||
|
||||
if self._phone_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("signal-phone", self._phone_lock_identity)
|
||||
except Exception as e:
|
||||
logger.warning("Signal: Error releasing phone lock: %s", e, exc_info=True)
|
||||
self._phone_lock_identity = None
|
||||
self._release_platform_lock()
|
||||
|
||||
logger.info("Signal: disconnected")
|
||||
|
||||
@@ -542,7 +507,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
)
|
||||
|
||||
logger.debug("Signal: message from %s in %s: %s",
|
||||
_redact_phone(sender), chat_id[:20], (text or "")[:50])
|
||||
redact_phone(sender), chat_id[:20], (text or "")[:50])
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ from pathlib import Path as _Path
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.helpers import MessageDeduplicator
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
@@ -89,11 +90,9 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
self._team_clients: Dict[str, AsyncWebClient] = {} # team_id → WebClient
|
||||
self._team_bot_user_ids: Dict[str, str] = {} # team_id → bot_user_id
|
||||
self._channel_team: Dict[str, str] = {} # channel_id → team_id
|
||||
# Dedup cache: event_ts → timestamp. Prevents duplicate bot
|
||||
# responses when Socket Mode reconnects redeliver events.
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
self._SEEN_TTL = 300 # 5 minutes
|
||||
self._SEEN_MAX = 2000 # prune threshold
|
||||
# Dedup cache: prevents duplicate bot responses when Socket Mode
|
||||
# reconnects redeliver events.
|
||||
self._dedup = MessageDeduplicator()
|
||||
# Track pending approval message_ts → resolved flag to prevent
|
||||
# double-clicks on approval buttons.
|
||||
self._approval_resolved: Dict[str, bool] = {}
|
||||
@@ -152,15 +151,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
logger.warning("[Slack] Failed to read %s: %s", tokens_file, e)
|
||||
|
||||
try:
|
||||
# Acquire scoped lock to prevent duplicate app token usage
|
||||
from gateway.status import acquire_scoped_lock
|
||||
self._token_lock_identity = app_token
|
||||
acquired, existing = acquire_scoped_lock('slack-app-token', app_token, metadata={'platform': 'slack'})
|
||||
if not acquired:
|
||||
owner_pid = existing.get('pid') if isinstance(existing, dict) else None
|
||||
message = f'Slack app token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.'
|
||||
logger.error('[%s] %s', self.name, message)
|
||||
self._set_fatal_error('slack_token_lock', message, retryable=False)
|
||||
if not self._acquire_platform_lock('slack-app-token', app_token, 'Slack app token'):
|
||||
return False
|
||||
|
||||
# First token is the primary — used for AsyncApp / Socket Mode
|
||||
@@ -247,14 +238,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
logger.warning("[Slack] Error while closing Socket Mode handler: %s", e, exc_info=True)
|
||||
self._running = False
|
||||
|
||||
# Release the token lock (use stored identity, not re-read env)
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
if getattr(self, '_token_lock_identity', None):
|
||||
release_scoped_lock('slack-app-token', self._token_lock_identity)
|
||||
self._token_lock_identity = None
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
|
||||
logger.info("[Slack] Disconnected")
|
||||
|
||||
@@ -953,17 +937,8 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
"""Handle an incoming Slack message event."""
|
||||
# Dedup: Slack Socket Mode can redeliver events after reconnects (#4777)
|
||||
event_ts = event.get("ts", "")
|
||||
if event_ts:
|
||||
now = time.time()
|
||||
if event_ts in self._seen_messages:
|
||||
return
|
||||
self._seen_messages[event_ts] = now
|
||||
if len(self._seen_messages) > self._SEEN_MAX:
|
||||
cutoff = now - self._SEEN_TTL
|
||||
self._seen_messages = {
|
||||
k: v for k, v in self._seen_messages.items()
|
||||
if v > cutoff
|
||||
}
|
||||
if event_ts and self._dedup.is_duplicate(event_ts):
|
||||
return
|
||||
|
||||
# Bot message filtering (SLACK_ALLOW_BOTS / config allow_bots):
|
||||
# "none" — ignore all bot messages (default, backward-compatible)
|
||||
|
||||
+129
-32
@@ -10,6 +10,9 @@ Shares credentials with the optional telephony skill — same env vars:
|
||||
|
||||
Gateway-specific env vars:
|
||||
- SMS_WEBHOOK_PORT (default 8080)
|
||||
- SMS_WEBHOOK_HOST (default 0.0.0.0)
|
||||
- SMS_WEBHOOK_URL (public URL for Twilio signature validation — required)
|
||||
- SMS_INSECURE_NO_SIGNATURE (true to disable signature validation — dev only)
|
||||
- SMS_ALLOWED_USERS (comma-separated E.164 phone numbers)
|
||||
- SMS_ALLOW_ALL_USERS (true/false)
|
||||
- SMS_HOME_CHANNEL (phone number for cron delivery)
|
||||
@@ -17,9 +20,10 @@ Gateway-specific env vars:
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -30,24 +34,14 @@ from gateway.platforms.base import (
|
||||
MessageType,
|
||||
SendResult,
|
||||
)
|
||||
from gateway.platforms.helpers import redact_phone, strip_markdown
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TWILIO_API_BASE = "https://api.twilio.com/2010-04-01/Accounts"
|
||||
MAX_SMS_LENGTH = 1600 # ~10 SMS segments
|
||||
DEFAULT_WEBHOOK_PORT = 8080
|
||||
|
||||
# E.164 phone number pattern for redaction
|
||||
_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
|
||||
|
||||
|
||||
def _redact_phone(phone: str) -> str:
|
||||
"""Redact a phone number for logging: +15551234567 -> +1555***4567."""
|
||||
if not phone:
|
||||
return "<none>"
|
||||
if len(phone) <= 8:
|
||||
return phone[:2] + "***" + phone[-2:] if len(phone) > 4 else "****"
|
||||
return phone[:5] + "***" + phone[-4:]
|
||||
DEFAULT_WEBHOOK_HOST = "0.0.0.0"
|
||||
|
||||
|
||||
def check_sms_requirements() -> bool:
|
||||
@@ -77,6 +71,8 @@ class SmsAdapter(BasePlatformAdapter):
|
||||
self._webhook_port: int = int(
|
||||
os.getenv("SMS_WEBHOOK_PORT", str(DEFAULT_WEBHOOK_PORT))
|
||||
)
|
||||
self._webhook_host: str = os.getenv("SMS_WEBHOOK_HOST", DEFAULT_WEBHOOK_HOST)
|
||||
self._webhook_url: str = os.getenv("SMS_WEBHOOK_URL", "").strip()
|
||||
self._runner = None
|
||||
self._http_session: Optional["aiohttp.ClientSession"] = None
|
||||
|
||||
@@ -98,13 +94,33 @@ class SmsAdapter(BasePlatformAdapter):
|
||||
logger.error("[sms] TWILIO_PHONE_NUMBER not set — cannot send replies")
|
||||
return False
|
||||
|
||||
insecure_no_sig = os.getenv("SMS_INSECURE_NO_SIGNATURE", "").lower() == "true"
|
||||
|
||||
if not self._webhook_url and not insecure_no_sig:
|
||||
logger.error(
|
||||
"[sms] Refusing to start: SMS_WEBHOOK_URL is required for Twilio "
|
||||
"signature validation. Set it to the public URL configured in your "
|
||||
"Twilio console (e.g. https://example.com/webhooks/twilio). "
|
||||
"For local development without validation, set "
|
||||
"SMS_INSECURE_NO_SIGNATURE=true (NOT recommended for production).",
|
||||
)
|
||||
return False
|
||||
|
||||
if insecure_no_sig and not self._webhook_url:
|
||||
logger.warning(
|
||||
"[sms] SMS_INSECURE_NO_SIGNATURE=true — Twilio signature validation "
|
||||
"is DISABLED. Any client that can reach port %d can inject messages. "
|
||||
"Do NOT use this in production.",
|
||||
self._webhook_port,
|
||||
)
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_post("/webhooks/twilio", self._handle_webhook)
|
||||
app.router.add_get("/health", lambda _: web.Response(text="ok"))
|
||||
|
||||
self._runner = web.AppRunner(app)
|
||||
await self._runner.setup()
|
||||
site = web.TCPSite(self._runner, "0.0.0.0", self._webhook_port)
|
||||
site = web.TCPSite(self._runner, self._webhook_host, self._webhook_port)
|
||||
await site.start()
|
||||
self._http_session = aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
@@ -112,9 +128,10 @@ class SmsAdapter(BasePlatformAdapter):
|
||||
self._running = True
|
||||
|
||||
logger.info(
|
||||
"[sms] Twilio webhook server listening on port %d, from: %s",
|
||||
"[sms] Twilio webhook server listening on %s:%d, from: %s",
|
||||
self._webhook_host,
|
||||
self._webhook_port,
|
||||
_redact_phone(self._from_number),
|
||||
redact_phone(self._from_number),
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -163,7 +180,7 @@ class SmsAdapter(BasePlatformAdapter):
|
||||
error_msg = body.get("message", str(body))
|
||||
logger.error(
|
||||
"[sms] send failed to %s: %s %s",
|
||||
_redact_phone(chat_id),
|
||||
redact_phone(chat_id),
|
||||
resp.status,
|
||||
error_msg,
|
||||
)
|
||||
@@ -174,7 +191,7 @@ class SmsAdapter(BasePlatformAdapter):
|
||||
msg_sid = body.get("sid", "")
|
||||
last_result = SendResult(success=True, message_id=msg_sid)
|
||||
except Exception as e:
|
||||
logger.error("[sms] send error to %s: %s", _redact_phone(chat_id), e)
|
||||
logger.error("[sms] send error to %s: %s", redact_phone(chat_id), e)
|
||||
return SendResult(success=False, error=str(e))
|
||||
finally:
|
||||
# Close session only if we created a fallback (no persistent session)
|
||||
@@ -192,16 +209,75 @@ class SmsAdapter(BasePlatformAdapter):
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
"""Strip markdown — SMS renders it as literal characters."""
|
||||
content = re.sub(r"\*\*(.+?)\*\*", r"\1", content, flags=re.DOTALL)
|
||||
content = re.sub(r"\*(.+?)\*", r"\1", content, flags=re.DOTALL)
|
||||
content = re.sub(r"__(.+?)__", r"\1", content, flags=re.DOTALL)
|
||||
content = re.sub(r"_(.+?)_", r"\1", content, flags=re.DOTALL)
|
||||
content = re.sub(r"```[a-z]*\n?", "", content)
|
||||
content = re.sub(r"`(.+?)`", r"\1", content)
|
||||
content = re.sub(r"^#{1,6}\s+", "", content, flags=re.MULTILINE)
|
||||
content = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", content)
|
||||
content = re.sub(r"\n{3,}", "\n\n", content)
|
||||
return content.strip()
|
||||
return strip_markdown(content)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Twilio signature validation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _validate_twilio_signature(
|
||||
self, url: str, post_params: dict, signature: str,
|
||||
) -> bool:
|
||||
"""Validate ``X-Twilio-Signature`` header (HMAC-SHA1, base64).
|
||||
|
||||
Tries both with and without the default port for the URL scheme,
|
||||
since Twilio may sign with either variant.
|
||||
|
||||
Algorithm: https://www.twilio.com/docs/usage/security#validating-requests
|
||||
"""
|
||||
if self._check_signature(url, post_params, signature):
|
||||
return True
|
||||
|
||||
variant = self._port_variant_url(url)
|
||||
if variant and self._check_signature(variant, post_params, signature):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _check_signature(
|
||||
self, url: str, post_params: dict, signature: str,
|
||||
) -> bool:
|
||||
"""Compute and compare a single Twilio signature."""
|
||||
data_to_sign = url
|
||||
for key in sorted(post_params.keys()):
|
||||
data_to_sign += key + post_params[key]
|
||||
mac = hmac.new(
|
||||
self._auth_token.encode("utf-8"),
|
||||
data_to_sign.encode("utf-8"),
|
||||
hashlib.sha1,
|
||||
)
|
||||
computed = base64.b64encode(mac.digest()).decode("utf-8")
|
||||
return hmac.compare_digest(computed, signature)
|
||||
|
||||
@staticmethod
|
||||
def _port_variant_url(url: str) -> str | None:
|
||||
"""Return the URL with the default port toggled, or None.
|
||||
|
||||
Only toggles default ports (443 for https, 80 for http).
|
||||
Non-standard ports are never modified.
|
||||
"""
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
default_ports = {"https": 443, "http": 80}
|
||||
default_port = default_ports.get(parsed.scheme)
|
||||
if default_port is None:
|
||||
return None
|
||||
|
||||
if parsed.port == default_port:
|
||||
# Has explicit default port → strip it
|
||||
return urllib.parse.urlunparse(
|
||||
(parsed.scheme, parsed.hostname, parsed.path,
|
||||
parsed.params, parsed.query, parsed.fragment)
|
||||
)
|
||||
elif parsed.port is None:
|
||||
# No port → add default
|
||||
netloc = f"{parsed.hostname}:{default_port}"
|
||||
return urllib.parse.urlunparse(
|
||||
(parsed.scheme, netloc, parsed.path,
|
||||
parsed.params, parsed.query, parsed.fragment)
|
||||
)
|
||||
|
||||
# Non-standard port — no variant
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Twilio webhook handler
|
||||
@@ -213,7 +289,7 @@ class SmsAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
raw = await request.read()
|
||||
# Twilio sends form-encoded data, not JSON
|
||||
form = urllib.parse.parse_qs(raw.decode("utf-8"))
|
||||
form = urllib.parse.parse_qs(raw.decode("utf-8"), keep_blank_values=True)
|
||||
except Exception as e:
|
||||
logger.error("[sms] webhook parse error: %s", e)
|
||||
return web.Response(
|
||||
@@ -222,6 +298,27 @@ class SmsAdapter(BasePlatformAdapter):
|
||||
status=400,
|
||||
)
|
||||
|
||||
# Validate Twilio request signature when SMS_WEBHOOK_URL is configured
|
||||
if self._webhook_url:
|
||||
twilio_sig = request.headers.get("X-Twilio-Signature", "")
|
||||
if not twilio_sig:
|
||||
logger.warning("[sms] Rejected: missing X-Twilio-Signature header")
|
||||
return web.Response(
|
||||
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
|
||||
content_type="application/xml",
|
||||
status=403,
|
||||
)
|
||||
flat_params = {k: v[0] for k, v in form.items() if v}
|
||||
if not self._validate_twilio_signature(
|
||||
self._webhook_url, flat_params, twilio_sig
|
||||
):
|
||||
logger.warning("[sms] Rejected: invalid Twilio signature")
|
||||
return web.Response(
|
||||
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
|
||||
content_type="application/xml",
|
||||
status=403,
|
||||
)
|
||||
|
||||
# Extract fields (parse_qs returns lists)
|
||||
from_number = (form.get("From", [""]))[0].strip()
|
||||
to_number = (form.get("To", [""]))[0].strip()
|
||||
@@ -236,7 +333,7 @@ class SmsAdapter(BasePlatformAdapter):
|
||||
|
||||
# Ignore messages from our own number (echo prevention)
|
||||
if from_number == self._from_number:
|
||||
logger.debug("[sms] ignoring echo from own number %s", _redact_phone(from_number))
|
||||
logger.debug("[sms] ignoring echo from own number %s", redact_phone(from_number))
|
||||
return web.Response(
|
||||
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
|
||||
content_type="application/xml",
|
||||
@@ -244,8 +341,8 @@ class SmsAdapter(BasePlatformAdapter):
|
||||
|
||||
logger.info(
|
||||
"[sms] inbound from %s -> %s: %s",
|
||||
_redact_phone(from_number),
|
||||
_redact_phone(to_number),
|
||||
redact_phone(from_number),
|
||||
redact_phone(to_number),
|
||||
text[:80],
|
||||
)
|
||||
|
||||
|
||||
@@ -147,7 +147,6 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
self._text_batch_split_delay_seconds = float(os.getenv("HERMES_TELEGRAM_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0"))
|
||||
self._pending_text_batches: Dict[str, MessageEvent] = {}
|
||||
self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._token_lock_identity: Optional[str] = None
|
||||
self._polling_error_task: Optional[asyncio.Task] = None
|
||||
self._polling_conflict_count: int = 0
|
||||
self._polling_network_error_count: int = 0
|
||||
@@ -300,9 +299,11 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
|
||||
# Exhausted retries — fatal
|
||||
message = (
|
||||
"Another Telegram bot poller is already using this token. "
|
||||
"Another process is already polling this Telegram bot token "
|
||||
"(possibly OpenClaw or another Hermes instance). "
|
||||
"Hermes stopped Telegram polling after %d retries. "
|
||||
"Make sure only one gateway instance is running for this bot token."
|
||||
"Only one poller can run per token — stop the other process "
|
||||
"and restart with 'hermes start'."
|
||||
% MAX_CONFLICT_RETRIES
|
||||
)
|
||||
logger.error("[%s] %s Original error: %s", self.name, message, error)
|
||||
@@ -497,23 +498,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
return False
|
||||
|
||||
try:
|
||||
from gateway.status import acquire_scoped_lock
|
||||
|
||||
self._token_lock_identity = self.config.token
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
"telegram-bot-token",
|
||||
self._token_lock_identity,
|
||||
metadata={"platform": self.platform.value},
|
||||
)
|
||||
if not acquired:
|
||||
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
|
||||
message = (
|
||||
"Another local Hermes gateway is already using this Telegram bot token"
|
||||
+ (f" (PID {owner_pid})." if owner_pid else ".")
|
||||
+ " Stop the other gateway before starting a second Telegram poller."
|
||||
)
|
||||
logger.error("[%s] %s", self.name, message)
|
||||
self._set_fatal_error("telegram_token_lock", message, retryable=False)
|
||||
if not self._acquire_platform_lock('telegram-bot-token', self.config.token, 'Telegram bot token'):
|
||||
return False
|
||||
|
||||
# Build the application
|
||||
@@ -737,12 +722,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
if self._token_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("telegram-bot-token", self._token_lock_identity)
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
message = f"Telegram startup failed: {e}"
|
||||
self._set_fatal_error("telegram_connect_error", message, retryable=True)
|
||||
logger.error("[%s] Failed to connect to Telegram: %s", self.name, e, exc_info=True)
|
||||
@@ -768,12 +748,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
await self._app.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Error during Telegram disconnect: %s", self.name, e, exc_info=True)
|
||||
if self._token_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("telegram-bot-token", self._token_lock_identity)
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Error releasing Telegram token lock: %s", self.name, e, exc_info=True)
|
||||
self._release_platform_lock()
|
||||
|
||||
for task in self._pending_photo_batch_tasks.values():
|
||||
if task and not task.done():
|
||||
@@ -784,7 +759,6 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
self._mark_disconnected()
|
||||
self._app = None
|
||||
self._bot = None
|
||||
self._token_lock_identity = None
|
||||
logger.info("[%s] Disconnected from Telegram", self.name)
|
||||
|
||||
def _should_thread_reply(self, reply_to: Optional[str], chunk_index: int) -> bool:
|
||||
|
||||
@@ -201,6 +201,7 @@ class WebhookAdapter(BasePlatformAdapter):
|
||||
"dingtalk",
|
||||
"feishu",
|
||||
"wecom",
|
||||
"wecom_callback",
|
||||
"weixin",
|
||||
"bluebubbles",
|
||||
):
|
||||
|
||||
+18
-22
@@ -59,6 +59,7 @@ except ImportError:
|
||||
httpx = None # type: ignore[assignment]
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.helpers import MessageDeduplicator
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
@@ -92,7 +93,6 @@ REQUEST_TIMEOUT_SECONDS = 15.0
|
||||
HEARTBEAT_INTERVAL_SECONDS = 30.0
|
||||
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
|
||||
|
||||
DEDUP_WINDOW_SECONDS = 300
|
||||
DEDUP_MAX_SIZE = 1000
|
||||
|
||||
IMAGE_MAX_BYTES = 10 * 1024 * 1024
|
||||
@@ -172,7 +172,7 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
self._listen_task: Optional[asyncio.Task] = None
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self._pending_responses: Dict[str, asyncio.Future] = {}
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
self._dedup = MessageDeduplicator(max_size=DEDUP_MAX_SIZE)
|
||||
self._reply_req_ids: Dict[str, str] = {}
|
||||
|
||||
# Text batching: merge rapid successive messages (Telegram-style).
|
||||
@@ -250,7 +250,7 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
await self._http_client.aclose()
|
||||
self._http_client = None
|
||||
|
||||
self._seen_messages.clear()
|
||||
self._dedup.clear()
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
async def _cleanup_ws(self) -> None:
|
||||
@@ -476,7 +476,7 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
return
|
||||
|
||||
msg_id = str(body.get("msgid") or self._payload_req_id(payload) or uuid.uuid4().hex)
|
||||
if self._is_duplicate(msg_id):
|
||||
if self._dedup.is_duplicate(msg_id):
|
||||
logger.debug("[%s] Duplicate message %s ignored", self.name, msg_id)
|
||||
return
|
||||
self._remember_reply_req_id(msg_id, self._payload_req_id(payload))
|
||||
@@ -636,6 +636,13 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
if voice_text:
|
||||
text_parts.append(voice_text)
|
||||
|
||||
# Extract appmsg title (filename) for WeCom AI Bot attachments
|
||||
if msgtype == "appmsg":
|
||||
appmsg = body.get("appmsg") if isinstance(body.get("appmsg"), dict) else {}
|
||||
title = str(appmsg.get("title") or "").strip()
|
||||
if title:
|
||||
text_parts.append(title)
|
||||
|
||||
quote = body.get("quote") if isinstance(body.get("quote"), dict) else {}
|
||||
quote_type = str(quote.get("msgtype") or "").lower()
|
||||
if quote_type == "text":
|
||||
@@ -668,6 +675,13 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
refs.append(("image", body["image"]))
|
||||
if msgtype == "file" and isinstance(body.get("file"), dict):
|
||||
refs.append(("file", body["file"]))
|
||||
# Handle appmsg (WeCom AI Bot attachments with PDF/Word/Excel)
|
||||
if msgtype == "appmsg" and isinstance(body.get("appmsg"), dict):
|
||||
appmsg = body["appmsg"]
|
||||
if isinstance(appmsg.get("file"), dict):
|
||||
refs.append(("file", appmsg["file"]))
|
||||
elif isinstance(appmsg.get("image"), dict):
|
||||
refs.append(("image", appmsg["image"]))
|
||||
|
||||
quote = body.get("quote") if isinstance(body.get("quote"), dict) else {}
|
||||
quote_type = str(quote.get("msgtype") or "").lower()
|
||||
@@ -825,24 +839,6 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
wildcard = self._groups.get("*")
|
||||
return wildcard if isinstance(wildcard, dict) else {}
|
||||
|
||||
def _is_duplicate(self, msg_id: str) -> bool:
|
||||
now = time.time()
|
||||
if len(self._seen_messages) > DEDUP_MAX_SIZE:
|
||||
cutoff = now - DEDUP_WINDOW_SECONDS
|
||||
self._seen_messages = {
|
||||
key: ts for key, ts in self._seen_messages.items() if ts > cutoff
|
||||
}
|
||||
if self._reply_req_ids:
|
||||
self._reply_req_ids = {
|
||||
key: value for key, value in self._reply_req_ids.items() if key in self._seen_messages
|
||||
}
|
||||
|
||||
if msg_id in self._seen_messages:
|
||||
return True
|
||||
|
||||
self._seen_messages[msg_id] = now
|
||||
return False
|
||||
|
||||
def _remember_reply_req_id(self, message_id: str, req_id: str) -> None:
|
||||
normalized_message_id = str(message_id or "").strip()
|
||||
normalized_req_id = str(req_id or "").strip()
|
||||
|
||||
@@ -0,0 +1,387 @@
|
||||
"""WeCom callback-mode adapter for self-built enterprise applications.
|
||||
|
||||
Unlike the bot/websocket adapter in ``wecom.py``, this handles the standard
|
||||
WeCom callback flow: WeCom POSTs encrypted XML to an HTTP endpoint, the
|
||||
adapter decrypts it, queues the message for the agent, and immediately
|
||||
acknowledges. The agent's reply is delivered later via the proactive
|
||||
``message/send`` API using an access-token.
|
||||
|
||||
Supports multiple self-built apps under one gateway instance, scoped by
|
||||
``corp_id:user_id`` to avoid cross-corp collisions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import socket as _socket
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from xml.etree import ElementTree as ET
|
||||
|
||||
try:
|
||||
from aiohttp import web
|
||||
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
web = None # type: ignore[assignment]
|
||||
AIOHTTP_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import httpx
|
||||
|
||||
HTTPX_AVAILABLE = True
|
||||
except ImportError:
|
||||
httpx = None # type: ignore[assignment]
|
||||
HTTPX_AVAILABLE = False
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
|
||||
from gateway.platforms.wecom_crypto import WXBizMsgCrypt, WeComCryptoError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_HOST = "0.0.0.0"
|
||||
DEFAULT_PORT = 8645
|
||||
DEFAULT_PATH = "/wecom/callback"
|
||||
ACCESS_TOKEN_TTL_SECONDS = 7200
|
||||
MESSAGE_DEDUP_TTL_SECONDS = 300
|
||||
|
||||
|
||||
def check_wecom_callback_requirements() -> bool:
|
||||
return AIOHTTP_AVAILABLE and HTTPX_AVAILABLE
|
||||
|
||||
|
||||
class WecomCallbackAdapter(BasePlatformAdapter):
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.WECOM_CALLBACK)
|
||||
extra = config.extra or {}
|
||||
self._host = str(extra.get("host") or DEFAULT_HOST)
|
||||
self._port = int(extra.get("port") or DEFAULT_PORT)
|
||||
self._path = str(extra.get("path") or DEFAULT_PATH)
|
||||
self._apps: List[Dict[str, Any]] = self._normalize_apps(extra)
|
||||
self._runner: Optional[web.AppRunner] = None
|
||||
self._site: Optional[web.TCPSite] = None
|
||||
self._app: Optional[web.Application] = None
|
||||
self._http_client: Optional[httpx.AsyncClient] = None
|
||||
self._message_queue: asyncio.Queue[MessageEvent] = asyncio.Queue()
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
self._user_app_map: Dict[str, str] = {}
|
||||
self._access_tokens: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# App normalisation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _user_app_key(corp_id: str, user_id: str) -> str:
|
||||
return f"{corp_id}:{user_id}" if corp_id else user_id
|
||||
|
||||
@staticmethod
|
||||
def _normalize_apps(extra: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
apps = extra.get("apps")
|
||||
if isinstance(apps, list) and apps:
|
||||
return [dict(app) for app in apps if isinstance(app, dict)]
|
||||
if extra.get("corp_id"):
|
||||
return [
|
||||
{
|
||||
"name": extra.get("name") or "default",
|
||||
"corp_id": extra.get("corp_id", ""),
|
||||
"corp_secret": extra.get("corp_secret", ""),
|
||||
"agent_id": str(extra.get("agent_id", "")),
|
||||
"token": extra.get("token", ""),
|
||||
"encoding_aes_key": extra.get("encoding_aes_key", ""),
|
||||
}
|
||||
]
|
||||
return []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def connect(self) -> bool:
|
||||
if not self._apps:
|
||||
logger.warning("[WecomCallback] No callback apps configured")
|
||||
return False
|
||||
if not check_wecom_callback_requirements():
|
||||
logger.warning("[WecomCallback] aiohttp/httpx not installed")
|
||||
return False
|
||||
|
||||
# Quick port-in-use check.
|
||||
try:
|
||||
with _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) as sock:
|
||||
sock.settimeout(1)
|
||||
sock.connect(("127.0.0.1", self._port))
|
||||
logger.error("[WecomCallback] Port %d already in use", self._port)
|
||||
return False
|
||||
except (ConnectionRefusedError, OSError):
|
||||
pass
|
||||
|
||||
try:
|
||||
self._http_client = httpx.AsyncClient(timeout=20.0)
|
||||
self._app = web.Application()
|
||||
self._app.router.add_get("/health", self._handle_health)
|
||||
self._app.router.add_get(self._path, self._handle_verify)
|
||||
self._app.router.add_post(self._path, self._handle_callback)
|
||||
self._runner = web.AppRunner(self._app)
|
||||
await self._runner.setup()
|
||||
self._site = web.TCPSite(self._runner, self._host, self._port)
|
||||
await self._site.start()
|
||||
self._poll_task = asyncio.create_task(self._poll_loop())
|
||||
self._mark_connected()
|
||||
logger.info(
|
||||
"[WecomCallback] HTTP server listening on %s:%s%s",
|
||||
self._host, self._port, self._path,
|
||||
)
|
||||
for app in self._apps:
|
||||
try:
|
||||
await self._refresh_access_token(app)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[WecomCallback] Initial token refresh failed for app '%s': %s",
|
||||
app.get("name", "default"), exc,
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
await self._cleanup()
|
||||
logger.exception("[WecomCallback] Failed to start")
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._running = False
|
||||
if self._poll_task:
|
||||
self._poll_task.cancel()
|
||||
try:
|
||||
await self._poll_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._poll_task = None
|
||||
await self._cleanup()
|
||||
self._mark_disconnected()
|
||||
logger.info("[WecomCallback] Disconnected")
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
self._site = None
|
||||
if self._runner:
|
||||
await self._runner.cleanup()
|
||||
self._runner = None
|
||||
self._app = None
|
||||
if self._http_client:
|
||||
await self._http_client.aclose()
|
||||
self._http_client = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Outbound: proactive send via access-token API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
app = self._resolve_app_for_chat(chat_id)
|
||||
touser = chat_id.split(":", 1)[1] if ":" in chat_id else chat_id
|
||||
try:
|
||||
token = await self._get_access_token(app)
|
||||
payload = {
|
||||
"touser": touser,
|
||||
"msgtype": "text",
|
||||
"agentid": int(str(app.get("agent_id") or 0)),
|
||||
"text": {"content": content[:2048]},
|
||||
"safe": 0,
|
||||
}
|
||||
resp = await self._http_client.post(
|
||||
f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={token}",
|
||||
json=payload,
|
||||
)
|
||||
data = resp.json()
|
||||
if data.get("errcode") != 0:
|
||||
return SendResult(success=False, error=str(data))
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=str(data.get("msgid", "")),
|
||||
raw_response=data,
|
||||
)
|
||||
except Exception as exc:
|
||||
return SendResult(success=False, error=str(exc))
|
||||
|
||||
def _resolve_app_for_chat(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Pick the app associated with *chat_id*, falling back sensibly."""
|
||||
app_name = self._user_app_map.get(chat_id)
|
||||
if not app_name and ":" not in chat_id:
|
||||
# Legacy bare user_id — try to find a unique match.
|
||||
matching = [k for k in self._user_app_map if k.endswith(f":{chat_id}")]
|
||||
if len(matching) == 1:
|
||||
app_name = self._user_app_map.get(matching[0])
|
||||
app = self._get_app_by_name(app_name) if app_name else None
|
||||
return app or self._apps[0]
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
return {"name": chat_id, "type": "dm"}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inbound: HTTP callback handlers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _handle_health(self, request: web.Request) -> web.Response:
|
||||
return web.json_response({"status": "ok", "platform": "wecom_callback"})
|
||||
|
||||
async def _handle_verify(self, request: web.Request) -> web.Response:
|
||||
"""GET endpoint — WeCom URL verification handshake."""
|
||||
msg_signature = request.query.get("msg_signature", "")
|
||||
timestamp = request.query.get("timestamp", "")
|
||||
nonce = request.query.get("nonce", "")
|
||||
echostr = request.query.get("echostr", "")
|
||||
for app in self._apps:
|
||||
try:
|
||||
crypt = self._crypt_for_app(app)
|
||||
plain = crypt.verify_url(msg_signature, timestamp, nonce, echostr)
|
||||
return web.Response(text=plain, content_type="text/plain")
|
||||
except Exception:
|
||||
continue
|
||||
return web.Response(status=403, text="signature verification failed")
|
||||
|
||||
async def _handle_callback(self, request: web.Request) -> web.Response:
|
||||
"""POST endpoint — receive an encrypted message callback."""
|
||||
msg_signature = request.query.get("msg_signature", "")
|
||||
timestamp = request.query.get("timestamp", "")
|
||||
nonce = request.query.get("nonce", "")
|
||||
body = await request.text()
|
||||
|
||||
for app in self._apps:
|
||||
try:
|
||||
decrypted = self._decrypt_request(
|
||||
app, body, msg_signature, timestamp, nonce,
|
||||
)
|
||||
event = self._build_event(app, decrypted)
|
||||
if event is not None:
|
||||
# Record which app this user belongs to.
|
||||
if event.source and event.source.user_id:
|
||||
map_key = self._user_app_key(
|
||||
str(app.get("corp_id") or ""), event.source.user_id,
|
||||
)
|
||||
self._user_app_map[map_key] = app["name"]
|
||||
await self._message_queue.put(event)
|
||||
# Immediately acknowledge — the agent's reply will arrive
|
||||
# later via the proactive message/send API.
|
||||
return web.Response(text="success", content_type="text/plain")
|
||||
except WeComCryptoError:
|
||||
continue
|
||||
except Exception:
|
||||
logger.exception("[WecomCallback] Error handling message")
|
||||
break
|
||||
return web.Response(status=400, text="invalid callback payload")
|
||||
|
||||
async def _poll_loop(self) -> None:
|
||||
"""Drain the message queue and dispatch to the gateway runner."""
|
||||
while True:
|
||||
event = await self._message_queue.get()
|
||||
try:
|
||||
task = asyncio.create_task(self.handle_message(event))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
except Exception:
|
||||
logger.exception("[WecomCallback] Failed to enqueue event")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# XML / crypto helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _decrypt_request(
|
||||
self, app: Dict[str, Any], body: str,
|
||||
msg_signature: str, timestamp: str, nonce: str,
|
||||
) -> str:
|
||||
root = ET.fromstring(body)
|
||||
encrypt = root.findtext("Encrypt", default="")
|
||||
crypt = self._crypt_for_app(app)
|
||||
return crypt.decrypt(msg_signature, timestamp, nonce, encrypt).decode("utf-8")
|
||||
|
||||
def _build_event(self, app: Dict[str, Any], xml_text: str) -> Optional[MessageEvent]:
|
||||
root = ET.fromstring(xml_text)
|
||||
msg_type = (root.findtext("MsgType") or "").lower()
|
||||
# Silently acknowledge lifecycle events.
|
||||
if msg_type == "event":
|
||||
event_name = (root.findtext("Event") or "").lower()
|
||||
if event_name in {"enter_agent", "subscribe"}:
|
||||
return None
|
||||
if msg_type not in {"text", "event"}:
|
||||
return None
|
||||
|
||||
user_id = root.findtext("FromUserName", default="")
|
||||
corp_id = root.findtext("ToUserName", default=app.get("corp_id", ""))
|
||||
scoped_chat_id = self._user_app_key(corp_id, user_id)
|
||||
content = root.findtext("Content", default="").strip()
|
||||
if not content and msg_type == "event":
|
||||
content = "/start"
|
||||
msg_id = (
|
||||
root.findtext("MsgId")
|
||||
or f"{user_id}:{root.findtext('CreateTime', default='0')}"
|
||||
)
|
||||
source = self.build_source(
|
||||
chat_id=scoped_chat_id,
|
||||
chat_name=user_id,
|
||||
chat_type="dm",
|
||||
user_id=user_id,
|
||||
user_name=user_id,
|
||||
)
|
||||
return MessageEvent(
|
||||
text=content,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
raw_message=xml_text,
|
||||
message_id=msg_id,
|
||||
)
|
||||
|
||||
def _crypt_for_app(self, app: Dict[str, Any]) -> WXBizMsgCrypt:
|
||||
return WXBizMsgCrypt(
|
||||
token=str(app.get("token") or ""),
|
||||
encoding_aes_key=str(app.get("encoding_aes_key") or ""),
|
||||
receive_id=str(app.get("corp_id") or ""),
|
||||
)
|
||||
|
||||
def _get_app_by_name(self, name: Optional[str]) -> Optional[Dict[str, Any]]:
|
||||
if not name:
|
||||
return None
|
||||
for app in self._apps:
|
||||
if app.get("name") == name:
|
||||
return app
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Access-token management
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _get_access_token(self, app: Dict[str, Any]) -> str:
|
||||
cached = self._access_tokens.get(app["name"])
|
||||
now = time.time()
|
||||
if cached and cached.get("expires_at", 0) > now + 60:
|
||||
return cached["token"]
|
||||
return await self._refresh_access_token(app)
|
||||
|
||||
async def _refresh_access_token(self, app: Dict[str, Any]) -> str:
|
||||
resp = await self._http_client.get(
|
||||
"https://qyapi.weixin.qq.com/cgi-bin/gettoken",
|
||||
params={
|
||||
"corpid": app.get("corp_id"),
|
||||
"corpsecret": app.get("corp_secret"),
|
||||
},
|
||||
)
|
||||
data = resp.json()
|
||||
if data.get("errcode") != 0:
|
||||
raise RuntimeError(f"WeCom token refresh failed: {data}")
|
||||
token = data["access_token"]
|
||||
expires_in = int(data.get("expires_in", ACCESS_TOKEN_TTL_SECONDS))
|
||||
self._access_tokens[app["name"]] = {
|
||||
"token": token,
|
||||
"expires_at": time.time() + expires_in,
|
||||
}
|
||||
logger.info(
|
||||
"[WecomCallback] Token refreshed for app '%s' (corp=%s), expires in %ss",
|
||||
app.get("name", "default"),
|
||||
app.get("corp_id", ""),
|
||||
expires_in,
|
||||
)
|
||||
return token
|
||||
@@ -0,0 +1,142 @@
|
||||
"""WeCom BizMsgCrypt-compatible AES-CBC encryption for callback mode.
|
||||
|
||||
Implements the same wire format as Tencent's official ``WXBizMsgCrypt``
|
||||
SDK so that WeCom can verify, encrypt, and decrypt callback payloads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import secrets
|
||||
import socket
|
||||
import struct
|
||||
from typing import Optional
|
||||
from xml.etree import ElementTree as ET
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
|
||||
class WeComCryptoError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SignatureError(WeComCryptoError):
|
||||
pass
|
||||
|
||||
|
||||
class DecryptError(WeComCryptoError):
|
||||
pass
|
||||
|
||||
|
||||
class EncryptError(WeComCryptoError):
|
||||
pass
|
||||
|
||||
|
||||
class PKCS7Encoder:
|
||||
block_size = 32
|
||||
|
||||
@classmethod
|
||||
def encode(cls, text: bytes) -> bytes:
|
||||
amount_to_pad = cls.block_size - (len(text) % cls.block_size)
|
||||
if amount_to_pad == 0:
|
||||
amount_to_pad = cls.block_size
|
||||
pad = bytes([amount_to_pad]) * amount_to_pad
|
||||
return text + pad
|
||||
|
||||
@classmethod
|
||||
def decode(cls, decrypted: bytes) -> bytes:
|
||||
if not decrypted:
|
||||
raise DecryptError("empty decrypted payload")
|
||||
pad = decrypted[-1]
|
||||
if pad < 1 or pad > cls.block_size:
|
||||
raise DecryptError("invalid PKCS7 padding")
|
||||
if decrypted[-pad:] != bytes([pad]) * pad:
|
||||
raise DecryptError("malformed PKCS7 padding")
|
||||
return decrypted[:-pad]
|
||||
|
||||
|
||||
def _sha1_signature(token: str, timestamp: str, nonce: str, encrypt: str) -> str:
|
||||
parts = sorted([token, timestamp, nonce, encrypt])
|
||||
return hashlib.sha1("".join(parts).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
class WXBizMsgCrypt:
|
||||
"""Minimal WeCom callback crypto helper compatible with BizMsgCrypt semantics."""
|
||||
|
||||
def __init__(self, token: str, encoding_aes_key: str, receive_id: str):
|
||||
if not token:
|
||||
raise ValueError("token is required")
|
||||
if not encoding_aes_key:
|
||||
raise ValueError("encoding_aes_key is required")
|
||||
if len(encoding_aes_key) != 43:
|
||||
raise ValueError("encoding_aes_key must be 43 chars")
|
||||
if not receive_id:
|
||||
raise ValueError("receive_id is required")
|
||||
|
||||
self.token = token
|
||||
self.receive_id = receive_id
|
||||
self.key = base64.b64decode(encoding_aes_key + "=")
|
||||
self.iv = self.key[:16]
|
||||
|
||||
def verify_url(self, msg_signature: str, timestamp: str, nonce: str, echostr: str) -> str:
|
||||
plain = self.decrypt(msg_signature, timestamp, nonce, echostr)
|
||||
return plain.decode("utf-8")
|
||||
|
||||
def decrypt(self, msg_signature: str, timestamp: str, nonce: str, encrypt: str) -> bytes:
|
||||
expected = _sha1_signature(self.token, timestamp, nonce, encrypt)
|
||||
if expected != msg_signature:
|
||||
raise SignatureError("signature mismatch")
|
||||
try:
|
||||
cipher_text = base64.b64decode(encrypt)
|
||||
except Exception as exc:
|
||||
raise DecryptError(f"invalid base64 payload: {exc}") from exc
|
||||
try:
|
||||
cipher = Cipher(algorithms.AES(self.key), modes.CBC(self.iv), backend=default_backend())
|
||||
decryptor = cipher.decryptor()
|
||||
padded = decryptor.update(cipher_text) + decryptor.finalize()
|
||||
plain = PKCS7Encoder.decode(padded)
|
||||
content = plain[16:] # skip 16-byte random prefix
|
||||
xml_length = socket.ntohl(struct.unpack("I", content[:4])[0])
|
||||
xml_content = content[4:4 + xml_length]
|
||||
receive_id = content[4 + xml_length:].decode("utf-8")
|
||||
except WeComCryptoError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise DecryptError(f"decrypt failed: {exc}") from exc
|
||||
|
||||
if receive_id != self.receive_id:
|
||||
raise DecryptError("receive_id mismatch")
|
||||
return xml_content
|
||||
|
||||
def encrypt(self, plaintext: str, nonce: Optional[str] = None, timestamp: Optional[str] = None) -> str:
|
||||
nonce = nonce or self._random_nonce()
|
||||
timestamp = timestamp or str(int(__import__("time").time()))
|
||||
encrypt = self._encrypt_bytes(plaintext.encode("utf-8"))
|
||||
signature = _sha1_signature(self.token, timestamp, nonce, encrypt)
|
||||
root = ET.Element("xml")
|
||||
ET.SubElement(root, "Encrypt").text = encrypt
|
||||
ET.SubElement(root, "MsgSignature").text = signature
|
||||
ET.SubElement(root, "TimeStamp").text = timestamp
|
||||
ET.SubElement(root, "Nonce").text = nonce
|
||||
return ET.tostring(root, encoding="unicode")
|
||||
|
||||
def _encrypt_bytes(self, raw: bytes) -> str:
|
||||
try:
|
||||
random_prefix = os.urandom(16)
|
||||
msg_len = struct.pack("I", socket.htonl(len(raw)))
|
||||
payload = random_prefix + msg_len + raw + self.receive_id.encode("utf-8")
|
||||
padded = PKCS7Encoder.encode(payload)
|
||||
cipher = Cipher(algorithms.AES(self.key), modes.CBC(self.iv), backend=default_backend())
|
||||
encryptor = cipher.encryptor()
|
||||
encrypted = encryptor.update(padded) + encryptor.finalize()
|
||||
return base64.b64encode(encrypted).decode("utf-8")
|
||||
except Exception as exc:
|
||||
raise EncryptError(f"encrypt failed: {exc}") from exc
|
||||
|
||||
@staticmethod
|
||||
def _random_nonce(length: int = 10) -> str:
|
||||
alphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
return "".join(secrets.choice(alphabet) for _ in range(length))
|
||||
+125
-60
@@ -53,6 +53,7 @@ except ImportError: # pragma: no cover - dependency gate
|
||||
CRYPTO_AVAILABLE = False
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.helpers import MessageDeduplicator
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
@@ -63,6 +64,7 @@ from gateway.platforms.base import (
|
||||
cache_image_from_bytes,
|
||||
)
|
||||
from hermes_constants import get_hermes_home
|
||||
from utils import atomic_json_write
|
||||
|
||||
ILINK_BASE_URL = "https://ilinkai.weixin.qq.com"
|
||||
WEIXIN_CDN_BASE_URL = "https://novac2c.cdn.weixin.qq.com/c2c"
|
||||
@@ -206,7 +208,7 @@ def save_weixin_account(
|
||||
"saved_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
||||
}
|
||||
path = _account_file(hermes_home, account_id)
|
||||
path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
||||
atomic_json_write(path, payload)
|
||||
try:
|
||||
path.chmod(0o600)
|
||||
except OSError:
|
||||
@@ -269,7 +271,7 @@ class ContextTokenStore:
|
||||
if key.startswith(prefix)
|
||||
}
|
||||
try:
|
||||
self._path(account_id).write_text(json.dumps(payload), encoding="utf-8")
|
||||
atomic_json_write(self._path(account_id), payload)
|
||||
except Exception as exc:
|
||||
logger.warning("weixin: failed to persist context tokens for %s: %s", _safe_id(account_id), exc)
|
||||
|
||||
@@ -755,23 +757,58 @@ def _pack_markdown_blocks_for_weixin(content: str, max_length: int) -> List[str]
|
||||
return packed
|
||||
|
||||
|
||||
def _split_text_for_weixin_delivery(content: str, max_length: int) -> List[str]:
|
||||
def _split_text_for_weixin_delivery(
|
||||
content: str, max_length: int, split_per_line: bool = False,
|
||||
) -> List[str]:
|
||||
"""Split content into sequential Weixin messages.
|
||||
|
||||
Prefer one message per top-level line/markdown unit when the author used
|
||||
explicit line breaks. Oversized units fall back to block-aware packing so
|
||||
long code fences still split safely.
|
||||
"""
|
||||
if len(content) <= max_length and "\n" not in content:
|
||||
return [content]
|
||||
*compact* (default): Keep everything in a single message whenever it fits
|
||||
within the platform limit, even when the author used explicit line breaks.
|
||||
Only fall back to block-aware packing when the payload exceeds
|
||||
``max_length``.
|
||||
|
||||
chunks: List[str] = []
|
||||
for unit in _split_delivery_units_for_weixin(content):
|
||||
if len(unit) <= max_length:
|
||||
chunks.append(unit)
|
||||
continue
|
||||
chunks.extend(_pack_markdown_blocks_for_weixin(unit, max_length))
|
||||
return chunks or [content]
|
||||
*per_line* (``split_per_line=True``): Legacy behavior — top-level line
|
||||
breaks become separate chat messages; oversized units still use
|
||||
block-aware packing.
|
||||
|
||||
The active mode is controlled via ``config.yaml`` ->
|
||||
``platforms.weixin.extra.split_multiline_messages`` (``true`` / ``false``)
|
||||
or the env var ``WEIXIN_SPLIT_MULTILINE_MESSAGES``.
|
||||
"""
|
||||
if split_per_line:
|
||||
# Legacy: one message per top-level delivery unit.
|
||||
if len(content) <= max_length and "\n" not in content:
|
||||
return [content]
|
||||
chunks: List[str] = []
|
||||
for unit in _split_delivery_units_for_weixin(content):
|
||||
if len(unit) <= max_length:
|
||||
chunks.append(unit)
|
||||
continue
|
||||
chunks.extend(_pack_markdown_blocks_for_weixin(unit, max_length))
|
||||
return chunks or [content]
|
||||
|
||||
# Compact (default): single message when under the limit.
|
||||
if len(content) <= max_length:
|
||||
return [content]
|
||||
return _pack_markdown_blocks_for_weixin(content, max_length) or [content]
|
||||
|
||||
|
||||
def _coerce_bool(value: Any, default: bool = True) -> bool:
|
||||
"""Coerce a config value to bool, tolerating strings like ``"true"``."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, (int, float)):
|
||||
return bool(value)
|
||||
text = str(value).strip().lower()
|
||||
if not text:
|
||||
return default
|
||||
if text in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
if text in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
return default
|
||||
|
||||
|
||||
def _extract_text(item_list: List[Dict[str, Any]]) -> str:
|
||||
@@ -833,7 +870,7 @@ def _load_sync_buf(hermes_home: str, account_id: str) -> str:
|
||||
|
||||
def _save_sync_buf(hermes_home: str, account_id: str, sync_buf: str) -> None:
|
||||
path = _sync_buf_path(hermes_home, account_id)
|
||||
path.write_text(json.dumps({"get_updates_buf": sync_buf}), encoding="utf-8")
|
||||
atomic_json_write(path, {"get_updates_buf": sync_buf})
|
||||
|
||||
|
||||
async def qr_login(
|
||||
@@ -972,8 +1009,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
self._typing_cache = TypingTicketCache()
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
self._token_lock_identity: Optional[str] = None
|
||||
self._dedup = MessageDeduplicator(ttl_seconds=MESSAGE_DEDUP_TTL_SECONDS)
|
||||
|
||||
self._account_id = str(extra.get("account_id") or os.getenv("WEIXIN_ACCOUNT_ID", "")).strip()
|
||||
self._token = str(config.token or extra.get("token") or os.getenv("WEIXIN_TOKEN", "")).strip()
|
||||
@@ -981,6 +1017,16 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
self._cdn_base_url = str(
|
||||
extra.get("cdn_base_url") or os.getenv("WEIXIN_CDN_BASE_URL", WEIXIN_CDN_BASE_URL)
|
||||
).strip().rstrip("/")
|
||||
self._send_chunk_delay_seconds = float(
|
||||
extra.get("send_chunk_delay_seconds") or os.getenv("WEIXIN_SEND_CHUNK_DELAY_SECONDS", "0.35")
|
||||
)
|
||||
self._send_chunk_retries = int(
|
||||
extra.get("send_chunk_retries") or os.getenv("WEIXIN_SEND_CHUNK_RETRIES", "2")
|
||||
)
|
||||
self._send_chunk_retry_delay_seconds = float(
|
||||
extra.get("send_chunk_retry_delay_seconds")
|
||||
or os.getenv("WEIXIN_SEND_CHUNK_RETRY_DELAY_SECONDS", "1.0")
|
||||
)
|
||||
self._dm_policy = str(extra.get("dm_policy") or os.getenv("WEIXIN_DM_POLICY", "open")).strip().lower()
|
||||
self._group_policy = str(extra.get("group_policy") or os.getenv("WEIXIN_GROUP_POLICY", "disabled")).strip().lower()
|
||||
allow_from = extra.get("allow_from")
|
||||
@@ -991,6 +1037,11 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
group_allow_from = os.getenv("WEIXIN_GROUP_ALLOWED_USERS", "")
|
||||
self._allow_from = self._coerce_list(allow_from)
|
||||
self._group_allow_from = self._coerce_list(group_allow_from)
|
||||
self._split_multiline_messages = _coerce_bool(
|
||||
extra.get("split_multiline_messages")
|
||||
or os.getenv("WEIXIN_SPLIT_MULTILINE_MESSAGES"),
|
||||
default=False,
|
||||
)
|
||||
|
||||
if self._account_id and not self._token:
|
||||
persisted = load_weixin_account(hermes_home, self._account_id)
|
||||
@@ -1026,23 +1077,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
return False
|
||||
|
||||
try:
|
||||
from gateway.status import acquire_scoped_lock
|
||||
|
||||
self._token_lock_identity = self._token
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
"weixin-bot-token",
|
||||
self._token_lock_identity,
|
||||
metadata={"platform": self.platform.value},
|
||||
)
|
||||
if not acquired:
|
||||
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
|
||||
message = (
|
||||
"Another local Hermes gateway is already using this Weixin token"
|
||||
+ (f" (PID {owner_pid})." if owner_pid else ".")
|
||||
+ " Stop the other gateway before starting a second Weixin poller."
|
||||
)
|
||||
logger.error("[%s] %s", self.name, message)
|
||||
self._set_fatal_error("weixin_token_lock", message, retryable=False)
|
||||
if not self._acquire_platform_lock('weixin-bot-token', self._token, 'Weixin bot token'):
|
||||
return False
|
||||
except Exception as exc:
|
||||
logger.debug("[%s] Token lock unavailable (non-fatal): %s", self.name, exc)
|
||||
@@ -1066,12 +1101,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
if self._token_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("weixin-bot-token", self._token_lock_identity)
|
||||
except Exception as exc:
|
||||
logger.warning("[%s] Error releasing Weixin token lock: %s", self.name, exc, exc_info=True)
|
||||
self._release_platform_lock()
|
||||
self._mark_disconnected()
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
@@ -1149,16 +1179,8 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
return
|
||||
|
||||
message_id = str(message.get("message_id") or "").strip()
|
||||
if message_id:
|
||||
now = time.time()
|
||||
self._seen_messages = {
|
||||
key: value
|
||||
for key, value in self._seen_messages.items()
|
||||
if now - value < MESSAGE_DEDUP_TTL_SECONDS
|
||||
}
|
||||
if message_id in self._seen_messages:
|
||||
return
|
||||
self._seen_messages[message_id] = now
|
||||
if message_id and self._dedup.is_duplicate(message_id):
|
||||
return
|
||||
|
||||
chat_type, effective_chat_id = _guess_chat_type(message, self._account_id)
|
||||
if chat_type == "group":
|
||||
@@ -1330,7 +1352,50 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
logger.debug("[%s] getConfig failed for %s: %s", self.name, _safe_id(user_id), exc)
|
||||
|
||||
def _split_text(self, content: str) -> List[str]:
|
||||
return _split_text_for_weixin_delivery(content, self.MAX_MESSAGE_LENGTH)
|
||||
return _split_text_for_weixin_delivery(
|
||||
content, self.MAX_MESSAGE_LENGTH, self._split_multiline_messages,
|
||||
)
|
||||
|
||||
async def _send_text_chunk(
|
||||
self,
|
||||
*,
|
||||
chat_id: str,
|
||||
chunk: str,
|
||||
context_token: Optional[str],
|
||||
client_id: str,
|
||||
) -> None:
|
||||
"""Send a single text chunk with per-chunk retry and backoff."""
|
||||
last_error: Optional[Exception] = None
|
||||
for attempt in range(self._send_chunk_retries + 1):
|
||||
try:
|
||||
await _send_message(
|
||||
self._session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
to=chat_id,
|
||||
text=chunk,
|
||||
context_token=context_token,
|
||||
client_id=client_id,
|
||||
)
|
||||
return
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
if attempt >= self._send_chunk_retries:
|
||||
break
|
||||
wait = self._send_chunk_retry_delay_seconds * (attempt + 1)
|
||||
logger.warning(
|
||||
"[%s] send chunk failed to=%s attempt=%d/%d, retrying in %.2fs: %s",
|
||||
self.name,
|
||||
_safe_id(chat_id),
|
||||
attempt + 1,
|
||||
self._send_chunk_retries + 1,
|
||||
wait,
|
||||
exc,
|
||||
)
|
||||
if wait > 0:
|
||||
await asyncio.sleep(wait)
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
|
||||
async def send(
|
||||
self,
|
||||
@@ -1344,18 +1409,18 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
context_token = self._token_store.get(self._account_id, chat_id)
|
||||
last_message_id: Optional[str] = None
|
||||
try:
|
||||
for chunk in self._split_text(self.format_message(content)):
|
||||
chunks = self._split_text(self.format_message(content))
|
||||
for idx, chunk in enumerate(chunks):
|
||||
client_id = f"hermes-weixin-{uuid.uuid4().hex}"
|
||||
await _send_message(
|
||||
self._session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
to=chat_id,
|
||||
text=chunk,
|
||||
await self._send_text_chunk(
|
||||
chat_id=chat_id,
|
||||
chunk=chunk,
|
||||
context_token=context_token,
|
||||
client_id=client_id,
|
||||
)
|
||||
last_message_id = client_id
|
||||
if idx < len(chunks) - 1 and self._send_chunk_delay_seconds > 0:
|
||||
await asyncio.sleep(self._send_chunk_delay_seconds)
|
||||
return SendResult(success=True, message_id=last_message_id)
|
||||
except Exception as exc:
|
||||
logger.error("[%s] send failed to=%s: %s", self.name, _safe_id(chat_id), exc)
|
||||
|
||||
@@ -145,7 +145,6 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self._bridge_log: Optional[Path] = None
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
self._http_session: Optional["aiohttp.ClientSession"] = None
|
||||
self._session_lock_identity: Optional[str] = None
|
||||
|
||||
def _whatsapp_require_mention(self) -> bool:
|
||||
configured = self.config.extra.get("require_mention")
|
||||
@@ -290,23 +289,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
|
||||
# Acquire scoped lock to prevent duplicate sessions
|
||||
try:
|
||||
from gateway.status import acquire_scoped_lock
|
||||
|
||||
self._session_lock_identity = str(self._session_path)
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
"whatsapp-session",
|
||||
self._session_lock_identity,
|
||||
metadata={"platform": self.platform.value},
|
||||
)
|
||||
if not acquired:
|
||||
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
|
||||
message = (
|
||||
"Another local Hermes gateway is already using this WhatsApp session"
|
||||
+ (f" (PID {owner_pid})." if owner_pid else ".")
|
||||
+ " Stop the other gateway before starting a second WhatsApp bridge."
|
||||
)
|
||||
logger.error("[%s] %s", self.name, message)
|
||||
self._set_fatal_error("whatsapp_session_lock", message, retryable=False)
|
||||
if not self._acquire_platform_lock('whatsapp-session', str(self._session_path), 'WhatsApp session'):
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Could not acquire session lock (non-fatal): %s", self.name, e)
|
||||
@@ -468,12 +451,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
if self._session_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("whatsapp-session", self._session_lock_identity)
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True)
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
@@ -546,17 +524,11 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
await self._http_session.close()
|
||||
self._http_session = None
|
||||
|
||||
if self._session_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("whatsapp-session", self._session_lock_identity)
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Error releasing WhatsApp session lock: %s", self.name, e, exc_info=True)
|
||||
self._release_platform_lock()
|
||||
|
||||
self._mark_disconnected()
|
||||
self._bridge_process = None
|
||||
self._close_bridge_log()
|
||||
self._session_lock_identity = None
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
async def send(
|
||||
|
||||
+591
-307
File diff suppressed because it is too large
Load Diff
+51
-1
@@ -368,6 +368,11 @@ class SessionEntry:
|
||||
# survives gateway restarts (the old in-memory _pre_flushed_sessions
|
||||
# set was lost on restart, causing redundant re-flushes).
|
||||
memory_flushed: bool = False
|
||||
|
||||
# When True the next call to get_or_create_session() will auto-reset
|
||||
# this session (create a new session_id) so the user starts fresh.
|
||||
# Set by /stop to break stuck-resume loops (#7536).
|
||||
suspended: bool = False
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result = {
|
||||
@@ -387,6 +392,7 @@ class SessionEntry:
|
||||
"estimated_cost_usd": self.estimated_cost_usd,
|
||||
"cost_status": self.cost_status,
|
||||
"memory_flushed": self.memory_flushed,
|
||||
"suspended": self.suspended,
|
||||
}
|
||||
if self.origin:
|
||||
result["origin"] = self.origin.to_dict()
|
||||
@@ -423,6 +429,7 @@ class SessionEntry:
|
||||
estimated_cost_usd=data.get("estimated_cost_usd", 0.0),
|
||||
cost_status=data.get("cost_status", "unknown"),
|
||||
memory_flushed=data.get("memory_flushed", False),
|
||||
suspended=data.get("suspended", False),
|
||||
)
|
||||
|
||||
|
||||
@@ -698,7 +705,12 @@ class SessionStore:
|
||||
if session_key in self._entries and not force_new:
|
||||
entry = self._entries[session_key]
|
||||
|
||||
reset_reason = self._should_reset(entry, source)
|
||||
# Auto-reset sessions marked as suspended (e.g. after /stop
|
||||
# broke a stuck loop — #7536).
|
||||
if entry.suspended:
|
||||
reset_reason = "suspended"
|
||||
else:
|
||||
reset_reason = self._should_reset(entry, source)
|
||||
if not reset_reason:
|
||||
entry.updated_at = now
|
||||
self._save()
|
||||
@@ -771,6 +783,44 @@ class SessionStore:
|
||||
entry.last_prompt_tokens = last_prompt_tokens
|
||||
self._save()
|
||||
|
||||
def suspend_session(self, session_key: str) -> bool:
|
||||
"""Mark a session as suspended so it auto-resets on next access.
|
||||
|
||||
Used by ``/stop`` to prevent stuck sessions from being resumed
|
||||
after a gateway restart (#7536). Returns True if the session
|
||||
existed and was marked.
|
||||
"""
|
||||
with self._lock:
|
||||
self._ensure_loaded_locked()
|
||||
if session_key in self._entries:
|
||||
self._entries[session_key].suspended = True
|
||||
self._save()
|
||||
return True
|
||||
return False
|
||||
|
||||
def suspend_recently_active(self, max_age_seconds: int = 120) -> int:
|
||||
"""Mark recently-active sessions as suspended.
|
||||
|
||||
Called on gateway startup to prevent sessions that were likely
|
||||
in-flight when the gateway last exited from being blindly resumed
|
||||
(#7536). Only suspends sessions updated within *max_age_seconds*
|
||||
to avoid resetting long-idle sessions that are harmless to resume.
|
||||
Returns the number of sessions that were suspended.
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
cutoff = _time.time() - max_age_seconds
|
||||
count = 0
|
||||
with self._lock:
|
||||
self._ensure_loaded_locked()
|
||||
for entry in self._entries.values():
|
||||
if not entry.suspended and entry.updated_at >= cutoff:
|
||||
entry.suspended = True
|
||||
count += 1
|
||||
if count:
|
||||
self._save()
|
||||
return count
|
||||
|
||||
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
|
||||
"""Force reset a session, creating a new session ID."""
|
||||
db_end_session_id = None
|
||||
|
||||
@@ -46,12 +46,18 @@ _SESSION_PLATFORM: ContextVar[str] = ContextVar("HERMES_SESSION_PLATFORM", defau
|
||||
_SESSION_CHAT_ID: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_ID", default="")
|
||||
_SESSION_CHAT_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_NAME", default="")
|
||||
_SESSION_THREAD_ID: ContextVar[str] = ContextVar("HERMES_SESSION_THREAD_ID", default="")
|
||||
_SESSION_USER_ID: ContextVar[str] = ContextVar("HERMES_SESSION_USER_ID", default="")
|
||||
_SESSION_USER_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_USER_NAME", default="")
|
||||
_SESSION_KEY: ContextVar[str] = ContextVar("HERMES_SESSION_KEY", default="")
|
||||
|
||||
_VAR_MAP = {
|
||||
"HERMES_SESSION_PLATFORM": _SESSION_PLATFORM,
|
||||
"HERMES_SESSION_CHAT_ID": _SESSION_CHAT_ID,
|
||||
"HERMES_SESSION_CHAT_NAME": _SESSION_CHAT_NAME,
|
||||
"HERMES_SESSION_THREAD_ID": _SESSION_THREAD_ID,
|
||||
"HERMES_SESSION_USER_ID": _SESSION_USER_ID,
|
||||
"HERMES_SESSION_USER_NAME": _SESSION_USER_NAME,
|
||||
"HERMES_SESSION_KEY": _SESSION_KEY,
|
||||
}
|
||||
|
||||
|
||||
@@ -60,6 +66,9 @@ def set_session_vars(
|
||||
chat_id: str = "",
|
||||
chat_name: str = "",
|
||||
thread_id: str = "",
|
||||
user_id: str = "",
|
||||
user_name: str = "",
|
||||
session_key: str = "",
|
||||
) -> list:
|
||||
"""Set all session context variables and return reset tokens.
|
||||
|
||||
@@ -74,6 +83,9 @@ def set_session_vars(
|
||||
_SESSION_CHAT_ID.set(chat_id),
|
||||
_SESSION_CHAT_NAME.set(chat_name),
|
||||
_SESSION_THREAD_ID.set(thread_id),
|
||||
_SESSION_USER_ID.set(user_id),
|
||||
_SESSION_USER_NAME.set(user_name),
|
||||
_SESSION_KEY.set(session_key),
|
||||
]
|
||||
return tokens
|
||||
|
||||
@@ -87,6 +99,9 @@ def clear_session_vars(tokens: list) -> None:
|
||||
_SESSION_CHAT_ID,
|
||||
_SESSION_CHAT_NAME,
|
||||
_SESSION_THREAD_ID,
|
||||
_SESSION_USER_ID,
|
||||
_SESSION_USER_NAME,
|
||||
_SESSION_KEY,
|
||||
]
|
||||
for var, token in zip(vars_in_order, tokens):
|
||||
var.reset(token)
|
||||
|
||||
+202
-52
@@ -32,11 +32,15 @@ _DONE = object()
|
||||
# new one so that subsequent text appears below tool progress messages.
|
||||
_NEW_SEGMENT = object()
|
||||
|
||||
# Queue marker for a completed assistant commentary message emitted between
|
||||
# API/tool iterations (for example: "I'll inspect the repo first.").
|
||||
_COMMENTARY = object()
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamConsumerConfig:
|
||||
"""Runtime config for a single stream consumer instance."""
|
||||
edit_interval: float = 0.3
|
||||
edit_interval: float = 1.0
|
||||
buffer_threshold: int = 40
|
||||
cursor: str = " ▉"
|
||||
|
||||
@@ -56,6 +60,10 @@ class GatewayStreamConsumer:
|
||||
await task # wait for final edit
|
||||
"""
|
||||
|
||||
# After this many consecutive flood-control failures, permanently disable
|
||||
# progressive edits for the remainder of the stream.
|
||||
_MAX_FLOOD_STRIKES = 3
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
adapter: Any,
|
||||
@@ -71,18 +79,43 @@ class GatewayStreamConsumer:
|
||||
self._accumulated = ""
|
||||
self._message_id: Optional[str] = None
|
||||
self._already_sent = False
|
||||
self._edit_supported = True # Disabled on first edit failure (Signal/Email/HA)
|
||||
self._edit_supported = True # Disabled when progressive edits are no longer usable
|
||||
self._last_edit_time = 0.0
|
||||
self._last_sent_text = "" # Track last-sent text to skip redundant edits
|
||||
self._fallback_final_send = False
|
||||
self._fallback_prefix = ""
|
||||
self._flood_strikes = 0 # Consecutive flood-control edit failures
|
||||
self._current_edit_interval = self.cfg.edit_interval # Adaptive backoff
|
||||
self._final_response_sent = False
|
||||
|
||||
@property
|
||||
def already_sent(self) -> bool:
|
||||
"""True if at least one message was sent/edited — signals the base
|
||||
adapter to skip re-sending the final response."""
|
||||
"""True if at least one message was sent or edited during the run."""
|
||||
return self._already_sent
|
||||
|
||||
@property
|
||||
def final_response_sent(self) -> bool:
|
||||
"""True when the stream consumer delivered the final assistant reply."""
|
||||
return self._final_response_sent
|
||||
|
||||
def on_segment_break(self) -> None:
|
||||
"""Finalize the current stream segment and start a fresh message."""
|
||||
self._queue.put(_NEW_SEGMENT)
|
||||
|
||||
def on_commentary(self, text: str) -> None:
|
||||
"""Queue a completed interim assistant commentary message."""
|
||||
if text:
|
||||
self._queue.put((_COMMENTARY, text))
|
||||
|
||||
def _reset_segment_state(self, *, preserve_no_edit: bool = False) -> None:
|
||||
if preserve_no_edit and self._message_id == "__no_edit__":
|
||||
return
|
||||
self._message_id = None
|
||||
self._accumulated = ""
|
||||
self._last_sent_text = ""
|
||||
self._fallback_final_send = False
|
||||
self._fallback_prefix = ""
|
||||
|
||||
def on_delta(self, text: str) -> None:
|
||||
"""Thread-safe callback — called from the agent's worker thread.
|
||||
|
||||
@@ -93,7 +126,7 @@ class GatewayStreamConsumer:
|
||||
if text:
|
||||
self._queue.put(text)
|
||||
elif text is None:
|
||||
self._queue.put(_NEW_SEGMENT)
|
||||
self.on_segment_break()
|
||||
|
||||
def finish(self) -> None:
|
||||
"""Signal that the stream is complete."""
|
||||
@@ -110,6 +143,7 @@ class GatewayStreamConsumer:
|
||||
# Drain all available items from the queue
|
||||
got_done = False
|
||||
got_segment_break = False
|
||||
commentary_text = None
|
||||
while True:
|
||||
try:
|
||||
item = self._queue.get_nowait()
|
||||
@@ -119,6 +153,9 @@ class GatewayStreamConsumer:
|
||||
if item is _NEW_SEGMENT:
|
||||
got_segment_break = True
|
||||
break
|
||||
if isinstance(item, tuple) and len(item) == 2 and item[0] is _COMMENTARY:
|
||||
commentary_text = item[1]
|
||||
break
|
||||
self._accumulated += item
|
||||
except queue.Empty:
|
||||
break
|
||||
@@ -129,11 +166,13 @@ class GatewayStreamConsumer:
|
||||
should_edit = (
|
||||
got_done
|
||||
or got_segment_break
|
||||
or (elapsed >= self.cfg.edit_interval
|
||||
or commentary_text is not None
|
||||
or (elapsed >= self._current_edit_interval
|
||||
and self._accumulated)
|
||||
or len(self._accumulated) >= self.cfg.buffer_threshold
|
||||
)
|
||||
|
||||
current_update_visible = False
|
||||
if should_edit and self._accumulated:
|
||||
# Split overflow: if accumulated text exceeds the platform
|
||||
# limit, split into properly sized chunks.
|
||||
@@ -155,6 +194,7 @@ class GatewayStreamConsumer:
|
||||
self._last_sent_text = ""
|
||||
self._last_edit_time = time.monotonic()
|
||||
if got_done:
|
||||
self._final_response_sent = self._already_sent
|
||||
return
|
||||
if got_segment_break:
|
||||
self._message_id = None
|
||||
@@ -173,22 +213,23 @@ class GatewayStreamConsumer:
|
||||
if split_at < _safe_limit // 2:
|
||||
split_at = _safe_limit
|
||||
chunk = self._accumulated[:split_at]
|
||||
await self._send_or_edit(chunk)
|
||||
if self._fallback_final_send:
|
||||
# Edit failed while attempting to split an oversized
|
||||
# message. Keep the full accumulated text intact so
|
||||
# the fallback final-send path can deliver the
|
||||
# remaining continuation without dropping content.
|
||||
ok = await self._send_or_edit(chunk)
|
||||
if self._fallback_final_send or not ok:
|
||||
# Edit failed (or backed off due to flood control)
|
||||
# while attempting to split an oversized message.
|
||||
# Keep the full accumulated text intact so the
|
||||
# fallback final-send path can deliver the remaining
|
||||
# continuation without dropping content.
|
||||
break
|
||||
self._accumulated = self._accumulated[split_at:].lstrip("\n")
|
||||
self._message_id = None
|
||||
self._last_sent_text = ""
|
||||
|
||||
display_text = self._accumulated
|
||||
if not got_done and not got_segment_break:
|
||||
if not got_done and not got_segment_break and commentary_text is None:
|
||||
display_text += self.cfg.cursor
|
||||
|
||||
await self._send_or_edit(display_text)
|
||||
current_update_visible = await self._send_or_edit(display_text)
|
||||
self._last_edit_time = time.monotonic()
|
||||
|
||||
if got_done:
|
||||
@@ -199,12 +240,20 @@ class GatewayStreamConsumer:
|
||||
if self._accumulated:
|
||||
if self._fallback_final_send:
|
||||
await self._send_fallback_final(self._accumulated)
|
||||
elif current_update_visible:
|
||||
self._final_response_sent = True
|
||||
elif self._message_id:
|
||||
await self._send_or_edit(self._accumulated)
|
||||
self._final_response_sent = await self._send_or_edit(self._accumulated)
|
||||
elif not self._already_sent:
|
||||
await self._send_or_edit(self._accumulated)
|
||||
self._final_response_sent = await self._send_or_edit(self._accumulated)
|
||||
return
|
||||
|
||||
if commentary_text is not None:
|
||||
self._reset_segment_state()
|
||||
await self._send_commentary(commentary_text)
|
||||
self._last_edit_time = time.monotonic()
|
||||
self._reset_segment_state()
|
||||
|
||||
# Tool boundary: reset message state so the next text chunk
|
||||
# creates a fresh message below any tool-progress messages.
|
||||
#
|
||||
@@ -213,17 +262,14 @@ class GatewayStreamConsumer:
|
||||
# github_comment delivery). Resetting to None would re-enter
|
||||
# the "first send" path on every tool boundary and post one
|
||||
# platform message per tool call — that is what caused 155
|
||||
# comments under a single PR. Instead, keep all state so the
|
||||
# full continuation is delivered once via _send_fallback_final.
|
||||
# comments under a single PR. Instead, preserve the sentinel
|
||||
# so the full continuation is delivered once via
|
||||
# _send_fallback_final.
|
||||
# (When editing fails mid-stream due to flood control the id is
|
||||
# a real string like "msg_1", not "__no_edit__", so that case
|
||||
# still resets and creates a fresh segment as intended.)
|
||||
if got_segment_break and self._message_id != "__no_edit__":
|
||||
self._message_id = None
|
||||
self._accumulated = ""
|
||||
self._last_sent_text = ""
|
||||
self._fallback_final_send = False
|
||||
self._fallback_prefix = ""
|
||||
if got_segment_break:
|
||||
self._reset_segment_state(preserve_no_edit=True)
|
||||
|
||||
await asyncio.sleep(0.05) # Small yield to not busy-loop
|
||||
|
||||
@@ -322,13 +368,17 @@ class GatewayStreamConsumer:
|
||||
return chunks
|
||||
|
||||
async def _send_fallback_final(self, text: str) -> None:
|
||||
"""Send the final continuation after streaming edits stop working."""
|
||||
"""Send the final continuation after streaming edits stop working.
|
||||
|
||||
Retries each chunk once on flood-control failures with a short delay.
|
||||
"""
|
||||
final_text = self._clean_for_display(text)
|
||||
continuation = self._continuation_text(final_text)
|
||||
self._fallback_final_send = False
|
||||
if not continuation.strip():
|
||||
# Nothing new to send — the visible partial already matches final text.
|
||||
self._already_sent = True
|
||||
self._final_response_sent = True
|
||||
return
|
||||
|
||||
raw_limit = getattr(self.adapter, "MAX_MESSAGE_LENGTH", 4096)
|
||||
@@ -339,17 +389,31 @@ class GatewayStreamConsumer:
|
||||
last_successful_chunk = ""
|
||||
sent_any_chunk = False
|
||||
for chunk in chunks:
|
||||
result = await self.adapter.send(
|
||||
chat_id=self.chat_id,
|
||||
content=chunk,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
if not result.success:
|
||||
# Try sending with one retry on flood-control errors.
|
||||
result = None
|
||||
for attempt in range(2):
|
||||
result = await self.adapter.send(
|
||||
chat_id=self.chat_id,
|
||||
content=chunk,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
if result.success:
|
||||
break
|
||||
if attempt == 0 and self._is_flood_error(result):
|
||||
logger.debug(
|
||||
"Flood control on fallback send, retrying in 3s"
|
||||
)
|
||||
await asyncio.sleep(3.0)
|
||||
else:
|
||||
break # non-flood error or second attempt failed
|
||||
|
||||
if not result or not result.success:
|
||||
if sent_any_chunk:
|
||||
# Some continuation text already reached the user. Suppress
|
||||
# the base gateway final-send path so we don't resend the
|
||||
# full response and create another duplicate.
|
||||
self._already_sent = True
|
||||
self._final_response_sent = True
|
||||
self._message_id = last_message_id
|
||||
self._last_sent_text = last_successful_chunk
|
||||
self._fallback_prefix = ""
|
||||
@@ -367,23 +431,74 @@ class GatewayStreamConsumer:
|
||||
|
||||
self._message_id = last_message_id
|
||||
self._already_sent = True
|
||||
self._final_response_sent = True
|
||||
self._last_sent_text = chunks[-1]
|
||||
self._fallback_prefix = ""
|
||||
|
||||
async def _send_or_edit(self, text: str) -> None:
|
||||
"""Send or edit the streaming message."""
|
||||
def _is_flood_error(self, result) -> bool:
|
||||
"""Check if a SendResult failure is due to flood control / rate limiting."""
|
||||
err = getattr(result, "error", "") or ""
|
||||
err_lower = err.lower()
|
||||
return "flood" in err_lower or "retry after" in err_lower or "rate" in err_lower
|
||||
|
||||
async def _try_strip_cursor(self) -> None:
|
||||
"""Best-effort edit to remove the cursor from the last visible message.
|
||||
|
||||
Called when entering fallback mode so the user doesn't see a stuck
|
||||
cursor (▉) in the partial message.
|
||||
"""
|
||||
if not self._message_id or self._message_id == "__no_edit__":
|
||||
return
|
||||
prefix = self._visible_prefix()
|
||||
if not prefix or not prefix.strip():
|
||||
return
|
||||
try:
|
||||
await self.adapter.edit_message(
|
||||
chat_id=self.chat_id,
|
||||
message_id=self._message_id,
|
||||
content=prefix,
|
||||
)
|
||||
self._last_sent_text = prefix
|
||||
except Exception:
|
||||
pass # best-effort — don't let this block the fallback path
|
||||
|
||||
async def _send_commentary(self, text: str) -> bool:
|
||||
"""Send a completed interim assistant commentary message."""
|
||||
text = self._clean_for_display(text)
|
||||
if not text.strip():
|
||||
return False
|
||||
try:
|
||||
result = await self.adapter.send(
|
||||
chat_id=self.chat_id,
|
||||
content=text,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
if result.success:
|
||||
self._already_sent = True
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Commentary send error: %s", e)
|
||||
return False
|
||||
|
||||
async def _send_or_edit(self, text: str) -> bool:
|
||||
"""Send or edit the streaming message.
|
||||
|
||||
Returns True if the text was successfully delivered (sent or edited),
|
||||
False otherwise. Callers like the overflow split loop use this to
|
||||
decide whether to advance past the delivered chunk.
|
||||
"""
|
||||
# Strip MEDIA: directives so they don't appear as visible text.
|
||||
# Media files are delivered as native attachments after the stream
|
||||
# finishes (via _deliver_media_from_response in gateway/run.py).
|
||||
text = self._clean_for_display(text)
|
||||
if not text.strip():
|
||||
return
|
||||
return True # nothing to send is "success"
|
||||
try:
|
||||
if self._message_id is not None:
|
||||
if self._edit_supported:
|
||||
# Skip if text is identical to what we last sent
|
||||
if text == self._last_sent_text:
|
||||
return
|
||||
return True
|
||||
# Edit existing message
|
||||
result = await self.adapter.edit_message(
|
||||
chat_id=self.chat_id,
|
||||
@@ -393,19 +508,52 @@ class GatewayStreamConsumer:
|
||||
if result.success:
|
||||
self._already_sent = True
|
||||
self._last_sent_text = text
|
||||
# Successful edit — reset flood strike counter
|
||||
self._flood_strikes = 0
|
||||
return True
|
||||
else:
|
||||
# If an edit fails mid-stream (especially Telegram flood control),
|
||||
# stop progressive edits and send only the missing tail once the
|
||||
# Edit failed. If this looks like flood control / rate
|
||||
# limiting, use adaptive backoff: double the edit interval
|
||||
# and retry on the next cycle. Only permanently disable
|
||||
# edits after _MAX_FLOOD_STRIKES consecutive failures.
|
||||
if self._is_flood_error(result):
|
||||
self._flood_strikes += 1
|
||||
self._current_edit_interval = min(
|
||||
self._current_edit_interval * 2, 10.0,
|
||||
)
|
||||
logger.debug(
|
||||
"Flood control on edit (strike %d/%d), "
|
||||
"backoff interval → %.1fs",
|
||||
self._flood_strikes,
|
||||
self._MAX_FLOOD_STRIKES,
|
||||
self._current_edit_interval,
|
||||
)
|
||||
if self._flood_strikes < self._MAX_FLOOD_STRIKES:
|
||||
# Don't disable edits yet — just slow down.
|
||||
# Update _last_edit_time so the next edit
|
||||
# respects the new interval.
|
||||
self._last_edit_time = time.monotonic()
|
||||
return False
|
||||
|
||||
# Non-flood error OR flood strikes exhausted: enter
|
||||
# fallback mode — send only the missing tail once the
|
||||
# final response is available.
|
||||
logger.debug("Edit failed, disabling streaming for this adapter")
|
||||
logger.debug(
|
||||
"Edit failed (strikes=%d), entering fallback mode",
|
||||
self._flood_strikes,
|
||||
)
|
||||
self._fallback_prefix = self._visible_prefix()
|
||||
self._fallback_final_send = True
|
||||
self._edit_supported = False
|
||||
self._already_sent = True
|
||||
# Best-effort: strip the cursor from the last visible
|
||||
# message so the user doesn't see a stuck ▉.
|
||||
await self._try_strip_cursor()
|
||||
return False
|
||||
else:
|
||||
# Editing not supported — skip intermediate updates.
|
||||
# The final response will be sent by the fallback path.
|
||||
pass
|
||||
return False
|
||||
else:
|
||||
# First message — send new
|
||||
result = await self.adapter.send(
|
||||
@@ -413,23 +561,25 @@ class GatewayStreamConsumer:
|
||||
content=text,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
if result.success and result.message_id:
|
||||
self._message_id = result.message_id
|
||||
if result.success:
|
||||
if result.message_id:
|
||||
self._message_id = result.message_id
|
||||
else:
|
||||
self._edit_supported = False
|
||||
self._already_sent = True
|
||||
self._last_sent_text = text
|
||||
elif result.success:
|
||||
# Platform accepted the message but returned no message_id
|
||||
# (e.g. Signal). Can't edit without an ID — switch to
|
||||
# fallback mode: suppress intermediate deltas, send only
|
||||
# the missing tail once the final response is ready.
|
||||
self._already_sent = True
|
||||
self._edit_supported = False
|
||||
self._fallback_prefix = self._clean_for_display(text)
|
||||
self._fallback_final_send = True
|
||||
# Sentinel prevents re-entering this branch on every delta
|
||||
self._message_id = "__no_edit__"
|
||||
if not result.message_id:
|
||||
self._fallback_prefix = self._visible_prefix()
|
||||
self._fallback_final_send = True
|
||||
# Sentinel prevents re-entering the first-send path on
|
||||
# every delta/tool boundary when platforms accept a
|
||||
# message but do not return an editable message id.
|
||||
self._message_id = "__no_edit__"
|
||||
return True
|
||||
else:
|
||||
# Initial send failed — disable streaming for this session
|
||||
self._edit_supported = False
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error("Stream send/edit error: %s", e)
|
||||
return False
|
||||
|
||||
@@ -250,9 +250,39 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
api_key_env_vars=("HF_TOKEN",),
|
||||
base_url_env_var="HF_BASE_URL",
|
||||
),
|
||||
"xiaomi": ProviderConfig(
|
||||
id="xiaomi",
|
||||
name="Xiaomi MiMo",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.xiaomimimo.com/v1",
|
||||
api_key_env_vars=("XIAOMI_API_KEY",),
|
||||
base_url_env_var="XIAOMI_BASE_URL",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Anthropic Key Helper
|
||||
# =============================================================================
|
||||
|
||||
def get_anthropic_key() -> str:
|
||||
"""Return the first usable Anthropic credential, or ``""``.
|
||||
|
||||
Checks both the ``.env`` file (via ``get_env_value``) and the process
|
||||
environment (``os.getenv``). The fallback order mirrors the
|
||||
``PROVIDER_REGISTRY["anthropic"].api_key_env_vars`` tuple:
|
||||
|
||||
ANTHROPIC_API_KEY -> ANTHROPIC_TOKEN -> CLAUDE_CODE_OAUTH_TOKEN
|
||||
"""
|
||||
from hermes_cli.config import get_env_value
|
||||
|
||||
for var in PROVIDER_REGISTRY["anthropic"].api_key_env_vars:
|
||||
value = get_env_value(var) or os.getenv(var, "")
|
||||
if value:
|
||||
return value
|
||||
return ""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Kimi Code Endpoint Detection
|
||||
# =============================================================================
|
||||
@@ -908,6 +938,7 @@ def resolve_provider(
|
||||
"opencode": "opencode-zen", "zen": "opencode-zen",
|
||||
"qwen-portal": "qwen-oauth", "qwen-cli": "qwen-oauth", "qwen-oauth": "qwen-oauth",
|
||||
"hf": "huggingface", "hugging-face": "huggingface", "huggingface-hub": "huggingface",
|
||||
"mimo": "xiaomi", "xiaomi-mimo": "xiaomi",
|
||||
"go": "opencode-go", "opencode-go-sub": "opencode-go",
|
||||
"kilo": "kilocode", "kilo-code": "kilocode", "kilo-gateway": "kilocode",
|
||||
# Local server aliases — route through the generic custom provider
|
||||
|
||||
+114
-16
@@ -1,8 +1,9 @@
|
||||
"""hermes claw — OpenClaw migration commands.
|
||||
|
||||
Usage:
|
||||
hermes claw migrate # Interactive migration from ~/.openclaw
|
||||
hermes claw migrate --dry-run # Preview what would be migrated
|
||||
hermes claw migrate # Preview then migrate (always shows preview first)
|
||||
hermes claw migrate --dry-run # Preview only, no changes
|
||||
hermes claw migrate --yes # Skip confirmation prompt
|
||||
hermes claw migrate --preset full --overwrite # Full migration, overwrite conflicts
|
||||
hermes claw cleanup # Archive leftover OpenClaw directories
|
||||
hermes claw cleanup --dry-run # Preview what would be archived
|
||||
@@ -51,6 +52,41 @@ _OPENCLAW_SCRIPT_INSTALLED = (
|
||||
# Known OpenClaw directory names (current + legacy)
|
||||
_OPENCLAW_DIR_NAMES = (".openclaw", ".clawdbot", ".moldbot")
|
||||
|
||||
def _warn_if_gateway_running(auto_yes: bool) -> None:
|
||||
"""Check if a Hermes gateway is running with connected platforms.
|
||||
|
||||
Migrating bot tokens while the gateway is polling will cause conflicts
|
||||
(e.g. Telegram 409 "terminated by other getUpdates request"). Warn the
|
||||
user and let them decide whether to continue.
|
||||
"""
|
||||
from gateway.status import get_running_pid, read_runtime_status
|
||||
|
||||
if not get_running_pid():
|
||||
return
|
||||
|
||||
data = read_runtime_status() or {}
|
||||
platforms = data.get("platforms") or {}
|
||||
connected = [name for name, info in platforms.items()
|
||||
if isinstance(info, dict) and info.get("state") == "connected"]
|
||||
if not connected:
|
||||
return
|
||||
|
||||
print()
|
||||
print_error(
|
||||
"Hermes gateway is running with active connections: "
|
||||
+ ", ".join(connected)
|
||||
)
|
||||
print_info(
|
||||
"Migrating bot tokens while the gateway is active will cause "
|
||||
"conflicts (Telegram, Discord, and Slack only allow one active "
|
||||
"session per token)."
|
||||
)
|
||||
print_info("Recommendation: stop the gateway first with 'hermes stop'.")
|
||||
print()
|
||||
if not auto_yes and not prompt_yes_no("Continue anyway?", default=False):
|
||||
print_info("Migration cancelled. Stop the gateway and try again.")
|
||||
sys.exit(0)
|
||||
|
||||
# State files commonly found in OpenClaw workspace directories that cause
|
||||
# confusion after migration (the agent discovers them and writes to them)
|
||||
_WORKSPACE_STATE_GLOBS = (
|
||||
@@ -237,12 +273,12 @@ def _cmd_migrate(args):
|
||||
|
||||
# Show what we're doing
|
||||
hermes_home = get_hermes_home()
|
||||
auto_yes = getattr(args, "yes", False)
|
||||
print()
|
||||
print_header("Migration Settings")
|
||||
print_info(f"Source: {source_dir}")
|
||||
print_info(f"Target: {hermes_home}")
|
||||
print_info(f"Preset: {preset}")
|
||||
print_info(f"Mode: {'dry run (preview only)' if dry_run else 'execute'}")
|
||||
print_info(f"Overwrite: {'yes' if overwrite else 'no (skip conflicts)'}")
|
||||
print_info(f"Secrets: {'yes (allowlisted only)' if migrate_secrets else 'no'}")
|
||||
if skill_conflict != "skip":
|
||||
@@ -251,31 +287,85 @@ def _cmd_migrate(args):
|
||||
print_info(f"Workspace: {workspace_target}")
|
||||
print()
|
||||
|
||||
# For execute mode (non-dry-run), confirm unless --yes was passed
|
||||
if not dry_run and not getattr(args, "yes", False):
|
||||
if not prompt_yes_no("Proceed with migration?", default=True):
|
||||
print_info("Migration cancelled.")
|
||||
return
|
||||
# Check if a gateway is running with connected platforms — migrating tokens
|
||||
# while the gateway is active will cause conflicts (e.g. Telegram 409).
|
||||
_warn_if_gateway_running(auto_yes)
|
||||
|
||||
# Ensure config.yaml exists before migration tries to read it
|
||||
config_path = get_config_path()
|
||||
if not config_path.exists():
|
||||
save_config(load_config())
|
||||
|
||||
# Load and run the migration
|
||||
# Load the migration module
|
||||
try:
|
||||
mod = _load_migration_module(script_path)
|
||||
if mod is None:
|
||||
print_error("Could not load migration script.")
|
||||
return
|
||||
except Exception as e:
|
||||
print()
|
||||
print_error(f"Could not load migration script: {e}")
|
||||
logger.debug("OpenClaw migration error", exc_info=True)
|
||||
return
|
||||
|
||||
selected = mod.resolve_selected_options(None, None, preset=preset)
|
||||
ws_target = Path(workspace_target).resolve() if workspace_target else None
|
||||
selected = mod.resolve_selected_options(None, None, preset=preset)
|
||||
ws_target = Path(workspace_target).resolve() if workspace_target else None
|
||||
|
||||
# ── Phase 1: Always preview first ──────────────────────────
|
||||
try:
|
||||
preview = mod.Migrator(
|
||||
source_root=source_dir.resolve(),
|
||||
target_root=hermes_home.resolve(),
|
||||
execute=False,
|
||||
workspace_target=ws_target,
|
||||
overwrite=overwrite,
|
||||
migrate_secrets=migrate_secrets,
|
||||
output_dir=None,
|
||||
selected_options=selected,
|
||||
preset_name=preset,
|
||||
skill_conflict_mode=skill_conflict,
|
||||
)
|
||||
preview_report = preview.migrate()
|
||||
except Exception as e:
|
||||
print()
|
||||
print_error(f"Migration preview failed: {e}")
|
||||
logger.debug("OpenClaw migration preview error", exc_info=True)
|
||||
return
|
||||
|
||||
preview_summary = preview_report.get("summary", {})
|
||||
preview_count = preview_summary.get("migrated", 0)
|
||||
|
||||
if preview_count == 0:
|
||||
print()
|
||||
print_info("Nothing to migrate from OpenClaw.")
|
||||
_print_migration_report(preview_report, dry_run=True)
|
||||
return
|
||||
|
||||
print()
|
||||
print_header(f"Migration Preview — {preview_count} item(s) would be imported")
|
||||
print_info("No changes have been made yet. Review the list below:")
|
||||
_print_migration_report(preview_report, dry_run=True)
|
||||
|
||||
# If --dry-run, stop here
|
||||
if dry_run:
|
||||
return
|
||||
|
||||
# ── Phase 2: Confirm and execute ───────────────────────────
|
||||
print()
|
||||
if not auto_yes:
|
||||
if not sys.stdin.isatty():
|
||||
print_info("Non-interactive session — preview only.")
|
||||
print_info("To execute, re-run with: hermes claw migrate --yes")
|
||||
return
|
||||
if not prompt_yes_no("Proceed with migration?", default=True):
|
||||
print_info("Migration cancelled.")
|
||||
return
|
||||
|
||||
try:
|
||||
migrator = mod.Migrator(
|
||||
source_root=source_dir.resolve(),
|
||||
target_root=hermes_home.resolve(),
|
||||
execute=not dry_run,
|
||||
execute=True,
|
||||
workspace_target=ws_target,
|
||||
overwrite=overwrite,
|
||||
migrate_secrets=migrate_secrets,
|
||||
@@ -292,11 +382,11 @@ def _cmd_migrate(args):
|
||||
return
|
||||
|
||||
# Print results
|
||||
_print_migration_report(report, dry_run)
|
||||
_print_migration_report(report, dry_run=False)
|
||||
|
||||
# After successful non-dry-run migration, offer to archive the source directory
|
||||
if not dry_run and report.get("summary", {}).get("migrated", 0) > 0:
|
||||
_offer_source_archival(source_dir, getattr(args, "yes", False))
|
||||
# After successful migration, offer to archive the source directory
|
||||
if report.get("summary", {}).get("migrated", 0) > 0:
|
||||
_offer_source_archival(source_dir, auto_yes)
|
||||
|
||||
|
||||
def _offer_source_archival(source_dir: Path, auto_yes: bool = False):
|
||||
@@ -330,6 +420,11 @@ def _offer_source_archival(source_dir: Path, auto_yes: bool = False):
|
||||
print_info("You can always rename it back if needed.")
|
||||
print()
|
||||
|
||||
if not auto_yes and not sys.stdin.isatty():
|
||||
print_info("Non-interactive session — skipping archival.")
|
||||
print_info("Run later with: hermes claw cleanup")
|
||||
return
|
||||
|
||||
if auto_yes or prompt_yes_no(f"Archive {source_dir} now?", default=True):
|
||||
try:
|
||||
archive_path = _archive_directory(source_dir)
|
||||
@@ -433,6 +528,9 @@ def _cmd_cleanup(args):
|
||||
if dry_run:
|
||||
archive_path = _archive_directory(source_dir, dry_run=True)
|
||||
print_info(f"Would archive: {source_dir} → {archive_path}")
|
||||
elif not auto_yes and not sys.stdin.isatty():
|
||||
print_info(f"Non-interactive session — would archive: {source_dir}")
|
||||
print_info("To execute, re-run with: hermes claw cleanup --yes")
|
||||
else:
|
||||
if auto_yes or prompt_yes_no(f"Archive {source_dir}?", default=True):
|
||||
try:
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
"""Shared CLI output helpers for Hermes CLI modules.
|
||||
|
||||
Extracts the identical ``print_info/success/warning/error`` and ``prompt()``
|
||||
functions previously duplicated across setup.py, tools_config.py,
|
||||
mcp_config.py, and memory_setup.py.
|
||||
"""
|
||||
|
||||
import getpass
|
||||
import sys
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
|
||||
# ─── Print Helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def print_info(text: str) -> None:
|
||||
"""Print a dim informational message."""
|
||||
print(color(f" {text}", Colors.DIM))
|
||||
|
||||
|
||||
def print_success(text: str) -> None:
|
||||
"""Print a green success message with ✓ prefix."""
|
||||
print(color(f"✓ {text}", Colors.GREEN))
|
||||
|
||||
|
||||
def print_warning(text: str) -> None:
|
||||
"""Print a yellow warning message with ⚠ prefix."""
|
||||
print(color(f"⚠ {text}", Colors.YELLOW))
|
||||
|
||||
|
||||
def print_error(text: str) -> None:
|
||||
"""Print a red error message with ✗ prefix."""
|
||||
print(color(f"✗ {text}", Colors.RED))
|
||||
|
||||
|
||||
def print_header(text: str) -> None:
|
||||
"""Print a bold yellow header."""
|
||||
print(color(f"\n {text}", Colors.YELLOW))
|
||||
|
||||
|
||||
# ─── Input Prompts ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def prompt(
|
||||
question: str,
|
||||
default: str | None = None,
|
||||
password: bool = False,
|
||||
) -> str:
|
||||
"""Prompt the user for input with optional default and password masking.
|
||||
|
||||
Replaces the four independent ``_prompt()`` / ``prompt()`` implementations
|
||||
in setup.py, tools_config.py, mcp_config.py, and memory_setup.py.
|
||||
|
||||
Returns the user's input (stripped), or *default* if the user presses Enter.
|
||||
Returns empty string on Ctrl-C or EOF.
|
||||
"""
|
||||
suffix = f" [{default}]" if default else ""
|
||||
display = color(f" {question}{suffix}: ", Colors.YELLOW)
|
||||
|
||||
try:
|
||||
if password:
|
||||
value = getpass.getpass(display)
|
||||
else:
|
||||
value = input(display)
|
||||
value = value.strip()
|
||||
return value if value else (default or "")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return ""
|
||||
|
||||
|
||||
def prompt_yes_no(question: str, default: bool = True) -> bool:
|
||||
"""Prompt for a yes/no answer. Returns bool."""
|
||||
hint = "Y/n" if default else "y/N"
|
||||
answer = prompt(f"{question} ({hint})")
|
||||
if not answer:
|
||||
return default
|
||||
return answer.lower().startswith("y")
|
||||
+145
-9
@@ -32,13 +32,15 @@ _ENV_VAR_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
_EXTRA_ENV_KEYS = frozenset({
|
||||
"OPENAI_API_KEY", "OPENAI_BASE_URL",
|
||||
"ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN",
|
||||
"AUXILIARY_VISION_MODEL",
|
||||
"DISCORD_HOME_CHANNEL", "TELEGRAM_HOME_CHANNEL",
|
||||
"SIGNAL_ACCOUNT", "SIGNAL_HTTP_URL",
|
||||
"SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS",
|
||||
"DINGTALK_CLIENT_ID", "DINGTALK_CLIENT_SECRET",
|
||||
"FEISHU_APP_ID", "FEISHU_APP_SECRET", "FEISHU_ENCRYPT_KEY", "FEISHU_VERIFICATION_TOKEN",
|
||||
"WECOM_BOT_ID", "WECOM_SECRET",
|
||||
"WECOM_CALLBACK_CORP_ID", "WECOM_CALLBACK_CORP_SECRET", "WECOM_CALLBACK_AGENT_ID",
|
||||
"WECOM_CALLBACK_TOKEN", "WECOM_CALLBACK_ENCODING_AES_KEY",
|
||||
"WECOM_CALLBACK_HOST", "WECOM_CALLBACK_PORT",
|
||||
"WEIXIN_ACCOUNT_ID", "WEIXIN_TOKEN", "WEIXIN_BASE_URL", "WEIXIN_CDN_BASE_URL",
|
||||
"WEIXIN_HOME_CHANNEL", "WEIXIN_HOME_CHANNEL_NAME", "WEIXIN_DM_POLICY", "WEIXIN_GROUP_POLICY",
|
||||
"WEIXIN_ALLOWED_USERS", "WEIXIN_GROUP_ALLOWED_USERS", "WEIXIN_ALLOW_ALL_USERS",
|
||||
@@ -141,6 +143,73 @@ def managed_error(action: str = "modify configuration"):
|
||||
print(format_managed_message(action), file=sys.stderr)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Container-aware CLI (NixOS container mode)
|
||||
# =============================================================================
|
||||
|
||||
def _is_inside_container() -> bool:
|
||||
"""Detect if we're already running inside a Docker/Podman container."""
|
||||
# Standard Docker/Podman indicators
|
||||
if os.path.exists("/.dockerenv"):
|
||||
return True
|
||||
# Podman uses /run/.containerenv
|
||||
if os.path.exists("/run/.containerenv"):
|
||||
return True
|
||||
# Check cgroup for container runtime evidence (works for both Docker & Podman)
|
||||
try:
|
||||
with open("/proc/1/cgroup", "r") as f:
|
||||
cgroup = f.read()
|
||||
if "docker" in cgroup or "podman" in cgroup or "/lxc/" in cgroup:
|
||||
return True
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def get_container_exec_info() -> Optional[dict]:
|
||||
"""Read container mode metadata from HERMES_HOME/.container-mode.
|
||||
|
||||
Returns a dict with keys: backend, container_name, exec_user, hermes_bin
|
||||
or None if container mode is not active, we're already inside the
|
||||
container, or HERMES_DEV=1 is set.
|
||||
|
||||
The .container-mode file is written by the NixOS activation script when
|
||||
container.enable = true. It tells the host CLI to exec into the container
|
||||
instead of running locally.
|
||||
"""
|
||||
if os.environ.get("HERMES_DEV") == "1":
|
||||
return None
|
||||
|
||||
if _is_inside_container():
|
||||
return None
|
||||
|
||||
container_mode_file = get_hermes_home() / ".container-mode"
|
||||
|
||||
try:
|
||||
info = {}
|
||||
with open(container_mode_file, "r") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if "=" in line and not line.startswith("#"):
|
||||
key, _, value = line.partition("=")
|
||||
info[key.strip()] = value.strip()
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
# All other exceptions (PermissionError, malformed data, etc.) propagate
|
||||
|
||||
backend = info.get("backend", "docker")
|
||||
container_name = info.get("container_name", "hermes-agent")
|
||||
exec_user = info.get("exec_user", "hermes")
|
||||
hermes_bin = info.get("hermes_bin", "/data/current-package/bin/hermes")
|
||||
|
||||
return {
|
||||
"backend": backend,
|
||||
"container_name": container_name,
|
||||
"exec_user": exec_user,
|
||||
"hermes_bin": hermes_bin,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Config paths
|
||||
# =============================================================================
|
||||
@@ -381,7 +450,7 @@ DEFAULT_CONFIG = {
|
||||
"model": "", # e.g. "google/gemini-2.5-flash", "gpt-4o"
|
||||
"base_url": "", # direct OpenAI-compatible endpoint (takes precedence over provider)
|
||||
"api_key": "", # API key for base_url (falls back to OPENAI_API_KEY)
|
||||
"timeout": 30, # seconds — LLM API call timeout; increase for slow local vision models
|
||||
"timeout": 120, # seconds — LLM API call timeout; vision payloads need generous timeout
|
||||
"download_timeout": 30, # seconds — image HTTP download timeout; increase for slow connections
|
||||
},
|
||||
"web_extract": {
|
||||
@@ -446,9 +515,11 @@ DEFAULT_CONFIG = {
|
||||
"inline_diffs": True, # Show inline diff previews for write actions (write_file, patch, skill_manage)
|
||||
"show_cost": False, # Show $ cost in the status bar (off by default)
|
||||
"skin": "default",
|
||||
"interim_assistant_messages": True, # Gateway: show natural mid-turn assistant status messages
|
||||
"tool_progress_command": False, # Enable /verbose command in messaging gateway
|
||||
"tool_progress_overrides": {}, # Per-platform overrides: {"signal": "off", "telegram": "all"}
|
||||
"tool_progress_overrides": {}, # DEPRECATED — use display.platforms instead
|
||||
"tool_preview_length": 0, # Max chars for tool call previews (0 = no limit, show full paths/commands)
|
||||
"platforms": {}, # Per-platform display overrides: {"telegram": {"tool_progress": "all"}, "slack": {"tool_progress": "off"}}
|
||||
},
|
||||
|
||||
# Privacy settings
|
||||
@@ -458,7 +529,7 @@ DEFAULT_CONFIG = {
|
||||
|
||||
# Text-to-speech configuration
|
||||
"tts": {
|
||||
"provider": "edge", # "edge" (free) | "elevenlabs" (premium) | "openai" | "neutts" (local)
|
||||
"provider": "edge", # "edge" (free) | "elevenlabs" (premium) | "openai" | "minimax" | "mistral" | "neutts" (local)
|
||||
"edge": {
|
||||
"voice": "en-US-AriaNeural",
|
||||
# Popular: AriaNeural, JennyNeural, AndrewNeural, BrianNeural, SoniaNeural
|
||||
@@ -472,6 +543,10 @@ DEFAULT_CONFIG = {
|
||||
"voice": "alloy",
|
||||
# Voices: alloy, echo, fable, onyx, nova, shimmer
|
||||
},
|
||||
"mistral": {
|
||||
"model": "voxtral-mini-tts-2603",
|
||||
"voice_id": "c69964a6-ab8b-4f8a-9465-ec0925096ec8", # Paul - Neutral
|
||||
},
|
||||
"neutts": {
|
||||
"ref_audio": "", # Path to reference voice audio (empty = bundled default)
|
||||
"ref_text": "", # Path to reference voice transcript (empty = bundled default)
|
||||
@@ -632,7 +707,7 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
|
||||
# Config schema version - bump this when adding new required fields
|
||||
"_config_version": 14,
|
||||
"_config_version": 16,
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
@@ -864,6 +939,21 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"XIAOMI_API_KEY": {
|
||||
"description": "Xiaomi MiMo API key for MiMo models (mimo-v2-pro, mimo-v2-omni, mimo-v2-flash)",
|
||||
"prompt": "Xiaomi MiMo API Key",
|
||||
"url": "https://platform.xiaomimimo.com",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
},
|
||||
"XIAOMI_BASE_URL": {
|
||||
"description": "Xiaomi MiMo base URL override (default: https://api.xiaomimimo.com/v1)",
|
||||
"prompt": "Xiaomi base URL (leave empty for default)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
|
||||
# ── Tool API keys ──
|
||||
"EXA_API_KEY": {
|
||||
@@ -1016,6 +1106,13 @@ OPTIONAL_ENV_VARS = {
|
||||
"password": True,
|
||||
"category": "tool",
|
||||
},
|
||||
"MISTRAL_API_KEY": {
|
||||
"description": "Mistral API key for Voxtral TTS and transcription (STT)",
|
||||
"prompt": "Mistral API key",
|
||||
"url": "https://console.mistral.ai/",
|
||||
"password": True,
|
||||
"category": "tool",
|
||||
},
|
||||
"GITHUB_TOKEN": {
|
||||
"description": "GitHub token for Skills Hub (higher API rate limits, skill publish)",
|
||||
"prompt": "GitHub Token",
|
||||
@@ -1472,7 +1569,7 @@ _KNOWN_ROOT_KEYS = {
|
||||
|
||||
# Valid fields inside a custom_providers list entry
|
||||
_VALID_CUSTOM_PROVIDER_FIELDS = {
|
||||
"name", "base_url", "api_key", "api_mode", "models",
|
||||
"name", "base_url", "api_key", "api_mode", "model", "models",
|
||||
"context_length", "rate_limit_delay",
|
||||
}
|
||||
|
||||
@@ -1837,6 +1934,44 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
|
||||
if not quiet:
|
||||
print(f" ✓ Migrated legacy stt.model to provider-specific config")
|
||||
|
||||
# ── Version 14 → 15: add explicit gateway interim-message gate ──
|
||||
if current_ver < 15:
|
||||
config = read_raw_config()
|
||||
display = config.get("display", {})
|
||||
if not isinstance(display, dict):
|
||||
display = {}
|
||||
if "interim_assistant_messages" not in display:
|
||||
display["interim_assistant_messages"] = True
|
||||
config["display"] = display
|
||||
results["config_added"].append("display.interim_assistant_messages=true (default)")
|
||||
save_config(config)
|
||||
if not quiet:
|
||||
print(" ✓ Added display.interim_assistant_messages=true")
|
||||
|
||||
# ── Version 15 → 16: migrate tool_progress_overrides into display.platforms ──
|
||||
if current_ver < 16:
|
||||
config = read_raw_config()
|
||||
display = config.get("display", {})
|
||||
if not isinstance(display, dict):
|
||||
display = {}
|
||||
old_overrides = display.get("tool_progress_overrides")
|
||||
if isinstance(old_overrides, dict) and old_overrides:
|
||||
platforms = display.get("platforms", {})
|
||||
if not isinstance(platforms, dict):
|
||||
platforms = {}
|
||||
for plat, mode in old_overrides.items():
|
||||
if plat not in platforms:
|
||||
platforms[plat] = {}
|
||||
if "tool_progress" not in platforms[plat]:
|
||||
platforms[plat]["tool_progress"] = mode
|
||||
display["platforms"] = platforms
|
||||
config["display"] = display
|
||||
save_config(config)
|
||||
if not quiet:
|
||||
migrated = ", ".join(f"{p}={m}" for p, m in old_overrides.items())
|
||||
print(f" ✓ Migrated tool_progress_overrides → display.platforms: {migrated}")
|
||||
results["config_added"].append("display.platforms (migrated from tool_progress_overrides)")
|
||||
|
||||
if current_ver < latest_ver and not quiet:
|
||||
print(f"Config version: {current_ver} → {latest_ver}")
|
||||
|
||||
@@ -2557,7 +2692,8 @@ def show_config():
|
||||
for env_key, name in keys:
|
||||
value = get_env_value(env_key)
|
||||
print(f" {name:<14} {redact_key(value)}")
|
||||
anthropic_value = get_env_value("ANTHROPIC_TOKEN") or get_env_value("ANTHROPIC_API_KEY")
|
||||
from hermes_cli.auth import get_anthropic_key
|
||||
anthropic_value = get_anthropic_key()
|
||||
print(f" {'Anthropic':<14} {redact_key(anthropic_value)}")
|
||||
|
||||
# Model settings
|
||||
@@ -2773,8 +2909,8 @@ def set_config_value(key: str, value: str):
|
||||
|
||||
# Write only user config back (not the full merged defaults)
|
||||
ensure_hermes_home()
|
||||
with open(config_path, 'w', encoding="utf-8") as f:
|
||||
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
|
||||
from utils import atomic_yaml_write
|
||||
atomic_yaml_write(config_path, user_config, sort_keys=False)
|
||||
|
||||
# Keep .env in sync for keys that terminal_tool reads directly from env vars.
|
||||
# config.yaml is authoritative, but terminal_tool only reads TERMINAL_ENV etc.
|
||||
|
||||
@@ -287,6 +287,129 @@ def _radio_numbered_fallback(
|
||||
return cancel_returns
|
||||
|
||||
|
||||
def curses_single_select(
|
||||
title: str,
|
||||
items: List[str],
|
||||
default_index: int = 0,
|
||||
*,
|
||||
cancel_label: str = "Cancel",
|
||||
) -> int | None:
|
||||
"""Curses single-select menu. Returns selected index or None on cancel.
|
||||
|
||||
Works inside prompt_toolkit because curses.wrapper() restores the terminal
|
||||
safely, unlike simple_term_menu which conflicts with /dev/tty.
|
||||
"""
|
||||
if not sys.stdin.isatty():
|
||||
return None
|
||||
|
||||
try:
|
||||
import curses
|
||||
result_holder: list = [None]
|
||||
|
||||
all_items = list(items) + [cancel_label]
|
||||
cancel_idx = len(items)
|
||||
|
||||
def _draw(stdscr):
|
||||
curses.curs_set(0)
|
||||
if curses.has_colors():
|
||||
curses.start_color()
|
||||
curses.use_default_colors()
|
||||
curses.init_pair(1, curses.COLOR_GREEN, -1)
|
||||
curses.init_pair(2, curses.COLOR_YELLOW, -1)
|
||||
cursor = min(default_index, len(all_items) - 1)
|
||||
scroll_offset = 0
|
||||
|
||||
while True:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
|
||||
try:
|
||||
hattr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
hattr |= curses.color_pair(2)
|
||||
stdscr.addnstr(0, 0, title, max_x - 1, hattr)
|
||||
stdscr.addnstr(
|
||||
1, 0,
|
||||
" ↑↓ navigate ENTER confirm ESC/q cancel",
|
||||
max_x - 1, curses.A_DIM,
|
||||
)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
visible_rows = max_y - 3
|
||||
if cursor < scroll_offset:
|
||||
scroll_offset = cursor
|
||||
elif cursor >= scroll_offset + visible_rows:
|
||||
scroll_offset = cursor - visible_rows + 1
|
||||
|
||||
for draw_i, i in enumerate(
|
||||
range(scroll_offset, min(len(all_items), scroll_offset + visible_rows))
|
||||
):
|
||||
y = draw_i + 3
|
||||
if y >= max_y - 1:
|
||||
break
|
||||
arrow = "→" if i == cursor else " "
|
||||
line = f" {arrow} {all_items[i]}"
|
||||
attr = curses.A_NORMAL
|
||||
if i == cursor:
|
||||
attr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
attr |= curses.color_pair(1)
|
||||
try:
|
||||
stdscr.addnstr(y, 0, line, max_x - 1, attr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
|
||||
if key in (curses.KEY_UP, ord("k")):
|
||||
cursor = (cursor - 1) % len(all_items)
|
||||
elif key in (curses.KEY_DOWN, ord("j")):
|
||||
cursor = (cursor + 1) % len(all_items)
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result_holder[0] = cursor
|
||||
return
|
||||
elif key in (27, ord("q")):
|
||||
result_holder[0] = None
|
||||
return
|
||||
|
||||
curses.wrapper(_draw)
|
||||
flush_stdin()
|
||||
if result_holder[0] is not None and result_holder[0] >= cancel_idx:
|
||||
return None
|
||||
return result_holder[0]
|
||||
|
||||
except Exception:
|
||||
all_items = list(items) + [cancel_label]
|
||||
cancel_idx = len(items)
|
||||
return _numbered_single_fallback(title, all_items, cancel_idx)
|
||||
|
||||
|
||||
def _numbered_single_fallback(
|
||||
title: str,
|
||||
items: List[str],
|
||||
cancel_idx: int,
|
||||
) -> int | None:
|
||||
"""Text-based numbered fallback for single-select."""
|
||||
print(f"\n {title}\n")
|
||||
for i, label in enumerate(items, 1):
|
||||
print(f" {i}. {label}")
|
||||
print()
|
||||
try:
|
||||
val = input(f" Choice [1-{len(items)}]: ").strip()
|
||||
if not val:
|
||||
return None
|
||||
idx = int(val) - 1
|
||||
if 0 <= idx < len(items) and idx < cancel_idx:
|
||||
return idx
|
||||
if idx == cancel_idx:
|
||||
return None
|
||||
except (ValueError, KeyboardInterrupt, EOFError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _numbered_fallback(
|
||||
title: str,
|
||||
items: List[str],
|
||||
|
||||
@@ -51,6 +51,7 @@ _PROVIDER_ENV_HINTS = (
|
||||
"AI_GATEWAY_API_KEY",
|
||||
"OPENCODE_ZEN_API_KEY",
|
||||
"OPENCODE_GO_API_KEY",
|
||||
"XIAOMI_API_KEY",
|
||||
)
|
||||
|
||||
|
||||
@@ -335,8 +336,8 @@ def run_doctor(args):
|
||||
model_section[k] = raw_config.pop(k)
|
||||
else:
|
||||
raw_config.pop(k)
|
||||
with open(config_path, "w") as f:
|
||||
yaml.dump(raw_config, f, default_flow_style=False)
|
||||
from utils import atomic_yaml_write
|
||||
atomic_yaml_write(config_path, raw_config)
|
||||
check_ok("Migrated stale root-level keys into model section")
|
||||
fixed_count += 1
|
||||
else:
|
||||
@@ -685,7 +686,8 @@ def run_doctor(args):
|
||||
else:
|
||||
check_warn("OpenRouter API", "(not configured)")
|
||||
|
||||
anthropic_key = os.getenv("ANTHROPIC_TOKEN") or os.getenv("ANTHROPIC_API_KEY")
|
||||
from hermes_cli.auth import get_anthropic_key
|
||||
anthropic_key = get_anthropic_key()
|
||||
if anthropic_key:
|
||||
print(" Checking Anthropic API...", end="", flush=True)
|
||||
try:
|
||||
|
||||
@@ -119,6 +119,7 @@ def _configured_platforms() -> list[str]:
|
||||
"dingtalk": "DINGTALK_CLIENT_ID",
|
||||
"feishu": "FEISHU_APP_ID",
|
||||
"wecom": "WECOM_BOT_ID",
|
||||
"wecom_callback": "WECOM_CALLBACK_CORP_ID",
|
||||
"weixin": "WEIXIN_ACCOUNT_ID",
|
||||
}
|
||||
return [name for name, env in checks.items() if os.getenv(env)]
|
||||
|
||||
+156
-25
@@ -157,30 +157,54 @@ def _request_gateway_self_restart(pid: int) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def find_gateway_pids(exclude_pids: set | None = None) -> list:
|
||||
def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = False) -> list:
|
||||
"""Find PIDs of running gateway processes.
|
||||
|
||||
Args:
|
||||
exclude_pids: PIDs to exclude from the result (e.g. service-managed
|
||||
PIDs that should not be killed during a stale-process sweep).
|
||||
all_profiles: When ``True``, return gateway PIDs across **all**
|
||||
profiles (the pre-7923 global behaviour). ``hermes update``
|
||||
needs this because a code update affects every profile.
|
||||
When ``False`` (default), only PIDs belonging to the current
|
||||
Hermes profile are returned.
|
||||
"""
|
||||
pids = []
|
||||
_exclude = exclude_pids or set()
|
||||
pids = [pid for pid in _get_service_pids() if pid not in _exclude]
|
||||
patterns = [
|
||||
"hermes_cli.main gateway",
|
||||
"hermes_cli.main --profile",
|
||||
"hermes_cli.main -p",
|
||||
"hermes_cli/main.py gateway",
|
||||
"hermes_cli/main.py --profile",
|
||||
"hermes_cli/main.py -p",
|
||||
"hermes gateway",
|
||||
"gateway/run.py",
|
||||
]
|
||||
current_home = str(get_hermes_home().resolve())
|
||||
current_profile_arg = _profile_arg(current_home)
|
||||
current_profile_name = current_profile_arg.split()[-1] if current_profile_arg else ""
|
||||
|
||||
def _matches_current_profile(command: str) -> bool:
|
||||
if current_profile_name:
|
||||
return (
|
||||
f"--profile {current_profile_name}" in command
|
||||
or f"-p {current_profile_name}" in command
|
||||
or f"HERMES_HOME={current_home}" in command
|
||||
)
|
||||
|
||||
if "--profile " in command or " -p " in command:
|
||||
return False
|
||||
if "HERMES_HOME=" in command and f"HERMES_HOME={current_home}" not in command:
|
||||
return False
|
||||
return True
|
||||
|
||||
try:
|
||||
if is_windows():
|
||||
# Windows: use wmic to search command lines
|
||||
result = subprocess.run(
|
||||
["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"],
|
||||
capture_output=True, text=True, timeout=10
|
||||
)
|
||||
# Parse WMIC LIST output: blocks of "CommandLine=...\nProcessId=...\n"
|
||||
current_cmd = ""
|
||||
for line in result.stdout.split('\n'):
|
||||
line = line.strip()
|
||||
@@ -188,7 +212,7 @@ def find_gateway_pids(exclude_pids: set | None = None) -> list:
|
||||
current_cmd = line[len("CommandLine="):]
|
||||
elif line.startswith("ProcessId="):
|
||||
pid_str = line[len("ProcessId="):]
|
||||
if any(p in current_cmd for p in patterns):
|
||||
if any(p in current_cmd for p in patterns) and (all_profiles or _matches_current_profile(current_cmd)):
|
||||
try:
|
||||
pid = int(pid_str)
|
||||
if pid != os.getpid() and pid not in pids and pid not in _exclude:
|
||||
@@ -198,41 +222,57 @@ def find_gateway_pids(exclude_pids: set | None = None) -> list:
|
||||
current_cmd = ""
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["ps", "aux"],
|
||||
["ps", "eww", "-ax", "-o", "pid=,command="],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
for line in result.stdout.split('\n'):
|
||||
# Skip grep and current process
|
||||
if 'grep' in line or str(os.getpid()) in line:
|
||||
stripped = line.strip()
|
||||
if not stripped or 'grep' in stripped:
|
||||
continue
|
||||
for pattern in patterns:
|
||||
if pattern in line:
|
||||
parts = line.split()
|
||||
if len(parts) > 1:
|
||||
try:
|
||||
pid = int(parts[1])
|
||||
if pid not in pids and pid not in _exclude:
|
||||
pids.append(pid)
|
||||
except ValueError:
|
||||
continue
|
||||
break
|
||||
except Exception:
|
||||
|
||||
pid = None
|
||||
command = ""
|
||||
|
||||
parts = stripped.split(None, 1)
|
||||
if len(parts) == 2:
|
||||
try:
|
||||
pid = int(parts[0])
|
||||
command = parts[1]
|
||||
except ValueError:
|
||||
pid = None
|
||||
|
||||
if pid is None:
|
||||
aux_parts = stripped.split()
|
||||
if len(aux_parts) > 10 and aux_parts[1].isdigit():
|
||||
pid = int(aux_parts[1])
|
||||
command = " ".join(aux_parts[10:])
|
||||
|
||||
if pid is None:
|
||||
continue
|
||||
if pid == os.getpid() or pid in pids or pid in _exclude:
|
||||
continue
|
||||
if any(pattern in command for pattern in patterns) and (all_profiles or _matches_current_profile(command)):
|
||||
pids.append(pid)
|
||||
except (OSError, subprocess.TimeoutExpired):
|
||||
pass
|
||||
|
||||
return pids
|
||||
|
||||
|
||||
def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None) -> int:
|
||||
def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None,
|
||||
all_profiles: bool = False) -> int:
|
||||
"""Kill any running gateway processes. Returns count killed.
|
||||
|
||||
Args:
|
||||
force: Use the platform's force-kill mechanism instead of graceful terminate.
|
||||
exclude_pids: PIDs to skip (e.g. service-managed PIDs that were just
|
||||
restarted and should not be killed).
|
||||
all_profiles: When ``True``, kill across all profiles. Passed
|
||||
through to :func:`find_gateway_pids`.
|
||||
"""
|
||||
pids = find_gateway_pids(exclude_pids=exclude_pids)
|
||||
pids = find_gateway_pids(exclude_pids=exclude_pids, all_profiles=all_profiles)
|
||||
killed = 0
|
||||
|
||||
for pid in pids:
|
||||
@@ -633,6 +673,17 @@ def print_systemd_linger_guidance() -> None:
|
||||
print(" If you want the gateway user service to survive logout, run:")
|
||||
print(" sudo loginctl enable-linger $USER")
|
||||
|
||||
def _launchd_user_home() -> Path:
|
||||
"""Return the real macOS user home for launchd artifacts.
|
||||
|
||||
Profile-mode Hermes often sets ``HOME`` to a profile-scoped directory, but
|
||||
launchd user agents still live under the actual account home.
|
||||
"""
|
||||
import pwd
|
||||
|
||||
return Path(pwd.getpwuid(os.getuid()).pw_dir)
|
||||
|
||||
|
||||
def get_launchd_plist_path() -> Path:
|
||||
"""Return the launchd plist path, scoped per profile.
|
||||
|
||||
@@ -641,7 +692,7 @@ def get_launchd_plist_path() -> Path:
|
||||
"""
|
||||
suffix = _profile_suffix()
|
||||
name = f"ai.hermes.gateway-{suffix}" if suffix else "ai.hermes.gateway"
|
||||
return Path.home() / "Library" / "LaunchAgents" / f"{name}.plist"
|
||||
return _launchd_user_home() / "Library" / "LaunchAgents" / f"{name}.plist"
|
||||
|
||||
def _detect_venv_dir() -> Path | None:
|
||||
"""Detect the active virtualenv directory.
|
||||
@@ -839,6 +890,25 @@ def _normalize_service_definition(text: str) -> str:
|
||||
return "\n".join(line.rstrip() for line in text.strip().splitlines())
|
||||
|
||||
|
||||
def _normalize_launchd_plist_for_comparison(text: str) -> str:
|
||||
"""Normalize launchd plist text for staleness checks.
|
||||
|
||||
The generated plist intentionally captures a broad PATH assembled from the
|
||||
invoking shell so user-installed tools remain reachable under launchd.
|
||||
That makes raw text comparison unstable across shells, so ignore the PATH
|
||||
payload when deciding whether the installed plist is stale.
|
||||
"""
|
||||
import re
|
||||
|
||||
normalized = _normalize_service_definition(text)
|
||||
return re.sub(
|
||||
r'(<key>PATH</key>\s*<string>)(.*?)(</string>)',
|
||||
r'\1__HERMES_PATH__\3',
|
||||
normalized,
|
||||
flags=re.S,
|
||||
)
|
||||
|
||||
|
||||
def systemd_unit_is_current(system: bool = False) -> bool:
|
||||
unit_path = get_systemd_unit_path(system=system)
|
||||
if not unit_path.exists():
|
||||
@@ -1220,7 +1290,7 @@ def launchd_plist_is_current() -> bool:
|
||||
|
||||
installed = plist_path.read_text(encoding="utf-8")
|
||||
expected = generate_launchd_plist()
|
||||
return _normalize_service_definition(installed) == _normalize_service_definition(expected)
|
||||
return _normalize_launchd_plist_for_comparison(installed) == _normalize_launchd_plist_for_comparison(expected)
|
||||
|
||||
|
||||
def refresh_launchd_plist_if_needed() -> bool:
|
||||
@@ -1751,6 +1821,37 @@ _PLATFORMS = [
|
||||
"help": "Chat ID for scheduled results and notifications."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"key": "wecom_callback",
|
||||
"label": "WeCom Callback (Self-Built App)",
|
||||
"emoji": "💬",
|
||||
"token_var": "WECOM_CALLBACK_CORP_ID",
|
||||
"setup_instructions": [
|
||||
"1. Go to WeCom Admin Console → Applications → Create Self-Built App",
|
||||
"2. Note the Corp ID (top of admin console) and create a Corp Secret",
|
||||
"3. Under Receive Messages, configure the callback URL to point to your server",
|
||||
"4. Copy the Token and EncodingAESKey from the callback configuration",
|
||||
"5. The adapter runs an HTTP server — ensure the port is reachable from WeCom",
|
||||
"6. Restrict access with WECOM_CALLBACK_ALLOWED_USERS for production use",
|
||||
],
|
||||
"vars": [
|
||||
{"name": "WECOM_CALLBACK_CORP_ID", "prompt": "Corp ID", "password": False,
|
||||
"help": "Your WeCom enterprise Corp ID."},
|
||||
{"name": "WECOM_CALLBACK_CORP_SECRET", "prompt": "Corp Secret", "password": True,
|
||||
"help": "The secret for your self-built application."},
|
||||
{"name": "WECOM_CALLBACK_AGENT_ID", "prompt": "Agent ID", "password": False,
|
||||
"help": "The Agent ID of your self-built application."},
|
||||
{"name": "WECOM_CALLBACK_TOKEN", "prompt": "Callback Token", "password": True,
|
||||
"help": "The Token from your WeCom callback configuration."},
|
||||
{"name": "WECOM_CALLBACK_ENCODING_AES_KEY", "prompt": "Encoding AES Key", "password": True,
|
||||
"help": "The EncodingAESKey from your WeCom callback configuration."},
|
||||
{"name": "WECOM_CALLBACK_PORT", "prompt": "Callback server port (default: 8645)", "password": False,
|
||||
"help": "Port for the HTTP callback server."},
|
||||
{"name": "WECOM_CALLBACK_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated, or empty)", "password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Restrict which WeCom users can interact with the app."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"key": "weixin",
|
||||
"label": "Weixin / WeChat",
|
||||
@@ -1981,6 +2082,36 @@ def _setup_whatsapp():
|
||||
cmd_whatsapp(argparse.Namespace())
|
||||
|
||||
|
||||
def _setup_email():
|
||||
"""Configure Email via the standard platform setup."""
|
||||
email_platform = next(p for p in _PLATFORMS if p["key"] == "email")
|
||||
_setup_standard_platform(email_platform)
|
||||
|
||||
|
||||
def _setup_sms():
|
||||
"""Configure SMS (Twilio) via the standard platform setup."""
|
||||
sms_platform = next(p for p in _PLATFORMS if p["key"] == "sms")
|
||||
_setup_standard_platform(sms_platform)
|
||||
|
||||
|
||||
def _setup_dingtalk():
|
||||
"""Configure DingTalk via the standard platform setup."""
|
||||
dingtalk_platform = next(p for p in _PLATFORMS if p["key"] == "dingtalk")
|
||||
_setup_standard_platform(dingtalk_platform)
|
||||
|
||||
|
||||
def _setup_feishu():
|
||||
"""Configure Feishu / Lark via the standard platform setup."""
|
||||
feishu_platform = next(p for p in _PLATFORMS if p["key"] == "feishu")
|
||||
_setup_standard_platform(feishu_platform)
|
||||
|
||||
|
||||
def _setup_wecom():
|
||||
"""Configure WeCom (Enterprise WeChat) via the standard platform setup."""
|
||||
wecom_platform = next(p for p in _PLATFORMS if p["key"] == "wecom")
|
||||
_setup_standard_platform(wecom_platform)
|
||||
|
||||
|
||||
def _is_service_installed() -> bool:
|
||||
"""Check if the gateway is installed as a system service."""
|
||||
if supports_systemd_services():
|
||||
@@ -2540,7 +2671,7 @@ def gateway_command(args):
|
||||
service_available = True
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
killed = kill_gateway_processes()
|
||||
killed = kill_gateway_processes(all_profiles=True)
|
||||
total = killed + (1 if service_available else 0)
|
||||
if total:
|
||||
print(f"✓ Stopped {total} gateway process(es) across all profiles")
|
||||
|
||||
+64
-9
@@ -1,16 +1,18 @@
|
||||
"""``hermes logs`` — view and filter Hermes log files.
|
||||
|
||||
Supports tailing, following, session filtering, level filtering, and
|
||||
relative time ranges. All log files live under ``~/.hermes/logs/``.
|
||||
Supports tailing, following, session filtering, level filtering,
|
||||
component filtering, and relative time ranges. All log files live
|
||||
under ``~/.hermes/logs/``.
|
||||
|
||||
Usage examples::
|
||||
|
||||
hermes logs # last 50 lines of agent.log
|
||||
hermes logs -f # follow agent.log in real time
|
||||
hermes logs errors # last 50 lines of errors.log
|
||||
hermes logs gateway -n 100 # last 100 lines of gateway.log
|
||||
hermes logs gateway -n 100 # last 100 lines of gateway.log
|
||||
hermes logs --level WARNING # only WARNING+ lines
|
||||
hermes logs --session abc123 # filter by session ID substring
|
||||
hermes logs --component tools # only tool-related lines
|
||||
hermes logs --since 1h # lines from the last hour
|
||||
hermes logs --since 30m -f # follow, starting 30 min ago
|
||||
"""
|
||||
@@ -20,7 +22,7 @@ import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from hermes_constants import get_hermes_home, display_hermes_home
|
||||
|
||||
@@ -38,6 +40,15 @@ _TS_RE = re.compile(r"^(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2})")
|
||||
# Level extraction — matches " INFO ", " WARNING ", " ERROR ", " DEBUG ", " CRITICAL "
|
||||
_LEVEL_RE = re.compile(r"\s(DEBUG|INFO|WARNING|ERROR|CRITICAL)\s")
|
||||
|
||||
# Logger name extraction — after level and optional session tag, the next
|
||||
# non-space token before ":" is the logger name.
|
||||
# Matches: "INFO gateway.run:" or "INFO [sess_abc] tools.terminal_tool:"
|
||||
_LOGGER_NAME_RE = re.compile(
|
||||
r"\s(?:DEBUG|INFO|WARNING|ERROR|CRITICAL)" # level
|
||||
r"(?:\s+\[.*?\])?" # optional session tag
|
||||
r"\s+(\S+):" # logger name
|
||||
)
|
||||
|
||||
# Level ordering for >= filtering
|
||||
_LEVEL_ORDER = {"DEBUG": 0, "INFO": 1, "WARNING": 2, "ERROR": 3, "CRITICAL": 4}
|
||||
|
||||
@@ -79,12 +90,27 @@ def _extract_level(line: str) -> Optional[str]:
|
||||
return m.group(1) if m else None
|
||||
|
||||
|
||||
def _extract_logger_name(line: str) -> Optional[str]:
|
||||
"""Extract the logger name from a log line."""
|
||||
m = _LOGGER_NAME_RE.search(line)
|
||||
return m.group(1) if m else None
|
||||
|
||||
|
||||
def _line_matches_component(line: str, prefixes: Sequence[str]) -> bool:
|
||||
"""Check if a log line's logger name starts with any of *prefixes*."""
|
||||
name = _extract_logger_name(line)
|
||||
if name is None:
|
||||
return False
|
||||
return name.startswith(tuple(prefixes))
|
||||
|
||||
|
||||
def _matches_filters(
|
||||
line: str,
|
||||
*,
|
||||
min_level: Optional[str] = None,
|
||||
session_filter: Optional[str] = None,
|
||||
since: Optional[datetime] = None,
|
||||
component_prefixes: Optional[Sequence[str]] = None,
|
||||
) -> bool:
|
||||
"""Check if a log line passes all active filters."""
|
||||
if since is not None:
|
||||
@@ -102,6 +128,10 @@ def _matches_filters(
|
||||
if session_filter not in line:
|
||||
return False
|
||||
|
||||
if component_prefixes is not None:
|
||||
if not _line_matches_component(line, component_prefixes):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -113,6 +143,7 @@ def tail_log(
|
||||
level: Optional[str] = None,
|
||||
session: Optional[str] = None,
|
||||
since: Optional[str] = None,
|
||||
component: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Read and display log lines, optionally following in real time.
|
||||
|
||||
@@ -130,6 +161,8 @@ def tail_log(
|
||||
Session ID substring to filter on.
|
||||
since
|
||||
Relative time string (e.g. ``"1h"``, ``"30m"``).
|
||||
component
|
||||
Component name to filter by (e.g. ``"gateway"``, ``"tools"``).
|
||||
"""
|
||||
filename = LOG_FILES.get(log_name)
|
||||
if filename is None:
|
||||
@@ -155,13 +188,29 @@ def tail_log(
|
||||
print(f"Invalid --level: {level!r}. Use DEBUG, INFO, WARNING, ERROR, or CRITICAL.")
|
||||
sys.exit(1)
|
||||
|
||||
has_filters = min_level is not None or session is not None or since_dt is not None
|
||||
# Resolve component to logger name prefixes
|
||||
component_prefixes = None
|
||||
if component:
|
||||
from hermes_logging import COMPONENT_PREFIXES
|
||||
component_lower = component.lower()
|
||||
if component_lower not in COMPONENT_PREFIXES:
|
||||
available = ", ".join(sorted(COMPONENT_PREFIXES))
|
||||
print(f"Unknown component: {component!r}. Available: {available}")
|
||||
sys.exit(1)
|
||||
component_prefixes = COMPONENT_PREFIXES[component_lower]
|
||||
|
||||
has_filters = (
|
||||
min_level is not None
|
||||
or session is not None
|
||||
or since_dt is not None
|
||||
or component_prefixes is not None
|
||||
)
|
||||
|
||||
# Read and display the tail
|
||||
try:
|
||||
lines = _read_tail(log_path, num_lines, has_filters=has_filters,
|
||||
min_level=min_level, session_filter=session,
|
||||
since=since_dt)
|
||||
since=since_dt, component_prefixes=component_prefixes)
|
||||
except PermissionError:
|
||||
print(f"Permission denied: {log_path}")
|
||||
sys.exit(1)
|
||||
@@ -172,6 +221,8 @@ def tail_log(
|
||||
filter_parts.append(f"level>={min_level}")
|
||||
if session:
|
||||
filter_parts.append(f"session={session}")
|
||||
if component:
|
||||
filter_parts.append(f"component={component}")
|
||||
if since:
|
||||
filter_parts.append(f"since={since}")
|
||||
filter_desc = f" [{', '.join(filter_parts)}]" if filter_parts else ""
|
||||
@@ -190,7 +241,7 @@ def tail_log(
|
||||
# Follow mode — poll for new content
|
||||
try:
|
||||
_follow_log(log_path, min_level=min_level, session_filter=session,
|
||||
since=since_dt)
|
||||
since=since_dt, component_prefixes=component_prefixes)
|
||||
except KeyboardInterrupt:
|
||||
print("\n--- stopped ---")
|
||||
|
||||
@@ -203,6 +254,7 @@ def _read_tail(
|
||||
min_level: Optional[str] = None,
|
||||
session_filter: Optional[str] = None,
|
||||
since: Optional[datetime] = None,
|
||||
component_prefixes: Optional[Sequence[str]] = None,
|
||||
) -> list:
|
||||
"""Read the last *num_lines* matching lines from a log file.
|
||||
|
||||
@@ -215,7 +267,8 @@ def _read_tail(
|
||||
filtered = [
|
||||
l for l in raw_lines
|
||||
if _matches_filters(l, min_level=min_level,
|
||||
session_filter=session_filter, since=since)
|
||||
session_filter=session_filter, since=since,
|
||||
component_prefixes=component_prefixes)
|
||||
]
|
||||
return filtered[-num_lines:]
|
||||
else:
|
||||
@@ -284,6 +337,7 @@ def _follow_log(
|
||||
min_level: Optional[str] = None,
|
||||
session_filter: Optional[str] = None,
|
||||
since: Optional[datetime] = None,
|
||||
component_prefixes: Optional[Sequence[str]] = None,
|
||||
) -> None:
|
||||
"""Poll a log file for new content and print matching lines."""
|
||||
with open(path, "r", encoding="utf-8", errors="replace") as f:
|
||||
@@ -293,7 +347,8 @@ def _follow_log(
|
||||
line = f.readline()
|
||||
if line:
|
||||
if _matches_filters(line, min_level=min_level,
|
||||
session_filter=session_filter, since=since):
|
||||
session_filter=session_filter, since=since,
|
||||
component_prefixes=component_prefixes):
|
||||
print(line, end="")
|
||||
sys.stdout.flush()
|
||||
else:
|
||||
|
||||
+175
-13
@@ -528,6 +528,113 @@ def _resolve_last_cli_session() -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def _probe_container(cmd: list, backend: str, via_sudo: bool = False):
|
||||
"""Run a container inspect probe, returning the CompletedProcess.
|
||||
|
||||
Catches TimeoutExpired specifically for a human-readable message;
|
||||
all other exceptions propagate naturally.
|
||||
"""
|
||||
try:
|
||||
return subprocess.run(cmd, capture_output=True, text=True, timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
label = f"sudo {backend}" if via_sudo else backend
|
||||
print(
|
||||
f"Error: timed out waiting for {label} to respond.\n"
|
||||
f"The {backend} daemon may be unresponsive or starting up.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _exec_in_container(container_info: dict, cli_args: list):
|
||||
"""Replace the current process with a command inside the managed container.
|
||||
|
||||
Probes whether sudo is needed (rootful containers), then os.execvp
|
||||
into the container. On success the Python process is replaced entirely
|
||||
and the container's exit code becomes the process exit code (OS semantics).
|
||||
On failure, OSError propagates naturally.
|
||||
|
||||
Args:
|
||||
container_info: dict with backend, container_name, exec_user, hermes_bin
|
||||
cli_args: the original CLI arguments (everything after 'hermes')
|
||||
"""
|
||||
import shutil
|
||||
|
||||
backend = container_info["backend"]
|
||||
container_name = container_info["container_name"]
|
||||
exec_user = container_info["exec_user"]
|
||||
hermes_bin = container_info["hermes_bin"]
|
||||
|
||||
runtime = shutil.which(backend)
|
||||
if not runtime:
|
||||
print(f"Error: {backend} not found on PATH. Cannot route to container.",
|
||||
file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Rootful containers (NixOS systemd service) are invisible to unprivileged
|
||||
# users — Podman uses per-user namespaces, Docker needs group access.
|
||||
# Probe whether the runtime can see the container; if not, try via sudo.
|
||||
sudo_path = None
|
||||
probe = _probe_container(
|
||||
[runtime, "inspect", "--format", "ok", container_name], backend,
|
||||
)
|
||||
if probe.returncode != 0:
|
||||
sudo_path = shutil.which("sudo")
|
||||
if sudo_path:
|
||||
probe2 = _probe_container(
|
||||
[sudo_path, "-n", runtime, "inspect", "--format", "ok", container_name],
|
||||
backend, via_sudo=True,
|
||||
)
|
||||
if probe2.returncode != 0:
|
||||
print(
|
||||
f"Error: container '{container_name}' not found via {backend}.\n"
|
||||
f"\n"
|
||||
f"The container is likely running as root. Your user cannot see it\n"
|
||||
f"because {backend} uses per-user namespaces. Grant passwordless\n"
|
||||
f"sudo for {backend} — the -n (non-interactive) flag is required\n"
|
||||
f"because a password prompt would hang or break piped commands.\n"
|
||||
f"\n"
|
||||
f"On NixOS:\n"
|
||||
f"\n"
|
||||
f' security.sudo.extraRules = [{{\n'
|
||||
f' users = [ "{os.getenv("USER", "your-user")}" ];\n'
|
||||
f' commands = [{{ command = "{runtime}"; options = [ "NOPASSWD" ]; }}];\n'
|
||||
f' }}];\n'
|
||||
f"\n"
|
||||
f"Or run: sudo hermes {' '.join(cli_args)}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(
|
||||
f"Error: container '{container_name}' not found via {backend}.\n"
|
||||
f"The container may be running under root. Try: sudo hermes {' '.join(cli_args)}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
is_tty = sys.stdin.isatty()
|
||||
tty_flags = ["-it"] if is_tty else ["-i"]
|
||||
|
||||
env_flags = []
|
||||
for var in ("TERM", "COLORTERM", "LANG", "LC_ALL"):
|
||||
val = os.environ.get(var)
|
||||
if val:
|
||||
env_flags.extend(["-e", f"{var}={val}"])
|
||||
|
||||
cmd_prefix = [sudo_path, "-n", runtime] if sudo_path else [runtime]
|
||||
exec_cmd = (
|
||||
cmd_prefix + ["exec"]
|
||||
+ tty_flags
|
||||
+ ["-u", exec_user]
|
||||
+ env_flags
|
||||
+ [container_name, hermes_bin]
|
||||
+ cli_args
|
||||
)
|
||||
|
||||
os.execvp(exec_cmd[0], exec_cmd)
|
||||
|
||||
|
||||
def _resolve_session_by_name_or_id(name_or_id: str) -> Optional[str]:
|
||||
"""Resolve a session name (title) or ID to a session ID.
|
||||
|
||||
@@ -934,6 +1041,7 @@ def select_provider_and_model(args=None):
|
||||
"kilocode": "Kilo Code",
|
||||
"alibaba": "Alibaba Cloud (DashScope)",
|
||||
"huggingface": "Hugging Face",
|
||||
"xiaomi": "Xiaomi MiMo",
|
||||
"custom": "Custom endpoint",
|
||||
}
|
||||
active_label = provider_labels.get(active, active) if active else "none"
|
||||
@@ -966,6 +1074,7 @@ def select_provider_and_model(args=None):
|
||||
("opencode-go", "OpenCode Go (open models, $10/month subscription)"),
|
||||
("ai-gateway", "AI Gateway (Vercel — 200+ models, pay-per-use)"),
|
||||
("alibaba", "Alibaba Cloud / DashScope Coding (Qwen + multi-provider)"),
|
||||
("xiaomi", "Xiaomi MiMo (MiMo-V2 models — pro, omni, flash)"),
|
||||
]
|
||||
|
||||
def _named_custom_provider_map(cfg) -> dict[str, dict[str, str]]:
|
||||
@@ -1077,9 +1186,45 @@ def select_provider_and_model(args=None):
|
||||
_model_flow_anthropic(config, current_model)
|
||||
elif selected_provider == "kimi-coding":
|
||||
_model_flow_kimi(config, current_model)
|
||||
elif selected_provider in ("gemini", "zai", "minimax", "minimax-cn", "kilocode", "opencode-zen", "opencode-go", "ai-gateway", "alibaba", "huggingface"):
|
||||
elif selected_provider in ("gemini", "zai", "minimax", "minimax-cn", "kilocode", "opencode-zen", "opencode-go", "ai-gateway", "alibaba", "huggingface", "xiaomi"):
|
||||
_model_flow_api_key_provider(config, selected_provider, current_model)
|
||||
|
||||
# ── Post-switch cleanup: clear stale OPENAI_BASE_URL ──────────────
|
||||
# When the user switches to a named provider (anything except "custom"),
|
||||
# a leftover OPENAI_BASE_URL in ~/.hermes/.env can poison auxiliary
|
||||
# clients that use provider:auto. Clear it proactively. (#5161)
|
||||
if selected_provider not in ("custom", "cancel", "remove-custom") \
|
||||
and not selected_provider.startswith("custom:"):
|
||||
_clear_stale_openai_base_url()
|
||||
|
||||
|
||||
def _clear_stale_openai_base_url():
|
||||
"""Remove OPENAI_BASE_URL from ~/.hermes/.env if the active provider is not 'custom'.
|
||||
|
||||
After a provider switch, a leftover OPENAI_BASE_URL causes auxiliary
|
||||
clients (compression, vision, delegation) with provider:auto to route
|
||||
requests to the old custom endpoint instead of the newly selected
|
||||
provider. See issue #5161.
|
||||
"""
|
||||
from hermes_cli.config import get_env_value, save_env_value, load_config
|
||||
|
||||
cfg = load_config()
|
||||
model_cfg = cfg.get("model", {})
|
||||
if isinstance(model_cfg, dict):
|
||||
provider = (model_cfg.get("provider") or "").strip().lower()
|
||||
else:
|
||||
provider = ""
|
||||
|
||||
if provider == "custom" or not provider:
|
||||
return # custom provider legitimately uses OPENAI_BASE_URL
|
||||
|
||||
stale_url = get_env_value("OPENAI_BASE_URL")
|
||||
if stale_url:
|
||||
save_env_value("OPENAI_BASE_URL", "")
|
||||
print(f"Cleared stale OPENAI_BASE_URL from .env (was: {stale_url[:40]}...)"
|
||||
if len(stale_url) > 40
|
||||
else f"Cleared stale OPENAI_BASE_URL from .env (was: {stale_url})")
|
||||
|
||||
|
||||
def _prompt_provider_choice(choices, *, default=0):
|
||||
"""Show provider selection menu with curses arrow-key navigation.
|
||||
@@ -2511,13 +2656,8 @@ def _model_flow_anthropic(config, current_model=""):
|
||||
from hermes_cli.models import _PROVIDER_MODELS
|
||||
|
||||
# Check ALL credential sources
|
||||
existing_key = (
|
||||
get_env_value("ANTHROPIC_TOKEN")
|
||||
or os.getenv("ANTHROPIC_TOKEN", "")
|
||||
or get_env_value("ANTHROPIC_API_KEY")
|
||||
or os.getenv("ANTHROPIC_API_KEY", "")
|
||||
or os.getenv("CLAUDE_CODE_OAUTH_TOKEN", "")
|
||||
)
|
||||
from hermes_cli.auth import get_anthropic_key
|
||||
existing_key = get_anthropic_key()
|
||||
cc_available = False
|
||||
try:
|
||||
from agent.anthropic_adapter import read_claude_code_credentials, is_claude_code_token_valid
|
||||
@@ -3843,7 +3983,7 @@ def cmd_update(args):
|
||||
# Exclude PIDs that belong to just-restarted services so we don't
|
||||
# immediately kill the process that systemd/launchd just spawned.
|
||||
service_pids = _get_service_pids()
|
||||
manual_pids = find_gateway_pids(exclude_pids=service_pids)
|
||||
manual_pids = find_gateway_pids(exclude_pids=service_pids, all_profiles=True)
|
||||
for pid in manual_pids:
|
||||
try:
|
||||
os.kill(pid, _signal.SIGTERM)
|
||||
@@ -4198,6 +4338,7 @@ def cmd_logs(args):
|
||||
level=getattr(args, "level", None),
|
||||
session=getattr(args, "session", None),
|
||||
since=getattr(args, "since", None),
|
||||
component=getattr(args, "component", None),
|
||||
)
|
||||
|
||||
|
||||
@@ -4321,7 +4462,7 @@ For more help on a command:
|
||||
)
|
||||
chat_parser.add_argument(
|
||||
"--provider",
|
||||
choices=["auto", "openrouter", "nous", "openai-codex", "copilot-acp", "copilot", "anthropic", "gemini", "huggingface", "zai", "kimi-coding", "minimax", "minimax-cn", "kilocode"],
|
||||
choices=["auto", "openrouter", "nous", "openai-codex", "copilot-acp", "copilot", "anthropic", "gemini", "huggingface", "zai", "kimi-coding", "minimax", "minimax-cn", "kilocode", "xiaomi"],
|
||||
default=None,
|
||||
help="Inference provider (default: auto)"
|
||||
)
|
||||
@@ -5113,6 +5254,8 @@ For more help on a command:
|
||||
mcp_add_p.add_argument("--command", help="Stdio command (e.g. npx)")
|
||||
mcp_add_p.add_argument("--args", nargs="*", default=[], help="Arguments for stdio command")
|
||||
mcp_add_p.add_argument("--auth", choices=["oauth", "header"], help="Auth method")
|
||||
mcp_add_p.add_argument("--preset", help="Known MCP preset name")
|
||||
mcp_add_p.add_argument("--env", nargs="*", default=[], help="Environment variables for stdio servers (KEY=VALUE)")
|
||||
|
||||
mcp_rm_p = mcp_sub.add_parser("remove", aliases=["rm"], help="Remove an MCP server")
|
||||
mcp_rm_p.add_argument("name", help="Server name to remove")
|
||||
@@ -5375,7 +5518,8 @@ For more help on a command:
|
||||
claw_migrate = claw_subparsers.add_parser(
|
||||
"migrate",
|
||||
help="Migrate from OpenClaw to Hermes",
|
||||
description="Import settings, memories, skills, and API keys from an OpenClaw installation"
|
||||
description="Import settings, memories, skills, and API keys from an OpenClaw installation. "
|
||||
"Always shows a preview before making changes."
|
||||
)
|
||||
claw_migrate.add_argument(
|
||||
"--source",
|
||||
@@ -5384,7 +5528,7 @@ For more help on a command:
|
||||
claw_migrate.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Preview what would be migrated without making changes"
|
||||
help="Preview only — stop after showing what would be migrated"
|
||||
)
|
||||
claw_migrate.add_argument(
|
||||
"--preset",
|
||||
@@ -5594,6 +5738,7 @@ Examples:
|
||||
hermes logs gateway -n 100 Show last 100 lines of gateway.log
|
||||
hermes logs --level WARNING Only show WARNING and above
|
||||
hermes logs --session abc123 Filter by session ID
|
||||
hermes logs --component tools Only show tool-related lines
|
||||
hermes logs --since 1h Lines from the last hour
|
||||
hermes logs --since 30m -f Follow, starting from 30 min ago
|
||||
hermes logs list List available log files with sizes
|
||||
@@ -5623,6 +5768,10 @@ Examples:
|
||||
"--since", metavar="TIME",
|
||||
help="Show lines since TIME ago (e.g. 1h, 30m, 2d)",
|
||||
)
|
||||
logs_parser.add_argument(
|
||||
"--component", metavar="NAME",
|
||||
help="Filter by component: gateway, agent, tools, cli, cron",
|
||||
)
|
||||
logs_parser.set_defaults(func=cmd_logs)
|
||||
|
||||
# =========================================================================
|
||||
@@ -5631,9 +5780,22 @@ Examples:
|
||||
# Pre-process argv so unquoted multi-word session names after -c / -r
|
||||
# are merged into a single token before argparse sees them.
|
||||
# e.g. ``hermes -c Pokemon Agent Dev`` → ``hermes -c 'Pokemon Agent Dev'``
|
||||
# ── Container-aware routing ────────────────────────────────────────
|
||||
# When NixOS container mode is active, route ALL subcommands into
|
||||
# the managed container. This MUST run before parse_args() so that
|
||||
# --help, unrecognised flags, and every subcommand are forwarded
|
||||
# transparently instead of being intercepted by argparse on the host.
|
||||
from hermes_cli.config import get_container_exec_info
|
||||
container_info = get_container_exec_info()
|
||||
if container_info:
|
||||
_exec_in_container(container_info, sys.argv[1:])
|
||||
# Unreachable: os.execvp never returns on success (process is replaced)
|
||||
# and raises OSError on failure (which propagates as a traceback).
|
||||
sys.exit(1)
|
||||
|
||||
_processed_argv = _coalesce_session_name_args(sys.argv[1:])
|
||||
args = parser.parse_args(_processed_argv)
|
||||
|
||||
|
||||
# Handle --version flag
|
||||
if args.version:
|
||||
cmd_version(args)
|
||||
|
||||
+87
-16
@@ -9,7 +9,6 @@ configuration in ~/.hermes/config.yaml under the ``mcp_servers`` key.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import getpass
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@@ -28,6 +27,11 @@ from hermes_constants import display_hermes_home
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ENV_VAR_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
|
||||
_MCP_PRESETS: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
# ─── UI Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -57,19 +61,8 @@ def _confirm(question: str, default: bool = True) -> bool:
|
||||
|
||||
|
||||
def _prompt(question: str, *, password: bool = False, default: str = "") -> str:
|
||||
display = f" {question}"
|
||||
if default:
|
||||
display += f" [{default}]"
|
||||
display += ": "
|
||||
try:
|
||||
if password:
|
||||
value = getpass.getpass(color(display, Colors.YELLOW))
|
||||
else:
|
||||
value = input(color(display, Colors.YELLOW))
|
||||
return value.strip() or default
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default
|
||||
from hermes_cli.cli_output import prompt as _shared_prompt
|
||||
return _shared_prompt(question, default=default, password=password)
|
||||
|
||||
|
||||
# ─── Config Helpers ───────────────────────────────────────────────────────────
|
||||
@@ -109,6 +102,59 @@ def _env_key_for_server(name: str) -> str:
|
||||
return f"MCP_{name.upper().replace('-', '_')}_API_KEY"
|
||||
|
||||
|
||||
def _parse_env_assignments(raw_env: Optional[List[str]]) -> Dict[str, str]:
|
||||
"""Parse ``KEY=VALUE`` strings from CLI args into an env dict."""
|
||||
parsed: Dict[str, str] = {}
|
||||
for item in raw_env or []:
|
||||
text = str(item or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
if "=" not in text:
|
||||
raise ValueError(f"Invalid --env value '{text}' (expected KEY=VALUE)")
|
||||
key, value = text.split("=", 1)
|
||||
key = key.strip()
|
||||
if not key:
|
||||
raise ValueError(f"Invalid --env value '{text}' (missing variable name)")
|
||||
if not _ENV_VAR_NAME_RE.match(key):
|
||||
raise ValueError(f"Invalid --env variable name '{key}'")
|
||||
parsed[key] = value
|
||||
return parsed
|
||||
|
||||
|
||||
def _apply_mcp_preset(
|
||||
name: str,
|
||||
*,
|
||||
preset_name: Optional[str],
|
||||
url: Optional[str],
|
||||
command: Optional[str],
|
||||
cmd_args: List[str],
|
||||
server_config: Dict[str, Any],
|
||||
) -> tuple[Optional[str], Optional[str], List[str], bool]:
|
||||
"""Apply a known MCP preset when transport details were omitted."""
|
||||
if not preset_name:
|
||||
return url, command, cmd_args, False
|
||||
|
||||
preset = _MCP_PRESETS.get(preset_name)
|
||||
if not preset:
|
||||
raise ValueError(f"Unknown MCP preset: {preset_name}")
|
||||
|
||||
if url or command:
|
||||
return url, command, cmd_args, False
|
||||
|
||||
url = preset.get("url")
|
||||
command = preset.get("command")
|
||||
cmd_args = list(preset.get("args") or [])
|
||||
|
||||
if url:
|
||||
server_config["url"] = url
|
||||
if command:
|
||||
server_config["command"] = command
|
||||
if cmd_args:
|
||||
server_config["args"] = cmd_args
|
||||
|
||||
return url, command, cmd_args, True
|
||||
|
||||
|
||||
# ─── Discovery (temporary connect) ───────────────────────────────────────────
|
||||
|
||||
def _probe_single_server(
|
||||
@@ -177,13 +223,35 @@ def cmd_mcp_add(args):
|
||||
command = getattr(args, "command", None)
|
||||
cmd_args = getattr(args, "args", None) or []
|
||||
auth_type = getattr(args, "auth", None)
|
||||
preset_name = getattr(args, "preset", None)
|
||||
raw_env = getattr(args, "env", None)
|
||||
|
||||
server_config: Dict[str, Any] = {}
|
||||
try:
|
||||
explicit_env = _parse_env_assignments(raw_env)
|
||||
url, command, cmd_args, _preset_applied = _apply_mcp_preset(
|
||||
name,
|
||||
preset_name=preset_name,
|
||||
url=url,
|
||||
command=command,
|
||||
cmd_args=list(cmd_args),
|
||||
server_config=server_config,
|
||||
)
|
||||
except ValueError as exc:
|
||||
_error(str(exc))
|
||||
return
|
||||
|
||||
if url and explicit_env:
|
||||
_error("--env is only supported for stdio MCP servers (--command or stdio presets)")
|
||||
return
|
||||
|
||||
# Validate transport
|
||||
if not url and not command:
|
||||
_error("Must specify --url <endpoint> or --command <cmd>")
|
||||
_error("Must specify --url <endpoint>, --command <cmd>, or --preset <name>")
|
||||
_info("Examples:")
|
||||
_info(' hermes mcp add ink --url "https://mcp.ml.ink/mcp"')
|
||||
_info(' hermes mcp add github --command npx --args @modelcontextprotocol/server-github')
|
||||
_info(' hermes mcp add myserver --preset mypreset')
|
||||
return
|
||||
|
||||
# Check if server already exists
|
||||
@@ -194,13 +262,15 @@ def cmd_mcp_add(args):
|
||||
return
|
||||
|
||||
# Build initial config
|
||||
server_config: Dict[str, Any] = {}
|
||||
if url:
|
||||
server_config["url"] = url
|
||||
else:
|
||||
server_config["command"] = command
|
||||
if cmd_args:
|
||||
server_config["args"] = cmd_args
|
||||
if explicit_env:
|
||||
server_config["env"] = explicit_env
|
||||
|
||||
|
||||
# ── Authentication ────────────────────────────────────────────────
|
||||
|
||||
@@ -638,6 +708,7 @@ def mcp_command(args):
|
||||
_info("hermes mcp serve Run as MCP server")
|
||||
_info("hermes mcp add <name> --url <endpoint> Add an MCP server")
|
||||
_info("hermes mcp add <name> --command <cmd> Add a stdio server")
|
||||
_info("hermes mcp add <name> --preset <preset> Add from a known preset")
|
||||
_info("hermes mcp remove <name> Remove a server")
|
||||
_info("hermes mcp list List servers")
|
||||
_info("hermes mcp test <name> Test connection")
|
||||
|
||||
@@ -25,85 +25,13 @@ def _curses_select(title: str, items: list[tuple[str, str]], default: int = 0) -
|
||||
items: list of (label, description) tuples.
|
||||
Returns selected index, or default on escape/quit.
|
||||
"""
|
||||
try:
|
||||
import curses
|
||||
result = [default]
|
||||
|
||||
def _menu(stdscr):
|
||||
curses.curs_set(0)
|
||||
if curses.has_colors():
|
||||
curses.start_color()
|
||||
curses.use_default_colors()
|
||||
curses.init_pair(1, curses.COLOR_GREEN, -1)
|
||||
curses.init_pair(2, curses.COLOR_YELLOW, -1)
|
||||
curses.init_pair(3, curses.COLOR_CYAN, -1)
|
||||
cursor = default
|
||||
|
||||
while True:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
|
||||
# Title
|
||||
try:
|
||||
stdscr.addnstr(0, 0, title, max_x - 1,
|
||||
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
|
||||
stdscr.addnstr(1, 0, " ↑↓ navigate ⏎ select q quit", max_x - 1,
|
||||
curses.color_pair(3) if curses.has_colors() else curses.A_DIM)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
for i, (label, desc) in enumerate(items):
|
||||
y = i + 3
|
||||
if y >= max_y - 1:
|
||||
break
|
||||
arrow = "→" if i == cursor else " "
|
||||
line = f" {arrow} {label}"
|
||||
if desc:
|
||||
line += f" {desc}"
|
||||
|
||||
attr = curses.A_NORMAL
|
||||
if i == cursor:
|
||||
attr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
attr |= curses.color_pair(1)
|
||||
try:
|
||||
stdscr.addnstr(y, 0, line[:max_x - 1], max_x - 1, attr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
|
||||
if key in (curses.KEY_UP, ord('k')):
|
||||
cursor = (cursor - 1) % len(items)
|
||||
elif key in (curses.KEY_DOWN, ord('j')):
|
||||
cursor = (cursor + 1) % len(items)
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result[0] = cursor
|
||||
return
|
||||
elif key in (27, ord('q')):
|
||||
return
|
||||
|
||||
curses.wrapper(_menu)
|
||||
return result[0]
|
||||
|
||||
except Exception:
|
||||
# Fallback: numbered input
|
||||
print(f"\n {title}\n")
|
||||
for i, (label, desc) in enumerate(items):
|
||||
marker = "→" if i == default else " "
|
||||
d = f" {desc}" if desc else ""
|
||||
print(f" {marker} {i + 1}. {label}{d}")
|
||||
while True:
|
||||
try:
|
||||
val = input(f"\n Select [1-{len(items)}] ({default + 1}): ")
|
||||
if not val:
|
||||
return default
|
||||
idx = int(val) - 1
|
||||
if 0 <= idx < len(items):
|
||||
return idx
|
||||
except (ValueError, EOFError):
|
||||
return default
|
||||
from hermes_cli.curses_ui import curses_radiolist
|
||||
# Format (label, desc) tuples into display strings
|
||||
display_items = [
|
||||
f"{label} {desc}" if desc else label
|
||||
for label, desc in items
|
||||
]
|
||||
return curses_radiolist(title, display_items, selected=default, cancel_returns=default)
|
||||
|
||||
|
||||
def _prompt(label: str, default: str | None = None, secret: bool = False) -> str:
|
||||
|
||||
@@ -74,13 +74,13 @@ _DOT_TO_HYPHEN_PROVIDERS: frozenset[str] = frozenset({
|
||||
_STRIP_VENDOR_ONLY_PROVIDERS: frozenset[str] = frozenset({
|
||||
"copilot",
|
||||
"copilot-acp",
|
||||
"openai-codex",
|
||||
})
|
||||
|
||||
# Providers whose native naming is authoritative -- pass through unchanged.
|
||||
_AUTHORITATIVE_NATIVE_PROVIDERS: frozenset[str] = frozenset({
|
||||
"gemini",
|
||||
"huggingface",
|
||||
"openai-codex",
|
||||
})
|
||||
|
||||
# Direct providers that accept bare native names but should repair a matching
|
||||
@@ -92,6 +92,7 @@ _MATCHING_PREFIX_STRIP_PROVIDERS: frozenset[str] = frozenset({
|
||||
"minimax-cn",
|
||||
"alibaba",
|
||||
"qwen-oauth",
|
||||
"xiaomi",
|
||||
"custom",
|
||||
})
|
||||
|
||||
@@ -359,7 +360,11 @@ def normalize_model_for_provider(model_input: str, target_provider: str) -> str:
|
||||
|
||||
# --- Copilot: strip matching provider prefix, keep dots ---
|
||||
if provider in _STRIP_VENDOR_ONLY_PROVIDERS:
|
||||
return _strip_matching_provider_prefix(name, provider)
|
||||
stripped = _strip_matching_provider_prefix(name, provider)
|
||||
if stripped == name and name.startswith("openai/"):
|
||||
# openai-codex maps openai/gpt-5.4 -> gpt-5.4
|
||||
return name.split("/", 1)[1]
|
||||
return stripped
|
||||
|
||||
# --- DeepSeek: map to one of two canonical names ---
|
||||
if provider == "deepseek":
|
||||
|
||||
+51
-7
@@ -56,6 +56,18 @@ OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
|
||||
_openrouter_catalog_cache: list[tuple[str, str]] | None = None
|
||||
|
||||
|
||||
def _codex_curated_models() -> list[str]:
|
||||
"""Derive the openai-codex curated list from codex_models.py.
|
||||
|
||||
Single source of truth: DEFAULT_CODEX_MODELS + forward-compat synthesis.
|
||||
This keeps the gateway /model picker in sync with the CLI `hermes model`
|
||||
flow without maintaining a separate static list.
|
||||
"""
|
||||
from hermes_cli.codex_models import DEFAULT_CODEX_MODELS, _add_forward_compat_models
|
||||
return _add_forward_compat_models(list(DEFAULT_CODEX_MODELS))
|
||||
|
||||
|
||||
_PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"nous": [
|
||||
"anthropic/claude-opus-4.6",
|
||||
@@ -86,12 +98,7 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"openai/gpt-5.4-pro",
|
||||
"openai/gpt-5.4-nano",
|
||||
],
|
||||
"openai-codex": [
|
||||
"gpt-5.3-codex",
|
||||
"gpt-5.2-codex",
|
||||
"gpt-5.1-codex-mini",
|
||||
"gpt-5.1-codex-max",
|
||||
],
|
||||
"openai-codex": _codex_curated_models(),
|
||||
"copilot-acp": [
|
||||
"copilot-acp",
|
||||
],
|
||||
@@ -181,6 +188,11 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"deepseek-chat",
|
||||
"deepseek-reasoner",
|
||||
],
|
||||
"xiaomi": [
|
||||
"mimo-v2-pro",
|
||||
"mimo-v2-omni",
|
||||
"mimo-v2-flash",
|
||||
],
|
||||
"opencode-zen": [
|
||||
"gpt-5.4-pro",
|
||||
"gpt-5.4",
|
||||
@@ -486,6 +498,7 @@ _PROVIDER_LABELS = {
|
||||
"alibaba": "Alibaba Cloud (DashScope)",
|
||||
"qwen-oauth": "Qwen OAuth (Portal)",
|
||||
"huggingface": "Hugging Face",
|
||||
"xiaomi": "Xiaomi MiMo",
|
||||
"custom": "Custom endpoint",
|
||||
}
|
||||
|
||||
@@ -528,6 +541,8 @@ _PROVIDER_ALIASES = {
|
||||
"hf": "huggingface",
|
||||
"hugging-face": "huggingface",
|
||||
"huggingface-hub": "huggingface",
|
||||
"mimo": "xiaomi",
|
||||
"xiaomi-mimo": "xiaomi",
|
||||
}
|
||||
|
||||
|
||||
@@ -812,7 +827,7 @@ def list_available_providers() -> list[dict[str, str]]:
|
||||
"openrouter", "nous", "openai-codex", "copilot", "copilot-acp",
|
||||
"gemini", "huggingface",
|
||||
"zai", "kimi-coding", "minimax", "minimax-cn", "kilocode", "anthropic", "alibaba",
|
||||
"qwen-oauth",
|
||||
"qwen-oauth", "xiaomi",
|
||||
"opencode-zen", "opencode-go",
|
||||
"ai-gateway", "deepseek", "custom",
|
||||
]
|
||||
@@ -1794,6 +1809,35 @@ def validate_requested_model(
|
||||
"message": message,
|
||||
}
|
||||
|
||||
# OpenAI Codex has its own catalog path; /v1/models probing is not the right validation path.
|
||||
if normalized == "openai-codex":
|
||||
try:
|
||||
codex_models = provider_model_ids("openai-codex")
|
||||
except Exception:
|
||||
codex_models = []
|
||||
if codex_models:
|
||||
if requested_for_lookup in set(codex_models):
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"recognized": True,
|
||||
"message": None,
|
||||
}
|
||||
suggestions = get_close_matches(requested_for_lookup, codex_models, n=3, cutoff=0.5)
|
||||
suggestion_text = ""
|
||||
if suggestions:
|
||||
suggestion_text = "\n Similar models: " + ", ".join(f"`{s}`" for s in suggestions)
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"recognized": False,
|
||||
"message": (
|
||||
f"Note: `{requested}` was not found in the OpenAI Codex model listing. "
|
||||
f"It may still work if your account has access to it."
|
||||
f"{suggestion_text}"
|
||||
),
|
||||
}
|
||||
|
||||
# Probe the live API to check if the model actually exists
|
||||
api_models = fetch_api_models(api_key, base_url)
|
||||
|
||||
|
||||
@@ -143,6 +143,7 @@ def _tts_label(current_provider: str) -> str:
|
||||
"openai": "OpenAI TTS",
|
||||
"elevenlabs": "ElevenLabs",
|
||||
"edge": "Edge TTS",
|
||||
"mistral": "Mistral Voxtral TTS",
|
||||
"neutts": "NeuTTS",
|
||||
}
|
||||
return mapping.get(current_provider or "edge", current_provider or "Edge TTS")
|
||||
@@ -309,6 +310,7 @@ def get_nous_subscription_features(
|
||||
tts_current_provider in {"edge", "neutts"}
|
||||
or (tts_current_provider == "openai" and (managed_tts_available or direct_openai_tts))
|
||||
or (tts_current_provider == "elevenlabs" and direct_elevenlabs)
|
||||
or (tts_current_provider == "mistral" and bool(get_env_value("MISTRAL_API_KEY")))
|
||||
)
|
||||
tts_active = bool(tts_tool_enabled and tts_available)
|
||||
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Shared platform registry for Hermes Agent.
|
||||
|
||||
Single source of truth for platform metadata consumed by both
|
||||
skills_config (label display) and tools_config (default toolset
|
||||
resolution). Import ``PLATFORMS`` from here instead of maintaining
|
||||
duplicate dicts in each module.
|
||||
"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
class PlatformInfo(NamedTuple):
|
||||
"""Metadata for a single platform entry."""
|
||||
label: str
|
||||
default_toolset: str
|
||||
|
||||
|
||||
# Ordered so that TUI menus are deterministic.
|
||||
PLATFORMS: OrderedDict[str, PlatformInfo] = OrderedDict([
|
||||
("cli", PlatformInfo(label="🖥️ CLI", default_toolset="hermes-cli")),
|
||||
("telegram", PlatformInfo(label="📱 Telegram", default_toolset="hermes-telegram")),
|
||||
("discord", PlatformInfo(label="💬 Discord", default_toolset="hermes-discord")),
|
||||
("slack", PlatformInfo(label="💼 Slack", default_toolset="hermes-slack")),
|
||||
("whatsapp", PlatformInfo(label="📱 WhatsApp", default_toolset="hermes-whatsapp")),
|
||||
("signal", PlatformInfo(label="📡 Signal", default_toolset="hermes-signal")),
|
||||
("bluebubbles", PlatformInfo(label="💙 BlueBubbles", default_toolset="hermes-bluebubbles")),
|
||||
("email", PlatformInfo(label="📧 Email", default_toolset="hermes-email")),
|
||||
("homeassistant", PlatformInfo(label="🏠 Home Assistant", default_toolset="hermes-homeassistant")),
|
||||
("mattermost", PlatformInfo(label="💬 Mattermost", default_toolset="hermes-mattermost")),
|
||||
("matrix", PlatformInfo(label="💬 Matrix", default_toolset="hermes-matrix")),
|
||||
("dingtalk", PlatformInfo(label="💬 DingTalk", default_toolset="hermes-dingtalk")),
|
||||
("feishu", PlatformInfo(label="🪽 Feishu", default_toolset="hermes-feishu")),
|
||||
("wecom", PlatformInfo(label="💬 WeCom", default_toolset="hermes-wecom")),
|
||||
("wecom_callback", PlatformInfo(label="💬 WeCom Callback", default_toolset="hermes-wecom-callback")),
|
||||
("weixin", PlatformInfo(label="💬 Weixin", default_toolset="hermes-weixin")),
|
||||
("webhook", PlatformInfo(label="🔗 Webhook", default_toolset="hermes-webhook")),
|
||||
("api_server", PlatformInfo(label="🌐 API Server", default_toolset="hermes-api-server")),
|
||||
])
|
||||
|
||||
|
||||
def platform_label(key: str, default: str = "") -> str:
|
||||
"""Return the display label for a platform key, or *default*."""
|
||||
info = PLATFORMS.get(key)
|
||||
return info.label if info is not None else default
|
||||
@@ -132,6 +132,10 @@ HERMES_OVERLAYS: Dict[str, HermesOverlay] = {
|
||||
base_url_override="https://api.x.ai/v1",
|
||||
base_url_env_var="XAI_BASE_URL",
|
||||
),
|
||||
"xiaomi": HermesOverlay(
|
||||
transport="openai_chat",
|
||||
base_url_env_var="XIAOMI_BASE_URL",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -222,6 +226,10 @@ ALIASES: Dict[str, str] = {
|
||||
"hugging-face": "huggingface",
|
||||
"huggingface-hub": "huggingface",
|
||||
|
||||
# xiaomi
|
||||
"mimo": "xiaomi",
|
||||
"xiaomi-mimo": "xiaomi",
|
||||
|
||||
# Local server aliases → virtual "local" concept (resolved via user config)
|
||||
"lmstudio": "lmstudio",
|
||||
"lm-studio": "lmstudio",
|
||||
@@ -242,6 +250,7 @@ _LABEL_OVERRIDES: Dict[str, str] = {
|
||||
"nous": "Nous Portal",
|
||||
"openai-codex": "OpenAI Codex",
|
||||
"copilot-acp": "GitHub Copilot ACP",
|
||||
"xiaomi": "Xiaomi MiMo",
|
||||
"local": "Local endpoint",
|
||||
}
|
||||
|
||||
|
||||
@@ -304,6 +304,9 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An
|
||||
api_mode = _parse_api_mode(entry.get("api_mode"))
|
||||
if api_mode:
|
||||
result["api_mode"] = api_mode
|
||||
model_name = str(entry.get("model", "") or "").strip()
|
||||
if model_name:
|
||||
result["model"] = model_name
|
||||
return result
|
||||
|
||||
return None
|
||||
@@ -329,6 +332,11 @@ def _resolve_named_custom_runtime(
|
||||
# Check if a credential pool exists for this custom endpoint
|
||||
pool_result = _try_resolve_from_custom_pool(base_url, "custom", custom_provider.get("api_mode"))
|
||||
if pool_result:
|
||||
# Propagate the model name even when using pooled credentials —
|
||||
# the pool doesn't know about the custom_providers model field.
|
||||
model_name = custom_provider.get("model")
|
||||
if model_name:
|
||||
pool_result["model"] = model_name
|
||||
return pool_result
|
||||
|
||||
api_key_candidates = [
|
||||
@@ -339,7 +347,7 @@ def _resolve_named_custom_runtime(
|
||||
]
|
||||
api_key = next((candidate for candidate in api_key_candidates if has_usable_secret(candidate)), "")
|
||||
|
||||
return {
|
||||
result = {
|
||||
"provider": "custom",
|
||||
"api_mode": custom_provider.get("api_mode")
|
||||
or _detect_api_mode_for_url(base_url)
|
||||
@@ -348,6 +356,11 @@ def _resolve_named_custom_runtime(
|
||||
"api_key": api_key or "no-key-required",
|
||||
"source": f"custom_provider:{custom_provider.get('name', requested_provider)}",
|
||||
}
|
||||
# Propagate the model name so callers can override self.model when the
|
||||
# provider name differs from the actual model string the API expects.
|
||||
if custom_provider.get("model"):
|
||||
result["model"] = custom_provider["model"]
|
||||
return result
|
||||
|
||||
|
||||
def _resolve_openrouter_runtime(
|
||||
|
||||
+102
-95
@@ -197,24 +197,12 @@ def print_header(title: str):
|
||||
print(color(f"◆ {title}", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
def print_info(text: str):
|
||||
"""Print info text."""
|
||||
print(color(f" {text}", Colors.DIM))
|
||||
|
||||
|
||||
def print_success(text: str):
|
||||
"""Print success message."""
|
||||
print(color(f"✓ {text}", Colors.GREEN))
|
||||
|
||||
|
||||
def print_warning(text: str):
|
||||
"""Print warning message."""
|
||||
print(color(f"⚠ {text}", Colors.YELLOW))
|
||||
|
||||
|
||||
def print_error(text: str):
|
||||
"""Print error message."""
|
||||
print(color(f"✗ {text}", Colors.RED))
|
||||
from hermes_cli.cli_output import ( # noqa: E402
|
||||
print_error,
|
||||
print_info,
|
||||
print_success,
|
||||
print_warning,
|
||||
)
|
||||
|
||||
|
||||
def is_interactive_stdin() -> bool:
|
||||
@@ -269,80 +257,9 @@ def prompt(question: str, default: str = None, password: bool = False) -> str:
|
||||
|
||||
|
||||
def _curses_prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
"""Single-select menu using curses to avoid simple_term_menu rendering bugs."""
|
||||
try:
|
||||
import curses
|
||||
result_holder = [default]
|
||||
|
||||
def _curses_menu(stdscr):
|
||||
curses.curs_set(0)
|
||||
if curses.has_colors():
|
||||
curses.start_color()
|
||||
curses.use_default_colors()
|
||||
curses.init_pair(1, curses.COLOR_GREEN, -1)
|
||||
curses.init_pair(2, curses.COLOR_YELLOW, -1)
|
||||
cursor = default
|
||||
scroll_offset = 0
|
||||
|
||||
while True:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
|
||||
# Rows available for list items: rows 2..(max_y-2) inclusive.
|
||||
visible = max(1, max_y - 3)
|
||||
|
||||
# Scroll the viewport so the cursor is always visible.
|
||||
if cursor < scroll_offset:
|
||||
scroll_offset = cursor
|
||||
elif cursor >= scroll_offset + visible:
|
||||
scroll_offset = cursor - visible + 1
|
||||
scroll_offset = max(0, min(scroll_offset, max(0, len(choices) - visible)))
|
||||
|
||||
try:
|
||||
stdscr.addnstr(
|
||||
0,
|
||||
0,
|
||||
question,
|
||||
max_x - 1,
|
||||
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0),
|
||||
)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
for row, i in enumerate(range(scroll_offset, min(scroll_offset + visible, len(choices)))):
|
||||
y = row + 2
|
||||
if y >= max_y - 1:
|
||||
break
|
||||
arrow = "→" if i == cursor else " "
|
||||
line = f" {arrow} {choices[i]}"
|
||||
attr = curses.A_NORMAL
|
||||
if i == cursor:
|
||||
attr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
attr |= curses.color_pair(1)
|
||||
try:
|
||||
stdscr.addnstr(y, 0, line, max_x - 1, attr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
if key in (curses.KEY_UP, ord("k")):
|
||||
cursor = (cursor - 1) % len(choices)
|
||||
elif key in (curses.KEY_DOWN, ord("j")):
|
||||
cursor = (cursor + 1) % len(choices)
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result_holder[0] = cursor
|
||||
return
|
||||
elif key in (27, ord("q")):
|
||||
return
|
||||
|
||||
curses.wrapper(_curses_menu)
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
return result_holder[0]
|
||||
except Exception:
|
||||
return -1
|
||||
"""Single-select menu using curses. Delegates to curses_radiolist."""
|
||||
from hermes_cli.curses_ui import curses_radiolist
|
||||
return curses_radiolist(question, choices, selected=default, cancel_returns=-1)
|
||||
|
||||
|
||||
|
||||
@@ -557,6 +474,8 @@ def _print_setup_summary(config: dict, hermes_home):
|
||||
tool_status.append(("Text-to-Speech (OpenAI)", True, None))
|
||||
elif tts_provider == "minimax" and get_env_value("MINIMAX_API_KEY"):
|
||||
tool_status.append(("Text-to-Speech (MiniMax)", True, None))
|
||||
elif tts_provider == "mistral" and get_env_value("MISTRAL_API_KEY"):
|
||||
tool_status.append(("Text-to-Speech (Mistral Voxtral)", True, None))
|
||||
elif tts_provider == "neutts":
|
||||
try:
|
||||
import importlib.util
|
||||
@@ -1044,6 +963,7 @@ def _setup_tts_provider(config: dict):
|
||||
"elevenlabs": "ElevenLabs",
|
||||
"openai": "OpenAI TTS",
|
||||
"minimax": "MiniMax TTS",
|
||||
"mistral": "Mistral Voxtral TTS",
|
||||
"neutts": "NeuTTS",
|
||||
}
|
||||
current_label = provider_labels.get(current_provider, current_provider)
|
||||
@@ -1064,10 +984,11 @@ def _setup_tts_provider(config: dict):
|
||||
"ElevenLabs (premium quality, needs API key)",
|
||||
"OpenAI TTS (good quality, needs API key)",
|
||||
"MiniMax TTS (high quality with voice cloning, needs API key)",
|
||||
"Mistral Voxtral TTS (multilingual, native Opus, needs API key)",
|
||||
"NeuTTS (local on-device, free, ~300MB model download)",
|
||||
]
|
||||
)
|
||||
providers.extend(["edge", "elevenlabs", "openai", "minimax", "neutts"])
|
||||
providers.extend(["edge", "elevenlabs", "openai", "minimax", "mistral", "neutts"])
|
||||
choices.append(f"Keep current ({current_label})")
|
||||
keep_current_idx = len(choices) - 1
|
||||
idx = prompt_choice("Select TTS provider:", choices, keep_current_idx)
|
||||
@@ -1145,6 +1066,18 @@ def _setup_tts_provider(config: dict):
|
||||
print_warning("No API key provided. Falling back to Edge TTS.")
|
||||
selected = "edge"
|
||||
|
||||
elif selected == "mistral":
|
||||
existing = get_env_value("MISTRAL_API_KEY")
|
||||
if not existing:
|
||||
print()
|
||||
api_key = prompt("Mistral API key for TTS", password=True)
|
||||
if api_key:
|
||||
save_env_value("MISTRAL_API_KEY", api_key)
|
||||
print_success("Mistral TTS API key saved")
|
||||
else:
|
||||
print_warning("No API key provided. Falling back to Edge TTS.")
|
||||
selected = "edge"
|
||||
|
||||
# Save the selection
|
||||
if "tts" not in config:
|
||||
config["tts"] = {}
|
||||
@@ -2036,6 +1969,48 @@ def _setup_weixin():
|
||||
_gateway_setup_weixin()
|
||||
|
||||
|
||||
def _setup_signal():
|
||||
"""Configure Signal via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_signal as _gateway_setup_signal
|
||||
_gateway_setup_signal()
|
||||
|
||||
|
||||
def _setup_email():
|
||||
"""Configure Email via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_email as _gateway_setup_email
|
||||
_gateway_setup_email()
|
||||
|
||||
|
||||
def _setup_sms():
|
||||
"""Configure SMS (Twilio) via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_sms as _gateway_setup_sms
|
||||
_gateway_setup_sms()
|
||||
|
||||
|
||||
def _setup_dingtalk():
|
||||
"""Configure DingTalk via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_dingtalk as _gateway_setup_dingtalk
|
||||
_gateway_setup_dingtalk()
|
||||
|
||||
|
||||
def _setup_feishu():
|
||||
"""Configure Feishu / Lark via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_feishu as _gateway_setup_feishu
|
||||
_gateway_setup_feishu()
|
||||
|
||||
|
||||
def _setup_wecom():
|
||||
"""Configure WeCom (Enterprise WeChat) via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_wecom as _gateway_setup_wecom
|
||||
_gateway_setup_wecom()
|
||||
|
||||
|
||||
def _setup_wecom_callback():
|
||||
"""Configure WeCom Callback (self-built app) via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_wecom_callback as _gw_setup
|
||||
_gw_setup()
|
||||
|
||||
|
||||
def _setup_bluebubbles():
|
||||
"""Configure BlueBubbles iMessage gateway."""
|
||||
print_header("BlueBubbles (iMessage)")
|
||||
@@ -2152,9 +2127,16 @@ _GATEWAY_PLATFORMS = [
|
||||
("Telegram", "TELEGRAM_BOT_TOKEN", _setup_telegram),
|
||||
("Discord", "DISCORD_BOT_TOKEN", _setup_discord),
|
||||
("Slack", "SLACK_BOT_TOKEN", _setup_slack),
|
||||
("Signal", "SIGNAL_HTTP_URL", _setup_signal),
|
||||
("Email", "EMAIL_ADDRESS", _setup_email),
|
||||
("SMS (Twilio)", "TWILIO_ACCOUNT_SID", _setup_sms),
|
||||
("Matrix", "MATRIX_ACCESS_TOKEN", _setup_matrix),
|
||||
("Mattermost", "MATTERMOST_TOKEN", _setup_mattermost),
|
||||
("WhatsApp", "WHATSAPP_ENABLED", _setup_whatsapp),
|
||||
("DingTalk", "DINGTALK_CLIENT_ID", _setup_dingtalk),
|
||||
("Feishu / Lark", "FEISHU_APP_ID", _setup_feishu),
|
||||
("WeCom (Enterprise WeChat)", "WECOM_BOT_ID", _setup_wecom),
|
||||
("WeCom Callback (Self-Built App)", "WECOM_CALLBACK_CORP_ID", _setup_wecom_callback),
|
||||
("Weixin (WeChat)", "WEIXIN_ACCOUNT_ID", _setup_weixin),
|
||||
("BlueBubbles (iMessage)", "BLUEBUBBLES_SERVER_URL", _setup_bluebubbles),
|
||||
("Webhooks (GitHub, GitLab, etc.)", "WEBHOOK_ENABLED", _setup_webhooks),
|
||||
@@ -2196,10 +2178,17 @@ def setup_gateway(config: dict):
|
||||
get_env_value("TELEGRAM_BOT_TOKEN")
|
||||
or get_env_value("DISCORD_BOT_TOKEN")
|
||||
or get_env_value("SLACK_BOT_TOKEN")
|
||||
or get_env_value("SIGNAL_HTTP_URL")
|
||||
or get_env_value("EMAIL_ADDRESS")
|
||||
or get_env_value("TWILIO_ACCOUNT_SID")
|
||||
or get_env_value("MATTERMOST_TOKEN")
|
||||
or get_env_value("MATRIX_ACCESS_TOKEN")
|
||||
or get_env_value("MATRIX_PASSWORD")
|
||||
or get_env_value("WHATSAPP_ENABLED")
|
||||
or get_env_value("DINGTALK_CLIENT_ID")
|
||||
or get_env_value("FEISHU_APP_ID")
|
||||
or get_env_value("WECOM_BOT_ID")
|
||||
or get_env_value("WEIXIN_ACCOUNT_ID")
|
||||
or get_env_value("BLUEBUBBLES_SERVER_URL")
|
||||
or get_env_value("WEBHOOK_ENABLED")
|
||||
)
|
||||
@@ -2388,12 +2377,30 @@ def _get_section_config_summary(config: dict, section_key: str) -> Optional[str]
|
||||
platforms.append("Discord")
|
||||
if get_env_value("SLACK_BOT_TOKEN"):
|
||||
platforms.append("Slack")
|
||||
if get_env_value("WHATSAPP_PHONE_NUMBER_ID"):
|
||||
platforms.append("WhatsApp")
|
||||
if get_env_value("SIGNAL_ACCOUNT"):
|
||||
platforms.append("Signal")
|
||||
if get_env_value("EMAIL_ADDRESS"):
|
||||
platforms.append("Email")
|
||||
if get_env_value("TWILIO_ACCOUNT_SID"):
|
||||
platforms.append("SMS")
|
||||
if get_env_value("MATRIX_ACCESS_TOKEN") or get_env_value("MATRIX_PASSWORD"):
|
||||
platforms.append("Matrix")
|
||||
if get_env_value("MATTERMOST_TOKEN"):
|
||||
platforms.append("Mattermost")
|
||||
if get_env_value("WHATSAPP_PHONE_NUMBER_ID"):
|
||||
platforms.append("WhatsApp")
|
||||
if get_env_value("DINGTALK_CLIENT_ID"):
|
||||
platforms.append("DingTalk")
|
||||
if get_env_value("FEISHU_APP_ID"):
|
||||
platforms.append("Feishu")
|
||||
if get_env_value("WECOM_BOT_ID"):
|
||||
platforms.append("WeCom")
|
||||
if get_env_value("WEIXIN_ACCOUNT_ID"):
|
||||
platforms.append("Weixin")
|
||||
if get_env_value("BLUEBUBBLES_SERVER_URL"):
|
||||
platforms.append("BlueBubbles")
|
||||
if get_env_value("WEBHOOK_ENABLED"):
|
||||
platforms.append("Webhooks")
|
||||
if platforms:
|
||||
return ", ".join(platforms)
|
||||
return None # No platforms configured — section must run
|
||||
|
||||
@@ -15,25 +15,12 @@ from typing import List, Optional, Set
|
||||
|
||||
from hermes_cli.config import load_config, save_config
|
||||
from hermes_cli.colors import Colors, color
|
||||
from hermes_cli.platforms import PLATFORMS as _PLATFORMS, platform_label
|
||||
|
||||
PLATFORMS = {
|
||||
"cli": "🖥️ CLI",
|
||||
"telegram": "📱 Telegram",
|
||||
"discord": "💬 Discord",
|
||||
"slack": "💼 Slack",
|
||||
"whatsapp": "📱 WhatsApp",
|
||||
"signal": "📡 Signal",
|
||||
"bluebubbles": "💬 BlueBubbles",
|
||||
"email": "📧 Email",
|
||||
"homeassistant": "🏠 Home Assistant",
|
||||
"mattermost": "💬 Mattermost",
|
||||
"matrix": "💬 Matrix",
|
||||
"dingtalk": "💬 DingTalk",
|
||||
"feishu": "🪽 Feishu",
|
||||
"wecom": "💬 WeCom",
|
||||
"weixin": "💬 Weixin",
|
||||
"webhook": "🔗 Webhook",
|
||||
}
|
||||
# Backward-compatible view: {key: label_string} so existing code that
|
||||
# iterates ``PLATFORMS.items()`` or calls ``PLATFORMS.get(key)`` keeps
|
||||
# working without changes to every call site.
|
||||
PLATFORMS = {k: info.label for k, info in _PLATFORMS.items() if k != "api_server"}
|
||||
|
||||
# ─── Config Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -141,11 +141,8 @@ def show_status(args):
|
||||
display = redact_key(value) if not show_all else value
|
||||
print(f" {name:<12} {check_mark(has_key)} {display}")
|
||||
|
||||
anthropic_value = (
|
||||
get_env_value("ANTHROPIC_TOKEN")
|
||||
or get_env_value("ANTHROPIC_API_KEY")
|
||||
or ""
|
||||
)
|
||||
from hermes_cli.auth import get_anthropic_key
|
||||
anthropic_value = get_anthropic_key()
|
||||
anthropic_display = redact_key(anthropic_value) if not show_all else anthropic_value
|
||||
print(f" {'Anthropic':<12} {check_mark(bool(anthropic_value))} {anthropic_display}")
|
||||
|
||||
@@ -305,6 +302,7 @@ def show_status(args):
|
||||
"DingTalk": ("DINGTALK_CLIENT_ID", None),
|
||||
"Feishu": ("FEISHU_APP_ID", "FEISHU_HOME_CHANNEL"),
|
||||
"WeCom": ("WECOM_BOT_ID", "WECOM_HOME_CHANNEL"),
|
||||
"WeCom Callback": ("WECOM_CALLBACK_CORP_ID", None),
|
||||
"Weixin": ("WEIXIN_ACCOUNT_ID", "WEIXIN_HOME_CHANNEL"),
|
||||
"BlueBubbles": ("BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_HOME_CHANNEL"),
|
||||
}
|
||||
|
||||
+30
-126
@@ -33,33 +33,13 @@ PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
# ─── UI Helpers (shared with setup.py) ────────────────────────────────────────
|
||||
|
||||
def _print_info(text: str):
|
||||
print(color(f" {text}", Colors.DIM))
|
||||
|
||||
def _print_success(text: str):
|
||||
print(color(f"✓ {text}", Colors.GREEN))
|
||||
|
||||
def _print_warning(text: str):
|
||||
print(color(f"⚠ {text}", Colors.YELLOW))
|
||||
|
||||
def _print_error(text: str):
|
||||
print(color(f"✗ {text}", Colors.RED))
|
||||
|
||||
def _prompt(question: str, default: str = None, password: bool = False) -> str:
|
||||
if default:
|
||||
display = f"{question} [{default}]: "
|
||||
else:
|
||||
display = f"{question}: "
|
||||
try:
|
||||
if password:
|
||||
import getpass
|
||||
value = getpass.getpass(color(display, Colors.YELLOW))
|
||||
else:
|
||||
value = input(color(display, Colors.YELLOW))
|
||||
return value.strip() or default or ""
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default or ""
|
||||
from hermes_cli.cli_output import ( # noqa: E402 — late import block
|
||||
print_error as _print_error,
|
||||
print_info as _print_info,
|
||||
print_success as _print_success,
|
||||
print_warning as _print_warning,
|
||||
prompt as _prompt,
|
||||
)
|
||||
|
||||
# ─── Toolset Registry ─────────────────────────────────────────────────────────
|
||||
|
||||
@@ -118,25 +98,14 @@ def _get_plugin_toolset_keys() -> set:
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
# Platform display config
|
||||
# Platform display config — derived from the canonical registry so every
|
||||
# module shares the same data. Kept as dict-of-dicts for backward
|
||||
# compatibility with existing ``PLATFORMS[key]["label"]`` access patterns.
|
||||
from hermes_cli.platforms import PLATFORMS as _PLATFORMS_REGISTRY
|
||||
|
||||
PLATFORMS = {
|
||||
"cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"},
|
||||
"telegram": {"label": "📱 Telegram", "default_toolset": "hermes-telegram"},
|
||||
"discord": {"label": "💬 Discord", "default_toolset": "hermes-discord"},
|
||||
"slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"},
|
||||
"whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"},
|
||||
"signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"},
|
||||
"bluebubbles": {"label": "💙 BlueBubbles", "default_toolset": "hermes-bluebubbles"},
|
||||
"homeassistant": {"label": "🏠 Home Assistant", "default_toolset": "hermes-homeassistant"},
|
||||
"email": {"label": "📧 Email", "default_toolset": "hermes-email"},
|
||||
"matrix": {"label": "💬 Matrix", "default_toolset": "hermes-matrix"},
|
||||
"dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"},
|
||||
"feishu": {"label": "🪽 Feishu", "default_toolset": "hermes-feishu"},
|
||||
"wecom": {"label": "💬 WeCom", "default_toolset": "hermes-wecom"},
|
||||
"weixin": {"label": "💬 Weixin", "default_toolset": "hermes-weixin"},
|
||||
"api_server": {"label": "🌐 API Server", "default_toolset": "hermes-api-server"},
|
||||
"mattermost": {"label": "💬 Mattermost", "default_toolset": "hermes-mattermost"},
|
||||
"webhook": {"label": "🔗 Webhook", "default_toolset": "hermes-webhook"},
|
||||
k: {"label": info.label, "default_toolset": info.default_toolset}
|
||||
for k, info in _PLATFORMS_REGISTRY.items()
|
||||
}
|
||||
|
||||
|
||||
@@ -181,6 +150,14 @@ TOOL_CATEGORIES = {
|
||||
],
|
||||
"tts_provider": "elevenlabs",
|
||||
},
|
||||
{
|
||||
"name": "Mistral (Voxtral TTS)",
|
||||
"tag": "Multilingual, native Opus, needs MISTRAL_API_KEY",
|
||||
"env_vars": [
|
||||
{"key": "MISTRAL_API_KEY", "prompt": "Mistral API key", "url": "https://console.mistral.ai/"},
|
||||
],
|
||||
"tts_provider": "mistral",
|
||||
},
|
||||
],
|
||||
},
|
||||
"web": {
|
||||
@@ -501,6 +478,10 @@ def _get_platform_tools(
|
||||
default_ts = PLATFORMS[platform]["default_toolset"]
|
||||
toolset_names = [default_ts]
|
||||
|
||||
# YAML may parse bare numeric names (e.g. ``12306:``) as int.
|
||||
# Normalise to str so downstream sorted() never mixes types.
|
||||
toolset_names = [str(ts) for ts in toolset_names]
|
||||
|
||||
configurable_keys = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS}
|
||||
|
||||
# If the saved list contains any configurable keys directly, the user
|
||||
@@ -559,7 +540,7 @@ def _get_platform_tools(
|
||||
# Special sentinel: "no_mcp" in the toolset list disables all MCP servers.
|
||||
mcp_servers = config.get("mcp_servers") or {}
|
||||
enabled_mcp_servers = {
|
||||
name
|
||||
str(name)
|
||||
for name, server_cfg in mcp_servers.items()
|
||||
if isinstance(server_cfg, dict)
|
||||
and _parse_enabled_flag(server_cfg.get("enabled", True), default=True)
|
||||
@@ -665,86 +646,9 @@ def _toolset_has_keys(ts_key: str, config: dict = None) -> bool:
|
||||
# ─── Menu Helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
"""Single-select menu (arrow keys). Uses curses to avoid simple_term_menu
|
||||
rendering bugs in tmux, iTerm, and other non-standard terminals."""
|
||||
|
||||
# Curses-based single-select — works in tmux, iTerm, and standard terminals
|
||||
try:
|
||||
import curses
|
||||
result_holder = [default]
|
||||
|
||||
def _curses_menu(stdscr):
|
||||
curses.curs_set(0)
|
||||
if curses.has_colors():
|
||||
curses.start_color()
|
||||
curses.use_default_colors()
|
||||
curses.init_pair(1, curses.COLOR_GREEN, -1)
|
||||
curses.init_pair(2, curses.COLOR_YELLOW, -1)
|
||||
cursor = default
|
||||
|
||||
while True:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
try:
|
||||
stdscr.addnstr(0, 0, question, max_x - 1,
|
||||
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
for i, c in enumerate(choices):
|
||||
y = i + 2
|
||||
if y >= max_y - 1:
|
||||
break
|
||||
arrow = "→" if i == cursor else " "
|
||||
line = f" {arrow} {c}"
|
||||
attr = curses.A_NORMAL
|
||||
if i == cursor:
|
||||
attr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
attr |= curses.color_pair(1)
|
||||
try:
|
||||
stdscr.addnstr(y, 0, line, max_x - 1, attr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
|
||||
if key in (curses.KEY_UP, ord('k')):
|
||||
cursor = (cursor - 1) % len(choices)
|
||||
elif key in (curses.KEY_DOWN, ord('j')):
|
||||
cursor = (cursor + 1) % len(choices)
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result_holder[0] = cursor
|
||||
return
|
||||
elif key in (27, ord('q')):
|
||||
return
|
||||
|
||||
curses.wrapper(_curses_menu)
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
return result_holder[0]
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: numbered input (Windows without curses, etc.)
|
||||
print(color(question, Colors.YELLOW))
|
||||
for i, c in enumerate(choices):
|
||||
marker = "●" if i == default else "○"
|
||||
style = Colors.GREEN if i == default else ""
|
||||
print(color(f" {marker} {i+1}. {c}", style) if style else f" {marker} {i+1}. {c}")
|
||||
while True:
|
||||
try:
|
||||
val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM))
|
||||
if not val:
|
||||
return default
|
||||
idx = int(val) - 1
|
||||
if 0 <= idx < len(choices):
|
||||
return idx
|
||||
except (ValueError, KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default
|
||||
"""Single-select menu (arrow keys). Delegates to curses_radiolist."""
|
||||
from hermes_cli.curses_ui import curses_radiolist
|
||||
return curses_radiolist(question, choices, selected=default, cancel_returns=default)
|
||||
|
||||
|
||||
# ─── Token Estimation ────────────────────────────────────────────────────────
|
||||
|
||||
@@ -189,6 +189,33 @@ def is_wsl() -> bool:
|
||||
return _wsl_detected
|
||||
|
||||
|
||||
# ─── Well-Known Paths ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_config_path() -> Path:
|
||||
"""Return the path to ``config.yaml`` under HERMES_HOME.
|
||||
|
||||
Replaces the ``get_hermes_home() / "config.yaml"`` pattern repeated
|
||||
in 7+ files (skill_utils.py, hermes_logging.py, hermes_time.py, etc.).
|
||||
"""
|
||||
return get_hermes_home() / "config.yaml"
|
||||
|
||||
|
||||
def get_skills_dir() -> Path:
|
||||
"""Return the path to the skills directory under HERMES_HOME."""
|
||||
return get_hermes_home() / "skills"
|
||||
|
||||
|
||||
def get_logs_dir() -> Path:
|
||||
"""Return the path to the logs directory under HERMES_HOME."""
|
||||
return get_hermes_home() / "logs"
|
||||
|
||||
|
||||
def get_env_path() -> Path:
|
||||
"""Return the path to the ``.env`` file under HERMES_HOME."""
|
||||
return get_hermes_home() / ".env"
|
||||
|
||||
|
||||
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
|
||||
OPENROUTER_MODELS_URL = f"{OPENROUTER_BASE_URL}/models"
|
||||
|
||||
|
||||
+141
-9
@@ -7,27 +7,44 @@ gateway call early in their startup path. All log files live under
|
||||
Log files produced:
|
||||
agent.log — INFO+, all agent/tool/session activity (the main log)
|
||||
errors.log — WARNING+, errors and warnings only (quick triage)
|
||||
gateway.log — INFO+, gateway-only events (created when mode="gateway")
|
||||
|
||||
Both files use ``RotatingFileHandler`` with ``RedactingFormatter`` so
|
||||
All files use ``RotatingFileHandler`` with ``RedactingFormatter`` so
|
||||
secrets are never written to disk.
|
||||
|
||||
Component separation:
|
||||
gateway.log only receives records from ``gateway.*`` loggers —
|
||||
platform adapters, session management, slash commands, delivery.
|
||||
agent.log remains the catch-all (everything goes there).
|
||||
|
||||
Session context:
|
||||
Call ``set_session_context(session_id)`` at the start of a conversation
|
||||
and ``clear_session_context()`` when done. All log lines emitted on
|
||||
that thread will include ``[session_id]`` for filtering/correlation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from hermes_constants import get_config_path, get_hermes_home
|
||||
|
||||
# Sentinel to track whether setup_logging() has already run. The function
|
||||
# is idempotent — calling it twice is safe but the second call is a no-op
|
||||
# unless ``force=True``.
|
||||
_logging_initialized = False
|
||||
|
||||
# Default log format — includes timestamp, level, logger name, and message.
|
||||
_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s: %(message)s"
|
||||
_LOG_FORMAT_VERBOSE = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
# Thread-local storage for per-conversation session context.
|
||||
_session_context = threading.local()
|
||||
|
||||
# Default log format — includes timestamp, level, optional session tag,
|
||||
# logger name, and message. The ``%(session_tag)s`` field is guaranteed to
|
||||
# exist on every LogRecord via _install_session_record_factory() below.
|
||||
_LOG_FORMAT = "%(asctime)s %(levelname)s%(session_tag)s %(name)s: %(message)s"
|
||||
_LOG_FORMAT_VERBOSE = "%(asctime)s - %(name)s - %(levelname)s%(session_tag)s - %(message)s"
|
||||
|
||||
# Third-party loggers that are noisy at DEBUG/INFO level.
|
||||
_NOISY_LOGGERS = (
|
||||
@@ -48,6 +65,99 @@ _NOISY_LOGGERS = (
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public session context API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def set_session_context(session_id: str) -> None:
|
||||
"""Set the session ID for the current thread.
|
||||
|
||||
All subsequent log records on this thread will include ``[session_id]``
|
||||
in the formatted output. Call at the start of ``run_conversation()``.
|
||||
"""
|
||||
_session_context.session_id = session_id
|
||||
|
||||
|
||||
def clear_session_context() -> None:
|
||||
"""Clear the session ID for the current thread.
|
||||
|
||||
Optional — ``set_session_context()`` overwrites the previous value,
|
||||
so explicit clearing is only needed if the thread is reused for
|
||||
non-conversation work after ``run_conversation()`` returns.
|
||||
"""
|
||||
_session_context.session_id = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Record factory — injects session_tag into every LogRecord at creation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _install_session_record_factory() -> None:
|
||||
"""Replace the global LogRecord factory with one that adds ``session_tag``.
|
||||
|
||||
Unlike a ``logging.Filter`` on a handler or logger, the record factory
|
||||
runs for EVERY record in the process — including records that propagate
|
||||
from child loggers and records handled by third-party handlers. This
|
||||
guarantees ``%(session_tag)s`` is always available in format strings,
|
||||
eliminating the KeyError that would occur if a handler used our format
|
||||
without having a ``_SessionFilter`` attached.
|
||||
|
||||
Idempotent — checks for a marker attribute to avoid double-wrapping if
|
||||
the module is reloaded.
|
||||
"""
|
||||
current_factory = logging.getLogRecordFactory()
|
||||
if getattr(current_factory, "_hermes_session_injector", False):
|
||||
return # already installed
|
||||
|
||||
def _session_record_factory(*args, **kwargs):
|
||||
record = current_factory(*args, **kwargs)
|
||||
sid = getattr(_session_context, "session_id", None)
|
||||
record.session_tag = f" [{sid}]" if sid else "" # type: ignore[attr-defined]
|
||||
return record
|
||||
|
||||
_session_record_factory._hermes_session_injector = True # type: ignore[attr-defined]
|
||||
logging.setLogRecordFactory(_session_record_factory)
|
||||
|
||||
|
||||
# Install immediately on import — session_tag is available on all records
|
||||
# from this point forward, even before setup_logging() is called.
|
||||
_install_session_record_factory()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Filters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _ComponentFilter(logging.Filter):
|
||||
"""Only pass records whose logger name starts with one of *prefixes*.
|
||||
|
||||
Used to route gateway-specific records to ``gateway.log`` while
|
||||
keeping ``agent.log`` as the catch-all.
|
||||
"""
|
||||
|
||||
def __init__(self, prefixes: Sequence[str]) -> None:
|
||||
super().__init__()
|
||||
self._prefixes = tuple(prefixes)
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
return record.name.startswith(self._prefixes)
|
||||
|
||||
|
||||
# Logger name prefixes that belong to each component.
|
||||
# Used by _ComponentFilter and exposed for ``hermes logs --component``.
|
||||
COMPONENT_PREFIXES = {
|
||||
"gateway": ("gateway",),
|
||||
"agent": ("agent", "run_agent", "model_tools", "batch_runner"),
|
||||
"tools": ("tools",),
|
||||
"cli": ("hermes_cli", "cli"),
|
||||
"cron": ("cron",),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main setup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def setup_logging(
|
||||
*,
|
||||
hermes_home: Optional[Path] = None,
|
||||
@@ -78,8 +188,9 @@ def setup_logging(
|
||||
Number of rotated backup files to keep.
|
||||
Defaults to 3 or the value from config.yaml ``logging.backup_count``.
|
||||
mode
|
||||
Hint for the caller context: ``"cli"``, ``"gateway"``, ``"cron"``.
|
||||
Currently used only for log format tuning (gateway includes PID).
|
||||
Caller context: ``"cli"``, ``"gateway"``, ``"cron"``.
|
||||
When ``"gateway"``, an additional ``gateway.log`` file is created
|
||||
that receives only gateway-component records.
|
||||
force
|
||||
Re-run setup even if it has already been called.
|
||||
|
||||
@@ -130,6 +241,18 @@ def setup_logging(
|
||||
formatter=RedactingFormatter(_LOG_FORMAT),
|
||||
)
|
||||
|
||||
# --- gateway.log (INFO+, gateway component only) ------------------------
|
||||
if mode == "gateway":
|
||||
_add_rotating_handler(
|
||||
root,
|
||||
log_dir / "gateway.log",
|
||||
level=logging.INFO,
|
||||
max_bytes=5 * 1024 * 1024,
|
||||
backup_count=3,
|
||||
formatter=RedactingFormatter(_LOG_FORMAT),
|
||||
log_filter=_ComponentFilter(COMPONENT_PREFIXES["gateway"]),
|
||||
)
|
||||
|
||||
# Ensure root logger level is low enough for the handlers to fire.
|
||||
if root.level == logging.NOTSET or root.level > level:
|
||||
root.setLevel(level)
|
||||
@@ -218,9 +341,16 @@ def _add_rotating_handler(
|
||||
max_bytes: int,
|
||||
backup_count: int,
|
||||
formatter: logging.Formatter,
|
||||
log_filter: Optional[logging.Filter] = None,
|
||||
) -> None:
|
||||
"""Add a ``RotatingFileHandler`` to *logger*, skipping if one already
|
||||
exists for the same resolved file path (idempotent).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
log_filter
|
||||
Optional filter to attach to the handler (e.g. ``_ComponentFilter``
|
||||
for gateway.log).
|
||||
"""
|
||||
resolved = path.resolve()
|
||||
for existing in logger.handlers:
|
||||
@@ -236,6 +366,8 @@ def _add_rotating_handler(
|
||||
)
|
||||
handler.setLevel(level)
|
||||
handler.setFormatter(formatter)
|
||||
if log_filter is not None:
|
||||
handler.addFilter(log_filter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
|
||||
@@ -246,7 +378,7 @@ def _read_logging_config():
|
||||
"""
|
||||
try:
|
||||
import yaml
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
config_path = get_config_path()
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
|
||||
+2
-3
@@ -16,7 +16,7 @@ crashes due to a bad timezone string.
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from hermes_constants import get_hermes_home
|
||||
from hermes_constants import get_config_path
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -48,8 +48,7 @@ def _resolve_timezone_name() -> str:
|
||||
# 2. config.yaml ``timezone`` key
|
||||
try:
|
||||
import yaml
|
||||
hermes_home = get_hermes_home()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
config_path = get_config_path()
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
|
||||
@@ -499,6 +499,16 @@
|
||||
default = "ubuntu:24.04";
|
||||
description = "OCI container image. The container pulls this at runtime via Docker/Podman.";
|
||||
};
|
||||
|
||||
hostUsers = mkOption {
|
||||
type = types.listOf types.str;
|
||||
default = [ ];
|
||||
description = ''
|
||||
Interactive users who get a ~/.hermes symlink to the service
|
||||
stateDir. These users are automatically added to the hermes group.
|
||||
'';
|
||||
example = [ "sidbin" ];
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
@@ -557,6 +567,25 @@
|
||||
environment.variables.HERMES_HOME = "${cfg.stateDir}/.hermes";
|
||||
})
|
||||
|
||||
# ── Host user group membership ─────────────────────────────────────
|
||||
(lib.mkIf (cfg.container.enable && cfg.container.hostUsers != []) {
|
||||
users.users = lib.genAttrs cfg.container.hostUsers (user: {
|
||||
extraGroups = [ cfg.group ];
|
||||
});
|
||||
})
|
||||
|
||||
# ── Warnings ──────────────────────────────────────────────────────
|
||||
(lib.mkIf (cfg.container.enable && !cfg.addToSystemPackages && cfg.container.hostUsers != []) {
|
||||
warnings = [
|
||||
''
|
||||
services.hermes-agent: container.enable is true and container.hostUsers
|
||||
is set, but addToSystemPackages is false. Without a host-installed hermes
|
||||
binary, container routing will not work for interactive users.
|
||||
Set addToSystemPackages = true or ensure hermes is on PATH.
|
||||
''
|
||||
];
|
||||
})
|
||||
|
||||
# ── Directories ───────────────────────────────────────────────────
|
||||
{
|
||||
systemd.tmpfiles.rules = [
|
||||
@@ -611,6 +640,59 @@
|
||||
chown ${cfg.user}:${cfg.group} ${cfg.stateDir}/.hermes/.managed
|
||||
chmod 0644 ${cfg.stateDir}/.hermes/.managed
|
||||
|
||||
# Container mode metadata — tells the host CLI to exec into the
|
||||
# container instead of running locally. Removed when container mode
|
||||
# is disabled so the host CLI falls back to native execution.
|
||||
${if cfg.container.enable then ''
|
||||
cat > ${cfg.stateDir}/.hermes/.container-mode <<'HERMES_CONTAINER_MODE_EOF'
|
||||
# Written by NixOS activation script. Do not edit manually.
|
||||
backend=${cfg.container.backend}
|
||||
container_name=${containerName}
|
||||
exec_user=${cfg.user}
|
||||
hermes_bin=${containerDataDir}/current-package/bin/hermes
|
||||
HERMES_CONTAINER_MODE_EOF
|
||||
chown ${cfg.user}:${cfg.group} ${cfg.stateDir}/.hermes/.container-mode
|
||||
chmod 0644 ${cfg.stateDir}/.hermes/.container-mode
|
||||
'' else ''
|
||||
rm -f ${cfg.stateDir}/.hermes/.container-mode
|
||||
|
||||
# Remove symlink bridge for hostUsers
|
||||
${lib.concatStringsSep "\n" (map (user:
|
||||
let
|
||||
userHome = config.users.users.${user}.home;
|
||||
symlinkPath = "${userHome}/.hermes";
|
||||
in ''
|
||||
if [ -L "${symlinkPath}" ] && [ "$(readlink "${symlinkPath}")" = "${cfg.stateDir}/.hermes" ]; then
|
||||
rm -f "${symlinkPath}"
|
||||
echo "hermes-agent: removed symlink ${symlinkPath}"
|
||||
fi
|
||||
'') cfg.container.hostUsers)}
|
||||
''}
|
||||
|
||||
# ── Symlink bridge for interactive users ───────────────────────
|
||||
# Create ~/.hermes -> stateDir/.hermes for each hostUser so the
|
||||
# host CLI shares state with the container service.
|
||||
# Only runs when container mode is enabled.
|
||||
${lib.optionalString cfg.container.enable
|
||||
(lib.concatStringsSep "\n" (map (user:
|
||||
let
|
||||
userHome = config.users.users.${user}.home;
|
||||
symlinkPath = "${userHome}/.hermes";
|
||||
target = "${cfg.stateDir}/.hermes";
|
||||
in ''
|
||||
if [ -d "${symlinkPath}" ] && [ ! -L "${symlinkPath}" ]; then
|
||||
# Real directory — back it up, then create symlink.
|
||||
# (ln -sfn cannot atomically replace a directory.)
|
||||
_backup="${symlinkPath}.bak.$(date +%s)"
|
||||
echo "hermes-agent: backing up existing ${symlinkPath} to $_backup"
|
||||
mv "${symlinkPath}" "$_backup"
|
||||
fi
|
||||
# For everything else (existing symlink, doesn't exist, etc.)
|
||||
# ln -sfn handles it: replaces symlinks, creates new ones.
|
||||
ln -sfn "${target}" "${symlinkPath}"
|
||||
chown -h ${user}:${cfg.group} "${symlinkPath}"
|
||||
'') cfg.container.hostUsers))}
|
||||
|
||||
# Seed auth file if provided
|
||||
${lib.optionalString (cfg.authFile != null) ''
|
||||
${if cfg.authFileForceOverwrite then ''
|
||||
|
||||
@@ -617,6 +617,19 @@ class Migrator:
|
||||
candidate = self.source_root / rel
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
# OpenClaw renamed workspace/ to workspace-main/ (and workspace-{agentId}
|
||||
# for multi-agent). Try the new path as a fallback.
|
||||
if rel.startswith("workspace/"):
|
||||
suffix = rel[len("workspace/"):]
|
||||
for variant in ("workspace-main", "workspace-assistant"):
|
||||
alt = self.source_root / variant / suffix
|
||||
if alt.exists():
|
||||
return alt
|
||||
elif rel.startswith("workspace.default/"):
|
||||
suffix = rel[len("workspace.default/"):]
|
||||
alt = self.source_root / "workspace-main" / suffix
|
||||
if alt.exists():
|
||||
return alt
|
||||
return None
|
||||
|
||||
def resolve_skill_destination(self, destination: Path) -> Path:
|
||||
@@ -1033,11 +1046,8 @@ class Migrator:
|
||||
def migrate_secret_settings(self, config: Dict[str, Any]) -> None:
|
||||
secret_additions: Dict[str, str] = {}
|
||||
|
||||
telegram_token = (
|
||||
config.get("channels", {})
|
||||
.get("telegram", {})
|
||||
.get("botToken")
|
||||
)
|
||||
tg_cfg = config.get("channels", {}).get("telegram", {})
|
||||
telegram_token = self._get_channel_field(tg_cfg, "botToken") if isinstance(tg_cfg, dict) else None
|
||||
if isinstance(telegram_token, str) and telegram_token.strip():
|
||||
secret_additions["TELEGRAM_BOT_TOKEN"] = telegram_token.strip()
|
||||
|
||||
@@ -1057,15 +1067,28 @@ class Migrator:
|
||||
"""Resolve a channel config value that may be a SecretRef."""
|
||||
return resolve_secret_input(value, self.load_openclaw_env())
|
||||
|
||||
@staticmethod
|
||||
def _get_channel_field(ch_cfg: Dict[str, Any], field: str) -> Any:
|
||||
"""Get a field from channel config, checking both flat and accounts.default layout."""
|
||||
val = ch_cfg.get(field)
|
||||
if val is not None:
|
||||
return val
|
||||
accounts = ch_cfg.get("accounts")
|
||||
if isinstance(accounts, dict):
|
||||
default = accounts.get("default")
|
||||
if isinstance(default, dict):
|
||||
return default.get(field)
|
||||
return None
|
||||
|
||||
def migrate_discord_settings(self, config: Optional[Dict[str, Any]] = None) -> None:
|
||||
config = config or self.load_openclaw_config()
|
||||
additions: Dict[str, str] = {}
|
||||
discord = config.get("channels", {}).get("discord", {})
|
||||
if isinstance(discord, dict):
|
||||
token = discord.get("token")
|
||||
token = self._get_channel_field(discord, "token")
|
||||
if isinstance(token, str) and token.strip():
|
||||
additions["DISCORD_BOT_TOKEN"] = token.strip()
|
||||
allow_from = discord.get("allowFrom", [])
|
||||
allow_from = self._get_channel_field(discord, "allowFrom") or []
|
||||
if isinstance(allow_from, list):
|
||||
users = [str(u).strip() for u in allow_from if str(u).strip()]
|
||||
if users:
|
||||
@@ -1080,13 +1103,13 @@ class Migrator:
|
||||
additions: Dict[str, str] = {}
|
||||
slack = config.get("channels", {}).get("slack", {})
|
||||
if isinstance(slack, dict):
|
||||
bot_token = slack.get("botToken")
|
||||
bot_token = self._get_channel_field(slack, "botToken")
|
||||
if isinstance(bot_token, str) and bot_token.strip():
|
||||
additions["SLACK_BOT_TOKEN"] = bot_token.strip()
|
||||
app_token = slack.get("appToken")
|
||||
app_token = self._get_channel_field(slack, "appToken")
|
||||
if isinstance(app_token, str) and app_token.strip():
|
||||
additions["SLACK_APP_TOKEN"] = app_token.strip()
|
||||
allow_from = slack.get("allowFrom", [])
|
||||
allow_from = self._get_channel_field(slack, "allowFrom") or []
|
||||
if isinstance(allow_from, list):
|
||||
users = [str(u).strip() for u in allow_from if str(u).strip()]
|
||||
if users:
|
||||
@@ -1101,7 +1124,7 @@ class Migrator:
|
||||
additions: Dict[str, str] = {}
|
||||
whatsapp = config.get("channels", {}).get("whatsapp", {})
|
||||
if isinstance(whatsapp, dict):
|
||||
allow_from = whatsapp.get("allowFrom", [])
|
||||
allow_from = self._get_channel_field(whatsapp, "allowFrom") or []
|
||||
if isinstance(allow_from, list):
|
||||
users = [str(u).strip() for u in allow_from if str(u).strip()]
|
||||
if users:
|
||||
@@ -1116,13 +1139,13 @@ class Migrator:
|
||||
additions: Dict[str, str] = {}
|
||||
signal = config.get("channels", {}).get("signal", {})
|
||||
if isinstance(signal, dict):
|
||||
account = signal.get("account")
|
||||
account = self._get_channel_field(signal, "account")
|
||||
if isinstance(account, str) and account.strip():
|
||||
additions["SIGNAL_ACCOUNT"] = account.strip()
|
||||
http_url = signal.get("httpUrl")
|
||||
http_url = self._get_channel_field(signal, "httpUrl")
|
||||
if isinstance(http_url, str) and http_url.strip():
|
||||
additions["SIGNAL_HTTP_URL"] = http_url.strip()
|
||||
allow_from = signal.get("allowFrom", [])
|
||||
allow_from = self._get_channel_field(signal, "allowFrom") or []
|
||||
if isinstance(allow_from, list):
|
||||
users = [str(u).strip() for u in allow_from if str(u).strip()]
|
||||
if users:
|
||||
@@ -1161,6 +1184,16 @@ class Migrator:
|
||||
raw_key = provider_cfg.get("apiKey")
|
||||
api_key = resolve_secret_input(raw_key, openclaw_env)
|
||||
if not api_key:
|
||||
# Warn if a SecretRef with file/exec source was silently unresolvable
|
||||
if isinstance(raw_key, dict) and raw_key.get("source") in ("file", "exec"):
|
||||
self.record(
|
||||
"provider-keys",
|
||||
self.source_root / "openclaw.json",
|
||||
None,
|
||||
"skipped",
|
||||
f"Provider '{provider_name}' uses a {raw_key['source']}-backed SecretRef "
|
||||
f"that cannot be auto-migrated. Add this key manually via: hermes config set",
|
||||
)
|
||||
continue
|
||||
|
||||
base_url = provider_cfg.get("baseUrl", "")
|
||||
@@ -1224,6 +1257,21 @@ class Migrator:
|
||||
if val and hermes_key not in secret_additions:
|
||||
secret_additions[hermes_key] = val
|
||||
|
||||
# Check the openclaw.json "env" sub-object — some OpenClaw setups
|
||||
# store API keys here instead of in a separate .env file.
|
||||
# Keys can be at env.<KEY> or env.vars.<KEY>.
|
||||
json_env = config.get("env")
|
||||
if isinstance(json_env, dict):
|
||||
env_vars = json_env.get("vars")
|
||||
sources = [json_env]
|
||||
if isinstance(env_vars, dict):
|
||||
sources.append(env_vars)
|
||||
for src in sources:
|
||||
for oc_key, hermes_key in env_key_mapping.items():
|
||||
val = src.get(oc_key)
|
||||
if isinstance(val, str) and val.strip() and hermes_key not in secret_additions:
|
||||
secret_additions[hermes_key] = val.strip()
|
||||
|
||||
# Check per-agent auth-profiles.json for additional credentials
|
||||
auth_profiles_path = self.source_root / "agents" / "main" / "agent" / "auth-profiles.json"
|
||||
if auth_profiles_path.exists():
|
||||
@@ -1324,8 +1372,9 @@ class Migrator:
|
||||
tts_data: Dict[str, Any] = {}
|
||||
|
||||
provider = tts.get("provider")
|
||||
if isinstance(provider, str) and provider in ("elevenlabs", "openai", "edge"):
|
||||
tts_data["provider"] = provider
|
||||
if isinstance(provider, str) and provider in ("elevenlabs", "openai", "edge", "microsoft"):
|
||||
# OpenClaw renamed "edge" to "microsoft"; Hermes still uses "edge"
|
||||
tts_data["provider"] = "edge" if provider == "microsoft" else provider
|
||||
|
||||
# TTS provider settings live under messages.tts.providers.{provider}
|
||||
# in OpenClaw (not messages.tts.elevenlabs directly)
|
||||
@@ -1374,9 +1423,9 @@ class Migrator:
|
||||
tts_data["openai"] = oai_settings
|
||||
|
||||
edge_tts = (
|
||||
(providers.get("edge") or {})
|
||||
if isinstance(providers.get("edge"), dict) else
|
||||
(tts.get("edge") or {})
|
||||
(providers.get("edge") or providers.get("microsoft") or {})
|
||||
if isinstance(providers.get("edge"), dict) or isinstance(providers.get("microsoft"), dict) else
|
||||
(tts.get("edge") or tts.get("microsoft") or {})
|
||||
)
|
||||
if isinstance(edge_tts, dict):
|
||||
edge_voice = edge_tts.get("voice")
|
||||
@@ -1890,11 +1939,11 @@ class Migrator:
|
||||
if defaults.get("thinkingDefault"):
|
||||
# Map OpenClaw thinking -> Hermes reasoning_effort
|
||||
thinking = defaults["thinkingDefault"]
|
||||
if thinking in ("always", "high"):
|
||||
if thinking in ("always", "high", "xhigh"):
|
||||
agent_cfg["reasoning_effort"] = "high"
|
||||
elif thinking in ("auto", "medium"):
|
||||
elif thinking in ("auto", "medium", "adaptive"):
|
||||
agent_cfg["reasoning_effort"] = "medium"
|
||||
elif thinking in ("off", "low", "none"):
|
||||
elif thinking in ("off", "low", "none", "minimal"):
|
||||
agent_cfg["reasoning_effort"] = "low"
|
||||
changes = True
|
||||
|
||||
@@ -2099,10 +2148,14 @@ class Migrator:
|
||||
f"Provider '{prov_name}' already exists")
|
||||
continue
|
||||
|
||||
api_type = prov_cfg.get("apiType") or prov_cfg.get("type") or "openai"
|
||||
api_type = prov_cfg.get("apiType") or prov_cfg.get("api") or prov_cfg.get("type") or "openai"
|
||||
api_mode_map = {
|
||||
"openai": "chat_completions",
|
||||
"openai-completions": "chat_completions",
|
||||
"openai-responses": "chat_completions",
|
||||
"anthropic": "anthropic_messages",
|
||||
"anthropic-messages": "anthropic_messages",
|
||||
"google-generative-ai": "chat_completions",
|
||||
"cohere": "chat_completions",
|
||||
}
|
||||
entry = {
|
||||
@@ -2142,7 +2195,7 @@ class Migrator:
|
||||
|
||||
# Extended channel token/allowlist mapping
|
||||
CHANNEL_ENV_MAP = {
|
||||
"matrix": {"token": "MATRIX_ACCESS_TOKEN", "allowFrom": "MATRIX_ALLOWED_USERS",
|
||||
"matrix": {"token": "MATRIX...OKEN", "tokenField": "accessToken", "allowFrom": "MATRIX_ALLOWED_USERS",
|
||||
"extras": {"homeserverUrl": "MATRIX_HOMESERVER_URL", "userId": "MATRIX_USER_ID"}},
|
||||
"mattermost": {"token": "MATTERMOST_BOT_TOKEN", "allowFrom": "MATTERMOST_ALLOWED_USERS",
|
||||
"extras": {"url": "MATTERMOST_URL", "teamId": "MATTERMOST_TEAM_ID"}},
|
||||
@@ -2160,19 +2213,21 @@ class Migrator:
|
||||
if not ch_cfg:
|
||||
continue
|
||||
|
||||
# Extract tokens
|
||||
if ch_mapping.get("token") and ch_cfg.get("botToken") and self.migrate_secrets:
|
||||
self._set_env_var(ch_mapping["token"], ch_cfg["botToken"],
|
||||
f"channels.{ch_name}.botToken")
|
||||
if ch_mapping.get("allowFrom") and ch_cfg.get("allowFrom"):
|
||||
allow_val = ch_cfg["allowFrom"]
|
||||
# Extract tokens (check flat path, then accounts.default)
|
||||
token_field = ch_mapping.get("tokenField", "botToken")
|
||||
bot_token = self._get_channel_field(ch_cfg, token_field)
|
||||
if ch_mapping.get("token") and bot_token and self.migrate_secrets:
|
||||
self._set_env_var(ch_mapping["token"], str(bot_token),
|
||||
f"channels.{ch_name}.{token_field}")
|
||||
allow_val = self._get_channel_field(ch_cfg, "allowFrom")
|
||||
if ch_mapping.get("allowFrom") and allow_val:
|
||||
if isinstance(allow_val, list):
|
||||
allow_val = ",".join(str(x) for x in allow_val)
|
||||
self._set_env_var(ch_mapping["allowFrom"], str(allow_val),
|
||||
f"channels.{ch_name}.allowFrom")
|
||||
# Extra fields
|
||||
for oc_key, env_key in (ch_mapping.get("extras") or {}).items():
|
||||
val = ch_cfg.get(oc_key)
|
||||
val = self._get_channel_field(ch_cfg, oc_key)
|
||||
if val:
|
||||
if isinstance(val, list):
|
||||
val = ",".join(str(x) for x in val)
|
||||
@@ -2495,6 +2550,33 @@ class Migrator:
|
||||
elif has_cron_store_archive:
|
||||
notes.append("- Run `hermes cron` to recreate scheduled tasks (see archived cron-store)")
|
||||
|
||||
# Check if skills were imported
|
||||
has_skills = any(i.kind == "skills" and i.status == "migrated" for i in self.items)
|
||||
if has_skills:
|
||||
notes.extend([
|
||||
"",
|
||||
"## Imported Skills",
|
||||
"",
|
||||
"Imported skills require a new session to take effect. After migration,",
|
||||
"restart your agent or start a new chat session, then run `/skills`",
|
||||
"to verify they loaded correctly.",
|
||||
"",
|
||||
])
|
||||
|
||||
# Check if WhatsApp was detected
|
||||
has_whatsapp = any(i.kind == "whatsapp-settings" and i.status == "migrated" for i in self.items)
|
||||
if has_whatsapp:
|
||||
notes.extend([
|
||||
"",
|
||||
"## WhatsApp Requires Re-Pairing",
|
||||
"",
|
||||
"WhatsApp uses QR-code pairing, not token-based auth. Your allowlist",
|
||||
"was migrated, but you must re-pair the device by running:",
|
||||
"",
|
||||
" hermes whatsapp",
|
||||
"",
|
||||
])
|
||||
|
||||
notes.extend([
|
||||
"- Run `hermes gateway install` if you need the gateway service",
|
||||
"- Review `~/.hermes/config.yaml` for any adjustments",
|
||||
|
||||
+402
-148
@@ -339,10 +339,7 @@ def _paths_overlap(left: Path, right: Path) -> bool:
|
||||
|
||||
_SURROGATE_RE = re.compile(r'[\ud800-\udfff]')
|
||||
|
||||
_BUDGET_WARNING_RE = re.compile(
|
||||
r"\[BUDGET(?:\s+WARNING)?:\s+Iteration\s+\d+/\d+\..*?\]",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
|
||||
def _sanitize_surrogates(text: str) -> str:
|
||||
@@ -463,34 +460,7 @@ def _sanitize_messages_non_ascii(messages: list) -> bool:
|
||||
return found
|
||||
|
||||
|
||||
def _strip_budget_warnings_from_history(messages: list) -> None:
|
||||
"""Remove budget pressure warnings from tool-result messages in-place.
|
||||
|
||||
Budget warnings are turn-scoped signals that must not leak into replayed
|
||||
history. They live in tool-result ``content`` either as a JSON key
|
||||
(``_budget_warning``) or appended plain text.
|
||||
"""
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict) or msg.get("role") != "tool":
|
||||
continue
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, str) or "_budget_warning" not in content and "[BUDGET" not in content:
|
||||
continue
|
||||
|
||||
# Try JSON first (the common case: _budget_warning key in a dict)
|
||||
try:
|
||||
parsed = json.loads(content)
|
||||
if isinstance(parsed, dict) and "_budget_warning" in parsed:
|
||||
del parsed["_budget_warning"]
|
||||
msg["content"] = json.dumps(parsed, ensure_ascii=False)
|
||||
continue
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Fallback: strip the text pattern from plain-text tool results
|
||||
cleaned = _BUDGET_WARNING_RE.sub("", content).strip()
|
||||
if cleaned != content:
|
||||
msg["content"] = cleaned
|
||||
|
||||
|
||||
# =========================================================================
|
||||
@@ -579,6 +549,7 @@ class AIAgent:
|
||||
clarify_callback: callable = None,
|
||||
step_callback: callable = None,
|
||||
stream_delta_callback: callable = None,
|
||||
interim_assistant_callback: callable = None,
|
||||
tool_gen_callback: callable = None,
|
||||
status_callback: callable = None,
|
||||
max_tokens: int = None,
|
||||
@@ -700,10 +671,14 @@ class AIAgent:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Direct OpenAI sessions use the Responses API path. GPT-5.x tool
|
||||
# calls with reasoning are rejected on /v1/chat/completions, and
|
||||
# Hermes is a tool-using client by default.
|
||||
if self.api_mode == "chat_completions" and self._is_direct_openai_url():
|
||||
# GPT-5.x models require the Responses API path — they are rejected
|
||||
# on /v1/chat/completions by both OpenAI and OpenRouter. Also
|
||||
# auto-upgrade for direct OpenAI URLs (api.openai.com) since all
|
||||
# newer tool-calling models prefer Responses there.
|
||||
if self.api_mode == "chat_completions" and (
|
||||
self._is_direct_openai_url()
|
||||
or self._model_requires_responses_api(self.model)
|
||||
):
|
||||
self.api_mode = "codex_responses"
|
||||
|
||||
# Pre-warm OpenRouter model metadata cache in a background thread.
|
||||
@@ -724,6 +699,7 @@ class AIAgent:
|
||||
self.clarify_callback = clarify_callback
|
||||
self.step_callback = step_callback
|
||||
self.stream_delta_callback = stream_delta_callback
|
||||
self.interim_assistant_callback = interim_assistant_callback
|
||||
self.status_callback = status_callback
|
||||
self.tool_gen_callback = tool_gen_callback
|
||||
|
||||
@@ -735,6 +711,7 @@ class AIAgent:
|
||||
# Interrupt mechanism for breaking out of tool loops
|
||||
self._interrupt_requested = False
|
||||
self._interrupt_message = None # Optional message that triggered interrupt
|
||||
self._execution_thread_id: int | None = None # Set at run_conversation() start
|
||||
self._client_lock = threading.RLock()
|
||||
|
||||
# Subagent delegation state
|
||||
@@ -770,12 +747,14 @@ class AIAgent:
|
||||
self._use_prompt_caching = (is_openrouter and is_claude) or is_native_anthropic
|
||||
self._cache_ttl = "5m" # Default 5-minute TTL (1.25x write cost)
|
||||
|
||||
# Iteration budget pressure: warn the LLM as it approaches max_iterations.
|
||||
# Warnings are injected into the last tool result JSON (not as separate
|
||||
# messages) so they don't break message structure or invalidate caching.
|
||||
self._budget_caution_threshold = 0.7 # 70% — nudge to start wrapping up
|
||||
self._budget_warning_threshold = 0.9 # 90% — urgent, respond now
|
||||
self._budget_pressure_enabled = True
|
||||
# Iteration budget: the LLM is only notified when it actually exhausts
|
||||
# the iteration budget (api_call_count >= max_iterations). At that
|
||||
# point we inject ONE message, allow one final API call, and if the
|
||||
# model doesn't produce a text response, force a user-message asking
|
||||
# it to summarise. No intermediate pressure warnings — they caused
|
||||
# models to "give up" prematurely on complex tasks (#7915).
|
||||
self._budget_exhausted_injected = False
|
||||
self._budget_grace_call = False
|
||||
|
||||
# Context pressure warnings: notify the USER (not the LLM) as context
|
||||
# fills up. Purely informational — displayed in CLI output and sent via
|
||||
@@ -826,6 +805,11 @@ class AIAgent:
|
||||
# Deferred paragraph break flag — set after tool iterations so a
|
||||
# single "\n\n" is prepended to the next real text delta.
|
||||
self._stream_needs_break = False
|
||||
# Visible assistant text already delivered through live token callbacks
|
||||
# during the current model response. Used to avoid re-sending the same
|
||||
# commentary when the provider later returns it as a completed interim
|
||||
# assistant message.
|
||||
self._current_streamed_assistant_text = ""
|
||||
|
||||
# Optional current-turn user-message override used when the API-facing
|
||||
# user message intentionally differs from the persisted transcript
|
||||
@@ -1326,6 +1310,19 @@ class AIAgent:
|
||||
)
|
||||
self.compression_enabled = compression_enabled
|
||||
|
||||
# Reject models whose context window is below the minimum required
|
||||
# for reliable tool-calling workflows (64K tokens).
|
||||
from agent.model_metadata import MINIMUM_CONTEXT_LENGTH
|
||||
_ctx = getattr(self.context_compressor, "context_length", 0)
|
||||
if _ctx and _ctx < MINIMUM_CONTEXT_LENGTH:
|
||||
raise ValueError(
|
||||
f"Model {self.model} has a context window of {_ctx:,} tokens, "
|
||||
f"which is below the minimum {MINIMUM_CONTEXT_LENGTH:,} required "
|
||||
f"by Hermes Agent. Choose a model with at least "
|
||||
f"{MINIMUM_CONTEXT_LENGTH // 1000}K context, or set "
|
||||
f"model.context_length in config.yaml to override."
|
||||
)
|
||||
|
||||
# Inject context engine tool schemas (e.g. lcm_grep, lcm_describe, lcm_expand)
|
||||
self._context_engine_tool_names: set = set()
|
||||
if hasattr(self, "context_compressor") and self.context_compressor and self.tools is not None:
|
||||
@@ -1402,6 +1399,12 @@ class AIAgent:
|
||||
else:
|
||||
print(f"📊 Context limit: {self.context_compressor.context_length:,} tokens (auto-compression disabled)")
|
||||
|
||||
# Check immediately so CLI users see the warning at startup.
|
||||
# Gateway status_callback is not yet wired, so any warning is stored
|
||||
# in _compression_warning and replayed in the first run_conversation().
|
||||
self._compression_warning = None
|
||||
self._check_compression_model_feasibility()
|
||||
|
||||
# Snapshot primary runtime for per-turn restoration. When fallback
|
||||
# activates during a turn, the next turn restores these values so the
|
||||
# preferred model gets a fresh attempt each time. Uses a single dict
|
||||
@@ -1693,6 +1696,104 @@ class AIAgent:
|
||||
except Exception:
|
||||
logger.debug("status_callback error in _emit_status", exc_info=True)
|
||||
|
||||
def _check_compression_model_feasibility(self) -> None:
|
||||
"""Warn at session start if the auxiliary compression model's context
|
||||
window is smaller than the main model's compression threshold.
|
||||
|
||||
When the auxiliary model cannot fit the content that needs summarising,
|
||||
compression will either fail outright (the LLM call errors) or produce
|
||||
a severely truncated summary.
|
||||
|
||||
Called during ``__init__`` so CLI users see the warning immediately
|
||||
(via ``_vprint``). The gateway sets ``status_callback`` *after*
|
||||
construction, so ``_replay_compression_warning()`` re-sends the
|
||||
stored warning through the callback on the first
|
||||
``run_conversation()`` call.
|
||||
"""
|
||||
if not self.compression_enabled:
|
||||
return
|
||||
try:
|
||||
from agent.auxiliary_client import get_text_auxiliary_client
|
||||
from agent.model_metadata import get_model_context_length
|
||||
|
||||
client, aux_model = get_text_auxiliary_client("compression")
|
||||
if client is None or not aux_model:
|
||||
msg = (
|
||||
"⚠ No auxiliary LLM provider configured — context "
|
||||
"compression will drop middle turns without a summary. "
|
||||
"Run `hermes setup` or set OPENROUTER_API_KEY."
|
||||
)
|
||||
self._compression_warning = msg
|
||||
self._emit_status(msg)
|
||||
logger.warning(
|
||||
"No auxiliary LLM provider for compression — "
|
||||
"summaries will be unavailable."
|
||||
)
|
||||
return
|
||||
|
||||
aux_base_url = str(getattr(client, "base_url", ""))
|
||||
aux_api_key = str(getattr(client, "api_key", ""))
|
||||
aux_context = get_model_context_length(
|
||||
aux_model,
|
||||
base_url=aux_base_url,
|
||||
api_key=aux_api_key,
|
||||
)
|
||||
|
||||
threshold = self.context_compressor.threshold_tokens
|
||||
if aux_context < threshold:
|
||||
# Suggest a threshold that would fit the aux model,
|
||||
# rounded down to a clean percentage.
|
||||
safe_pct = int((aux_context / self.context_compressor.context_length) * 100)
|
||||
msg = (
|
||||
f"⚠ Compression model ({aux_model}) context "
|
||||
f"is {aux_context:,} tokens, but the main model's "
|
||||
f"compression threshold is {threshold:,} tokens. "
|
||||
f"Context compression will not be possible — the "
|
||||
f"content to summarise will exceed the auxiliary "
|
||||
f"model's context window.\n"
|
||||
f" Fix options (config.yaml):\n"
|
||||
f" 1. Use a larger compression model:\n"
|
||||
f" auxiliary:\n"
|
||||
f" compression:\n"
|
||||
f" model: <model-with-{threshold:,}+-context>\n"
|
||||
f" 2. Lower the compression threshold to fit "
|
||||
f"the current model:\n"
|
||||
f" compression:\n"
|
||||
f" threshold: 0.{safe_pct:02d}"
|
||||
)
|
||||
self._compression_warning = msg
|
||||
self._emit_status(msg)
|
||||
logger.warning(
|
||||
"Auxiliary compression model %s has %d token context, "
|
||||
"below the main model's compression threshold of %d "
|
||||
"tokens — compression summaries will fail or be "
|
||||
"severely truncated.",
|
||||
aux_model,
|
||||
aux_context,
|
||||
threshold,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"Compression feasibility check failed (non-fatal): %s", exc
|
||||
)
|
||||
|
||||
def _replay_compression_warning(self) -> None:
|
||||
"""Re-send the compression warning through ``status_callback``.
|
||||
|
||||
During ``__init__`` the gateway's ``status_callback`` is not yet
|
||||
wired, so ``_emit_status`` only reaches ``_vprint`` (CLI). This
|
||||
method is called once at the start of the first
|
||||
``run_conversation()`` — by then the gateway has set the callback,
|
||||
so every platform (Telegram, Discord, Slack, etc.) receives the
|
||||
warning.
|
||||
"""
|
||||
msg = getattr(self, "_compression_warning", None)
|
||||
if msg and self.status_callback:
|
||||
try:
|
||||
self.status_callback("lifecycle", msg)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _is_direct_openai_url(self, base_url: str = None) -> bool:
|
||||
"""Return True when a base URL targets OpenAI's native API."""
|
||||
url = (base_url or self._base_url_lower).lower()
|
||||
@@ -1702,6 +1803,21 @@ class AIAgent:
|
||||
"""Return True when the base URL targets OpenRouter."""
|
||||
return "openrouter" in self._base_url_lower
|
||||
|
||||
@staticmethod
|
||||
def _model_requires_responses_api(model: str) -> bool:
|
||||
"""Return True for models that require the Responses API path.
|
||||
|
||||
GPT-5.x models are rejected on /v1/chat/completions by both
|
||||
OpenAI and OpenRouter (error: ``unsupported_api_for_model``).
|
||||
Detect these so the correct api_mode is set regardless of
|
||||
which provider is serving the model.
|
||||
"""
|
||||
m = model.lower()
|
||||
# Strip vendor prefix (e.g. "openai/gpt-5.4" → "gpt-5.4")
|
||||
if "/" in m:
|
||||
m = m.rsplit("/", 1)[-1]
|
||||
return m.startswith("gpt-5")
|
||||
|
||||
def _max_tokens_param(self, value: int) -> dict:
|
||||
"""Return the correct max tokens kwarg for the current provider.
|
||||
|
||||
@@ -2709,8 +2825,10 @@ class AIAgent:
|
||||
"""
|
||||
self._interrupt_requested = True
|
||||
self._interrupt_message = message
|
||||
# Signal all tools to abort any in-flight operations immediately
|
||||
_set_interrupt(True)
|
||||
# Signal all tools to abort any in-flight operations immediately.
|
||||
# Scope the interrupt to this agent's execution thread so other
|
||||
# agents running in the same process (gateway) are not affected.
|
||||
_set_interrupt(True, self._execution_thread_id)
|
||||
# Propagate interrupt to any running child agents (subagent delegation)
|
||||
with self._active_children_lock:
|
||||
children_copy = list(self._active_children)
|
||||
@@ -2723,10 +2841,10 @@ class AIAgent:
|
||||
print("\n⚡ Interrupt requested" + (f": '{message[:40]}...'" if message and len(message) > 40 else f": '{message}'" if message else ""))
|
||||
|
||||
def clear_interrupt(self) -> None:
|
||||
"""Clear any pending interrupt request and the global tool interrupt signal."""
|
||||
"""Clear any pending interrupt request and the per-thread tool interrupt signal."""
|
||||
self._interrupt_requested = False
|
||||
self._interrupt_message = None
|
||||
_set_interrupt(False)
|
||||
_set_interrupt(False, self._execution_thread_id)
|
||||
|
||||
def _touch_activity(self, desc: str) -> None:
|
||||
"""Update the last-activity timestamp and description (thread-safe)."""
|
||||
@@ -3064,7 +3182,7 @@ class AIAgent:
|
||||
if platform_key in PLATFORM_HINTS:
|
||||
prompt_parts.append(PLATFORM_HINTS[platform_key])
|
||||
|
||||
return "\n\n".join(prompt_parts)
|
||||
return "\n\n".join(p.strip() for p in prompt_parts if p.strip())
|
||||
|
||||
# =========================================================================
|
||||
# Pre/post-call guardrails (inspired by PR #1321 — @alireza78a)
|
||||
@@ -3320,6 +3438,7 @@ class AIAgent:
|
||||
def _chat_messages_to_responses_input(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Convert internal chat-style messages to Responses input items."""
|
||||
items: List[Dict[str, Any]] = []
|
||||
seen_item_ids: set = set()
|
||||
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
@@ -3340,7 +3459,12 @@ class AIAgent:
|
||||
if isinstance(codex_reasoning, list):
|
||||
for ri in codex_reasoning:
|
||||
if isinstance(ri, dict) and ri.get("encrypted_content"):
|
||||
item_id = ri.get("id")
|
||||
if item_id and item_id in seen_item_ids:
|
||||
continue
|
||||
items.append(ri)
|
||||
if item_id:
|
||||
seen_item_ids.add(item_id)
|
||||
has_codex_reasoning = True
|
||||
|
||||
if content_text.strip():
|
||||
@@ -3420,6 +3544,7 @@ class AIAgent:
|
||||
raise ValueError("Codex Responses input must be a list of input items.")
|
||||
|
||||
normalized: List[Dict[str, Any]] = []
|
||||
seen_ids: set = set()
|
||||
for idx, item in enumerate(raw_items):
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError(f"Codex Responses input[{idx}] must be an object.")
|
||||
@@ -3472,8 +3597,12 @@ class AIAgent:
|
||||
if item_type == "reasoning":
|
||||
encrypted = item.get("encrypted_content")
|
||||
if isinstance(encrypted, str) and encrypted:
|
||||
reasoning_item = {"type": "reasoning", "encrypted_content": encrypted}
|
||||
item_id = item.get("id")
|
||||
if isinstance(item_id, str) and item_id:
|
||||
if item_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(item_id)
|
||||
reasoning_item = {"type": "reasoning", "encrypted_content": encrypted}
|
||||
if isinstance(item_id, str) and item_id:
|
||||
reasoning_item["id"] = item_id
|
||||
summary = item.get("summary")
|
||||
@@ -4593,6 +4722,49 @@ class AIAgent:
|
||||
|
||||
# ── Unified streaming API call ─────────────────────────────────────────
|
||||
|
||||
def _reset_stream_delivery_tracking(self) -> None:
|
||||
"""Reset tracking for text delivered during the current model response."""
|
||||
self._current_streamed_assistant_text = ""
|
||||
|
||||
def _record_streamed_assistant_text(self, text: str) -> None:
|
||||
"""Accumulate visible assistant text emitted through stream callbacks."""
|
||||
if isinstance(text, str) and text:
|
||||
self._current_streamed_assistant_text = (
|
||||
getattr(self, "_current_streamed_assistant_text", "") + text
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_interim_visible_text(text: str) -> str:
|
||||
if not isinstance(text, str):
|
||||
return ""
|
||||
return re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
def _interim_content_was_streamed(self, content: str) -> bool:
|
||||
visible_content = self._normalize_interim_visible_text(
|
||||
self._strip_think_blocks(content or "")
|
||||
)
|
||||
if not visible_content:
|
||||
return False
|
||||
streamed = self._normalize_interim_visible_text(
|
||||
self._strip_think_blocks(getattr(self, "_current_streamed_assistant_text", "") or "")
|
||||
)
|
||||
return bool(streamed) and streamed == visible_content
|
||||
|
||||
def _emit_interim_assistant_message(self, assistant_msg: Dict[str, Any]) -> None:
|
||||
"""Surface a real mid-turn assistant commentary message to the UI layer."""
|
||||
cb = getattr(self, "interim_assistant_callback", None)
|
||||
if cb is None or not isinstance(assistant_msg, dict):
|
||||
return
|
||||
content = assistant_msg.get("content")
|
||||
visible = self._strip_think_blocks(content or "").strip()
|
||||
if not visible or visible == "(empty)":
|
||||
return
|
||||
already_streamed = self._interim_content_was_streamed(visible)
|
||||
try:
|
||||
cb(visible, already_streamed=already_streamed)
|
||||
except Exception:
|
||||
logger.debug("interim_assistant_callback error", exc_info=True)
|
||||
|
||||
def _fire_stream_delta(self, text: str) -> None:
|
||||
"""Fire all registered stream delta callbacks (display + TTS)."""
|
||||
# If a tool iteration set the break flag, prepend a single paragraph
|
||||
@@ -4602,12 +4774,16 @@ class AIAgent:
|
||||
if getattr(self, "_stream_needs_break", False) and text and text.strip():
|
||||
self._stream_needs_break = False
|
||||
text = "\n\n" + text
|
||||
for cb in (self.stream_delta_callback, self._stream_callback):
|
||||
if cb is not None:
|
||||
try:
|
||||
cb(text)
|
||||
except Exception:
|
||||
pass
|
||||
callbacks = [cb for cb in (self.stream_delta_callback, self._stream_callback) if cb is not None]
|
||||
delivered = False
|
||||
for cb in callbacks:
|
||||
try:
|
||||
cb(text)
|
||||
delivered = True
|
||||
except Exception:
|
||||
pass
|
||||
if delivered:
|
||||
self._record_streamed_assistant_text(text)
|
||||
|
||||
def _fire_reasoning_delta(self, text: str) -> None:
|
||||
"""Fire reasoning callback if registered."""
|
||||
@@ -4791,6 +4967,7 @@ class AIAgent:
|
||||
if self.stream_delta_callback:
|
||||
try:
|
||||
self.stream_delta_callback(delta.content)
|
||||
self._record_streamed_assistant_text(delta.content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -5251,7 +5428,7 @@ class AIAgent:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Determine api_mode from provider / base URL
|
||||
# Determine api_mode from provider / base URL / model
|
||||
fb_api_mode = "chat_completions"
|
||||
fb_base_url = str(fb_client.base_url)
|
||||
if fb_provider == "openai-codex":
|
||||
@@ -5260,6 +5437,10 @@ class AIAgent:
|
||||
fb_api_mode = "anthropic_messages"
|
||||
elif self._is_direct_openai_url(fb_base_url):
|
||||
fb_api_mode = "codex_responses"
|
||||
elif self._model_requires_responses_api(fb_model):
|
||||
# GPT-5.x models need Responses API on every provider
|
||||
# (OpenRouter, Copilot, direct OpenAI, etc.)
|
||||
fb_api_mode = "codex_responses"
|
||||
|
||||
old_model = self.model
|
||||
self.model = fb_model
|
||||
@@ -5348,8 +5529,8 @@ class AIAgent:
|
||||
to the fallback provider for every subsequent turn. Calling this at
|
||||
the top of ``run_conversation()`` makes fallback turn-scoped.
|
||||
|
||||
The gateway creates a fresh agent per message so this is a no-op
|
||||
there (``_fallback_activated`` is always False at turn start).
|
||||
The gateway caches agents across messages (``_agent_cache`` in
|
||||
``gateway/run.py``), so this restoration IS needed there too.
|
||||
"""
|
||||
if not self._fallback_activated:
|
||||
return False
|
||||
@@ -5888,8 +6069,16 @@ class AIAgent:
|
||||
api_kwargs["tools"] = self.tools
|
||||
|
||||
if self.max_tokens is not None:
|
||||
if not self._is_qwen_portal():
|
||||
api_kwargs.update(self._max_tokens_param(self.max_tokens))
|
||||
api_kwargs.update(self._max_tokens_param(self.max_tokens))
|
||||
elif self._is_qwen_portal():
|
||||
# Qwen Portal defaults to a very low max_tokens when omitted.
|
||||
# Reasoning models (qwen3-coder-plus) exhaust that budget on
|
||||
# thinking tokens alone, causing the portal to return
|
||||
# finish_reason="stop" with truncated output — the agent sees
|
||||
# this as an intentional stop and exits the loop. Send 65536
|
||||
# (the documented max output for qwen3-coder models) so the
|
||||
# model has adequate output budget for tool calls.
|
||||
api_kwargs.update(self._max_tokens_param(65536))
|
||||
elif (self._is_openrouter_url() or "nousresearch" in self._base_url_lower) and "claude" in (self.model or "").lower():
|
||||
# OpenRouter and Nous Portal translate requests to Anthropic's
|
||||
# Messages API, which requires max_tokens as a mandatory field.
|
||||
@@ -6770,24 +6959,6 @@ class AIAgent:
|
||||
turn_tool_msgs = messages[-num_tools:]
|
||||
enforce_turn_budget(turn_tool_msgs, env=get_active_env(effective_task_id))
|
||||
|
||||
# ── Budget pressure injection ────────────────────────────────────
|
||||
budget_warning = self._get_budget_warning(api_call_count)
|
||||
if budget_warning and messages and messages[-1].get("role") == "tool":
|
||||
last_content = messages[-1]["content"]
|
||||
try:
|
||||
parsed = json.loads(last_content)
|
||||
if isinstance(parsed, dict):
|
||||
parsed["_budget_warning"] = budget_warning
|
||||
messages[-1]["content"] = json.dumps(parsed, ensure_ascii=False)
|
||||
else:
|
||||
messages[-1]["content"] = last_content + f"\n\n{budget_warning}"
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
messages[-1]["content"] = last_content + f"\n\n{budget_warning}"
|
||||
if not self.quiet_mode:
|
||||
remaining = self.max_iterations - api_call_count
|
||||
tier = "⚠️ WARNING" if remaining <= self.max_iterations * 0.1 else "💡 CAUTION"
|
||||
print(f"{self.log_prefix}{tier}: {remaining} iterations remaining")
|
||||
|
||||
def _execute_tool_calls_sequential(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
|
||||
"""Execute tool calls sequentially (original behavior). Used for single calls or interactive tools."""
|
||||
for i, tool_call in enumerate(assistant_message.tool_calls, 1):
|
||||
@@ -6836,6 +7007,15 @@ class AIAgent:
|
||||
self._current_tool = function_name
|
||||
self._touch_activity(f"executing tool: {function_name}")
|
||||
|
||||
# Set activity callback for long-running tool execution (terminal
|
||||
# commands, etc.) so the gateway's inactivity monitor doesn't kill
|
||||
# the agent while a command is running.
|
||||
try:
|
||||
from tools.environments.base import set_activity_callback
|
||||
set_activity_callback(self._touch_activity)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if self.tool_progress_callback:
|
||||
try:
|
||||
preview = _build_tool_preview(function_name, function_args)
|
||||
@@ -7125,50 +7305,7 @@ class AIAgent:
|
||||
if num_tools_seq > 0:
|
||||
enforce_turn_budget(messages[-num_tools_seq:], env=get_active_env(effective_task_id))
|
||||
|
||||
# ── Budget pressure injection ─────────────────────────────────
|
||||
# After all tool calls in this turn are processed, check if we're
|
||||
# approaching max_iterations. If so, inject a warning into the LAST
|
||||
# tool result's JSON so the LLM sees it naturally when reading results.
|
||||
budget_warning = self._get_budget_warning(api_call_count)
|
||||
if budget_warning and messages and messages[-1].get("role") == "tool":
|
||||
last_content = messages[-1]["content"]
|
||||
try:
|
||||
parsed = json.loads(last_content)
|
||||
if isinstance(parsed, dict):
|
||||
parsed["_budget_warning"] = budget_warning
|
||||
messages[-1]["content"] = json.dumps(parsed, ensure_ascii=False)
|
||||
else:
|
||||
messages[-1]["content"] = last_content + f"\n\n{budget_warning}"
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
messages[-1]["content"] = last_content + f"\n\n{budget_warning}"
|
||||
if not self.quiet_mode:
|
||||
remaining = self.max_iterations - api_call_count
|
||||
tier = "⚠️ WARNING" if remaining <= self.max_iterations * 0.1 else "💡 CAUTION"
|
||||
print(f"{self.log_prefix}{tier}: {remaining} iterations remaining")
|
||||
|
||||
def _get_budget_warning(self, api_call_count: int) -> Optional[str]:
|
||||
"""Return a budget pressure string, or None if not yet needed.
|
||||
|
||||
Two-tier system:
|
||||
- Caution (70%): nudge to consolidate work
|
||||
- Warning (90%): urgent, must respond now
|
||||
"""
|
||||
if not self._budget_pressure_enabled or self.max_iterations <= 0:
|
||||
return None
|
||||
progress = api_call_count / self.max_iterations
|
||||
remaining = self.max_iterations - api_call_count
|
||||
if progress >= self._budget_warning_threshold:
|
||||
return (
|
||||
f"[BUDGET WARNING: Iteration {api_call_count}/{self.max_iterations}. "
|
||||
f"Only {remaining} iteration(s) left. "
|
||||
"Provide your final response NOW. No more tool calls unless absolutely critical.]"
|
||||
)
|
||||
if progress >= self._budget_caution_threshold:
|
||||
return (
|
||||
f"[BUDGET: Iteration {api_call_count}/{self.max_iterations}. "
|
||||
f"{remaining} iterations left. Start consolidating your work.]"
|
||||
)
|
||||
return None
|
||||
|
||||
def _emit_context_pressure(self, compaction_progress: float, compressor) -> None:
|
||||
"""Notify the user that context is approaching the compaction threshold.
|
||||
@@ -7392,6 +7529,11 @@ class AIAgent:
|
||||
# Installed once, transparent when streams are healthy, prevents crash on write.
|
||||
_install_safe_stdio()
|
||||
|
||||
# Tag all log records on this thread with the session ID so
|
||||
# ``hermes logs --session <id>`` can filter a single conversation.
|
||||
from hermes_logging import set_session_context
|
||||
set_session_context(self.session_id)
|
||||
|
||||
# If the previous turn activated fallback, restore the primary
|
||||
# runtime so this turn gets a fresh attempt with the preferred model.
|
||||
# No-op when _fallback_activated is False (gateway, first turn, etc.).
|
||||
@@ -7437,6 +7579,12 @@ class AIAgent:
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
# Replay compression warning through status_callback for gateway
|
||||
# platforms (the callback was not wired during __init__).
|
||||
if self._compression_warning:
|
||||
self._replay_compression_warning()
|
||||
self._compression_warning = None # send once
|
||||
|
||||
# NOTE: _turns_since_memory and _iters_since_skill are NOT reset here.
|
||||
# They are initialized in __init__ and must persist across run_conversation
|
||||
# calls so that nudge logic accumulates correctly in CLI mode.
|
||||
@@ -7455,14 +7603,6 @@ class AIAgent:
|
||||
# Initialize conversation (copy to avoid mutating the caller's list)
|
||||
messages = list(conversation_history) if conversation_history else []
|
||||
|
||||
# Strip budget pressure warnings from previous turns. These are
|
||||
# turn-scoped signals injected by _get_budget_warning() into tool
|
||||
# result content. If left in the replayed history, models (especially
|
||||
# GPT-family) interpret them as still-active instructions and avoid
|
||||
# making tool calls in ALL subsequent turns.
|
||||
if messages:
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
|
||||
# Hydrate todo store from conversation history (gateway creates a fresh
|
||||
# AIAgent per message, so the in-memory store is empty -- we need to
|
||||
# recover the todo state from the most recent todo tool response in history)
|
||||
@@ -7658,6 +7798,11 @@ class AIAgent:
|
||||
compression_attempts = 0
|
||||
_turn_exit_reason = "unknown" # Diagnostic: why the loop ended
|
||||
|
||||
# Record the execution thread so interrupt()/clear_interrupt() can
|
||||
# scope the tool-level interrupt signal to THIS agent's thread only.
|
||||
# Must be set before clear_interrupt() which uses it.
|
||||
self._execution_thread_id = threading.current_thread().ident
|
||||
|
||||
# Clear any stale interrupt state at start
|
||||
self.clear_interrupt()
|
||||
|
||||
@@ -7674,7 +7819,7 @@ class AIAgent:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
while api_call_count < self.max_iterations and self.iteration_budget.remaining > 0:
|
||||
while (api_call_count < self.max_iterations and self.iteration_budget.remaining > 0) or self._budget_grace_call:
|
||||
# Reset per-turn checkpoint dedup so each iteration can take one snapshot
|
||||
self._checkpoint_mgr.new_turn()
|
||||
|
||||
@@ -7689,7 +7834,13 @@ class AIAgent:
|
||||
api_call_count += 1
|
||||
self._api_call_count = api_call_count
|
||||
self._touch_activity(f"starting API call #{api_call_count}")
|
||||
if not self.iteration_budget.consume():
|
||||
|
||||
# Grace call: the budget is exhausted but we gave the model one
|
||||
# more chance. Consume the grace flag so the loop exits after
|
||||
# this iteration regardless of outcome.
|
||||
if self._budget_grace_call:
|
||||
self._budget_grace_call = False
|
||||
elif not self.iteration_budget.consume():
|
||||
_turn_exit_reason = "budget_exhausted"
|
||||
if not self.quiet_mode:
|
||||
self._safe_print(f"\n⚠️ Iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} iterations used)")
|
||||
@@ -7816,9 +7967,39 @@ class AIAgent:
|
||||
# manual message manipulation are always caught.
|
||||
api_messages = self._sanitize_api_messages(api_messages)
|
||||
|
||||
# Normalize message whitespace and tool-call JSON for consistent
|
||||
# prefix matching. Ensures bit-perfect prefixes across turns,
|
||||
# which enables KV cache reuse on local inference servers
|
||||
# (llama.cpp, vLLM, Ollama) and improves cache hit rates for
|
||||
# cloud providers. Operates on api_messages (the API copy) so
|
||||
# the original conversation history in `messages` is untouched.
|
||||
for am in api_messages:
|
||||
if isinstance(am.get("content"), str):
|
||||
am["content"] = am["content"].strip()
|
||||
for am in api_messages:
|
||||
tcs = am.get("tool_calls")
|
||||
if not tcs:
|
||||
continue
|
||||
new_tcs = []
|
||||
for tc in tcs:
|
||||
if isinstance(tc, dict) and "function" in tc:
|
||||
try:
|
||||
args_obj = json.loads(tc["function"]["arguments"])
|
||||
tc = {**tc, "function": {
|
||||
**tc["function"],
|
||||
"arguments": json.dumps(
|
||||
args_obj, separators=(",", ":"),
|
||||
sort_keys=True,
|
||||
),
|
||||
}}
|
||||
except Exception:
|
||||
pass
|
||||
new_tcs.append(tc)
|
||||
am["tool_calls"] = new_tcs
|
||||
|
||||
# Calculate approximate request size for logging
|
||||
total_chars = sum(len(str(msg)) for msg in api_messages)
|
||||
approx_tokens = total_chars // 4 # Rough estimate: 4 chars per token
|
||||
approx_tokens = estimate_messages_tokens_rough(api_messages)
|
||||
|
||||
# Thinking spinner for quiet mode (animated during API call)
|
||||
thinking_spinner = None
|
||||
@@ -7867,6 +8048,7 @@ class AIAgent:
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
self._reset_stream_delivery_tracking()
|
||||
api_kwargs = self._build_api_kwargs(api_messages)
|
||||
if self.api_mode == "codex_responses":
|
||||
api_kwargs = self._preflight_codex_api_kwargs(api_kwargs, allow_stream=False)
|
||||
@@ -8025,6 +8207,8 @@ class AIAgent:
|
||||
self._emit_status("⚠️ Empty/malformed response — switching to fallback...")
|
||||
if self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
compression_attempts = 0
|
||||
primary_recovery_attempted = False
|
||||
continue
|
||||
|
||||
# Check for error field in response (some providers include this)
|
||||
@@ -8060,6 +8244,8 @@ class AIAgent:
|
||||
self._emit_status(f"⚠️ Max retries ({max_retries}) for invalid responses — trying fallback...")
|
||||
if self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
compression_attempts = 0
|
||||
primary_recovery_attempted = False
|
||||
continue
|
||||
self._emit_status(f"❌ Max retries ({max_retries}) exceeded for invalid responses. Giving up.")
|
||||
logging.error(f"{self.log_prefix}Invalid API response after {max_retries} retries.")
|
||||
@@ -8136,8 +8322,24 @@ class AIAgent:
|
||||
_text_parts.append(getattr(_blk, "text", ""))
|
||||
_trunc_content = "\n".join(_text_parts) if _text_parts else None
|
||||
|
||||
# A response is "thinking exhausted" only when the model
|
||||
# actually produced reasoning blocks but no visible text after
|
||||
# them. Models that do not use <think> tags (e.g. GLM-4.7 on
|
||||
# NVIDIA Build, minimax) may return content=None or an empty
|
||||
# string for unrelated reasons — treat those as normal
|
||||
# truncations that deserve continuation retries, not as
|
||||
# thinking-budget exhaustion.
|
||||
_has_think_tags = bool(
|
||||
_trunc_content and re.search(
|
||||
r'<(?:think|thinking|reasoning|REASONING_SCRATCHPAD)[^>]*>',
|
||||
_trunc_content,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
)
|
||||
_thinking_exhausted = (
|
||||
not _trunc_has_tool_calls and (
|
||||
not _trunc_has_tool_calls
|
||||
and _has_think_tags
|
||||
and (
|
||||
(_trunc_content is not None and not self._has_content_after_think_block(_trunc_content))
|
||||
or _trunc_content is None
|
||||
)
|
||||
@@ -8698,6 +8900,8 @@ class AIAgent:
|
||||
self._emit_status("⚠️ Rate limited — switching to fallback provider...")
|
||||
if self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
compression_attempts = 0
|
||||
primary_recovery_attempted = False
|
||||
continue
|
||||
|
||||
is_payload_too_large = (
|
||||
@@ -8911,6 +9115,8 @@ class AIAgent:
|
||||
self._emit_status(f"⚠️ Non-retryable error (HTTP {status_code}) — trying fallback...")
|
||||
if self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
compression_attempts = 0
|
||||
primary_recovery_attempted = False
|
||||
continue
|
||||
if api_kwargs is not None:
|
||||
self._dump_api_request_debug(
|
||||
@@ -8976,6 +9182,8 @@ class AIAgent:
|
||||
self._emit_status(f"⚠️ Max retries ({max_retries}) exhausted — trying fallback...")
|
||||
if self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
compression_attempts = 0
|
||||
primary_recovery_attempted = False
|
||||
continue
|
||||
_final_summary = self._summarize_api_error(api_error)
|
||||
if is_rate_limited:
|
||||
@@ -9199,8 +9407,6 @@ class AIAgent:
|
||||
# Check for incomplete <REASONING_SCRATCHPAD> (opened but never closed)
|
||||
# This means the model ran out of output tokens mid-reasoning — retry up to 2 times
|
||||
if has_incomplete_scratchpad(assistant_message.content or ""):
|
||||
if not hasattr(self, '_incomplete_scratchpad_retries'):
|
||||
self._incomplete_scratchpad_retries = 0
|
||||
self._incomplete_scratchpad_retries += 1
|
||||
|
||||
self._vprint(f"{self.log_prefix}⚠️ Incomplete <REASONING_SCRATCHPAD> detected (opened but never closed)")
|
||||
@@ -9228,12 +9434,9 @@ class AIAgent:
|
||||
}
|
||||
|
||||
# Reset incomplete scratchpad counter on clean response
|
||||
if hasattr(self, '_incomplete_scratchpad_retries'):
|
||||
self._incomplete_scratchpad_retries = 0
|
||||
self._incomplete_scratchpad_retries = 0
|
||||
|
||||
if self.api_mode == "codex_responses" and finish_reason == "incomplete":
|
||||
if not hasattr(self, "_codex_incomplete_retries"):
|
||||
self._codex_incomplete_retries = 0
|
||||
self._codex_incomplete_retries += 1
|
||||
|
||||
interim_msg = self._build_assistant_message(assistant_message, finish_reason)
|
||||
@@ -9260,6 +9463,7 @@ class AIAgent:
|
||||
)
|
||||
if not duplicate_interim:
|
||||
messages.append(interim_msg)
|
||||
self._emit_interim_assistant_message(interim_msg)
|
||||
|
||||
if self._codex_incomplete_retries < 3:
|
||||
if not self.quiet_mode:
|
||||
@@ -9304,8 +9508,6 @@ class AIAgent:
|
||||
]
|
||||
if invalid_tool_calls:
|
||||
# Track retries for invalid tool calls
|
||||
if not hasattr(self, '_invalid_tool_retries'):
|
||||
self._invalid_tool_retries = 0
|
||||
self._invalid_tool_retries += 1
|
||||
|
||||
# Return helpful error to model — model can self-correct next turn
|
||||
@@ -9341,8 +9543,7 @@ class AIAgent:
|
||||
})
|
||||
continue
|
||||
# Reset retry counter on successful tool call validation
|
||||
if hasattr(self, '_invalid_tool_retries'):
|
||||
self._invalid_tool_retries = 0
|
||||
self._invalid_tool_retries = 0
|
||||
|
||||
# Validate tool call arguments are valid JSON
|
||||
# Handle empty strings as empty objects (common model quirk)
|
||||
@@ -9365,12 +9566,41 @@ class AIAgent:
|
||||
invalid_json_args.append((tc.function.name, str(e)))
|
||||
|
||||
if invalid_json_args:
|
||||
# Check if the invalid JSON is due to truncation rather
|
||||
# than a model formatting mistake. Routers sometimes
|
||||
# rewrite finish_reason from "length" to "tool_calls",
|
||||
# hiding the truncation from the length handler above.
|
||||
# Detect truncation: args that don't end with } or ]
|
||||
# (after stripping whitespace) are cut off mid-stream.
|
||||
_truncated = any(
|
||||
not (tc.function.arguments or "").rstrip().endswith(("}", "]"))
|
||||
for tc in assistant_message.tool_calls
|
||||
if tc.function.name in {n for n, _ in invalid_json_args}
|
||||
)
|
||||
if _truncated:
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Truncated tool call arguments detected "
|
||||
f"(finish_reason={finish_reason!r}) — refusing to execute.",
|
||||
force=True,
|
||||
)
|
||||
self._invalid_json_retries = 0
|
||||
self._cleanup_task_resources(effective_task_id)
|
||||
self._persist_session(messages, conversation_history)
|
||||
return {
|
||||
"final_response": None,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"partial": True,
|
||||
"error": "Response truncated due to output length limit",
|
||||
}
|
||||
|
||||
# Track retries for invalid JSON arguments
|
||||
self._invalid_json_retries += 1
|
||||
|
||||
|
||||
tool_name, error_msg = invalid_json_args[0]
|
||||
self._vprint(f"{self.log_prefix}⚠️ Invalid JSON in tool call arguments for '{tool_name}': {error_msg}")
|
||||
|
||||
|
||||
if self._invalid_json_retries < 3:
|
||||
self._vprint(f"{self.log_prefix}🔄 Retrying API call ({self._invalid_json_retries}/3)...")
|
||||
# Don't add anything to messages, just retry the API call
|
||||
@@ -9453,6 +9683,7 @@ class AIAgent:
|
||||
messages.pop()
|
||||
|
||||
messages.append(assistant_msg)
|
||||
self._emit_interim_assistant_message(assistant_msg)
|
||||
|
||||
# Close any open streaming display (response box, reasoning
|
||||
# box) before tool execution begins. Intermediate turns may
|
||||
@@ -9715,8 +9946,7 @@ class AIAgent:
|
||||
break
|
||||
|
||||
# Reset retry counter/signature on successful content
|
||||
if hasattr(self, '_empty_content_retries'):
|
||||
self._empty_content_retries = 0
|
||||
self._empty_content_retries = 0
|
||||
self._thinking_prefill_retries = 0
|
||||
|
||||
if (
|
||||
@@ -9732,6 +9962,7 @@ class AIAgent:
|
||||
codex_ack_continuations += 1
|
||||
interim_msg = self._build_assistant_message(assistant_message, "incomplete")
|
||||
messages.append(interim_msg)
|
||||
self._emit_interim_assistant_message(interim_msg)
|
||||
|
||||
continue_msg = {
|
||||
"role": "user",
|
||||
@@ -9782,8 +10013,7 @@ class AIAgent:
|
||||
except (OSError, ValueError):
|
||||
logger.error(error_msg)
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.exception("Detailed error information:")
|
||||
logger.debug("Outer loop error in API call #%d", api_call_count, exc_info=True)
|
||||
|
||||
# If an assistant message with tool_calls was already appended,
|
||||
# the API expects a role="tool" result for every tool_call_id.
|
||||
@@ -9829,7 +10059,31 @@ class AIAgent:
|
||||
if final_response is None and (
|
||||
api_call_count >= self.max_iterations
|
||||
or self.iteration_budget.remaining <= 0
|
||||
):
|
||||
) and not self._budget_exhausted_injected:
|
||||
# Budget exhausted but we haven't tried asking the model to
|
||||
# summarise yet. Inject a user message and give it one grace
|
||||
# API call to produce a text response.
|
||||
self._budget_exhausted_injected = True
|
||||
self._budget_grace_call = True
|
||||
_grace_msg = (
|
||||
"Your tool budget ran out. Please give me the information "
|
||||
"or actions you've completed so far."
|
||||
)
|
||||
messages.append({"role": "user", "content": _grace_msg})
|
||||
self._emit_status(
|
||||
f"⚠️ Iteration budget exhausted ({api_call_count}/{self.max_iterations}) "
|
||||
"— asking model to summarise"
|
||||
)
|
||||
if not self.quiet_mode:
|
||||
self._safe_print(
|
||||
f"\n⚠️ Iteration budget exhausted ({api_call_count}/{self.max_iterations}) "
|
||||
"— requesting summary..."
|
||||
)
|
||||
|
||||
if final_response is None and (
|
||||
api_call_count >= self.max_iterations
|
||||
or self.iteration_budget.remaining <= 0
|
||||
) and not self._budget_grace_call:
|
||||
_turn_exit_reason = f"max_iterations_reached({api_call_count}/{self.max_iterations})"
|
||||
if self.iteration_budget.remaining <= 0 and not self.quiet_mode:
|
||||
print(f"\n⚠️ Iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} iterations used)")
|
||||
|
||||
@@ -249,8 +249,12 @@ def check_config(groq_key, eleven_key):
|
||||
|
||||
if stt_provider == "groq" and not groq_key:
|
||||
warn("STT config says groq but GROQ_API_KEY is missing")
|
||||
if stt_provider == "mistral" and not os.getenv("MISTRAL_API_KEY"):
|
||||
warn("STT config says mistral but MISTRAL_API_KEY is missing")
|
||||
if tts_provider == "elevenlabs" and not eleven_key:
|
||||
warn("TTS config says elevenlabs but ELEVENLABS_API_KEY is missing")
|
||||
if tts_provider == "mistral" and not os.getenv("MISTRAL_API_KEY"):
|
||||
warn("TTS config says mistral but MISTRAL_API_KEY is missing")
|
||||
except Exception as e:
|
||||
warn("config.yaml", f"parse error: {e}")
|
||||
else:
|
||||
|
||||
+11
-4
@@ -8,7 +8,7 @@
|
||||
"name": "hermes-whatsapp-bridge",
|
||||
"version": "1.0.0",
|
||||
"dependencies": {
|
||||
"@whiskeysockets/baileys": "7.0.0-rc.9",
|
||||
"@whiskeysockets/baileys": "WhiskeySockets/Baileys#fix/abprops-abt-fetch",
|
||||
"express": "^4.21.0",
|
||||
"pino": "^9.0.0",
|
||||
"qrcode-terminal": "^0.12.0"
|
||||
@@ -730,21 +730,22 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@whiskeysockets/baileys": {
|
||||
"name": "baileys",
|
||||
"version": "7.0.0-rc.9",
|
||||
"resolved": "https://registry.npmjs.org/@whiskeysockets/baileys/-/baileys-7.0.0-rc.9.tgz",
|
||||
"integrity": "sha512-YFm5gKXfDP9byCXCW3OPHKXLzrAKzolzgVUlRosHHgwbnf2YOO3XknkMm6J7+F0ns8OA0uuSBhgkRHTDtqkacw==",
|
||||
"resolved": "git+ssh://git@github.com/WhiskeySockets/Baileys.git#01047debd81beb20da7b7779b08edcb06aa03770",
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@cacheable/node-cache": "^1.4.0",
|
||||
"@hapi/boom": "^9.1.3",
|
||||
"async-mutex": "^0.5.0",
|
||||
"libsignal": "git+https://github.com/whiskeysockets/libsignal-node.git",
|
||||
"libsignal": "git+https://github.com/whiskeysockets/libsignal-node",
|
||||
"lru-cache": "^11.1.0",
|
||||
"music-metadata": "^11.7.0",
|
||||
"p-queue": "^9.0.0",
|
||||
"pino": "^9.6",
|
||||
"protobufjs": "^7.2.4",
|
||||
"whatsapp-rust-bridge": "0.5.2",
|
||||
"ws": "^8.13.0"
|
||||
},
|
||||
"engines": {
|
||||
@@ -2125,6 +2126,12 @@
|
||||
"node": ">= 0.8"
|
||||
}
|
||||
},
|
||||
"node_modules/whatsapp-rust-bridge": {
|
||||
"version": "0.5.2",
|
||||
"resolved": "https://registry.npmjs.org/whatsapp-rust-bridge/-/whatsapp-rust-bridge-0.5.2.tgz",
|
||||
"integrity": "sha512-6KBRNvxg6WMIwZ/euA8qVzj16qxMBzLllfmaJIP1JGAAfSvwn6nr8JDOMXeqpXPEOl71UfOG+79JwKEoT2b1Fw==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/win-guid": {
|
||||
"version": "0.2.1",
|
||||
"resolved": "https://registry.npmjs.org/win-guid/-/win-guid-0.2.1.tgz",
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
"start": "node bridge.js"
|
||||
},
|
||||
"dependencies": {
|
||||
"@whiskeysockets/baileys": "7.0.0-rc.9",
|
||||
"@whiskeysockets/baileys": "WhiskeySockets/Baileys#fix/abprops-abt-fetch",
|
||||
"express": "^4.21.0",
|
||||
"qrcode-terminal": "^0.12.0",
|
||||
"pino": "^9.0.0"
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Tests for agent.auxiliary_client resolution chain, provider overrides, and model overrides."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -14,6 +15,7 @@ from agent.auxiliary_client import (
|
||||
resolve_provider_client,
|
||||
auxiliary_max_tokens_param,
|
||||
call_llm,
|
||||
async_call_llm,
|
||||
_read_codex_access_token,
|
||||
_get_auxiliary_provider,
|
||||
_get_provider_chain,
|
||||
@@ -1122,8 +1124,8 @@ class TestCallLlmPaymentFallback:
|
||||
exc.status_code = 402
|
||||
return exc
|
||||
|
||||
def test_402_triggers_fallback(self, monkeypatch):
|
||||
"""When the primary provider returns 402, call_llm tries the next one."""
|
||||
def test_402_triggers_fallback_when_auto(self, monkeypatch):
|
||||
"""When provider is auto and returns 402, call_llm tries the next one."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
primary_client = MagicMock()
|
||||
@@ -1136,7 +1138,7 @@ class TestCallLlmPaymentFallback:
|
||||
with patch("agent.auxiliary_client._get_cached_client",
|
||||
return_value=(primary_client, "google/gemini-3-flash-preview")), \
|
||||
patch("agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("openrouter", "google/gemini-3-flash-preview", None, None)), \
|
||||
return_value=("auto", "google/gemini-3-flash-preview", None, None, None)), \
|
||||
patch("agent.auxiliary_client._try_payment_fallback",
|
||||
return_value=(fallback_client, "gpt-5.2-codex", "openai-codex")) as mock_fb:
|
||||
result = call_llm(
|
||||
@@ -1145,13 +1147,62 @@ class TestCallLlmPaymentFallback:
|
||||
)
|
||||
|
||||
assert result is fallback_response
|
||||
mock_fb.assert_called_once_with("openrouter", "compression")
|
||||
mock_fb.assert_called_once_with("auto", "compression", reason="payment error")
|
||||
# Fallback call should use the fallback model
|
||||
fb_kwargs = fallback_client.chat.completions.create.call_args.kwargs
|
||||
assert fb_kwargs["model"] == "gpt-5.2-codex"
|
||||
|
||||
def test_402_no_fallback_when_explicit_provider(self, monkeypatch):
|
||||
"""When provider is explicitly configured (not auto), 402 should NOT fallback (#7559)."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
primary_client = MagicMock()
|
||||
primary_client.chat.completions.create.side_effect = self._make_402_error()
|
||||
|
||||
with patch("agent.auxiliary_client._get_cached_client",
|
||||
return_value=(primary_client, "local-model")), \
|
||||
patch("agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("custom", "local-model", None, None, None)), \
|
||||
patch("agent.auxiliary_client._try_payment_fallback") as mock_fb:
|
||||
with pytest.raises(Exception, match="insufficient credits"):
|
||||
call_llm(
|
||||
task="compression",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
)
|
||||
|
||||
# Fallback should NOT be attempted when provider is explicit
|
||||
mock_fb.assert_not_called()
|
||||
|
||||
def test_connection_error_triggers_fallback_when_auto(self, monkeypatch):
|
||||
"""Connection errors also trigger fallback when provider is auto."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
primary_client = MagicMock()
|
||||
conn_err = Exception("Connection refused")
|
||||
conn_err.status_code = None
|
||||
primary_client.chat.completions.create.side_effect = conn_err
|
||||
|
||||
fallback_client = MagicMock()
|
||||
fallback_response = MagicMock()
|
||||
fallback_client.chat.completions.create.return_value = fallback_response
|
||||
|
||||
with patch("agent.auxiliary_client._get_cached_client",
|
||||
return_value=(primary_client, "model")), \
|
||||
patch("agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", "model", None, None, None)), \
|
||||
patch("agent.auxiliary_client._is_connection_error", return_value=True), \
|
||||
patch("agent.auxiliary_client._try_payment_fallback",
|
||||
return_value=(fallback_client, "fb-model", "nous")) as mock_fb:
|
||||
result = call_llm(
|
||||
task="compression",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
)
|
||||
|
||||
assert result is fallback_response
|
||||
mock_fb.assert_called_once_with("auto", "compression", reason="connection error")
|
||||
|
||||
def test_non_payment_error_not_caught(self, monkeypatch):
|
||||
"""Non-payment errors (500, connection, etc.) should NOT trigger fallback."""
|
||||
"""Non-payment/non-connection errors (500) should NOT trigger fallback."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
primary_client = MagicMock()
|
||||
@@ -1162,7 +1213,7 @@ class TestCallLlmPaymentFallback:
|
||||
with patch("agent.auxiliary_client._get_cached_client",
|
||||
return_value=(primary_client, "google/gemini-3-flash-preview")), \
|
||||
patch("agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("openrouter", "google/gemini-3-flash-preview", None, None)):
|
||||
return_value=("auto", "google/gemini-3-flash-preview", None, None, None)):
|
||||
with pytest.raises(Exception, match="Internal Server Error"):
|
||||
call_llm(
|
||||
task="compression",
|
||||
@@ -1179,7 +1230,7 @@ class TestCallLlmPaymentFallback:
|
||||
with patch("agent.auxiliary_client._get_cached_client",
|
||||
return_value=(primary_client, "google/gemini-3-flash-preview")), \
|
||||
patch("agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("openrouter", "google/gemini-3-flash-preview", None, None)), \
|
||||
return_value=("auto", "google/gemini-3-flash-preview", None, None, None)), \
|
||||
patch("agent.auxiliary_client._try_payment_fallback",
|
||||
return_value=(None, None, "")):
|
||||
with pytest.raises(Exception, match="insufficient credits"):
|
||||
@@ -1229,3 +1280,283 @@ def test_resolve_api_key_provider_skips_unconfigured_anthropic(monkeypatch):
|
||||
|
||||
assert "anthropic" not in called, \
|
||||
"_try_anthropic() should not be called when anthropic is not explicitly configured"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# model="default" elimination (#7512)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestModelDefaultElimination:
|
||||
"""_resolve_api_key_provider must skip providers without known aux models."""
|
||||
|
||||
def test_unknown_provider_skipped(self, monkeypatch):
|
||||
"""Providers not in _API_KEY_PROVIDER_AUX_MODELS are skipped, not sent model='default'."""
|
||||
from agent.auxiliary_client import _API_KEY_PROVIDER_AUX_MODELS
|
||||
|
||||
# Verify our known providers have entries
|
||||
assert "gemini" in _API_KEY_PROVIDER_AUX_MODELS
|
||||
assert "kimi-coding" in _API_KEY_PROVIDER_AUX_MODELS
|
||||
|
||||
# A random provider_id not in the dict should return None
|
||||
assert _API_KEY_PROVIDER_AUX_MODELS.get("totally-unknown-provider") is None
|
||||
|
||||
def test_known_provider_gets_real_model(self):
|
||||
"""Known providers get a real model name, not 'default'."""
|
||||
from agent.auxiliary_client import _API_KEY_PROVIDER_AUX_MODELS
|
||||
|
||||
for provider_id, model in _API_KEY_PROVIDER_AUX_MODELS.items():
|
||||
assert model != "default", f"{provider_id} should not map to 'default'"
|
||||
assert isinstance(model, str) and model.strip(), \
|
||||
f"{provider_id} should have a non-empty model string"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _try_payment_fallback reason parameter (#7512 bug 3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTryPaymentFallbackReason:
|
||||
"""_try_payment_fallback uses the reason parameter in log messages."""
|
||||
|
||||
def test_reason_parameter_passed_through(self, monkeypatch):
|
||||
"""The reason= parameter is accepted without error."""
|
||||
from agent.auxiliary_client import _try_payment_fallback
|
||||
|
||||
# Mock the provider chain to return nothing
|
||||
monkeypatch.setattr(
|
||||
"agent.auxiliary_client._get_provider_chain",
|
||||
lambda: [],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.auxiliary_client._read_main_provider",
|
||||
lambda: "",
|
||||
)
|
||||
|
||||
client, model, label = _try_payment_fallback(
|
||||
"openrouter", task="compression", reason="connection error"
|
||||
)
|
||||
assert client is None
|
||||
assert label == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_connection_error coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsConnectionError:
|
||||
"""Tests for _is_connection_error detection."""
|
||||
|
||||
def test_connection_refused(self):
|
||||
from agent.auxiliary_client import _is_connection_error
|
||||
err = Exception("Connection refused")
|
||||
assert _is_connection_error(err) is True
|
||||
|
||||
def test_timeout(self):
|
||||
from agent.auxiliary_client import _is_connection_error
|
||||
err = Exception("Request timed out.")
|
||||
assert _is_connection_error(err) is True
|
||||
|
||||
def test_dns_failure(self):
|
||||
from agent.auxiliary_client import _is_connection_error
|
||||
err = Exception("Name or service not known")
|
||||
assert _is_connection_error(err) is True
|
||||
|
||||
def test_normal_api_error_not_connection(self):
|
||||
from agent.auxiliary_client import _is_connection_error
|
||||
err = Exception("Bad Request: invalid model")
|
||||
err.status_code = 400
|
||||
assert _is_connection_error(err) is False
|
||||
|
||||
def test_500_not_connection(self):
|
||||
from agent.auxiliary_client import _is_connection_error
|
||||
err = Exception("Internal Server Error")
|
||||
err.status_code = 500
|
||||
assert _is_connection_error(err) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# async_call_llm payment / connection fallback (#7512 bug 2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAsyncCallLlmFallback:
|
||||
"""async_call_llm mirrors call_llm fallback behavior."""
|
||||
|
||||
def _make_402_error(self, msg="Payment Required: insufficient credits"):
|
||||
exc = Exception(msg)
|
||||
exc.status_code = 402
|
||||
return exc
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_402_triggers_async_fallback_when_auto(self, monkeypatch):
|
||||
"""When provider is auto and returns 402, async_call_llm tries fallback."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
primary_client = MagicMock()
|
||||
primary_client.chat.completions.create = AsyncMock(
|
||||
side_effect=self._make_402_error())
|
||||
|
||||
# Fallback client (sync) returned by _try_payment_fallback
|
||||
fb_sync_client = MagicMock()
|
||||
fb_async_client = MagicMock()
|
||||
fb_response = MagicMock()
|
||||
fb_async_client.chat.completions.create = AsyncMock(return_value=fb_response)
|
||||
|
||||
with patch("agent.auxiliary_client._get_cached_client",
|
||||
return_value=(primary_client, "google/gemini-3-flash-preview")), \
|
||||
patch("agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", "google/gemini-3-flash-preview", None, None, None)), \
|
||||
patch("agent.auxiliary_client._try_payment_fallback",
|
||||
return_value=(fb_sync_client, "gpt-5.2-codex", "openai-codex")) as mock_fb, \
|
||||
patch("agent.auxiliary_client._to_async_client",
|
||||
return_value=(fb_async_client, "gpt-5.2-codex")):
|
||||
result = await async_call_llm(
|
||||
task="compression",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
)
|
||||
|
||||
assert result is fb_response
|
||||
mock_fb.assert_called_once_with("auto", "compression", reason="payment error")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_402_no_async_fallback_when_explicit(self, monkeypatch):
|
||||
"""When provider is explicit, 402 should NOT trigger async fallback."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
primary_client = MagicMock()
|
||||
primary_client.chat.completions.create = AsyncMock(
|
||||
side_effect=self._make_402_error())
|
||||
|
||||
with patch("agent.auxiliary_client._get_cached_client",
|
||||
return_value=(primary_client, "local-model")), \
|
||||
patch("agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("custom", "local-model", None, None, None)), \
|
||||
patch("agent.auxiliary_client._try_payment_fallback") as mock_fb:
|
||||
with pytest.raises(Exception, match="insufficient credits"):
|
||||
await async_call_llm(
|
||||
task="compression",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
)
|
||||
|
||||
mock_fb.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_error_triggers_async_fallback(self, monkeypatch):
|
||||
"""Connection errors trigger async fallback when provider is auto."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
primary_client = MagicMock()
|
||||
conn_err = Exception("Connection refused")
|
||||
conn_err.status_code = None
|
||||
primary_client.chat.completions.create = AsyncMock(side_effect=conn_err)
|
||||
|
||||
fb_sync_client = MagicMock()
|
||||
fb_async_client = MagicMock()
|
||||
fb_response = MagicMock()
|
||||
fb_async_client.chat.completions.create = AsyncMock(return_value=fb_response)
|
||||
|
||||
with patch("agent.auxiliary_client._get_cached_client",
|
||||
return_value=(primary_client, "model")), \
|
||||
patch("agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", "model", None, None, None)), \
|
||||
patch("agent.auxiliary_client._is_connection_error", return_value=True), \
|
||||
patch("agent.auxiliary_client._try_payment_fallback",
|
||||
return_value=(fb_sync_client, "fb-model", "nous")) as mock_fb, \
|
||||
patch("agent.auxiliary_client._to_async_client",
|
||||
return_value=(fb_async_client, "fb-model")):
|
||||
result = await async_call_llm(
|
||||
task="compression",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
)
|
||||
|
||||
assert result is fb_response
|
||||
mock_fb.assert_called_once_with("auto", "compression", reason="connection error")
|
||||
class TestStaleBaseUrlWarning:
|
||||
"""_resolve_auto() warns when OPENAI_BASE_URL conflicts with config provider (#5161)."""
|
||||
|
||||
def test_warns_when_openai_base_url_set_with_named_provider(self, monkeypatch, caplog):
|
||||
"""Warning fires when OPENAI_BASE_URL is set but provider is a named provider."""
|
||||
import agent.auxiliary_client as mod
|
||||
# Reset the module-level flag so the warning fires
|
||||
monkeypatch.setattr(mod, "_stale_base_url_warned", False)
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "http://localhost:11434/v1")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-test")
|
||||
|
||||
with patch("agent.auxiliary_client._read_main_provider", return_value="openrouter"), \
|
||||
patch("agent.auxiliary_client._read_main_model", return_value="google/gemini-flash"), \
|
||||
caplog.at_level(logging.WARNING, logger="agent.auxiliary_client"):
|
||||
_resolve_auto()
|
||||
|
||||
assert any("OPENAI_BASE_URL is set" in rec.message for rec in caplog.records), \
|
||||
"Expected a warning about stale OPENAI_BASE_URL"
|
||||
assert mod._stale_base_url_warned is True
|
||||
|
||||
def test_no_warning_when_provider_is_custom(self, monkeypatch, caplog):
|
||||
"""No warning when the provider is 'custom' — OPENAI_BASE_URL is expected."""
|
||||
import agent.auxiliary_client as mod
|
||||
monkeypatch.setattr(mod, "_stale_base_url_warned", False)
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "http://localhost:11434/v1")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||
|
||||
with patch("agent.auxiliary_client._read_main_provider", return_value="custom"), \
|
||||
patch("agent.auxiliary_client._read_main_model", return_value="llama3"), \
|
||||
patch("agent.auxiliary_client._resolve_custom_runtime",
|
||||
return_value=("http://localhost:11434/v1", "test-key", None)), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai, \
|
||||
caplog.at_level(logging.WARNING, logger="agent.auxiliary_client"):
|
||||
mock_openai.return_value = MagicMock()
|
||||
_resolve_auto()
|
||||
|
||||
assert not any("OPENAI_BASE_URL is set" in rec.message for rec in caplog.records), \
|
||||
"Should NOT warn when provider is 'custom'"
|
||||
|
||||
def test_no_warning_when_provider_is_named_custom(self, monkeypatch, caplog):
|
||||
"""No warning when the provider is 'custom:myname' — base_url comes from config."""
|
||||
import agent.auxiliary_client as mod
|
||||
monkeypatch.setattr(mod, "_stale_base_url_warned", False)
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "http://localhost:11434/v1")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||
|
||||
with patch("agent.auxiliary_client._read_main_provider", return_value="custom:ollama-local"), \
|
||||
patch("agent.auxiliary_client._read_main_model", return_value="llama3"), \
|
||||
patch("agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(MagicMock(), "llama3")), \
|
||||
caplog.at_level(logging.WARNING, logger="agent.auxiliary_client"):
|
||||
_resolve_auto()
|
||||
|
||||
assert not any("OPENAI_BASE_URL is set" in rec.message for rec in caplog.records), \
|
||||
"Should NOT warn when provider is 'custom:*'"
|
||||
|
||||
def test_no_warning_when_openai_base_url_not_set(self, monkeypatch, caplog):
|
||||
"""No warning when OPENAI_BASE_URL is absent."""
|
||||
import agent.auxiliary_client as mod
|
||||
monkeypatch.setattr(mod, "_stale_base_url_warned", False)
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-test")
|
||||
|
||||
with patch("agent.auxiliary_client._read_main_provider", return_value="openrouter"), \
|
||||
patch("agent.auxiliary_client._read_main_model", return_value="google/gemini-flash"), \
|
||||
caplog.at_level(logging.WARNING, logger="agent.auxiliary_client"):
|
||||
_resolve_auto()
|
||||
|
||||
assert not any("OPENAI_BASE_URL is set" in rec.message for rec in caplog.records), \
|
||||
"Should NOT warn when OPENAI_BASE_URL is not set"
|
||||
|
||||
def test_warning_only_fires_once(self, monkeypatch, caplog):
|
||||
"""Warning is suppressed after the first invocation."""
|
||||
import agent.auxiliary_client as mod
|
||||
monkeypatch.setattr(mod, "_stale_base_url_warned", False)
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "http://localhost:11434/v1")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-test")
|
||||
|
||||
with patch("agent.auxiliary_client._read_main_provider", return_value="openrouter"), \
|
||||
patch("agent.auxiliary_client._read_main_model", return_value="google/gemini-flash"), \
|
||||
caplog.at_level(logging.WARNING, logger="agent.auxiliary_client"):
|
||||
_resolve_auto()
|
||||
caplog.clear()
|
||||
_resolve_auto()
|
||||
|
||||
assert not any("OPENAI_BASE_URL is set" in rec.message for rec in caplog.records), \
|
||||
"Warning should not fire a second time"
|
||||
|
||||
@@ -576,11 +576,19 @@ class TestSummaryTargetRatio:
|
||||
assert c.summary_target_ratio == 0.80
|
||||
|
||||
def test_default_threshold_is_50_percent(self):
|
||||
"""Default compression threshold should be 50%."""
|
||||
"""Default compression threshold should be 50%, with a 64K floor."""
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=100_000):
|
||||
c = ContextCompressor(model="test", quiet_mode=True)
|
||||
assert c.threshold_percent == 0.50
|
||||
assert c.threshold_tokens == 50_000
|
||||
# 50% of 100K = 50K, but the floor is 64K
|
||||
assert c.threshold_tokens == 64_000
|
||||
|
||||
def test_threshold_floor_does_not_apply_above_128k(self):
|
||||
"""On large-context models the 50% percentage is used directly."""
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=200_000):
|
||||
c = ContextCompressor(model="test", quiet_mode=True)
|
||||
# 50% of 200K = 100K, which is above the 64K floor
|
||||
assert c.threshold_tokens == 100_000
|
||||
|
||||
def test_default_protect_last_n_is_20(self):
|
||||
"""Default protect_last_n should be 20."""
|
||||
|
||||
@@ -22,6 +22,9 @@ class TestLocalStreamReadTimeout:
|
||||
"http://0.0.0.0:5000",
|
||||
"http://192.168.1.100:8000",
|
||||
"http://10.0.0.5:1234",
|
||||
"http://host.docker.internal:11434",
|
||||
"http://host.containers.internal:11434",
|
||||
"http://host.lima.internal:11434",
|
||||
])
|
||||
def test_local_endpoint_bumps_read_timeout(self, base_url):
|
||||
"""Local endpoint + default timeout -> bumps to base_timeout."""
|
||||
@@ -68,3 +71,38 @@ class TestLocalStreamReadTimeout:
|
||||
if _stream_read_timeout == 120.0 and base_url and is_local_endpoint(base_url):
|
||||
_stream_read_timeout = _base_timeout
|
||||
assert _stream_read_timeout == 120.0
|
||||
|
||||
|
||||
class TestIsLocalEndpoint:
|
||||
"""Direct unit tests for is_local_endpoint."""
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"http://localhost:11434",
|
||||
"http://127.0.0.1:8080",
|
||||
"http://0.0.0.0:5000",
|
||||
"http://[::1]:11434",
|
||||
"http://192.168.1.100:8000",
|
||||
"http://10.0.0.5:1234",
|
||||
"http://172.17.0.1:11434",
|
||||
])
|
||||
def test_classic_local_addresses(self, url):
|
||||
assert is_local_endpoint(url) is True
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"http://host.docker.internal:11434",
|
||||
"http://host.docker.internal:8080/v1",
|
||||
"http://gateway.docker.internal:11434",
|
||||
"http://host.containers.internal:11434",
|
||||
"http://host.lima.internal:11434",
|
||||
])
|
||||
def test_container_dns_names(self, url):
|
||||
assert is_local_endpoint(url) is True
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"https://api.openai.com",
|
||||
"https://openrouter.ai/api",
|
||||
"https://api.anthropic.com",
|
||||
"https://evil.docker.internal.example.com",
|
||||
])
|
||||
def test_remote_endpoints(self, url):
|
||||
assert is_local_endpoint(url) is False
|
||||
|
||||
@@ -50,7 +50,8 @@ class TestEstimateTokensRough:
|
||||
assert estimate_tokens_rough("a" * 400) == 100
|
||||
|
||||
def test_short_text(self):
|
||||
assert estimate_tokens_rough("hello") == 1
|
||||
# "hello" = 5 chars → ceil(5/4) = 2
|
||||
assert estimate_tokens_rough("hello") == 2
|
||||
|
||||
def test_proportional(self):
|
||||
short = estimate_tokens_rough("hello world")
|
||||
@@ -68,10 +69,11 @@ class TestEstimateMessagesTokensRough:
|
||||
assert estimate_messages_tokens_rough([]) == 0
|
||||
|
||||
def test_single_message_concrete_value(self):
|
||||
"""Verify against known str(msg) length."""
|
||||
"""Verify against known str(msg) length (ceiling division)."""
|
||||
msg = {"role": "user", "content": "a" * 400}
|
||||
result = estimate_messages_tokens_rough([msg])
|
||||
expected = len(str(msg)) // 4
|
||||
n = len(str(msg))
|
||||
expected = (n + 3) // 4
|
||||
assert result == expected
|
||||
|
||||
def test_multiple_messages_additive(self):
|
||||
@@ -80,7 +82,8 @@ class TestEstimateMessagesTokensRough:
|
||||
{"role": "assistant", "content": "Hi there, how can I help?"},
|
||||
]
|
||||
result = estimate_messages_tokens_rough(msgs)
|
||||
expected = sum(len(str(m)) for m in msgs) // 4
|
||||
n = sum(len(str(m)) for m in msgs)
|
||||
expected = (n + 3) // 4
|
||||
assert result == expected
|
||||
|
||||
def test_tool_call_message(self):
|
||||
@@ -89,7 +92,7 @@ class TestEstimateMessagesTokensRough:
|
||||
"tool_calls": [{"id": "1", "function": {"name": "terminal", "arguments": "{}"}}]}
|
||||
result = estimate_messages_tokens_rough([msg])
|
||||
assert result > 0
|
||||
assert result == len(str(msg)) // 4
|
||||
assert result == (len(str(msg)) + 3) // 4
|
||||
|
||||
def test_message_with_list_content(self):
|
||||
"""Vision messages with multimodal content arrays."""
|
||||
@@ -98,7 +101,7 @@ class TestEstimateMessagesTokensRough:
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}}
|
||||
]}
|
||||
result = estimate_messages_tokens_rough([msg])
|
||||
assert result == len(str(msg)) // 4
|
||||
assert result == (len(str(msg)) + 3) // 4
|
||||
|
||||
|
||||
# =========================================================================
|
||||
@@ -222,6 +225,24 @@ class TestGetModelContextLength:
|
||||
mock_fetch.return_value = {}
|
||||
assert get_model_context_length("openai/gpt-4o") == 128000
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_qwen3_coder_plus_context_length(self, mock_fetch):
|
||||
"""qwen3-coder-plus has a 1M context window, not the generic 128K Qwen default."""
|
||||
mock_fetch.return_value = {}
|
||||
assert get_model_context_length("qwen3-coder-plus") == 1000000
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_qwen3_coder_context_length(self, mock_fetch):
|
||||
"""qwen3-coder has a 256K context window, not the generic 128K Qwen default."""
|
||||
mock_fetch.return_value = {}
|
||||
assert get_model_context_length("qwen3-coder") == 262144
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_qwen_generic_context_length(self, mock_fetch):
|
||||
"""Generic qwen models still get the 128K default."""
|
||||
mock_fetch.return_value = {}
|
||||
assert get_model_context_length("qwen3-plus") == 131072
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_api_missing_context_length_key(self, mock_fetch):
|
||||
"""Model in API but without context_length → defaults to 128000."""
|
||||
|
||||
@@ -7,6 +7,7 @@ from agent.models_dev import (
|
||||
PROVIDER_TO_MODELS_DEV,
|
||||
_extract_context,
|
||||
fetch_models_dev,
|
||||
get_model_capabilities,
|
||||
lookup_models_dev_context,
|
||||
)
|
||||
|
||||
@@ -195,3 +196,88 @@ class TestFetchModelsDev:
|
||||
result = fetch_models_dev()
|
||||
mock_get.assert_not_called()
|
||||
assert result == SAMPLE_REGISTRY
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_model_capabilities — vision via modalities.input
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
CAPS_REGISTRY = {
|
||||
"google": {
|
||||
"id": "google",
|
||||
"models": {
|
||||
"gemma-4-31b-it": {
|
||||
"id": "gemma-4-31b-it",
|
||||
"attachment": False,
|
||||
"tool_call": True,
|
||||
"modalities": {"input": ["text", "image"]},
|
||||
"limit": {"context": 128000, "output": 8192},
|
||||
},
|
||||
"gemma-3-1b": {
|
||||
"id": "gemma-3-1b",
|
||||
"tool_call": True,
|
||||
"limit": {"context": 32000, "output": 8192},
|
||||
},
|
||||
},
|
||||
},
|
||||
"anthropic": {
|
||||
"id": "anthropic",
|
||||
"models": {
|
||||
"claude-sonnet-4": {
|
||||
"id": "claude-sonnet-4",
|
||||
"attachment": True,
|
||||
"tool_call": True,
|
||||
"limit": {"context": 200000, "output": 64000},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestGetModelCapabilities:
|
||||
"""Tests for get_model_capabilities vision detection."""
|
||||
|
||||
def test_vision_from_attachment_flag(self):
|
||||
"""Models with attachment=True should report supports_vision=True."""
|
||||
with patch("agent.models_dev.fetch_models_dev", return_value=CAPS_REGISTRY):
|
||||
caps = get_model_capabilities("anthropic", "claude-sonnet-4")
|
||||
assert caps is not None
|
||||
assert caps.supports_vision is True
|
||||
|
||||
def test_vision_from_modalities_input_image(self):
|
||||
"""Models with 'image' in modalities.input but attachment=False should
|
||||
still report supports_vision=True (the core fix in this PR)."""
|
||||
with patch("agent.models_dev.fetch_models_dev", return_value=CAPS_REGISTRY):
|
||||
caps = get_model_capabilities("google", "gemma-4-31b-it")
|
||||
assert caps is not None
|
||||
assert caps.supports_vision is True
|
||||
|
||||
def test_no_vision_without_attachment_or_modalities(self):
|
||||
"""Models with neither attachment nor image modality should be non-vision."""
|
||||
with patch("agent.models_dev.fetch_models_dev", return_value=CAPS_REGISTRY):
|
||||
caps = get_model_capabilities("google", "gemma-3-1b")
|
||||
assert caps is not None
|
||||
assert caps.supports_vision is False
|
||||
|
||||
def test_modalities_non_dict_handled(self):
|
||||
"""Non-dict modalities field should not crash."""
|
||||
registry = {
|
||||
"google": {"id": "google", "models": {
|
||||
"weird-model": {
|
||||
"id": "weird-model",
|
||||
"modalities": "text", # not a dict
|
||||
"limit": {"context": 200000, "output": 8192},
|
||||
},
|
||||
}},
|
||||
}
|
||||
with patch("agent.models_dev.fetch_models_dev", return_value=registry):
|
||||
caps = get_model_capabilities("gemini", "weird-model")
|
||||
assert caps is not None
|
||||
assert caps.supports_vision is False
|
||||
|
||||
def test_model_not_found_returns_none(self):
|
||||
"""Unknown model should return None."""
|
||||
with patch("agent.models_dev.fetch_models_dev", return_value=CAPS_REGISTRY):
|
||||
caps = get_model_capabilities("anthropic", "nonexistent-model")
|
||||
assert caps is None
|
||||
|
||||
@@ -1009,65 +1009,4 @@ class TestOpenAIModelExecutionGuidance:
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestStripBudgetWarningsFromHistory:
|
||||
def test_strips_json_budget_warning_key(self):
|
||||
import json
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "c1", "content": json.dumps({
|
||||
"output": "hello",
|
||||
"exit_code": 0,
|
||||
"_budget_warning": "[BUDGET: Iteration 55/60. 5 iterations left. Start consolidating your work.]",
|
||||
})},
|
||||
]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
parsed = json.loads(messages[0]["content"])
|
||||
assert "_budget_warning" not in parsed
|
||||
assert parsed["output"] == "hello"
|
||||
assert parsed["exit_code"] == 0
|
||||
|
||||
def test_strips_text_budget_warning(self):
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "c1",
|
||||
"content": "some result\n\n[BUDGET WARNING: Iteration 58/60. Only 2 iteration(s) left. Provide your final response NOW. No more tool calls unless absolutely critical.]"},
|
||||
]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
assert messages[0]["content"] == "some result"
|
||||
|
||||
def test_leaves_non_tool_messages_unchanged(self):
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "assistant", "content": "[BUDGET WARNING: Iteration 58/60. Only 2 iteration(s) left. Provide your final response NOW. No more tool calls unless absolutely critical.]"},
|
||||
{"role": "user", "content": "hello"},
|
||||
]
|
||||
original_contents = [m["content"] for m in messages]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
assert [m["content"] for m in messages] == original_contents
|
||||
|
||||
def test_handles_empty_and_missing_content(self):
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "c1", "content": ""},
|
||||
{"role": "tool", "tool_call_id": "c2"},
|
||||
]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
assert messages[0]["content"] == ""
|
||||
|
||||
def test_strips_caution_variant(self):
|
||||
import json
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "c1", "content": json.dumps({
|
||||
"output": "ok",
|
||||
"_budget_warning": "[BUDGET: Iteration 42/60. 18 iterations left. Start consolidating your work.]",
|
||||
})},
|
||||
]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
parsed = json.loads(messages[0]["content"])
|
||||
assert "_budget_warning" not in parsed
|
||||
|
||||
@@ -211,7 +211,8 @@ def make_adapter(platform: Platform, runner=None):
|
||||
config = PlatformConfig(enabled=True, token="e2e-test-token")
|
||||
|
||||
if platform == Platform.DISCORD:
|
||||
with patch.object(DiscordAdapter, "_load_participated_threads", return_value=set()):
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
with patch.object(ThreadParticipationTracker, "_load", return_value=set()):
|
||||
adapter = DiscordAdapter(config)
|
||||
platform_key = Platform.DISCORD
|
||||
elif platform == Platform.SLACK:
|
||||
|
||||
@@ -409,11 +409,50 @@ class TestChatCompletionsEndpoint:
|
||||
)
|
||||
assert resp.status == 200
|
||||
assert "text/event-stream" in resp.headers.get("Content-Type", "")
|
||||
assert resp.headers.get("X-Accel-Buffering") == "no"
|
||||
body = await resp.text()
|
||||
assert "data: " in body
|
||||
assert "[DONE]" in body
|
||||
assert "Hello!" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_sends_keepalive_during_quiet_tool_gap(self, adapter):
|
||||
"""Idle SSE streams should send keepalive comments while tools run silently."""
|
||||
import asyncio
|
||||
import gateway.platforms.api_server as api_server_mod
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
async def _mock_run_agent(**kwargs):
|
||||
cb = kwargs.get("stream_delta_callback")
|
||||
if cb:
|
||||
cb("Working")
|
||||
await asyncio.sleep(0.65)
|
||||
cb("...done")
|
||||
return (
|
||||
{"final_response": "Working...done", "messages": [], "api_calls": 1},
|
||||
{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(api_server_mod, "CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS", 0.01),
|
||||
patch.object(adapter, "_run_agent", side_effect=_mock_run_agent),
|
||||
):
|
||||
resp = await cli.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "test",
|
||||
"messages": [{"role": "user", "content": "do the thing"}],
|
||||
"stream": True,
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.text()
|
||||
assert ": keepalive" in body
|
||||
assert "Working" in body
|
||||
assert "...done" in body
|
||||
assert "[DONE]" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_survives_tool_call_none_sentinel(self, adapter):
|
||||
"""stream_delta_callback(None) mid-stream (tool calls) must NOT kill the SSE stream.
|
||||
|
||||
@@ -345,6 +345,11 @@ class TestBlockingApprovalE2E:
|
||||
|
||||
def setup_method(self):
|
||||
_clear_approval_state()
|
||||
os.environ.pop("HERMES_YOLO_MODE", None)
|
||||
os.environ.pop("HERMES_INTERACTIVE", None)
|
||||
os.environ.pop("HERMES_GATEWAY_SESSION", None)
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||
|
||||
def test_blocking_approval_approve_once(self):
|
||||
"""check_all_command_guards blocks until resolve_gateway_approval is called."""
|
||||
@@ -364,6 +369,7 @@ class TestBlockingApprovalE2E:
|
||||
from tools.approval import reset_current_session_key, set_current_session_key
|
||||
|
||||
token = set_current_session_key(session_key)
|
||||
os.environ["HERMES_GATEWAY_SESSION"] = "1"
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||
try:
|
||||
@@ -371,6 +377,7 @@ class TestBlockingApprovalE2E:
|
||||
"rm -rf /important", "local"
|
||||
)
|
||||
finally:
|
||||
os.environ.pop("HERMES_GATEWAY_SESSION", None)
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||
reset_current_session_key(token)
|
||||
@@ -410,6 +417,7 @@ class TestBlockingApprovalE2E:
|
||||
from tools.approval import reset_current_session_key, set_current_session_key
|
||||
|
||||
token = set_current_session_key(session_key)
|
||||
os.environ["HERMES_GATEWAY_SESSION"] = "1"
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||
try:
|
||||
@@ -417,6 +425,7 @@ class TestBlockingApprovalE2E:
|
||||
"rm -rf /important", "local"
|
||||
)
|
||||
finally:
|
||||
os.environ.pop("HERMES_GATEWAY_SESSION", None)
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||
reset_current_session_key(token)
|
||||
@@ -451,6 +460,7 @@ class TestBlockingApprovalE2E:
|
||||
from tools.approval import reset_current_session_key, set_current_session_key
|
||||
|
||||
token = set_current_session_key(session_key)
|
||||
os.environ["HERMES_GATEWAY_SESSION"] = "1"
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||
try:
|
||||
@@ -460,6 +470,7 @@ class TestBlockingApprovalE2E:
|
||||
"rm -rf /important", "local"
|
||||
)
|
||||
finally:
|
||||
os.environ.pop("HERMES_GATEWAY_SESSION", None)
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||
reset_current_session_key(token)
|
||||
@@ -491,11 +502,13 @@ class TestBlockingApprovalE2E:
|
||||
from tools.approval import reset_current_session_key, set_current_session_key
|
||||
|
||||
token = set_current_session_key(session_key)
|
||||
os.environ["HERMES_GATEWAY_SESSION"] = "1"
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||
try:
|
||||
results[idx] = check_all_command_guards(cmd, "local")
|
||||
finally:
|
||||
os.environ.pop("HERMES_GATEWAY_SESSION", None)
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||
reset_current_session_key(token)
|
||||
@@ -546,11 +559,13 @@ class TestBlockingApprovalE2E:
|
||||
from tools.approval import reset_current_session_key, set_current_session_key
|
||||
|
||||
token = set_current_session_key(session_key)
|
||||
os.environ["HERMES_GATEWAY_SESSION"] = "1"
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||
try:
|
||||
results[idx] = check_all_command_guards(cmd, "local")
|
||||
finally:
|
||||
os.environ.pop("HERMES_GATEWAY_SESSION", None)
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||
reset_current_session_key(token)
|
||||
|
||||
@@ -119,28 +119,29 @@ class TestDeduplication:
|
||||
def test_first_message_not_duplicate(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
assert adapter._is_duplicate("msg-1") is False
|
||||
assert adapter._dedup.is_duplicate("msg-1") is False
|
||||
|
||||
def test_second_same_message_is_duplicate(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._is_duplicate("msg-1")
|
||||
assert adapter._is_duplicate("msg-1") is True
|
||||
adapter._dedup.is_duplicate("msg-1")
|
||||
assert adapter._dedup.is_duplicate("msg-1") is True
|
||||
|
||||
def test_different_messages_not_duplicate(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._is_duplicate("msg-1")
|
||||
assert adapter._is_duplicate("msg-2") is False
|
||||
adapter._dedup.is_duplicate("msg-1")
|
||||
assert adapter._dedup.is_duplicate("msg-2") is False
|
||||
|
||||
def test_cache_cleanup_on_overflow(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter, DEDUP_MAX_SIZE
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
max_size = adapter._dedup._max_size
|
||||
# Fill beyond max
|
||||
for i in range(DEDUP_MAX_SIZE + 10):
|
||||
adapter._is_duplicate(f"msg-{i}")
|
||||
for i in range(max_size + 10):
|
||||
adapter._dedup.is_duplicate(f"msg-{i}")
|
||||
# Cache should have been pruned
|
||||
assert len(adapter._seen_messages) <= DEDUP_MAX_SIZE + 10
|
||||
assert len(adapter._dedup._seen) <= max_size + 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -253,13 +254,13 @@ class TestConnect:
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._session_webhooks["a"] = "http://x"
|
||||
adapter._seen_messages["b"] = 1.0
|
||||
adapter._dedup._seen["b"] = 1.0
|
||||
adapter._http_client = AsyncMock()
|
||||
adapter._stream_task = None
|
||||
|
||||
await adapter.disconnect()
|
||||
assert len(adapter._session_webhooks) == 0
|
||||
assert len(adapter._seen_messages) == 0
|
||||
assert len(adapter._dedup._seen) == 0
|
||||
assert adapter._http_client is None
|
||||
|
||||
|
||||
|
||||
@@ -137,4 +137,4 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch):
|
||||
|
||||
assert ok is False
|
||||
assert released == [("discord-bot-token", "test-token")]
|
||||
assert adapter._token_lock_identity is None
|
||||
assert adapter._platform_lock_identity is None
|
||||
|
||||
@@ -302,7 +302,7 @@ async def test_discord_bot_thread_skips_mention_requirement(adapter, monkeypatch
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
|
||||
# Simulate bot having previously participated in thread 456
|
||||
adapter._bot_participated_threads.add("456")
|
||||
adapter._threads.mark("456")
|
||||
|
||||
thread = FakeThread(channel_id=456, name="existing thread")
|
||||
message = make_message(channel=thread, content="follow-up without mention")
|
||||
@@ -344,7 +344,7 @@ async def test_discord_auto_thread_tracks_participation(adapter, monkeypatch):
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
assert "555" in adapter._bot_participated_threads
|
||||
assert "555" in adapter._threads
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -358,4 +358,4 @@ async def test_discord_thread_participation_tracked_on_dispatch(adapter, monkeyp
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
assert "777" in adapter._bot_participated_threads
|
||||
assert "777" in adapter._threads
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for Discord thread participation persistence.
|
||||
|
||||
Verifies that _bot_participated_threads survives adapter restarts by
|
||||
Verifies that _threads (ThreadParticipationTracker) survives adapter restarts by
|
||||
being persisted to ~/.hermes/discord_threads.json.
|
||||
"""
|
||||
|
||||
@@ -25,13 +25,13 @@ class TestDiscordThreadPersistence:
|
||||
|
||||
def test_starts_empty_when_no_state_file(self, tmp_path):
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
assert adapter._bot_participated_threads == set()
|
||||
assert "$nonexistent" not in adapter._threads
|
||||
|
||||
def test_track_thread_persists_to_disk(self, tmp_path):
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
adapter._track_thread("111")
|
||||
adapter._track_thread("222")
|
||||
adapter._threads.mark("111")
|
||||
adapter._threads.mark("222")
|
||||
|
||||
state_file = tmp_path / "discord_threads.json"
|
||||
assert state_file.exists()
|
||||
@@ -42,42 +42,43 @@ class TestDiscordThreadPersistence:
|
||||
"""Threads tracked by one adapter instance are visible to the next."""
|
||||
adapter1 = self._make_adapter(tmp_path)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
adapter1._track_thread("aaa")
|
||||
adapter1._track_thread("bbb")
|
||||
adapter1._threads.mark("aaa")
|
||||
adapter1._threads.mark("bbb")
|
||||
|
||||
adapter2 = self._make_adapter(tmp_path)
|
||||
assert "aaa" in adapter2._bot_participated_threads
|
||||
assert "bbb" in adapter2._bot_participated_threads
|
||||
assert "aaa" in adapter2._threads
|
||||
assert "bbb" in adapter2._threads
|
||||
|
||||
def test_duplicate_track_does_not_double_save(self, tmp_path):
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
adapter._track_thread("111")
|
||||
adapter._track_thread("111") # no-op
|
||||
adapter._threads.mark("111")
|
||||
adapter._threads.mark("111") # no-op
|
||||
|
||||
saved = json.loads((tmp_path / "discord_threads.json").read_text())
|
||||
assert saved.count("111") == 1
|
||||
|
||||
def test_caps_at_max_tracked_threads(self, tmp_path):
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
adapter._MAX_TRACKED_THREADS = 5
|
||||
adapter._threads._max_tracked = 5
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
for i in range(10):
|
||||
adapter._track_thread(str(i))
|
||||
adapter._threads.mark(str(i))
|
||||
|
||||
assert len(adapter._bot_participated_threads) == 5
|
||||
saved = json.loads((tmp_path / "discord_threads.json").read_text())
|
||||
assert len(saved) == 5
|
||||
|
||||
def test_corrupted_state_file_falls_back_to_empty(self, tmp_path):
|
||||
state_file = tmp_path / "discord_threads.json"
|
||||
state_file.write_text("not valid json{{{")
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
assert adapter._bot_participated_threads == set()
|
||||
assert "$nonexistent" not in adapter._threads
|
||||
|
||||
def test_missing_hermes_home_does_not_crash(self, tmp_path):
|
||||
"""Load/save tolerate missing directories."""
|
||||
fake_home = tmp_path / "nonexistent" / "deep"
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(fake_home)}):
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
# _load should return empty set, not crash
|
||||
threads = DiscordAdapter._load_participated_threads()
|
||||
assert threads == set()
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
# ThreadParticipationTracker should return empty set, not crash
|
||||
tracker = ThreadParticipationTracker("discord")
|
||||
assert "$test" not in tracker
|
||||
|
||||
@@ -0,0 +1,355 @@
|
||||
"""Tests for gateway.display_config — per-platform display/verbosity resolver."""
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Resolver: resolution order
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResolveDisplaySetting:
|
||||
"""resolve_display_setting() resolves with correct priority."""
|
||||
|
||||
def test_explicit_platform_override_wins(self):
|
||||
"""display.platforms.<plat>.<key> takes top priority."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "all",
|
||||
"platforms": {
|
||||
"telegram": {"tool_progress": "verbose"},
|
||||
},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "verbose"
|
||||
|
||||
def test_global_setting_when_no_platform_override(self):
|
||||
"""Falls back to display.<key> when no platform override exists."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "new",
|
||||
"platforms": {},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "new"
|
||||
|
||||
def test_platform_default_when_no_user_config(self):
|
||||
"""Falls back to built-in platform default."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
# Empty config — should get built-in defaults
|
||||
config = {}
|
||||
# Telegram defaults to tier_high → "all"
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "all"
|
||||
# Email defaults to tier_minimal → "off"
|
||||
assert resolve_display_setting(config, "email", "tool_progress") == "off"
|
||||
|
||||
def test_global_default_for_unknown_platform(self):
|
||||
"""Unknown platforms get the global defaults."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {}
|
||||
# Unknown platform, no config → global default "all"
|
||||
assert resolve_display_setting(config, "unknown_platform", "tool_progress") == "all"
|
||||
|
||||
def test_fallback_parameter_used_last(self):
|
||||
"""Explicit fallback is used when nothing else matches."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {}
|
||||
# "nonexistent_key" isn't in any defaults
|
||||
result = resolve_display_setting(config, "telegram", "nonexistent_key", "my_fallback")
|
||||
assert result == "my_fallback"
|
||||
|
||||
def test_platform_override_only_affects_that_platform(self):
|
||||
"""Other platforms are unaffected by a specific platform override."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "all",
|
||||
"platforms": {
|
||||
"slack": {"tool_progress": "off"},
|
||||
},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "slack", "tool_progress") == "off"
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "all"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backward compatibility: tool_progress_overrides
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBackwardCompat:
|
||||
"""Legacy tool_progress_overrides is still respected as a fallback."""
|
||||
|
||||
def test_legacy_overrides_read(self):
|
||||
"""tool_progress_overrides is read when no platforms entry exists."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "all",
|
||||
"tool_progress_overrides": {
|
||||
"signal": "off",
|
||||
"telegram": "verbose",
|
||||
},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "signal", "tool_progress") == "off"
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "verbose"
|
||||
|
||||
def test_new_platforms_takes_precedence_over_legacy(self):
|
||||
"""display.platforms beats tool_progress_overrides."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "all",
|
||||
"tool_progress_overrides": {"telegram": "verbose"},
|
||||
"platforms": {"telegram": {"tool_progress": "new"}},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "new"
|
||||
|
||||
def test_legacy_overrides_only_for_tool_progress(self):
|
||||
"""Legacy overrides don't affect other settings."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress_overrides": {"telegram": "verbose"},
|
||||
}
|
||||
}
|
||||
# show_reasoning should NOT read from tool_progress_overrides
|
||||
assert resolve_display_setting(config, "telegram", "show_reasoning") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# YAML normalisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestYAMLNormalisation:
|
||||
"""YAML 1.1 quirks (bare off → False, on → True) are handled."""
|
||||
|
||||
def test_tool_progress_false_normalised_to_off(self):
|
||||
"""YAML's bare `off` parses as False — normalised to 'off' string."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"tool_progress": False}}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "off"
|
||||
|
||||
def test_tool_progress_true_normalised_to_all(self):
|
||||
"""YAML's bare `on` parses as True — normalised to 'all'."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"tool_progress": True}}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "all"
|
||||
|
||||
def test_show_reasoning_string_true(self):
|
||||
"""String 'true' is normalised to bool True."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"platforms": {"telegram": {"show_reasoning": "true"}}}}
|
||||
assert resolve_display_setting(config, "telegram", "show_reasoning") is True
|
||||
|
||||
def test_tool_preview_length_string(self):
|
||||
"""String numbers are normalised to int."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"platforms": {"slack": {"tool_preview_length": "80"}}}}
|
||||
assert resolve_display_setting(config, "slack", "tool_preview_length") == 80
|
||||
|
||||
def test_platform_override_false_tool_progress(self):
|
||||
"""Per-platform bare off → normalised."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"platforms": {"slack": {"tool_progress": False}}}}
|
||||
assert resolve_display_setting(config, "slack", "tool_progress") == "off"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Built-in platform defaults (tier system)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPlatformDefaults:
|
||||
"""Built-in defaults reflect platform capability tiers."""
|
||||
|
||||
def test_high_tier_platforms(self):
|
||||
"""Telegram and Discord default to 'all' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("telegram", "discord"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "all", plat
|
||||
|
||||
def test_medium_tier_platforms(self):
|
||||
"""Slack, Mattermost, Matrix default to 'new' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("slack", "mattermost", "matrix", "feishu"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "new", plat
|
||||
|
||||
def test_low_tier_platforms(self):
|
||||
"""Signal, WhatsApp, etc. default to 'off' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("signal", "whatsapp", "bluebubbles", "weixin", "wecom", "dingtalk"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "off", plat
|
||||
|
||||
def test_minimal_tier_platforms(self):
|
||||
"""Email, SMS, webhook default to 'off' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("email", "sms", "webhook", "homeassistant"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "off", plat
|
||||
|
||||
def test_low_tier_streaming_defaults_to_false(self):
|
||||
"""Low-tier platforms default streaming to False."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
assert resolve_display_setting({}, "signal", "streaming") is False
|
||||
assert resolve_display_setting({}, "email", "streaming") is False
|
||||
|
||||
def test_high_tier_streaming_defaults_to_none(self):
|
||||
"""High-tier platforms default streaming to None (follow global)."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
assert resolve_display_setting({}, "telegram", "streaming") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_effective_display / get_platform_defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHelpers:
|
||||
"""Helper functions return correct composite results."""
|
||||
|
||||
def test_get_effective_display_merges_correctly(self):
|
||||
from gateway.display_config import get_effective_display
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "new",
|
||||
"show_reasoning": True,
|
||||
"platforms": {
|
||||
"telegram": {"tool_progress": "verbose"},
|
||||
},
|
||||
}
|
||||
}
|
||||
eff = get_effective_display(config, "telegram")
|
||||
assert eff["tool_progress"] == "verbose" # platform override
|
||||
assert eff["show_reasoning"] is True # global
|
||||
assert "tool_preview_length" in eff # default filled in
|
||||
|
||||
def test_get_platform_defaults_returns_dict(self):
|
||||
from gateway.display_config import get_platform_defaults
|
||||
|
||||
defaults = get_platform_defaults("telegram")
|
||||
assert "tool_progress" in defaults
|
||||
assert "show_reasoning" in defaults
|
||||
# Returns a new dict (not the shared tier dict)
|
||||
defaults["tool_progress"] = "changed"
|
||||
assert get_platform_defaults("telegram")["tool_progress"] != "changed"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config migration: tool_progress_overrides → display.platforms
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestConfigMigration:
|
||||
"""Version 16 migration moves tool_progress_overrides into display.platforms."""
|
||||
|
||||
def test_migration_creates_platforms_entries(self, tmp_path, monkeypatch):
|
||||
"""Old overrides are migrated into display.platforms.<plat>.tool_progress."""
|
||||
import yaml
|
||||
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config = {
|
||||
"_config_version": 15,
|
||||
"display": {
|
||||
"tool_progress_overrides": {
|
||||
"signal": "off",
|
||||
"telegram": "all",
|
||||
},
|
||||
},
|
||||
}
|
||||
config_path.write_text(yaml.dump(config))
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
# Re-import to pick up the new HERMES_HOME
|
||||
import importlib
|
||||
import hermes_cli.config as cfg_mod
|
||||
importlib.reload(cfg_mod)
|
||||
|
||||
result = cfg_mod.migrate_config(interactive=False, quiet=True)
|
||||
# Re-read config
|
||||
updated = yaml.safe_load(config_path.read_text())
|
||||
platforms = updated.get("display", {}).get("platforms", {})
|
||||
assert platforms.get("signal", {}).get("tool_progress") == "off"
|
||||
assert platforms.get("telegram", {}).get("tool_progress") == "all"
|
||||
|
||||
def test_migration_preserves_existing_platforms_entries(self, tmp_path, monkeypatch):
|
||||
"""Existing display.platforms entries are NOT overwritten by migration."""
|
||||
import yaml
|
||||
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config = {
|
||||
"_config_version": 15,
|
||||
"display": {
|
||||
"tool_progress_overrides": {"telegram": "off"},
|
||||
"platforms": {"telegram": {"tool_progress": "verbose"}},
|
||||
},
|
||||
}
|
||||
config_path.write_text(yaml.dump(config))
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
import importlib
|
||||
import hermes_cli.config as cfg_mod
|
||||
importlib.reload(cfg_mod)
|
||||
|
||||
cfg_mod.migrate_config(interactive=False, quiet=True)
|
||||
updated = yaml.safe_load(config_path.read_text())
|
||||
# Existing "verbose" should NOT be overwritten by legacy "off"
|
||||
assert updated["display"]["platforms"]["telegram"]["tool_progress"] == "verbose"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming per-platform (None = follow global)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestStreamingPerPlatform:
|
||||
"""Streaming per-platform override semantics."""
|
||||
|
||||
def test_none_means_follow_global(self):
|
||||
"""When streaming is None, the caller should use global config."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {}
|
||||
# Telegram has no streaming override in defaults → None
|
||||
result = resolve_display_setting(config, "telegram", "streaming")
|
||||
assert result is None # caller should check global StreamingConfig
|
||||
|
||||
def test_explicit_false_disables(self):
|
||||
"""Explicit False disables streaming for that platform."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"platforms": {"telegram": {"streaming": False}},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "telegram", "streaming") is False
|
||||
|
||||
def test_explicit_true_enables(self):
|
||||
"""Explicit True enables streaming for that platform."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"platforms": {"email": {"streaming": True}},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "email", "streaming") is True
|
||||
@@ -195,6 +195,105 @@ async def test_internal_event_does_not_trigger_pairing(monkeypatch, tmp_path):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notify_on_complete_preserves_user_identity(monkeypatch, tmp_path):
|
||||
"""Synthetic completion event should carry user_id and user_name from the watcher."""
|
||||
import tools.process_registry as pr_module
|
||||
|
||||
sessions = [
|
||||
SimpleNamespace(
|
||||
output_buffer="done\n", exited=True, exit_code=0, command="echo test"
|
||||
),
|
||||
]
|
||||
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
|
||||
|
||||
async def _instant_sleep(*_a, **_kw):
|
||||
pass
|
||||
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
|
||||
|
||||
runner = _build_runner(monkeypatch, tmp_path)
|
||||
adapter = runner.adapters[Platform.DISCORD]
|
||||
|
||||
watcher = _watcher_dict_with_notify()
|
||||
watcher["user_id"] = "user-42"
|
||||
watcher["user_name"] = "alice"
|
||||
|
||||
await runner._run_process_watcher(watcher)
|
||||
|
||||
assert adapter.handle_message.await_count == 1
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.source.user_id == "user-42"
|
||||
assert event.source.user_name == "alice"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_user_id_skips_pairing(monkeypatch, tmp_path):
|
||||
"""A non-internal event with user_id=None should be silently dropped."""
|
||||
import gateway.run as gateway_run
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
(tmp_path / "config.yaml").write_text("", encoding="utf-8")
|
||||
|
||||
runner = GatewayRunner(GatewayConfig())
|
||||
adapter = SimpleNamespace(send=AsyncMock())
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="123",
|
||||
chat_type="dm",
|
||||
user_id=None,
|
||||
)
|
||||
event = MessageEvent(
|
||||
text="service message",
|
||||
source=source,
|
||||
internal=False,
|
||||
)
|
||||
|
||||
result = await runner._handle_message(event)
|
||||
|
||||
# Should return None (dropped) and NOT send any pairing message
|
||||
assert result is None
|
||||
assert adapter.send.await_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_user_id_does_not_generate_pairing_code(monkeypatch, tmp_path):
|
||||
"""A message with user_id=None must never call generate_code."""
|
||||
import gateway.run as gateway_run
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
(tmp_path / "config.yaml").write_text("", encoding="utf-8")
|
||||
|
||||
runner = GatewayRunner(GatewayConfig())
|
||||
adapter = SimpleNamespace(send=AsyncMock())
|
||||
runner.adapters[Platform.DISCORD] = adapter
|
||||
|
||||
generate_called = False
|
||||
original_generate = runner.pairing_store.generate_code
|
||||
|
||||
def tracking_generate(*args, **kwargs):
|
||||
nonlocal generate_called
|
||||
generate_called = True
|
||||
return original_generate(*args, **kwargs)
|
||||
|
||||
runner.pairing_store.generate_code = tracking_generate
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="456",
|
||||
chat_type="dm",
|
||||
user_id=None,
|
||||
)
|
||||
event = MessageEvent(text="anonymous", source=source, internal=False)
|
||||
|
||||
await runner._handle_message(event)
|
||||
|
||||
assert not generate_called, (
|
||||
"Pairing code should NOT be generated for messages with user_id=None"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_internal_event_without_user_triggers_pairing(monkeypatch, tmp_path):
|
||||
"""Verify the normal (non-internal) path still triggers pairing for unknown users."""
|
||||
|
||||
@@ -157,7 +157,9 @@ def _make_fake_mautrix():
|
||||
mautrix_crypto_store = types.ModuleType("mautrix.crypto.store")
|
||||
|
||||
class MemoryCryptoStore:
|
||||
pass
|
||||
def __init__(self, account_id="", pickle_key=""):
|
||||
self.account_id = account_id
|
||||
self.pickle_key = pickle_key
|
||||
|
||||
mautrix_crypto_store.MemoryCryptoStore = MemoryCryptoStore
|
||||
|
||||
@@ -1041,20 +1043,28 @@ class TestMatrixSyncLoop:
|
||||
call_count += 1
|
||||
if call_count >= 1:
|
||||
adapter._closing = True
|
||||
return {"rooms": {"join": {"!room:example.org": {}}}}
|
||||
return {"rooms": {"join": {"!room:example.org": {}}}, "next_batch": "s1234"}
|
||||
|
||||
mock_crypto = MagicMock()
|
||||
mock_crypto.share_keys = AsyncMock()
|
||||
|
||||
mock_sync_store = MagicMock()
|
||||
mock_sync_store.get_next_batch = AsyncMock(return_value=None)
|
||||
mock_sync_store.put_next_batch = AsyncMock()
|
||||
|
||||
fake_client = MagicMock()
|
||||
fake_client.sync = AsyncMock(side_effect=_sync_once)
|
||||
fake_client.crypto = mock_crypto
|
||||
fake_client.sync_store = mock_sync_store
|
||||
fake_client.handle_sync = MagicMock(return_value=[])
|
||||
adapter._client = fake_client
|
||||
|
||||
await adapter._sync_loop()
|
||||
|
||||
fake_client.sync.assert_awaited_once()
|
||||
mock_crypto.share_keys.assert_awaited_once()
|
||||
fake_client.handle_sync.assert_called_once()
|
||||
mock_sync_store.put_next_batch.assert_awaited_once_with("s1234")
|
||||
|
||||
|
||||
class TestMatrixEncryptedSendFallback:
|
||||
|
||||
@@ -247,7 +247,7 @@ async def test_require_mention_bot_participated_thread(monkeypatch):
|
||||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
adapter = _make_adapter()
|
||||
adapter._bot_participated_threads.add("$thread1")
|
||||
adapter._threads.mark("$thread1")
|
||||
|
||||
event = _make_event("hello without mention", thread_id="$thread1")
|
||||
|
||||
@@ -298,7 +298,7 @@ async def test_auto_thread_preserves_existing_thread(monkeypatch):
|
||||
monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False)
|
||||
|
||||
adapter = _make_adapter()
|
||||
adapter._bot_participated_threads.add("$thread_root")
|
||||
adapter._threads.mark("$thread_root")
|
||||
event = _make_event("reply in thread", thread_id="$thread_root")
|
||||
|
||||
await adapter._on_room_message(event)
|
||||
@@ -340,17 +340,17 @@ async def test_auto_thread_disabled(monkeypatch):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_thread_tracks_participation(monkeypatch):
|
||||
"""Auto-created threads are tracked in _bot_participated_threads."""
|
||||
"""Auto-created threads are tracked in _threads."""
|
||||
monkeypatch.setenv("MATRIX_REQUIRE_MENTION", "false")
|
||||
monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False)
|
||||
|
||||
adapter = _make_adapter()
|
||||
event = _make_event("hello", event_id="$msg1")
|
||||
|
||||
with patch.object(adapter, "_save_participated_threads"):
|
||||
with patch.object(adapter._threads, "_save"):
|
||||
await adapter._on_room_message(event)
|
||||
|
||||
assert "$msg1" in adapter._bot_participated_threads
|
||||
assert "$msg1" in adapter._threads
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -361,56 +361,54 @@ async def test_auto_thread_tracks_participation(monkeypatch):
|
||||
class TestThreadPersistence:
|
||||
def test_empty_state_file(self, tmp_path, monkeypatch):
|
||||
"""No state file → empty set."""
|
||||
from gateway.platforms.matrix import MatrixAdapter
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
monkeypatch.setattr(
|
||||
MatrixAdapter, "_thread_state_path",
|
||||
staticmethod(lambda: tmp_path / "matrix_threads.json"),
|
||||
ThreadParticipationTracker, "_state_path",
|
||||
lambda self: tmp_path / "matrix_threads.json",
|
||||
)
|
||||
adapter = _make_adapter()
|
||||
loaded = adapter._load_participated_threads()
|
||||
assert loaded == set()
|
||||
assert "$nonexistent" not in adapter._threads
|
||||
|
||||
def test_track_thread_persists(self, tmp_path, monkeypatch):
|
||||
"""_track_thread writes to disk."""
|
||||
from gateway.platforms.matrix import MatrixAdapter
|
||||
"""mark() writes to disk."""
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
state_path = tmp_path / "matrix_threads.json"
|
||||
monkeypatch.setattr(
|
||||
MatrixAdapter, "_thread_state_path",
|
||||
staticmethod(lambda: state_path),
|
||||
ThreadParticipationTracker, "_state_path",
|
||||
lambda self: state_path,
|
||||
)
|
||||
adapter = _make_adapter()
|
||||
adapter._track_thread("$thread_abc")
|
||||
adapter._threads.mark("$thread_abc")
|
||||
|
||||
data = json.loads(state_path.read_text())
|
||||
assert "$thread_abc" in data
|
||||
|
||||
def test_threads_survive_reload(self, tmp_path, monkeypatch):
|
||||
"""Persisted threads are loaded by a new adapter instance."""
|
||||
from gateway.platforms.matrix import MatrixAdapter
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
state_path = tmp_path / "matrix_threads.json"
|
||||
state_path.write_text(json.dumps(["$t1", "$t2"]))
|
||||
monkeypatch.setattr(
|
||||
MatrixAdapter, "_thread_state_path",
|
||||
staticmethod(lambda: state_path),
|
||||
ThreadParticipationTracker, "_state_path",
|
||||
lambda self: state_path,
|
||||
)
|
||||
adapter = _make_adapter()
|
||||
assert "$t1" in adapter._bot_participated_threads
|
||||
assert "$t2" in adapter._bot_participated_threads
|
||||
assert "$t1" in adapter._threads
|
||||
assert "$t2" in adapter._threads
|
||||
|
||||
def test_cap_max_tracked_threads(self, tmp_path, monkeypatch):
|
||||
"""Thread set is trimmed to _MAX_TRACKED_THREADS."""
|
||||
from gateway.platforms.matrix import MatrixAdapter
|
||||
"""Thread set is trimmed to max_tracked."""
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
state_path = tmp_path / "matrix_threads.json"
|
||||
monkeypatch.setattr(
|
||||
MatrixAdapter, "_thread_state_path",
|
||||
staticmethod(lambda: state_path),
|
||||
ThreadParticipationTracker, "_state_path",
|
||||
lambda self: state_path,
|
||||
)
|
||||
adapter = _make_adapter()
|
||||
adapter._MAX_TRACKED_THREADS = 5
|
||||
adapter._threads._max_tracked = 5
|
||||
|
||||
for i in range(10):
|
||||
adapter._bot_participated_threads.add(f"$t{i}")
|
||||
adapter._save_participated_threads()
|
||||
adapter._threads.mark(f"$t{i}")
|
||||
|
||||
data = json.loads(state_path.read_text())
|
||||
assert len(data) == 5
|
||||
@@ -447,7 +445,7 @@ async def test_dm_mention_thread_creates_thread(monkeypatch):
|
||||
_set_dm(adapter)
|
||||
event = _make_event("@hermes:example.org help me", event_id="$dm1")
|
||||
|
||||
with patch.object(adapter, "_save_participated_threads"):
|
||||
with patch.object(adapter._threads, "_save"):
|
||||
await adapter._on_room_message(event)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
@@ -480,7 +478,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch):
|
||||
|
||||
adapter = _make_adapter()
|
||||
_set_dm(adapter)
|
||||
adapter._bot_participated_threads.add("$existing_thread")
|
||||
adapter._threads.mark("$existing_thread")
|
||||
event = _make_event("@hermes:example.org help me", thread_id="$existing_thread")
|
||||
|
||||
await adapter._on_room_message(event)
|
||||
@@ -491,7 +489,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_mention_thread_tracks_participation(monkeypatch):
|
||||
"""DM mention-thread tracks the thread in _bot_participated_threads."""
|
||||
"""DM mention-thread tracks the thread in _threads."""
|
||||
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
|
||||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
@@ -499,10 +497,10 @@ async def test_dm_mention_thread_tracks_participation(monkeypatch):
|
||||
_set_dm(adapter)
|
||||
event = _make_event("@hermes:example.org help", event_id="$dm1")
|
||||
|
||||
with patch.object(adapter, "_save_participated_threads"):
|
||||
with patch.object(adapter._threads, "_save"):
|
||||
await adapter._on_room_message(event)
|
||||
|
||||
assert "$dm1" in adapter._bot_participated_threads
|
||||
assert "$dm1" in adapter._threads
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -614,25 +614,27 @@ class TestMattermostDedup:
|
||||
assert self.adapter.handle_message.call_count == 2
|
||||
|
||||
def test_prune_seen_clears_expired(self):
|
||||
"""_prune_seen should remove entries older than _SEEN_TTL."""
|
||||
"""Dedup cache should remove entries older than TTL on overflow."""
|
||||
now = time.time()
|
||||
dedup = self.adapter._dedup
|
||||
# Fill with enough expired entries to trigger pruning
|
||||
for i in range(self.adapter._SEEN_MAX + 10):
|
||||
self.adapter._seen_posts[f"old_{i}"] = now - 600 # 10 min ago
|
||||
for i in range(dedup._max_size + 10):
|
||||
dedup._seen[f"old_{i}"] = now - 600 # 10 min ago (older than default TTL)
|
||||
|
||||
# Add a fresh one
|
||||
self.adapter._seen_posts["fresh"] = now
|
||||
dedup._seen["fresh"] = now
|
||||
|
||||
self.adapter._prune_seen()
|
||||
# Trigger pruning by calling is_duplicate with a new entry (over max_size)
|
||||
dedup.is_duplicate("trigger_prune")
|
||||
|
||||
# Old entries should be pruned, fresh one kept
|
||||
assert "fresh" in self.adapter._seen_posts
|
||||
assert len(self.adapter._seen_posts) < self.adapter._SEEN_MAX
|
||||
assert "fresh" in dedup._seen
|
||||
assert len(dedup._seen) < dedup._max_size + 10
|
||||
|
||||
def test_seen_cache_tracks_post_ids(self):
|
||||
"""Posts are tracked in _seen_posts dict."""
|
||||
self.adapter._seen_posts["test_post"] = time.time()
|
||||
assert "test_post" in self.adapter._seen_posts
|
||||
"""Posts are tracked in the dedup cache."""
|
||||
self.adapter._dedup._seen["test_post"] = time.time()
|
||||
assert "test_post" in self.adapter._dedup._seen
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.run import _dequeue_pending_event
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
@@ -79,6 +80,26 @@ class TestQueueMessageStorage:
|
||||
# Should be consumed (cleared)
|
||||
assert adapter.get_pending_message(session_key) is None
|
||||
|
||||
def test_dequeue_pending_event_preserves_voice_media_metadata(self):
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:voice"
|
||||
event = MessageEvent(
|
||||
text="",
|
||||
message_type=MessageType.VOICE,
|
||||
source=MagicMock(chat_id="123", platform=Platform.TELEGRAM),
|
||||
message_id="voice-q1",
|
||||
media_urls=["/tmp/voice.ogg"],
|
||||
media_types=["audio/ogg"],
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
|
||||
retrieved = _dequeue_pending_event(adapter, session_key)
|
||||
|
||||
assert retrieved is event
|
||||
assert retrieved.media_urls == ["/tmp/voice.ogg"]
|
||||
assert retrieved.media_types == ["audio/ogg"]
|
||||
assert adapter.get_pending_message(session_key) is None
|
||||
|
||||
def test_queue_does_not_set_interrupt_event(self):
|
||||
"""The whole point of /queue — no interrupt signal."""
|
||||
adapter = _StubAdapter()
|
||||
|
||||
@@ -221,5 +221,6 @@ class TestHandleResumeCommand:
|
||||
|
||||
runner._async_flush_memories.assert_called_once_with(
|
||||
"current_session_001",
|
||||
"agent:main:telegram:dm:67890",
|
||||
)
|
||||
db.close()
|
||||
|
||||
@@ -8,8 +8,8 @@ from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, SendResult
|
||||
from gateway.config import Platform, PlatformConfig, StreamingConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
@@ -104,6 +104,11 @@ def _make_runner(adapter):
|
||||
runner._session_db = None
|
||||
runner._running_agents = {}
|
||||
runner.hooks = SimpleNamespace(loaded_hooks=False)
|
||||
runner.config = SimpleNamespace(
|
||||
thread_sessions_per_user=False,
|
||||
group_sessions_per_user=False,
|
||||
stt_enabled=False,
|
||||
)
|
||||
return runner
|
||||
|
||||
|
||||
@@ -118,6 +123,7 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = FakeAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
import tools.terminal_tool # noqa: F401 - register terminal emoji for this fake-agent test
|
||||
|
||||
adapter = ProgressCaptureAdapter()
|
||||
runner = _make_runner(adapter)
|
||||
@@ -144,7 +150,7 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
|
||||
assert adapter.sent == [
|
||||
{
|
||||
"chat_id": "-1001",
|
||||
"content": '⚙️ terminal: "pwd"',
|
||||
"content": '💻 terminal: "pwd"',
|
||||
"reply_to": None,
|
||||
"metadata": {"thread_id": "17585"},
|
||||
}
|
||||
@@ -334,3 +340,238 @@ def test_all_mode_no_truncation_when_preview_fits(monkeypatch, tmp_path):
|
||||
content = adapter.sent[0]["content"]
|
||||
# With a 200-char cap, the 165-char command should NOT be truncated
|
||||
assert "..." not in content, f"Preview was truncated when it shouldn't be: {content}"
|
||||
|
||||
|
||||
class CommentaryAgent:
|
||||
def __init__(self, **kwargs):
|
||||
self.tool_progress_callback = kwargs.get("tool_progress_callback")
|
||||
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
|
||||
self.stream_delta_callback = kwargs.get("stream_delta_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
if self.interim_assistant_callback:
|
||||
self.interim_assistant_callback("I'll inspect the repo first.", already_streamed=False)
|
||||
time.sleep(0.1)
|
||||
if self.stream_delta_callback:
|
||||
self.stream_delta_callback("done")
|
||||
return {
|
||||
"final_response": "done",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
class PreviewedResponseAgent:
|
||||
def __init__(self, **kwargs):
|
||||
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
if self.interim_assistant_callback:
|
||||
self.interim_assistant_callback("You're welcome.", already_streamed=False)
|
||||
return {
|
||||
"final_response": "You're welcome.",
|
||||
"response_previewed": True,
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
class QueuedCommentaryAgent:
|
||||
calls = 0
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
type(self).calls += 1
|
||||
if type(self).calls == 1 and self.interim_assistant_callback:
|
||||
self.interim_assistant_callback("I'll inspect the repo first.", already_streamed=False)
|
||||
return {
|
||||
"final_response": f"final response {type(self).calls}",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
async def _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
agent_cls,
|
||||
*,
|
||||
session_id,
|
||||
pending_text=None,
|
||||
config_data=None,
|
||||
):
|
||||
if config_data:
|
||||
import yaml
|
||||
|
||||
(tmp_path / "config.yaml").write_text(yaml.dump(config_data), encoding="utf-8")
|
||||
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = agent_cls
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
adapter = ProgressCaptureAdapter()
|
||||
runner = _make_runner(adapter)
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
if config_data and "streaming" in config_data:
|
||||
runner.config.streaming = StreamingConfig.from_dict(config_data["streaming"])
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1001",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
)
|
||||
session_key = "agent:main:telegram:group:-1001:17585"
|
||||
if pending_text is not None:
|
||||
adapter._pending_messages[session_key] = MessageEvent(
|
||||
text=pending_text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
message_id="queued-1",
|
||||
)
|
||||
|
||||
result = await runner._run_agent(
|
||||
message="hello",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id=session_id,
|
||||
session_key=session_key,
|
||||
)
|
||||
return adapter, result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_surfaces_real_interim_commentary(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary",
|
||||
config_data={"display": {"interim_assistant_messages": True}},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is not True
|
||||
assert any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_surfaces_interim_commentary_by_default(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-default-on",
|
||||
)
|
||||
|
||||
assert any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_suppresses_interim_commentary_when_disabled(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-disabled",
|
||||
config_data={"display": {"interim_assistant_messages": False}},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is not True
|
||||
assert not any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_tool_progress_does_not_control_interim_commentary(monkeypatch, tmp_path):
|
||||
"""tool_progress=all with interim_assistant_messages=false should not surface commentary."""
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-tool-progress",
|
||||
config_data={"display": {"tool_progress": "all", "interim_assistant_messages": False}},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is not True
|
||||
assert not any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_streaming_does_not_enable_completed_interim_commentary(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
"""Streaming alone with interim_assistant_messages=false should not surface commentary."""
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-streaming",
|
||||
config_data={
|
||||
"display": {"tool_progress": "off", "interim_assistant_messages": False},
|
||||
"streaming": {"enabled": True},
|
||||
},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is True
|
||||
assert not any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_interim_commentary_works_with_tool_progress_off(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-explicit-on",
|
||||
config_data={
|
||||
"display": {
|
||||
"tool_progress": "off",
|
||||
"interim_assistant_messages": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is not True
|
||||
assert any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_previewed_final_marks_already_sent(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
PreviewedResponseAgent,
|
||||
session_id="sess-previewed",
|
||||
config_data={"display": {"interim_assistant_messages": True}},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is True
|
||||
assert [call["content"] for call in adapter.sent] == ["You're welcome."]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_queued_message_does_not_treat_commentary_as_final(monkeypatch, tmp_path):
|
||||
QueuedCommentaryAgent.calls = 0
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
QueuedCommentaryAgent,
|
||||
session_id="sess-queued-commentary",
|
||||
pending_text="queued follow-up",
|
||||
config_data={"display": {"interim_assistant_messages": True}},
|
||||
)
|
||||
|
||||
sent_texts = [call["content"] for call in adapter.sent]
|
||||
assert result["final_response"] == "final response 2"
|
||||
assert "I'll inspect the repo first." in sent_texts
|
||||
assert "final response 1" in sent_texts
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from gateway.config import Platform
|
||||
@@ -18,6 +19,8 @@ def test_set_session_env_sets_contextvars(monkeypatch):
|
||||
chat_id="-1001",
|
||||
chat_name="Group",
|
||||
chat_type="group",
|
||||
user_id="123456",
|
||||
user_name="alice",
|
||||
thread_id="17585",
|
||||
)
|
||||
context = SessionContext(source=source, connected_platforms=[], home_channels={})
|
||||
@@ -25,6 +28,8 @@ def test_set_session_env_sets_contextvars(monkeypatch):
|
||||
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_USER_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_USER_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False)
|
||||
|
||||
tokens = runner._set_session_env(context)
|
||||
@@ -33,6 +38,8 @@ def test_set_session_env_sets_contextvars(monkeypatch):
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
assert get_session_env("HERMES_SESSION_CHAT_ID") == "-1001"
|
||||
assert get_session_env("HERMES_SESSION_CHAT_NAME") == "Group"
|
||||
assert get_session_env("HERMES_SESSION_USER_ID") == "123456"
|
||||
assert get_session_env("HERMES_SESSION_USER_NAME") == "alice"
|
||||
assert get_session_env("HERMES_SESSION_THREAD_ID") == "17585"
|
||||
|
||||
# os.environ should NOT be touched
|
||||
@@ -50,6 +57,8 @@ def test_clear_session_env_restores_previous_state(monkeypatch):
|
||||
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_USER_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_USER_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False)
|
||||
|
||||
source = SessionSource(
|
||||
@@ -57,12 +66,15 @@ def test_clear_session_env_restores_previous_state(monkeypatch):
|
||||
chat_id="-1001",
|
||||
chat_name="Group",
|
||||
chat_type="group",
|
||||
user_id="123456",
|
||||
user_name="alice",
|
||||
thread_id="17585",
|
||||
)
|
||||
context = SessionContext(source=source, connected_platforms=[], home_channels={})
|
||||
|
||||
tokens = runner._set_session_env(context)
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
assert get_session_env("HERMES_SESSION_USER_ID") == "123456"
|
||||
|
||||
runner._clear_session_env(tokens)
|
||||
|
||||
@@ -70,6 +82,8 @@ def test_clear_session_env_restores_previous_state(monkeypatch):
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == ""
|
||||
assert get_session_env("HERMES_SESSION_CHAT_ID") == ""
|
||||
assert get_session_env("HERMES_SESSION_CHAT_NAME") == ""
|
||||
assert get_session_env("HERMES_SESSION_USER_ID") == ""
|
||||
assert get_session_env("HERMES_SESSION_USER_NAME") == ""
|
||||
assert get_session_env("HERMES_SESSION_THREAD_ID") == ""
|
||||
|
||||
|
||||
@@ -117,3 +131,99 @@ def test_set_session_env_handles_missing_optional_fields():
|
||||
assert get_session_env("HERMES_SESSION_THREAD_ID") == ""
|
||||
|
||||
runner._clear_session_env(tokens)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SESSION_KEY contextvars tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_session_key_set_via_contextvars(monkeypatch):
|
||||
"""set_session_vars should set HERMES_SESSION_KEY via contextvars."""
|
||||
monkeypatch.delenv("HERMES_SESSION_KEY", raising=False)
|
||||
|
||||
tokens = set_session_vars(
|
||||
platform="telegram",
|
||||
chat_id="-1001",
|
||||
session_key="tg:-1001:17585",
|
||||
)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "tg:-1001:17585"
|
||||
|
||||
clear_session_vars(tokens)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == ""
|
||||
|
||||
|
||||
def test_session_key_falls_back_to_os_environ(monkeypatch):
|
||||
"""get_session_env for SESSION_KEY should fall back to os.environ."""
|
||||
monkeypatch.setenv("HERMES_SESSION_KEY", "env-session-123")
|
||||
|
||||
# No contextvar set — should read from os.environ
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "env-session-123"
|
||||
|
||||
# Set contextvar — should prefer it
|
||||
tokens = set_session_vars(session_key="ctx-session-456")
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "ctx-session-456"
|
||||
|
||||
# Restore — should fall back to os.environ
|
||||
clear_session_vars(tokens)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "env-session-123"
|
||||
|
||||
|
||||
def test_set_session_env_includes_session_key():
|
||||
"""_set_session_env should propagate session_key from SessionContext."""
|
||||
runner = object.__new__(GatewayRunner)
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1001",
|
||||
chat_name="Group",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
)
|
||||
context = SessionContext(
|
||||
source=source,
|
||||
connected_platforms=[],
|
||||
home_channels={},
|
||||
session_key="tg:-1001:17585",
|
||||
)
|
||||
|
||||
tokens = runner._set_session_env(context)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "tg:-1001:17585"
|
||||
runner._clear_session_env(tokens)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == ""
|
||||
|
||||
|
||||
def test_session_key_no_race_condition_with_contextvars(monkeypatch):
|
||||
"""Prove contextvars isolates SESSION_KEY across concurrent async tasks.
|
||||
|
||||
Two tasks set different session keys. With contextvars each task
|
||||
reads back its own value. With os.environ the second task would
|
||||
overwrite the first (the old bug).
|
||||
"""
|
||||
monkeypatch.delenv("HERMES_SESSION_KEY", raising=False)
|
||||
|
||||
results = {}
|
||||
|
||||
async def handler(key: str, delay: float):
|
||||
tokens = set_session_vars(session_key=key)
|
||||
try:
|
||||
await asyncio.sleep(delay)
|
||||
read_back = get_session_env("HERMES_SESSION_KEY")
|
||||
results[key] = read_back
|
||||
finally:
|
||||
clear_session_vars(tokens)
|
||||
|
||||
async def run():
|
||||
task_a = asyncio.create_task(handler("session-A", 0.15))
|
||||
await asyncio.sleep(0.05)
|
||||
task_b = asyncio.create_task(handler("session-B", 0.05))
|
||||
await asyncio.gather(task_a, task_b)
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
# Both tasks must read back their own session key
|
||||
assert results["session-A"] == "session-A", (
|
||||
f"Session A got '{results['session-A']}' instead of 'session-A' — race condition!"
|
||||
)
|
||||
assert results["session-B"] == "session-B", (
|
||||
f"Session B got '{results['session-B']}' instead of 'session-B' — race condition!"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
"""Regression tests for session-scoped model/provider overrides in gateway agents.
|
||||
|
||||
These cover the bug where `/model ...` stored a session override, but fresh
|
||||
agent constructions still resolved model/provider from global config/runtime.
|
||||
That let helper agents (and cache-miss main agents) route GPT-5.4 to the wrong
|
||||
provider, e.g. Nous instead of OpenAI Codex.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import threading
|
||||
import types
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import gateway.run as gateway_run
|
||||
from gateway.config import Platform
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
class _CapturingAgent:
|
||||
"""Fake agent that records init kwargs for assertions."""
|
||||
|
||||
last_init = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
type(self).last_init = dict(kwargs)
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, user_message: str, conversation_history=None, task_id=None):
|
||||
return {
|
||||
"final_response": "ok",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
def _make_runner():
|
||||
runner = object.__new__(gateway_run.GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner.session_store = None
|
||||
runner.config = None
|
||||
runner._voice_mode = {}
|
||||
runner._ephemeral_system_prompt = ""
|
||||
runner._prefill_messages = []
|
||||
runner._reasoning_config = None
|
||||
runner._show_reasoning = False
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._service_tier = None
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner._background_tasks = set()
|
||||
runner._session_db = None
|
||||
runner._session_model_overrides = {}
|
||||
runner._pending_model_notes = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._agent_cache = {}
|
||||
runner._agent_cache_lock = threading.Lock()
|
||||
runner._get_or_create_gateway_honcho = lambda session_key: (None, None)
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
runner.hooks.loaded_hooks = []
|
||||
return runner
|
||||
|
||||
|
||||
def _codex_override():
|
||||
return {
|
||||
"model": "gpt-5.4",
|
||||
"provider": "openai-codex",
|
||||
"api_key": "***",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_mode": "codex_responses",
|
||||
}
|
||||
|
||||
|
||||
def _explode_runtime_resolution():
|
||||
raise AssertionError(
|
||||
"global runtime resolution should not run when a complete session override exists"
|
||||
)
|
||||
|
||||
|
||||
def test_run_agent_prefers_session_override_over_global_runtime(monkeypatch):
|
||||
monkeypatch.setattr(gateway_run, "_load_gateway_config", lambda: {})
|
||||
monkeypatch.setattr(gateway_run, "load_dotenv", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", _explode_runtime_resolution)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = _CapturingAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
_CapturingAgent.last_init = None
|
||||
runner = _make_runner()
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.LOCAL,
|
||||
chat_id="cli",
|
||||
chat_name="CLI",
|
||||
chat_type="dm",
|
||||
user_id="user-1",
|
||||
)
|
||||
session_key = "agent:main:local:dm"
|
||||
runner._session_model_overrides[session_key] = _codex_override()
|
||||
|
||||
result = asyncio.run(
|
||||
runner._run_agent(
|
||||
message="ping",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id="session-1",
|
||||
session_key=session_key,
|
||||
)
|
||||
)
|
||||
|
||||
assert result["final_response"] == "ok"
|
||||
assert _CapturingAgent.last_init is not None
|
||||
assert _CapturingAgent.last_init["model"] == "gpt-5.4"
|
||||
assert _CapturingAgent.last_init["provider"] == "openai-codex"
|
||||
assert _CapturingAgent.last_init["api_mode"] == "codex_responses"
|
||||
assert _CapturingAgent.last_init["base_url"] == "https://chatgpt.com/backend-api/codex"
|
||||
assert _CapturingAgent.last_init["api_key"] == "***"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_background_task_prefers_session_override_over_global_runtime(monkeypatch):
|
||||
monkeypatch.setattr(gateway_run, "_load_gateway_config", lambda: {})
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", _explode_runtime_resolution)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = _CapturingAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
_CapturingAgent.last_init = None
|
||||
runner = _make_runner()
|
||||
|
||||
adapter = AsyncMock()
|
||||
adapter.send = AsyncMock()
|
||||
adapter.extract_media = MagicMock(return_value=([], "ok"))
|
||||
adapter.extract_images = MagicMock(return_value=([], "ok"))
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="12345",
|
||||
chat_id="67890",
|
||||
user_name="testuser",
|
||||
)
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._session_model_overrides[session_key] = _codex_override()
|
||||
|
||||
await runner._run_background_task("say hello", source, "bg_test")
|
||||
|
||||
assert _CapturingAgent.last_init is not None
|
||||
assert _CapturingAgent.last_init["model"] == "gpt-5.4"
|
||||
assert _CapturingAgent.last_init["provider"] == "openai-codex"
|
||||
assert _CapturingAgent.last_init["api_mode"] == "codex_responses"
|
||||
assert _CapturingAgent.last_init["base_url"] == "https://chatgpt.com/backend-api/codex"
|
||||
assert _CapturingAgent.last_init["api_key"] == "***"
|
||||
@@ -114,16 +114,16 @@ class TestSignalAdapterInit:
|
||||
|
||||
class TestSignalHelpers:
|
||||
def test_redact_phone_long(self):
|
||||
from gateway.platforms.signal import _redact_phone
|
||||
assert _redact_phone("+15551234567") == "+155****4567"
|
||||
from gateway.platforms.helpers import redact_phone
|
||||
assert redact_phone("+155****4567") == "+155****4567"
|
||||
|
||||
def test_redact_phone_short(self):
|
||||
from gateway.platforms.signal import _redact_phone
|
||||
assert _redact_phone("+12345") == "+1****45"
|
||||
from gateway.platforms.helpers import redact_phone
|
||||
assert redact_phone("+12345") == "+1****45"
|
||||
|
||||
def test_redact_phone_empty(self):
|
||||
from gateway.platforms.signal import _redact_phone
|
||||
assert _redact_phone("") == "<none>"
|
||||
from gateway.platforms.helpers import redact_phone
|
||||
assert redact_phone("") == "<none>"
|
||||
|
||||
def test_parse_comma_list(self):
|
||||
from gateway.platforms.signal import _parse_comma_list
|
||||
|
||||
+337
-2
@@ -1,11 +1,14 @@
|
||||
"""Tests for SMS (Twilio) platform integration.
|
||||
|
||||
Covers config loading, format/truncate, echo prevention,
|
||||
requirements check, and toolset verification.
|
||||
requirements check, toolset verification, and Twilio signature validation.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -213,3 +216,335 @@ class TestSmsToolset:
|
||||
from tools.cronjob_tools import CRONJOB_SCHEMA
|
||||
deliver_desc = CRONJOB_SCHEMA["parameters"]["properties"]["deliver"]["description"]
|
||||
assert "sms" in deliver_desc.lower()
|
||||
|
||||
|
||||
# ── Webhook host configuration ─────────────────────────────────────
|
||||
|
||||
class TestWebhookHostConfig:
|
||||
"""Verify SMS_WEBHOOK_HOST env var and default."""
|
||||
|
||||
def test_default_host_is_all_interfaces(self):
|
||||
from gateway.platforms.sms import DEFAULT_WEBHOOK_HOST
|
||||
assert DEFAULT_WEBHOOK_HOST == "0.0.0.0"
|
||||
|
||||
def test_host_from_env(self):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "tok",
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
"SMS_WEBHOOK_HOST": "127.0.0.1",
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
pc = PlatformConfig(enabled=True, api_key="tok")
|
||||
adapter = SmsAdapter(pc)
|
||||
assert adapter._webhook_host == "127.0.0.1"
|
||||
|
||||
def test_webhook_url_from_env(self):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "tok",
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
"SMS_WEBHOOK_URL": "https://example.com/webhooks/twilio",
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
pc = PlatformConfig(enabled=True, api_key="tok")
|
||||
adapter = SmsAdapter(pc)
|
||||
assert adapter._webhook_url == "https://example.com/webhooks/twilio"
|
||||
|
||||
def test_webhook_url_stripped(self):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "tok",
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
"SMS_WEBHOOK_URL": " https://example.com/webhooks/twilio ",
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
pc = PlatformConfig(enabled=True, api_key="tok")
|
||||
adapter = SmsAdapter(pc)
|
||||
assert adapter._webhook_url == "https://example.com/webhooks/twilio"
|
||||
|
||||
|
||||
# ── Startup guard (fail-closed) ────────────────────────────────────
|
||||
|
||||
class TestStartupGuard:
|
||||
"""Adapter must refuse to start without SMS_WEBHOOK_URL."""
|
||||
|
||||
def _make_adapter(self, extra_env=None):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "tok",
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
}
|
||||
if extra_env:
|
||||
env.update(extra_env)
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
pc = PlatformConfig(enabled=True, api_key="tok")
|
||||
adapter = SmsAdapter(pc)
|
||||
return adapter
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refuses_start_without_webhook_url(self):
|
||||
adapter = self._make_adapter()
|
||||
result = await adapter.connect()
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insecure_flag_allows_start_without_url(self):
|
||||
mock_session = AsyncMock()
|
||||
with patch.dict(os.environ, {"SMS_INSECURE_NO_SIGNATURE": "true"}), \
|
||||
patch("aiohttp.web.AppRunner") as mock_runner_cls, \
|
||||
patch("aiohttp.web.TCPSite") as mock_site_cls, \
|
||||
patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
mock_runner_cls.return_value.setup = AsyncMock()
|
||||
mock_runner_cls.return_value.cleanup = AsyncMock()
|
||||
mock_site_cls.return_value.start = AsyncMock()
|
||||
adapter = self._make_adapter()
|
||||
result = await adapter.connect()
|
||||
assert result is True
|
||||
await adapter.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_url_allows_start(self):
|
||||
mock_session = AsyncMock()
|
||||
with patch("aiohttp.web.AppRunner") as mock_runner_cls, \
|
||||
patch("aiohttp.web.TCPSite") as mock_site_cls, \
|
||||
patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
mock_runner_cls.return_value.setup = AsyncMock()
|
||||
mock_runner_cls.return_value.cleanup = AsyncMock()
|
||||
mock_site_cls.return_value.start = AsyncMock()
|
||||
adapter = self._make_adapter(
|
||||
extra_env={"SMS_WEBHOOK_URL": "https://example.com/webhooks/twilio"}
|
||||
)
|
||||
result = await adapter.connect()
|
||||
assert result is True
|
||||
await adapter.disconnect()
|
||||
|
||||
|
||||
# ── Twilio signature validation ────────────────────────────────────
|
||||
|
||||
def _compute_twilio_signature(auth_token, url, params):
|
||||
"""Reference implementation of Twilio's signature algorithm."""
|
||||
data_to_sign = url
|
||||
for key in sorted(params.keys()):
|
||||
data_to_sign += key + params[key]
|
||||
mac = hmac.new(
|
||||
auth_token.encode("utf-8"),
|
||||
data_to_sign.encode("utf-8"),
|
||||
hashlib.sha1,
|
||||
)
|
||||
return base64.b64encode(mac.digest()).decode("utf-8")
|
||||
|
||||
|
||||
class TestTwilioSignatureValidation:
|
||||
"""Unit tests for SmsAdapter._validate_twilio_signature."""
|
||||
|
||||
def _make_adapter(self, auth_token="test_token_secret"):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": auth_token,
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
pc = PlatformConfig(enabled=True, api_key=auth_token)
|
||||
adapter = SmsAdapter(pc)
|
||||
return adapter
|
||||
|
||||
def test_valid_signature_accepted(self):
|
||||
adapter = self._make_adapter()
|
||||
url = "https://example.com/webhooks/twilio"
|
||||
params = {"From": "+15551234567", "Body": "hello", "To": "+15550001111"}
|
||||
sig = _compute_twilio_signature("test_token_secret", url, params)
|
||||
assert adapter._validate_twilio_signature(url, params, sig) is True
|
||||
|
||||
def test_invalid_signature_rejected(self):
|
||||
adapter = self._make_adapter()
|
||||
url = "https://example.com/webhooks/twilio"
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
assert adapter._validate_twilio_signature(url, params, "badsig") is False
|
||||
|
||||
def test_wrong_token_rejected(self):
|
||||
adapter = self._make_adapter(auth_token="correct_token")
|
||||
url = "https://example.com/webhooks/twilio"
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
sig = _compute_twilio_signature("wrong_token", url, params)
|
||||
assert adapter._validate_twilio_signature(url, params, sig) is False
|
||||
|
||||
def test_params_sorted_by_key(self):
|
||||
"""Signature must be computed with params sorted alphabetically."""
|
||||
adapter = self._make_adapter()
|
||||
url = "https://example.com/webhooks/twilio"
|
||||
params = {"Zebra": "last", "Alpha": "first", "Middle": "mid"}
|
||||
sig = _compute_twilio_signature("test_token_secret", url, params)
|
||||
assert adapter._validate_twilio_signature(url, params, sig) is True
|
||||
|
||||
def test_empty_param_values_included(self):
|
||||
"""Blank values must be included in signature computation."""
|
||||
adapter = self._make_adapter()
|
||||
url = "https://example.com/webhooks/twilio"
|
||||
params = {"From": "+15551234567", "Body": "", "SmsStatus": "received"}
|
||||
sig = _compute_twilio_signature("test_token_secret", url, params)
|
||||
assert adapter._validate_twilio_signature(url, params, sig) is True
|
||||
|
||||
def test_url_matters(self):
|
||||
"""Different URLs produce different signatures."""
|
||||
adapter = self._make_adapter()
|
||||
params = {"Body": "hello"}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "https://a.com/webhooks/twilio", params
|
||||
)
|
||||
assert adapter._validate_twilio_signature(
|
||||
"https://b.com/webhooks/twilio", params, sig
|
||||
) is False
|
||||
|
||||
def test_port_variant_443_matches_without_port(self):
|
||||
"""Signature for https URL with :443 validates against URL without port."""
|
||||
adapter = self._make_adapter()
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "https://example.com:443/webhooks/twilio", params
|
||||
)
|
||||
assert adapter._validate_twilio_signature(
|
||||
"https://example.com/webhooks/twilio", params, sig
|
||||
) is True
|
||||
|
||||
def test_port_variant_without_port_matches_443(self):
|
||||
"""Signature for https URL without port validates against URL with :443."""
|
||||
adapter = self._make_adapter()
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "https://example.com/webhooks/twilio", params
|
||||
)
|
||||
assert adapter._validate_twilio_signature(
|
||||
"https://example.com:443/webhooks/twilio", params, sig
|
||||
) is True
|
||||
|
||||
def test_non_standard_port_no_variant(self):
|
||||
"""Non-standard port must NOT match URL without port."""
|
||||
adapter = self._make_adapter()
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "https://example.com/webhooks/twilio", params
|
||||
)
|
||||
assert adapter._validate_twilio_signature(
|
||||
"https://example.com:8080/webhooks/twilio", params, sig
|
||||
) is False
|
||||
|
||||
def test_port_variant_http_80(self):
|
||||
"""Port variant also works for http with port 80."""
|
||||
adapter = self._make_adapter()
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "http://example.com:80/webhooks/twilio", params
|
||||
)
|
||||
assert adapter._validate_twilio_signature(
|
||||
"http://example.com/webhooks/twilio", params, sig
|
||||
) is True
|
||||
|
||||
|
||||
# ── Webhook signature enforcement (handler-level) ──────────────────
|
||||
|
||||
class TestWebhookSignatureEnforcement:
|
||||
"""Integration tests for signature validation in _handle_webhook."""
|
||||
|
||||
def _make_adapter(self, webhook_url=""):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "test_token_secret",
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
"SMS_WEBHOOK_URL": webhook_url,
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
pc = PlatformConfig(enabled=True, api_key="test_token_secret")
|
||||
adapter = SmsAdapter(pc)
|
||||
adapter._message_handler = AsyncMock()
|
||||
return adapter
|
||||
|
||||
def _mock_request(self, body, headers=None):
|
||||
request = MagicMock()
|
||||
request.read = AsyncMock(return_value=body)
|
||||
request.headers = headers or {}
|
||||
return request
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insecure_flag_skips_validation(self):
|
||||
"""With SMS_INSECURE_NO_SIGNATURE=true and no URL, requests are accepted."""
|
||||
env = {"SMS_INSECURE_NO_SIGNATURE": "true"}
|
||||
with patch.dict(os.environ, env):
|
||||
adapter = self._make_adapter(webhook_url="")
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body)
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insecure_flag_with_url_still_validates(self):
|
||||
"""When both SMS_WEBHOOK_URL and SMS_INSECURE_NO_SIGNATURE are set,
|
||||
validation stays active (URL takes precedence)."""
|
||||
adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio")
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body, headers={})
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_signature_returns_403(self):
|
||||
adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio")
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body, headers={})
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_signature_returns_403(self):
|
||||
adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio")
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body, headers={"X-Twilio-Signature": "invalid"})
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_signature_returns_200(self):
|
||||
webhook_url = "https://example.com/webhooks/twilio"
|
||||
adapter = self._make_adapter(webhook_url=webhook_url)
|
||||
params = {
|
||||
"From": "+15551234567",
|
||||
"To": "+15550001111",
|
||||
"Body": "hello",
|
||||
"MessageSid": "SM123",
|
||||
}
|
||||
sig = _compute_twilio_signature("test_token_secret", webhook_url, params)
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body, headers={"X-Twilio-Signature": sig})
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_port_variant_signature_returns_200(self):
|
||||
"""Signature computed with :443 should pass when URL configured without port."""
|
||||
webhook_url = "https://example.com/webhooks/twilio"
|
||||
adapter = self._make_adapter(webhook_url=webhook_url)
|
||||
params = {
|
||||
"From": "+15551234567",
|
||||
"To": "+15550001111",
|
||||
"Body": "hello",
|
||||
"MessageSid": "SM123",
|
||||
}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "https://example.com:443/webhooks/twilio", params
|
||||
)
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body, headers={"X-Twilio-Signature": sig})
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 200
|
||||
|
||||
@@ -505,3 +505,81 @@ class TestSegmentBreakOnToolBoundary:
|
||||
assert len(sent_texts) == 3
|
||||
assert sent_texts[0].startswith(prefix)
|
||||
assert sum(len(t) for t in sent_texts[1:]) == len(tail)
|
||||
|
||||
|
||||
class TestInterimCommentaryMessages:
|
||||
@pytest.mark.asyncio
|
||||
async def test_commentary_message_stays_separate_from_final_stream(self):
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock(side_effect=[
|
||||
SimpleNamespace(success=True, message_id="msg_1"),
|
||||
SimpleNamespace(success=True, message_id="msg_2"),
|
||||
])
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
consumer = GatewayStreamConsumer(
|
||||
adapter,
|
||||
"chat_123",
|
||||
StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5),
|
||||
)
|
||||
|
||||
consumer.on_commentary("I'll inspect the repository first.")
|
||||
consumer.on_delta("Done.")
|
||||
consumer.finish()
|
||||
|
||||
await consumer.run()
|
||||
|
||||
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
|
||||
assert sent_texts == ["I'll inspect the repository first.", "Done."]
|
||||
assert consumer.final_response_sent is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_final_send_does_not_mark_final_response_sent(self):
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock(return_value=SimpleNamespace(success=False, message_id=None))
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
consumer = GatewayStreamConsumer(
|
||||
adapter,
|
||||
"chat_123",
|
||||
StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5),
|
||||
)
|
||||
|
||||
consumer.on_delta("Done.")
|
||||
consumer.finish()
|
||||
|
||||
await consumer.run()
|
||||
|
||||
assert consumer.final_response_sent is False
|
||||
assert consumer.already_sent is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_without_message_id_marks_visible_and_sends_only_tail(self):
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock(side_effect=[
|
||||
SimpleNamespace(success=True, message_id=None),
|
||||
SimpleNamespace(success=True, message_id=None),
|
||||
])
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
consumer = GatewayStreamConsumer(
|
||||
adapter,
|
||||
"chat_123",
|
||||
StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉"),
|
||||
)
|
||||
|
||||
consumer.on_delta("Hello")
|
||||
task = asyncio.create_task(consumer.run())
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.on_delta(" world")
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.finish()
|
||||
await task
|
||||
|
||||
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
|
||||
assert sent_texts == ["Hello ▉", "world"]
|
||||
assert consumer.already_sent is True
|
||||
assert consumer.final_response_sent is True
|
||||
|
||||
@@ -6,7 +6,9 @@ from unittest.mock import AsyncMock, patch
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from gateway.config import GatewayConfig, load_gateway_config
|
||||
from gateway.config import GatewayConfig, Platform, load_gateway_config
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def test_gateway_config_stt_disabled_from_dict_nested():
|
||||
@@ -69,3 +71,46 @@ async def test_enrich_message_with_transcription_avoids_bogus_no_provider_messag
|
||||
assert "No STT provider is configured" not in result
|
||||
assert "trouble transcribing" in result
|
||||
assert "caption" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_inbound_message_text_transcribes_queued_voice_event():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(stt_enabled=True)
|
||||
runner.adapters = {}
|
||||
runner._model = "test-model"
|
||||
runner._base_url = ""
|
||||
runner._has_setup_skill = lambda: False
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="123",
|
||||
chat_type="dm",
|
||||
)
|
||||
event = MessageEvent(
|
||||
text="",
|
||||
message_type=MessageType.VOICE,
|
||||
source=source,
|
||||
media_urls=["/tmp/queued-voice.ogg"],
|
||||
media_types=["audio/ogg"],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"tools.transcription_tools.transcribe_audio",
|
||||
return_value={
|
||||
"success": True,
|
||||
"transcript": "queued voice transcript",
|
||||
"provider": "local_command",
|
||||
},
|
||||
):
|
||||
result = await runner._prepare_inbound_message_text(
|
||||
event=event,
|
||||
source=source,
|
||||
history=[],
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert "queued voice transcript" in result
|
||||
assert "voice message" in result.lower()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user