Compare commits
188 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c45d18265c | |||
| 1c6d144a10 | |||
| 496e378b10 | |||
| 03f23f10e1 | |||
| 2b4abf8d9c | |||
| 8bcb8b8e87 | |||
| f07b35acba | |||
| 363d5d57be | |||
| 7ccdb74364 | |||
| 6c115440fd | |||
| 4fb42d0193 | |||
| f83e86d826 | |||
| 0bea603510 | |||
| 360b21ce95 | |||
| 37a1c75716 | |||
| c6e1add6f1 | |||
| 2c99b4e79b | |||
| 71036a7a75 | |||
| 7e28b7b5d5 | |||
| a093eb47f7 | |||
| f72faf191c | |||
| f8dbe0ffd1 | |||
| 7e60b09274 | |||
| 970192f183 | |||
| 5b8beb0ead | |||
| 7cec784b64 | |||
| be4f049f46 | |||
| 5b63bf7f9a | |||
| 4a65c9cd08 | |||
| 916fbf362c | |||
| b730c2955a | |||
| fd5cc6e1b4 | |||
| 1662b7f82a | |||
| e3b395e17d | |||
| 0cdf5232ae | |||
| 49bba1096e | |||
| fd3e855d58 | |||
| 5fc5ced972 | |||
| 0e315a6f02 | |||
| 6d2fa03837 | |||
| f3ae1d765d | |||
| 49da1ff1b1 | |||
| 76a1e6e0fe | |||
| 21bb2547c6 | |||
| 58413c411f | |||
| cc12ab8290 | |||
| 74e883ca37 | |||
| e376a9b2c9 | |||
| 2629927032 | |||
| aedf6c7964 | |||
| 5a1cce53e4 | |||
| 419b719c2b | |||
| f3fb3eded4 | |||
| d7164603da | |||
| e683c9db90 | |||
| 7663c98c1e | |||
| 714809634f | |||
| f4c7086035 | |||
| 0b143f2ea3 | |||
| c8e4dcf412 | |||
| 00dd5cc491 | |||
| 9bb8cb8d83 | |||
| 5dea7e1ebc | |||
| b1e2b5ea74 | |||
| 96f9b91489 | |||
| bb3a4fc68e | |||
| 429da6cbce | |||
| 4f2f09affa | |||
| af7d809354 | |||
| fbfa7c27d5 | |||
| 1bcc87a153 | |||
| 437feabb74 | |||
| 957485876b | |||
| c6c769772f | |||
| f63cc3c0c7 | |||
| cff9b7ffab | |||
| 96c060018a | |||
| 04baab5422 | |||
| 9a0dfb5a6d | |||
| 68528068ec | |||
| 8dd738c2e6 | |||
| 0f597dd127 | |||
| 5a8b5f149d | |||
| f4f8b9579e | |||
| c6ff5e5d30 | |||
| 9aedab00f4 | |||
| 19292eb8bf | |||
| 6d5f607e48 | |||
| 52bd3bd200 | |||
| 568be71003 | |||
| a2f46e4665 | |||
| 7d426e6536 | |||
| 30ae68dd33 | |||
| 9afe1784bd | |||
| 94f5979cc2 | |||
| 738f0bac13 | |||
| 37bb4f807b | |||
| b577697189 | |||
| 5b22e61cfa | |||
| b39ea46488 | |||
| aad40f6d0c | |||
| 41c233cb99 | |||
| 1f1f297528 | |||
| 1495647636 | |||
| 4e78963fe8 | |||
| f92298fe95 | |||
| eaa21a8275 | |||
| a420235b66 | |||
| 6c3565df57 | |||
| 51d826f889 | |||
| a04854800f | |||
| 940237c6fd | |||
| 95ee453bc0 | |||
| 38cce22e2c | |||
| 7368854398 | |||
| 38ccd9eb95 | |||
| 45034b746f | |||
| a7588830d4 | |||
| 9431f82aff | |||
| 6da952bc50 | |||
| 8779a268a7 | |||
| 0848a79476 | |||
| 871313ae2d | |||
| 13d7ff3420 | |||
| d5023d36d8 | |||
| 0602ff8f58 | |||
| 8104f400f8 | |||
| 1ed00496f2 | |||
| f92a0b8596 | |||
| 1723e8e998 | |||
| 07148cac9a | |||
| 0fc0c1c83b | |||
| 5075717949 | |||
| f783986f5a | |||
| bda9aa17cb | |||
| 8394b5ddd2 | |||
| d416a69288 | |||
| 4caa635803 | |||
| a64d8a83e1 | |||
| dfde4058cf | |||
| 13b3ea6484 | |||
| b87d00288d | |||
| 08e2a1a51e | |||
| 42e7755d4c | |||
| 68954b7c03 | |||
| 9634e20e15 | |||
| 2d0d05a337 | |||
| 3b554bf839 | |||
| 69a0092c38 | |||
| c3141429b7 | |||
| 769ec1ee1a | |||
| 3237733ca5 | |||
| 54d5138a54 | |||
| 6dcb3c4774 | |||
| 096b3f9f12 | |||
| a3aed1bd26 | |||
| 4970705ed3 | |||
| 2194425918 | |||
| 3878495972 | |||
| 4e40e93b98 | |||
| 122925a6f2 | |||
| e79cc88985 | |||
| e053433c84 | |||
| 1789c2699a | |||
| aed9b90ae3 | |||
| 6b437f7934 | |||
| f91fffbe33 | |||
| 49d8c9557f | |||
| c3854e0f85 | |||
| 97308707e9 | |||
| e9168f917e | |||
| c8bbd29aae | |||
| 73eb59db8d | |||
| 127b4caf0d | |||
| 1780ad24b1 | |||
| 775a46ce75 | |||
| 6f8e426275 | |||
| 88dbbfe982 | |||
| 88845b99d2 | |||
| 18d8e91a5a | |||
| 1773e3d647 | |||
| 7f7b02b764 | |||
| 7d499c75db | |||
| 997e219c14 | |||
| ab7b407224 | |||
| 95220facdf | |||
| 5ea9bf70de | |||
| 67e4d43ea1 |
+2
-1
@@ -13,7 +13,8 @@ COPY . /opt/hermes
|
||||
WORKDIR /opt/hermes
|
||||
|
||||
# Install Python and Node dependencies in one layer, no cache
|
||||
RUN pip install --no-cache-dir -e ".[all]" --break-system-packages && \
|
||||
RUN pip install --no-cache-dir uv --break-system-packages && \
|
||||
uv pip install --system --break-system-packages --no-cache -e ".[all]" && \
|
||||
npm install --prefer-offline --no-audit && \
|
||||
npx playwright install --with-deps chromium --only-shell && \
|
||||
cd /opt/hermes/scripts/whatsapp-bridge && \
|
||||
|
||||
@@ -33,8 +33,10 @@ Use any model you want — [Nous Portal](https://portal.nousresearch.com), [Open
|
||||
curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
|
||||
```
|
||||
|
||||
Works on Linux, macOS, and WSL2. The installer handles everything — Python, Node.js, dependencies, and the `hermes` command. No prerequisites except git.
|
||||
Works on Linux, macOS, WSL2, and Android via Termux. The installer handles the platform-specific setup for you.
|
||||
|
||||
> **Android / Termux:** The tested manual path is documented in the [Termux guide](https://hermes-agent.nousresearch.com/docs/getting-started/termux). On Termux, Hermes installs a curated `.[termux]` extra because the full `.[all]` extra currently pulls Android-incompatible voice dependencies.
|
||||
>
|
||||
> **Windows:** Native Windows is not supported. Please install [WSL2](https://learn.microsoft.com/en-us/windows/wsl/install) and run the command above.
|
||||
|
||||
After installation:
|
||||
|
||||
@@ -36,6 +36,7 @@ from acp.schema import (
|
||||
SessionCapabilities,
|
||||
SessionForkCapabilities,
|
||||
SessionListCapabilities,
|
||||
SessionResumeCapabilities,
|
||||
SessionInfo,
|
||||
TextContentBlock,
|
||||
UnstructuredCommandInput,
|
||||
@@ -245,9 +246,11 @@ class HermesACPAgent(acp.Agent):
|
||||
protocol_version=acp.PROTOCOL_VERSION,
|
||||
agent_info=Implementation(name="hermes-agent", version=HERMES_VERSION),
|
||||
agent_capabilities=AgentCapabilities(
|
||||
load_session=True,
|
||||
session_capabilities=SessionCapabilities(
|
||||
fork=SessionForkCapabilities(),
|
||||
list=SessionListCapabilities(),
|
||||
resume=SessionResumeCapabilities(),
|
||||
),
|
||||
),
|
||||
auth_methods=auth_methods,
|
||||
@@ -451,14 +454,13 @@ class HermesACPAgent(acp.Agent):
|
||||
await conn.session_update(session_id, update)
|
||||
|
||||
usage = None
|
||||
usage_data = result.get("usage")
|
||||
if usage_data and isinstance(usage_data, dict):
|
||||
if any(result.get(key) is not None for key in ("prompt_tokens", "completion_tokens", "total_tokens")):
|
||||
usage = Usage(
|
||||
input_tokens=usage_data.get("prompt_tokens", 0),
|
||||
output_tokens=usage_data.get("completion_tokens", 0),
|
||||
total_tokens=usage_data.get("total_tokens", 0),
|
||||
thought_tokens=usage_data.get("reasoning_tokens"),
|
||||
cached_read_tokens=usage_data.get("cached_tokens"),
|
||||
input_tokens=result.get("prompt_tokens", 0),
|
||||
output_tokens=result.get("completion_tokens", 0),
|
||||
total_tokens=result.get("total_tokens", 0),
|
||||
thought_tokens=result.get("reasoning_tokens"),
|
||||
cached_read_tokens=result.get("cache_read_tokens"),
|
||||
)
|
||||
|
||||
stop_reason = "cancelled" if state.cancel_event and state.cancel_event.is_set() else "end_turn"
|
||||
|
||||
+54
-86
@@ -74,8 +74,11 @@ def _get_anthropic_max_output(model: str) -> int:
|
||||
model IDs (claude-sonnet-4-5-20250929) and variant suffixes (:1m, :fast)
|
||||
resolve correctly. Longest-prefix match wins to avoid e.g. "claude-3-5"
|
||||
matching before "claude-3-5-sonnet".
|
||||
|
||||
Normalizes dots to hyphens so that model names like
|
||||
``anthropic/claude-opus-4.6`` match the ``claude-opus-4-6`` table key.
|
||||
"""
|
||||
m = model.lower()
|
||||
m = model.lower().replace(".", "-")
|
||||
best_key = ""
|
||||
best_val = _ANTHROPIC_DEFAULT_OUTPUT_LIMIT
|
||||
for key, val in _ANTHROPIC_OUTPUT_LIMITS.items():
|
||||
@@ -95,6 +98,15 @@ _COMMON_BETAS = [
|
||||
"interleaved-thinking-2025-05-14",
|
||||
"fine-grained-tool-streaming-2025-05-14",
|
||||
]
|
||||
# MiniMax's Anthropic-compatible endpoints fail tool-use requests when
|
||||
# the fine-grained tool streaming beta is present. Omit it so tool calls
|
||||
# fall back to the provider's default response path.
|
||||
_TOOL_STREAMING_BETA = "fine-grained-tool-streaming-2025-05-14"
|
||||
|
||||
# Fast mode beta — enables the ``speed: "fast"`` request parameter for
|
||||
# significantly higher output token throughput on Opus 4.6 (~2.5x).
|
||||
# See https://platform.claude.com/docs/en/build-with-claude/fast-mode
|
||||
_FAST_MODE_BETA = "fast-mode-2026-02-01"
|
||||
|
||||
# Additional beta headers required for OAuth/subscription auth.
|
||||
# Matches what Claude Code (and pi-ai / OpenCode) send.
|
||||
@@ -204,6 +216,19 @@ def _requires_bearer_auth(base_url: str | None) -> bool:
|
||||
return normalized.startswith(("https://api.minimax.io/anthropic", "https://api.minimaxi.com/anthropic"))
|
||||
|
||||
|
||||
def _common_betas_for_base_url(base_url: str | None) -> list[str]:
|
||||
"""Return the beta headers that are safe for the configured endpoint.
|
||||
|
||||
MiniMax's Anthropic-compatible endpoints (Bearer-auth) reject requests
|
||||
that include Anthropic's ``fine-grained-tool-streaming`` beta — every
|
||||
tool-use message triggers a connection error. Strip that beta for
|
||||
Bearer-auth endpoints while keeping all other betas intact.
|
||||
"""
|
||||
if _requires_bearer_auth(base_url):
|
||||
return [b for b in _COMMON_BETAS if b != _TOOL_STREAMING_BETA]
|
||||
return _COMMON_BETAS
|
||||
|
||||
|
||||
def build_anthropic_client(api_key: str, base_url: str = None):
|
||||
"""Create an Anthropic client, auto-detecting setup-tokens vs API keys.
|
||||
|
||||
@@ -222,6 +247,7 @@ def build_anthropic_client(api_key: str, base_url: str = None):
|
||||
}
|
||||
if normalized_base_url:
|
||||
kwargs["base_url"] = normalized_base_url
|
||||
common_betas = _common_betas_for_base_url(normalized_base_url)
|
||||
|
||||
if _requires_bearer_auth(normalized_base_url):
|
||||
# Some Anthropic-compatible providers (e.g. MiniMax) expect the API key in
|
||||
@@ -231,21 +257,21 @@ def build_anthropic_client(api_key: str, base_url: str = None):
|
||||
# not use Anthropic's sk-ant-api prefix and would otherwise be misread as
|
||||
# Anthropic OAuth/setup tokens.
|
||||
kwargs["auth_token"] = api_key
|
||||
if _COMMON_BETAS:
|
||||
kwargs["default_headers"] = {"anthropic-beta": ",".join(_COMMON_BETAS)}
|
||||
if common_betas:
|
||||
kwargs["default_headers"] = {"anthropic-beta": ",".join(common_betas)}
|
||||
elif _is_third_party_anthropic_endpoint(base_url):
|
||||
# Third-party proxies (Azure AI Foundry, AWS Bedrock, etc.) use their
|
||||
# own API keys with x-api-key auth. Skip OAuth detection — their keys
|
||||
# don't follow Anthropic's sk-ant-* prefix convention and would be
|
||||
# misclassified as OAuth tokens.
|
||||
kwargs["api_key"] = api_key
|
||||
if _COMMON_BETAS:
|
||||
kwargs["default_headers"] = {"anthropic-beta": ",".join(_COMMON_BETAS)}
|
||||
if common_betas:
|
||||
kwargs["default_headers"] = {"anthropic-beta": ",".join(common_betas)}
|
||||
elif _is_oauth_token(api_key):
|
||||
# OAuth access token / setup-token → Bearer auth + Claude Code identity.
|
||||
# Anthropic routes OAuth requests based on user-agent and headers;
|
||||
# without Claude Code's fingerprint, requests get intermittent 500s.
|
||||
all_betas = _COMMON_BETAS + _OAUTH_ONLY_BETAS
|
||||
all_betas = common_betas + _OAUTH_ONLY_BETAS
|
||||
kwargs["auth_token"] = api_key
|
||||
kwargs["default_headers"] = {
|
||||
"anthropic-beta": ",".join(all_betas),
|
||||
@@ -255,8 +281,8 @@ def build_anthropic_client(api_key: str, base_url: str = None):
|
||||
else:
|
||||
# Regular API key → x-api-key header + common betas
|
||||
kwargs["api_key"] = api_key
|
||||
if _COMMON_BETAS:
|
||||
kwargs["default_headers"] = {"anthropic-beta": ",".join(_COMMON_BETAS)}
|
||||
if common_betas:
|
||||
kwargs["default_headers"] = {"anthropic-beta": ",".join(common_betas)}
|
||||
|
||||
return _anthropic_sdk.Anthropic(**kwargs)
|
||||
|
||||
@@ -485,35 +511,6 @@ def _prefer_refreshable_claude_code_token(env_token: str, creds: Optional[Dict[s
|
||||
return None
|
||||
|
||||
|
||||
def get_anthropic_token_source(token: Optional[str] = None) -> str:
|
||||
"""Best-effort source classification for an Anthropic credential token."""
|
||||
token = (token or "").strip()
|
||||
if not token:
|
||||
return "none"
|
||||
|
||||
env_token = os.getenv("ANTHROPIC_TOKEN", "").strip()
|
||||
if env_token and env_token == token:
|
||||
return "anthropic_token_env"
|
||||
|
||||
cc_env_token = os.getenv("CLAUDE_CODE_OAUTH_TOKEN", "").strip()
|
||||
if cc_env_token and cc_env_token == token:
|
||||
return "claude_code_oauth_token_env"
|
||||
|
||||
creds = read_claude_code_credentials()
|
||||
if creds and creds.get("accessToken") == token:
|
||||
return str(creds.get("source") or "claude_code_credentials")
|
||||
|
||||
managed_key = read_claude_managed_key()
|
||||
if managed_key and managed_key == token:
|
||||
return "claude_json_primary_api_key"
|
||||
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY", "").strip()
|
||||
if api_key and api_key == token:
|
||||
return "anthropic_api_key_env"
|
||||
|
||||
return "unknown"
|
||||
|
||||
|
||||
def resolve_anthropic_token() -> Optional[str]:
|
||||
"""Resolve an Anthropic token from all available sources.
|
||||
|
||||
@@ -720,21 +717,6 @@ def run_hermes_oauth_login_pure() -> Optional[Dict[str, Any]]:
|
||||
}
|
||||
|
||||
|
||||
def _save_hermes_oauth_credentials(access_token: str, refresh_token: str, expires_at_ms: int) -> None:
|
||||
"""Save OAuth credentials to ~/.hermes/.anthropic_oauth.json."""
|
||||
data = {
|
||||
"accessToken": access_token,
|
||||
"refreshToken": refresh_token,
|
||||
"expiresAt": expires_at_ms,
|
||||
}
|
||||
try:
|
||||
_HERMES_OAUTH_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_HERMES_OAUTH_FILE.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
||||
_HERMES_OAUTH_FILE.chmod(0o600)
|
||||
except (OSError, IOError) as e:
|
||||
logger.debug("Failed to save Hermes OAuth credentials: %s", e)
|
||||
|
||||
|
||||
def read_hermes_oauth_credentials() -> Optional[Dict[str, Any]]:
|
||||
"""Read Hermes-managed OAuth credentials from ~/.hermes/.anthropic_oauth.json."""
|
||||
if _HERMES_OAUTH_FILE.exists():
|
||||
@@ -783,39 +765,6 @@ def _sanitize_tool_id(tool_id: str) -> str:
|
||||
return sanitized or "tool_0"
|
||||
|
||||
|
||||
def _convert_openai_image_part_to_anthropic(part: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""Convert an OpenAI-style image block to Anthropic's image source format."""
|
||||
image_data = part.get("image_url", {})
|
||||
url = image_data.get("url", "") if isinstance(image_data, dict) else str(image_data)
|
||||
if not isinstance(url, str) or not url.strip():
|
||||
return None
|
||||
url = url.strip()
|
||||
|
||||
if url.startswith("data:"):
|
||||
header, sep, data = url.partition(",")
|
||||
if sep and ";base64" in header:
|
||||
media_type = header[5:].split(";", 1)[0] or "image/png"
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": data,
|
||||
},
|
||||
}
|
||||
|
||||
if url.startswith(("http://", "https://")):
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "url",
|
||||
"url": url,
|
||||
},
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def convert_tools_to_anthropic(tools: List[Dict]) -> List[Dict]:
|
||||
"""Convert OpenAI tool definitions to Anthropic format."""
|
||||
if not tools:
|
||||
@@ -1235,6 +1184,7 @@ def build_anthropic_kwargs(
|
||||
preserve_dots: bool = False,
|
||||
context_length: Optional[int] = None,
|
||||
base_url: str | None = None,
|
||||
fast_mode: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build kwargs for anthropic.messages.create().
|
||||
|
||||
@@ -1268,6 +1218,10 @@ def build_anthropic_kwargs(
|
||||
|
||||
When *base_url* points to a third-party Anthropic-compatible endpoint,
|
||||
thinking block signatures are stripped (they are Anthropic-proprietary).
|
||||
|
||||
When *fast_mode* is True, adds ``speed: "fast"`` and the fast-mode beta
|
||||
header for ~2.5x faster output throughput on Opus 4.6. Currently only
|
||||
supported on native Anthropic endpoints (not third-party compatible ones).
|
||||
"""
|
||||
system, anthropic_messages = convert_messages_to_anthropic(messages, base_url=base_url)
|
||||
anthropic_tools = convert_tools_to_anthropic(tools) if tools else []
|
||||
@@ -1366,6 +1320,20 @@ def build_anthropic_kwargs(
|
||||
kwargs["temperature"] = 1
|
||||
kwargs["max_tokens"] = max(effective_max_tokens, budget + 4096)
|
||||
|
||||
# ── Fast mode (Opus 4.6 only) ────────────────────────────────────
|
||||
# Adds speed:"fast" + the fast-mode beta header for ~2.5x output speed.
|
||||
# Only for native Anthropic endpoints — third-party providers would
|
||||
# reject the unknown beta header and speed parameter.
|
||||
if fast_mode and not _is_third_party_anthropic_endpoint(base_url):
|
||||
kwargs["speed"] = "fast"
|
||||
# Build extra_headers with ALL applicable betas (the per-request
|
||||
# extra_headers override the client-level anthropic-beta header).
|
||||
betas = list(_common_betas_for_base_url(base_url))
|
||||
if is_oauth:
|
||||
betas.extend(_OAUTH_ONLY_BETAS)
|
||||
betas.append(_FAST_MODE_BETA)
|
||||
kwargs["extra_headers"] = {"anthropic-beta": ",".join(betas)}
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
@@ -1427,4 +1395,4 @@ def normalize_anthropic_response(
|
||||
reasoning_details=reasoning_details or None,
|
||||
),
|
||||
finish_reason,
|
||||
)
|
||||
)
|
||||
|
||||
+41
-76
@@ -687,6 +687,15 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
if pconfig.auth_type != "api_key":
|
||||
continue
|
||||
if provider_id == "anthropic":
|
||||
# Only try anthropic when the user has explicitly configured it.
|
||||
# Without this gate, Claude Code credentials get silently used
|
||||
# as auxiliary fallback when the user's primary provider fails.
|
||||
try:
|
||||
from hermes_cli.auth import is_provider_explicitly_configured
|
||||
if not is_provider_explicitly_configured("anthropic"):
|
||||
continue
|
||||
except ImportError:
|
||||
pass
|
||||
return _try_anthropic()
|
||||
|
||||
pool_present, entry = _select_pool_entry(provider_id)
|
||||
@@ -702,7 +711,7 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
logger.debug("Auxiliary text client: %s (%s) via pool", pconfig.name, model)
|
||||
extra = {}
|
||||
if "api.kimi.com" in base_url.lower():
|
||||
extra["default_headers"] = {"User-Agent": "KimiCLI/1.3"}
|
||||
extra["default_headers"] = {"User-Agent": "KimiCLI/1.30.0"}
|
||||
elif "api.githubcopilot.com" in base_url.lower():
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
@@ -721,7 +730,7 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
logger.debug("Auxiliary text client: %s (%s)", pconfig.name, model)
|
||||
extra = {}
|
||||
if "api.kimi.com" in base_url.lower():
|
||||
extra["default_headers"] = {"User-Agent": "KimiCLI/1.3"}
|
||||
extra["default_headers"] = {"User-Agent": "KimiCLI/1.30.0"}
|
||||
elif "api.githubcopilot.com" in base_url.lower():
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
@@ -967,40 +976,6 @@ def _try_anthropic() -> Tuple[Optional[Any], Optional[str]]:
|
||||
return AnthropicAuxiliaryClient(real_client, model, token, base_url, is_oauth=is_oauth), model
|
||||
|
||||
|
||||
def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Resolve a specific forced provider. Returns (None, None) if creds missing."""
|
||||
if forced == "openrouter":
|
||||
client, model = _try_openrouter()
|
||||
if client is None:
|
||||
logger.warning("auxiliary.provider=openrouter but OPENROUTER_API_KEY not set")
|
||||
return client, model
|
||||
|
||||
if forced == "nous":
|
||||
client, model = _try_nous()
|
||||
if client is None:
|
||||
logger.warning("auxiliary.provider=nous but Nous Portal not configured (run: hermes auth)")
|
||||
return client, model
|
||||
|
||||
if forced == "codex":
|
||||
client, model = _try_codex()
|
||||
if client is None:
|
||||
logger.warning("auxiliary.provider=codex but no Codex OAuth token found (run: hermes model)")
|
||||
return client, model
|
||||
|
||||
if forced == "main":
|
||||
# "main" = skip OpenRouter/Nous, use the main chat model's credentials.
|
||||
for try_fn in (_try_custom_endpoint, _try_codex, _resolve_api_key_provider):
|
||||
client, model = try_fn()
|
||||
if client is not None:
|
||||
return client, model
|
||||
logger.warning("auxiliary.provider=main but no main endpoint credentials found")
|
||||
return None, None
|
||||
|
||||
# Unknown provider name — fall through to auto
|
||||
logger.warning("Unknown auxiliary.provider=%r, falling back to auto", forced)
|
||||
return None, None
|
||||
|
||||
|
||||
_AUTO_PROVIDER_LABELS = {
|
||||
"_try_openrouter": "openrouter",
|
||||
"_try_nous": "nous",
|
||||
@@ -1195,10 +1170,22 @@ def _to_async_client(sync_client, model: str):
|
||||
|
||||
async_kwargs["default_headers"] = copilot_default_headers()
|
||||
elif "api.kimi.com" in base_lower:
|
||||
async_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.3"}
|
||||
async_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.30.0"}
|
||||
return AsyncOpenAI(**async_kwargs), model
|
||||
|
||||
|
||||
def _normalize_resolved_model(model_name: Optional[str], provider: str) -> Optional[str]:
|
||||
"""Normalize a resolved model for the provider that will receive it."""
|
||||
if not model_name:
|
||||
return model_name
|
||||
try:
|
||||
from hermes_cli.model_normalize import normalize_model_for_provider
|
||||
|
||||
return normalize_model_for_provider(model_name, provider)
|
||||
except Exception:
|
||||
return model_name
|
||||
|
||||
|
||||
def resolve_provider_client(
|
||||
provider: str,
|
||||
model: str = None,
|
||||
@@ -1261,7 +1248,7 @@ def resolve_provider_client(
|
||||
logger.warning("resolve_provider_client: openrouter requested "
|
||||
"but OPENROUTER_API_KEY not set")
|
||||
return None, None
|
||||
final_model = model or default
|
||||
final_model = _normalize_resolved_model(model or default, provider)
|
||||
return (_to_async_client(client, final_model) if async_mode
|
||||
else (client, final_model))
|
||||
|
||||
@@ -1272,7 +1259,7 @@ def resolve_provider_client(
|
||||
logger.warning("resolve_provider_client: nous requested "
|
||||
"but Nous Portal not configured (run: hermes auth)")
|
||||
return None, None
|
||||
final_model = model or default
|
||||
final_model = _normalize_resolved_model(model or default, provider)
|
||||
return (_to_async_client(client, final_model) if async_mode
|
||||
else (client, final_model))
|
||||
|
||||
@@ -1286,7 +1273,7 @@ def resolve_provider_client(
|
||||
logger.warning("resolve_provider_client: openai-codex requested "
|
||||
"but no Codex OAuth token found (run: hermes model)")
|
||||
return None, None
|
||||
final_model = model or _CODEX_AUX_MODEL
|
||||
final_model = _normalize_resolved_model(model or _CODEX_AUX_MODEL, provider)
|
||||
raw_client = OpenAI(api_key=codex_token, base_url=_CODEX_AUX_BASE_URL)
|
||||
return (raw_client, final_model)
|
||||
# Standard path: wrap in CodexAuxiliaryClient adapter
|
||||
@@ -1295,7 +1282,7 @@ def resolve_provider_client(
|
||||
logger.warning("resolve_provider_client: openai-codex requested "
|
||||
"but no Codex OAuth token found (run: hermes model)")
|
||||
return None, None
|
||||
final_model = model or default
|
||||
final_model = _normalize_resolved_model(model or default, provider)
|
||||
return (_to_async_client(client, final_model) if async_mode
|
||||
else (client, final_model))
|
||||
|
||||
@@ -1314,10 +1301,13 @@ def resolve_provider_client(
|
||||
"but base_url is empty"
|
||||
)
|
||||
return None, None
|
||||
final_model = model or _read_main_model() or "gpt-4o-mini"
|
||||
final_model = _normalize_resolved_model(
|
||||
model or _read_main_model() or "gpt-4o-mini",
|
||||
provider,
|
||||
)
|
||||
extra = {}
|
||||
if "api.kimi.com" in custom_base.lower():
|
||||
extra["default_headers"] = {"User-Agent": "KimiCLI/1.3"}
|
||||
extra["default_headers"] = {"User-Agent": "KimiCLI/1.30.0"}
|
||||
elif "api.githubcopilot.com" in custom_base.lower():
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
extra["default_headers"] = copilot_default_headers()
|
||||
@@ -1329,7 +1319,7 @@ def resolve_provider_client(
|
||||
_resolve_api_key_provider):
|
||||
client, default = try_fn()
|
||||
if client is not None:
|
||||
final_model = model or default
|
||||
final_model = _normalize_resolved_model(model or default, provider)
|
||||
return (_to_async_client(client, final_model) if async_mode
|
||||
else (client, final_model))
|
||||
logger.warning("resolve_provider_client: custom/main requested "
|
||||
@@ -1344,7 +1334,10 @@ def resolve_provider_client(
|
||||
custom_base = custom_entry.get("base_url", "").strip()
|
||||
custom_key = custom_entry.get("api_key", "").strip() or "no-key-required"
|
||||
if custom_base:
|
||||
final_model = model or _read_main_model() or "gpt-4o-mini"
|
||||
final_model = _normalize_resolved_model(
|
||||
model or _read_main_model() or "gpt-4o-mini",
|
||||
provider,
|
||||
)
|
||||
client = OpenAI(api_key=custom_key, base_url=custom_base)
|
||||
logger.debug(
|
||||
"resolve_provider_client: named custom provider %r (%s)",
|
||||
@@ -1376,7 +1369,7 @@ def resolve_provider_client(
|
||||
if client is None:
|
||||
logger.warning("resolve_provider_client: anthropic requested but no Anthropic credentials found")
|
||||
return None, None
|
||||
final_model = model or default_model
|
||||
final_model = _normalize_resolved_model(model or default_model, provider)
|
||||
return (_to_async_client(client, final_model) if async_mode else (client, final_model))
|
||||
|
||||
creds = resolve_api_key_provider_credentials(provider)
|
||||
@@ -1395,12 +1388,12 @@ def resolve_provider_client(
|
||||
)
|
||||
|
||||
default_model = _API_KEY_PROVIDER_AUX_MODELS.get(provider, "")
|
||||
final_model = model or default_model
|
||||
final_model = _normalize_resolved_model(model or default_model, provider)
|
||||
|
||||
# Provider-specific headers
|
||||
headers = {}
|
||||
if "api.kimi.com" in base_url.lower():
|
||||
headers["User-Agent"] = "KimiCLI/1.3"
|
||||
headers["User-Agent"] = "KimiCLI/1.30.0"
|
||||
elif "api.githubcopilot.com" in base_url.lower():
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
@@ -1495,22 +1488,6 @@ def _strict_vision_backend_available(provider: str) -> bool:
|
||||
return _resolve_strict_vision_backend(provider)[0] is not None
|
||||
|
||||
|
||||
def _preferred_main_vision_provider() -> Optional[str]:
|
||||
"""Return the selected main provider when it is also a supported vision backend."""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
model_cfg = config.get("model", {})
|
||||
if isinstance(model_cfg, dict):
|
||||
provider = _normalize_vision_provider(model_cfg.get("provider", ""))
|
||||
if provider in _VISION_AUTO_PROVIDER_ORDER:
|
||||
return provider
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def get_available_vision_backends() -> List[str]:
|
||||
"""Return the currently available vision backends in auto-selection order.
|
||||
|
||||
@@ -1624,18 +1601,6 @@ def resolve_vision_provider_client(
|
||||
return requested, client, final_model
|
||||
|
||||
|
||||
def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Return (client, default_model_slug) for vision/multimodal auxiliary tasks."""
|
||||
_, client, final_model = resolve_vision_provider_client(async_mode=False)
|
||||
return client, final_model
|
||||
|
||||
|
||||
def get_async_vision_auxiliary_client():
|
||||
"""Return (async_client, model_slug) for async vision consumers."""
|
||||
_, client, final_model = resolve_vision_provider_client(async_mode=True)
|
||||
return client, final_model
|
||||
|
||||
|
||||
def get_auxiliary_extra_body() -> dict:
|
||||
"""Return extra_body kwargs for auxiliary API calls.
|
||||
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
"""BuiltinMemoryProvider — wraps MEMORY.md / USER.md as a MemoryProvider.
|
||||
|
||||
Always registered as the first provider. Cannot be disabled or removed.
|
||||
This is the existing Hermes memory system exposed through the provider
|
||||
interface for compatibility with the MemoryManager.
|
||||
|
||||
The actual storage logic lives in tools/memory_tool.py (MemoryStore).
|
||||
This provider is a thin adapter that delegates to MemoryStore and
|
||||
exposes the memory tool schema.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
from tools.registry import tool_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BuiltinMemoryProvider(MemoryProvider):
|
||||
"""Built-in file-backed memory (MEMORY.md + USER.md).
|
||||
|
||||
Always active, never disabled by other providers. The `memory` tool
|
||||
is handled by run_agent.py's agent-level tool interception (not through
|
||||
the normal registry), so get_tool_schemas() returns an empty list —
|
||||
the memory tool is already wired separately.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
memory_store=None,
|
||||
memory_enabled: bool = False,
|
||||
user_profile_enabled: bool = False,
|
||||
):
|
||||
self._store = memory_store
|
||||
self._memory_enabled = memory_enabled
|
||||
self._user_profile_enabled = user_profile_enabled
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "builtin"
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Built-in memory is always available."""
|
||||
return True
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
|
||||
"""Load memory from disk if not already loaded."""
|
||||
if self._store is not None:
|
||||
self._store.load_from_disk()
|
||||
|
||||
def system_prompt_block(self) -> str:
|
||||
"""Return MEMORY.md and USER.md content for the system prompt.
|
||||
|
||||
Uses the frozen snapshot captured at load time. This ensures the
|
||||
system prompt stays stable throughout a session (preserving the
|
||||
prompt cache), even though the live entries may change via tool calls.
|
||||
"""
|
||||
if not self._store:
|
||||
return ""
|
||||
|
||||
parts = []
|
||||
if self._memory_enabled:
|
||||
mem_block = self._store.format_for_system_prompt("memory")
|
||||
if mem_block:
|
||||
parts.append(mem_block)
|
||||
if self._user_profile_enabled:
|
||||
user_block = self._store.format_for_system_prompt("user")
|
||||
if user_block:
|
||||
parts.append(user_block)
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def prefetch(self, query: str, *, session_id: str = "") -> str:
|
||||
"""Built-in memory doesn't do query-based recall — it's injected via system_prompt_block."""
|
||||
return ""
|
||||
|
||||
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
|
||||
"""Built-in memory doesn't auto-sync turns — writes happen via the memory tool."""
|
||||
|
||||
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
"""Return empty list.
|
||||
|
||||
The `memory` tool is an agent-level intercepted tool, handled
|
||||
specially in run_agent.py before normal tool dispatch. It's not
|
||||
part of the standard tool registry. We don't duplicate it here.
|
||||
"""
|
||||
return []
|
||||
|
||||
def handle_tool_call(self, tool_name: str, args: Dict[str, Any], **kwargs) -> str:
|
||||
"""Not used — the memory tool is intercepted in run_agent.py."""
|
||||
return tool_error("Built-in memory tool is handled by the agent loop")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""No cleanup needed — files are saved on every write."""
|
||||
|
||||
# -- Property access for backward compatibility --------------------------
|
||||
|
||||
@property
|
||||
def store(self):
|
||||
"""Access the underlying MemoryStore for legacy code paths."""
|
||||
return self._store
|
||||
|
||||
@property
|
||||
def memory_enabled(self) -> bool:
|
||||
return self._memory_enabled
|
||||
|
||||
@property
|
||||
def user_profile_enabled(self) -> bool:
|
||||
return self._user_profile_enabled
|
||||
+36
-43
@@ -114,7 +114,6 @@ class ContextCompressor:
|
||||
|
||||
self.last_prompt_tokens = 0
|
||||
self.last_completion_tokens = 0
|
||||
self.last_total_tokens = 0
|
||||
|
||||
self.summary_model = summary_model_override or ""
|
||||
|
||||
@@ -126,28 +125,12 @@ class ContextCompressor:
|
||||
"""Update tracked token usage from API response."""
|
||||
self.last_prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
self.last_completion_tokens = usage.get("completion_tokens", 0)
|
||||
self.last_total_tokens = usage.get("total_tokens", 0)
|
||||
|
||||
def should_compress(self, prompt_tokens: int = None) -> bool:
|
||||
"""Check if context exceeds the compression threshold."""
|
||||
tokens = prompt_tokens if prompt_tokens is not None else self.last_prompt_tokens
|
||||
return tokens >= self.threshold_tokens
|
||||
|
||||
def should_compress_preflight(self, messages: List[Dict[str, Any]]) -> bool:
|
||||
"""Quick pre-flight check using rough estimate (before API call)."""
|
||||
rough_estimate = estimate_messages_tokens_rough(messages)
|
||||
return rough_estimate >= self.threshold_tokens
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""Get current compression status for display/logging."""
|
||||
return {
|
||||
"last_prompt_tokens": self.last_prompt_tokens,
|
||||
"threshold_tokens": self.threshold_tokens,
|
||||
"context_length": self.context_length,
|
||||
"usage_percent": min(100, (self.last_prompt_tokens / self.context_length * 100)) if self.context_length else 0,
|
||||
"compression_count": self.compression_count,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool output pruning (cheap pre-pass, no LLM call)
|
||||
# ------------------------------------------------------------------
|
||||
@@ -691,33 +674,43 @@ Write only the summary body. Do not include any preamble or prefix."""
|
||||
)
|
||||
compressed.append(msg)
|
||||
|
||||
_merge_summary_into_tail = False
|
||||
if summary:
|
||||
last_head_role = messages[compress_start - 1].get("role", "user") if compress_start > 0 else "user"
|
||||
first_tail_role = messages[compress_end].get("role", "user") if compress_end < n_messages else "user"
|
||||
# Pick a role that avoids consecutive same-role with both neighbors.
|
||||
# Priority: avoid colliding with head (already committed), then tail.
|
||||
if last_head_role in ("assistant", "tool"):
|
||||
summary_role = "user"
|
||||
else:
|
||||
summary_role = "assistant"
|
||||
# If the chosen role collides with the tail AND flipping wouldn't
|
||||
# collide with the head, flip it.
|
||||
if summary_role == first_tail_role:
|
||||
flipped = "assistant" if summary_role == "user" else "user"
|
||||
if flipped != last_head_role:
|
||||
summary_role = flipped
|
||||
else:
|
||||
# Both roles would create consecutive same-role messages
|
||||
# (e.g. head=assistant, tail=user — neither role works).
|
||||
# Merge the summary into the first tail message instead
|
||||
# of inserting a standalone message that breaks alternation.
|
||||
_merge_summary_into_tail = True
|
||||
if not _merge_summary_into_tail:
|
||||
compressed.append({"role": summary_role, "content": summary})
|
||||
else:
|
||||
# If LLM summary failed, insert a static fallback so the model
|
||||
# knows context was lost rather than silently dropping everything.
|
||||
if not summary:
|
||||
if not self.quiet_mode:
|
||||
logger.debug("No summary model available — middle turns dropped without summary")
|
||||
logger.warning("Summary generation failed — inserting static fallback context marker")
|
||||
n_dropped = compress_end - compress_start
|
||||
summary = (
|
||||
f"{SUMMARY_PREFIX}\n"
|
||||
f"Summary generation was unavailable. {n_dropped} conversation turns were "
|
||||
f"removed to free context space but could not be summarized. The removed "
|
||||
f"turns contained earlier work in this session. Continue based on the "
|
||||
f"recent messages below and the current state of any files or resources."
|
||||
)
|
||||
|
||||
_merge_summary_into_tail = False
|
||||
last_head_role = messages[compress_start - 1].get("role", "user") if compress_start > 0 else "user"
|
||||
first_tail_role = messages[compress_end].get("role", "user") if compress_end < n_messages else "user"
|
||||
# Pick a role that avoids consecutive same-role with both neighbors.
|
||||
# Priority: avoid colliding with head (already committed), then tail.
|
||||
if last_head_role in ("assistant", "tool"):
|
||||
summary_role = "user"
|
||||
else:
|
||||
summary_role = "assistant"
|
||||
# If the chosen role collides with the tail AND flipping wouldn't
|
||||
# collide with the head, flip it.
|
||||
if summary_role == first_tail_role:
|
||||
flipped = "assistant" if summary_role == "user" else "user"
|
||||
if flipped != last_head_role:
|
||||
summary_role = flipped
|
||||
else:
|
||||
# Both roles would create consecutive same-role messages
|
||||
# (e.g. head=assistant, tail=user — neither role works).
|
||||
# Merge the summary into the first tail message instead
|
||||
# of inserting a standalone message that breaks alternation.
|
||||
_merge_summary_into_tail = True
|
||||
if not _merge_summary_into_tail:
|
||||
compressed.append({"role": summary_role, "content": summary})
|
||||
|
||||
for i in range(compress_end, n_messages):
|
||||
msg = messages[i].copy()
|
||||
|
||||
@@ -13,8 +13,9 @@ from typing import Awaitable, Callable
|
||||
|
||||
from agent.model_metadata import estimate_tokens_rough
|
||||
|
||||
_QUOTED_REFERENCE_VALUE = r'(?:`[^`\n]+`|"[^"\n]+"|\'[^\'\n]+\')'
|
||||
REFERENCE_PATTERN = re.compile(
|
||||
r"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>\S+))"
|
||||
rf"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>{_QUOTED_REFERENCE_VALUE}(?::\d+(?:-\d+)?)?|\S+))"
|
||||
)
|
||||
TRAILING_PUNCTUATION = ",.;!?"
|
||||
_SENSITIVE_HOME_DIRS = (".ssh", ".aws", ".gnupg", ".kube", ".docker", ".azure", ".config/gh")
|
||||
@@ -81,14 +82,10 @@ def parse_context_references(message: str) -> list[ContextReference]:
|
||||
value = _strip_trailing_punctuation(match.group("value") or "")
|
||||
line_start = None
|
||||
line_end = None
|
||||
target = value
|
||||
target = _strip_reference_wrappers(value)
|
||||
|
||||
if kind == "file":
|
||||
range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value)
|
||||
if range_match:
|
||||
target = range_match.group("path")
|
||||
line_start = int(range_match.group("start"))
|
||||
line_end = int(range_match.group("end") or range_match.group("start"))
|
||||
target, line_start, line_end = _parse_file_reference_value(value)
|
||||
|
||||
refs.append(
|
||||
ContextReference(
|
||||
@@ -375,6 +372,38 @@ def _strip_trailing_punctuation(value: str) -> str:
|
||||
return stripped
|
||||
|
||||
|
||||
def _strip_reference_wrappers(value: str) -> str:
|
||||
if len(value) >= 2 and value[0] == value[-1] and value[0] in "`\"'":
|
||||
return value[1:-1]
|
||||
return value
|
||||
|
||||
|
||||
def _parse_file_reference_value(value: str) -> tuple[str, int | None, int | None]:
|
||||
quoted_match = re.match(
|
||||
r'^(?P<quote>`|"|\')(?P<path>.+?)(?P=quote)(?::(?P<start>\d+)(?:-(?P<end>\d+))?)?$',
|
||||
value,
|
||||
)
|
||||
if quoted_match:
|
||||
line_start = quoted_match.group("start")
|
||||
line_end = quoted_match.group("end")
|
||||
return (
|
||||
quoted_match.group("path"),
|
||||
int(line_start) if line_start is not None else None,
|
||||
int(line_end or line_start) if line_start is not None else None,
|
||||
)
|
||||
|
||||
range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value)
|
||||
if range_match:
|
||||
line_start = int(range_match.group("start"))
|
||||
return (
|
||||
range_match.group("path"),
|
||||
line_start,
|
||||
int(range_match.group("end") or range_match.group("start")),
|
||||
)
|
||||
|
||||
return _strip_reference_wrappers(value), None, None
|
||||
|
||||
|
||||
def _remove_reference_tokens(message: str, refs: list[ContextReference]) -> str:
|
||||
pieces: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
+124
-16
@@ -20,6 +20,7 @@ from hermes_cli.auth import (
|
||||
DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
|
||||
KIMI_CODE_BASE_URL,
|
||||
PROVIDER_REGISTRY,
|
||||
_auth_store_lock,
|
||||
_codex_access_token_is_expiring,
|
||||
_decode_jwt_claims,
|
||||
_import_codex_cli_tokens,
|
||||
@@ -27,6 +28,8 @@ from hermes_cli.auth import (
|
||||
_load_provider_state,
|
||||
_resolve_kimi_base_url,
|
||||
_resolve_zai_base_url,
|
||||
_save_auth_store,
|
||||
_save_provider_state,
|
||||
read_credential_pool,
|
||||
write_credential_pool,
|
||||
)
|
||||
@@ -479,6 +482,67 @@ class CredentialPool:
|
||||
logger.debug("Failed to sync from ~/.codex/auth.json: %s", exc)
|
||||
return entry
|
||||
|
||||
def _sync_device_code_entry_to_auth_store(self, entry: PooledCredential) -> None:
|
||||
"""Write refreshed pool entry tokens back to auth.json providers.
|
||||
|
||||
After a pool-level refresh, the pool entry has fresh tokens but
|
||||
auth.json's ``providers.<id>`` still holds the pre-refresh state.
|
||||
On the next ``load_pool()``, ``_seed_from_singletons()`` reads that
|
||||
stale state and can overwrite the fresh pool entry — potentially
|
||||
re-seeding a consumed single-use refresh token.
|
||||
|
||||
Applies to any OAuth provider whose singleton lives in auth.json
|
||||
(currently Nous and OpenAI Codex).
|
||||
"""
|
||||
if entry.source != "device_code":
|
||||
return
|
||||
try:
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
if self.provider == "nous":
|
||||
state = _load_provider_state(auth_store, "nous")
|
||||
if state is None:
|
||||
return
|
||||
state["access_token"] = entry.access_token
|
||||
if entry.refresh_token:
|
||||
state["refresh_token"] = entry.refresh_token
|
||||
if entry.expires_at:
|
||||
state["expires_at"] = entry.expires_at
|
||||
if entry.agent_key:
|
||||
state["agent_key"] = entry.agent_key
|
||||
if entry.agent_key_expires_at:
|
||||
state["agent_key_expires_at"] = entry.agent_key_expires_at
|
||||
for extra_key in ("obtained_at", "expires_in", "agent_key_id",
|
||||
"agent_key_expires_in", "agent_key_reused",
|
||||
"agent_key_obtained_at"):
|
||||
val = entry.extra.get(extra_key)
|
||||
if val is not None:
|
||||
state[extra_key] = val
|
||||
if entry.inference_base_url:
|
||||
state["inference_base_url"] = entry.inference_base_url
|
||||
_save_provider_state(auth_store, "nous", state)
|
||||
|
||||
elif self.provider == "openai-codex":
|
||||
state = _load_provider_state(auth_store, "openai-codex")
|
||||
if not isinstance(state, dict):
|
||||
return
|
||||
tokens = state.get("tokens")
|
||||
if not isinstance(tokens, dict):
|
||||
return
|
||||
tokens["access_token"] = entry.access_token
|
||||
if entry.refresh_token:
|
||||
tokens["refresh_token"] = entry.refresh_token
|
||||
if entry.last_refresh:
|
||||
state["last_refresh"] = entry.last_refresh
|
||||
_save_provider_state(auth_store, "openai-codex", state)
|
||||
|
||||
else:
|
||||
return
|
||||
|
||||
_save_auth_store(auth_store)
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to sync %s pool entry back to auth store: %s", self.provider, exc)
|
||||
|
||||
def _refresh_entry(self, entry: PooledCredential, *, force: bool) -> Optional[PooledCredential]:
|
||||
if entry.auth_type != AUTH_TYPE_OAUTH or not entry.refresh_token:
|
||||
if force:
|
||||
@@ -513,6 +577,13 @@ class CredentialPool:
|
||||
except Exception as wexc:
|
||||
logger.debug("Failed to write refreshed token to credentials file: %s", wexc)
|
||||
elif self.provider == "openai-codex":
|
||||
# Proactively sync from ~/.codex/auth.json before refresh.
|
||||
# The Codex CLI (or another Hermes profile) may have already
|
||||
# consumed our refresh_token. Syncing first avoids a
|
||||
# "refresh_token_reused" error when the CLI has a newer pair.
|
||||
synced = self._sync_codex_entry_from_cli(entry)
|
||||
if synced is not entry:
|
||||
entry = synced
|
||||
refreshed = auth_mod.refresh_codex_oauth_pure(
|
||||
entry.access_token,
|
||||
entry.refresh_token,
|
||||
@@ -598,6 +669,37 @@ class CredentialPool:
|
||||
# Credentials file had a valid (non-expired) token — use it directly
|
||||
logger.debug("Credentials file has valid token, using without refresh")
|
||||
return synced
|
||||
# For openai-codex: the refresh_token may have been consumed by
|
||||
# the Codex CLI between our proactive sync and the refresh call.
|
||||
# Re-sync and retry once.
|
||||
if self.provider == "openai-codex":
|
||||
synced = self._sync_codex_entry_from_cli(entry)
|
||||
if synced.refresh_token != entry.refresh_token:
|
||||
logger.debug("Retrying Codex refresh with synced token from ~/.codex/auth.json")
|
||||
try:
|
||||
refreshed = auth_mod.refresh_codex_oauth_pure(
|
||||
synced.access_token,
|
||||
synced.refresh_token,
|
||||
)
|
||||
updated = replace(
|
||||
synced,
|
||||
access_token=refreshed["access_token"],
|
||||
refresh_token=refreshed["refresh_token"],
|
||||
last_refresh=refreshed.get("last_refresh"),
|
||||
last_status=STATUS_OK,
|
||||
last_status_at=None,
|
||||
last_error_code=None,
|
||||
)
|
||||
self._replace_entry(synced, updated)
|
||||
self._persist()
|
||||
self._sync_device_code_entry_to_auth_store(updated)
|
||||
return updated
|
||||
except Exception as retry_exc:
|
||||
logger.debug("Codex retry refresh also failed: %s", retry_exc)
|
||||
elif not self._entry_needs_refresh(synced):
|
||||
logger.debug("Codex CLI has valid token, using without refresh")
|
||||
self._sync_device_code_entry_to_auth_store(synced)
|
||||
return synced
|
||||
self._mark_exhausted(entry, None)
|
||||
return None
|
||||
|
||||
@@ -612,6 +714,10 @@ class CredentialPool:
|
||||
)
|
||||
self._replace_entry(entry, updated)
|
||||
self._persist()
|
||||
# Sync refreshed tokens back to auth.json providers so that
|
||||
# _seed_from_singletons() on the next load_pool() sees fresh state
|
||||
# instead of re-seeding stale/consumed tokens.
|
||||
self._sync_device_code_entry_to_auth_store(updated)
|
||||
return updated
|
||||
|
||||
def _entry_needs_refresh(self, entry: PooledCredential) -> bool:
|
||||
@@ -633,17 +739,6 @@ class CredentialPool:
|
||||
return False
|
||||
return False
|
||||
|
||||
def mark_used(self, entry_id: Optional[str] = None) -> None:
|
||||
"""Increment request_count for tracking. Used by least_used strategy."""
|
||||
target_id = entry_id or self._current_id
|
||||
if not target_id:
|
||||
return
|
||||
with self._lock:
|
||||
for idx, entry in enumerate(self._entries):
|
||||
if entry.id == target_id:
|
||||
self._entries[idx] = replace(entry, request_count=entry.request_count + 1)
|
||||
return
|
||||
|
||||
def select(self) -> Optional[PooledCredential]:
|
||||
with self._lock:
|
||||
return self._select_unlocked()
|
||||
@@ -805,11 +900,6 @@ class CredentialPool:
|
||||
else:
|
||||
self._active_leases[credential_id] = count - 1
|
||||
|
||||
def active_lease_count(self, credential_id: str) -> int:
|
||||
"""Return the number of active leases for a credential."""
|
||||
with self._lock:
|
||||
return self._active_leases.get(credential_id, 0)
|
||||
|
||||
def try_refresh_current(self) -> Optional[PooledCredential]:
|
||||
with self._lock:
|
||||
return self._try_refresh_current_unlocked()
|
||||
@@ -969,6 +1059,17 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
|
||||
auth_store = _load_auth_store()
|
||||
|
||||
if provider == "anthropic":
|
||||
# Only auto-discover external credentials (Claude Code, Hermes PKCE)
|
||||
# when the user has explicitly configured anthropic as their provider.
|
||||
# Without this gate, auxiliary client fallback chains silently read
|
||||
# ~/.claude/.credentials.json without user consent. See PR #4210.
|
||||
try:
|
||||
from hermes_cli.auth import is_provider_explicitly_configured
|
||||
if not is_provider_explicitly_configured("anthropic"):
|
||||
return changed, active_sources
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from agent.anthropic_adapter import read_claude_code_credentials, read_hermes_oauth_credentials
|
||||
|
||||
for source_name, creds in (
|
||||
@@ -976,6 +1077,13 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
|
||||
("claude_code", read_claude_code_credentials()),
|
||||
):
|
||||
if creds and creds.get("accessToken"):
|
||||
# Check if user explicitly removed this source
|
||||
try:
|
||||
from hermes_cli.auth import is_source_suppressed
|
||||
if is_source_suppressed(provider, source_name):
|
||||
continue
|
||||
except ImportError:
|
||||
pass
|
||||
active_sources.add(source_name)
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
|
||||
@@ -67,26 +67,6 @@ def _get_skin():
|
||||
return None
|
||||
|
||||
|
||||
def get_skin_faces(key: str, default: list) -> list:
|
||||
"""Get spinner face list from active skin, falling back to default."""
|
||||
skin = _get_skin()
|
||||
if skin:
|
||||
faces = skin.get_spinner_list(key)
|
||||
if faces:
|
||||
return faces
|
||||
return default
|
||||
|
||||
|
||||
def get_skin_verbs() -> list:
|
||||
"""Get thinking verbs from active skin."""
|
||||
skin = _get_skin()
|
||||
if skin:
|
||||
verbs = skin.get_spinner_list("thinking_verbs")
|
||||
if verbs:
|
||||
return verbs
|
||||
return KawaiiSpinner.THINKING_VERBS
|
||||
|
||||
|
||||
def get_skin_tool_prefix() -> str:
|
||||
"""Get tool output prefix character from active skin."""
|
||||
skin = _get_skin()
|
||||
@@ -723,46 +703,6 @@ class KawaiiSpinner:
|
||||
return False
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Kawaii face arrays (used by AIAgent._execute_tool_calls for spinner text)
|
||||
# =========================================================================
|
||||
|
||||
KAWAII_SEARCH = [
|
||||
"♪(´ε` )", "(。◕‿◕。)", "ヾ(^∇^)", "(◕ᴗ◕✿)", "( ˘▽˘)っ",
|
||||
"٩(◕‿◕。)۶", "(✿◠‿◠)", "♪~(´ε` )", "(ノ´ヮ`)ノ*:・゚✧", "\(◎o◎)/",
|
||||
]
|
||||
KAWAII_READ = [
|
||||
"φ(゜▽゜*)♪", "( ˘▽˘)っ", "(⌐■_■)", "٩(。•́‿•̀。)۶", "(◕‿◕✿)",
|
||||
"ヾ(@⌒ー⌒@)ノ", "(✧ω✧)", "♪(๑ᴖ◡ᴖ๑)♪", "(≧◡≦)", "( ´ ▽ ` )ノ",
|
||||
]
|
||||
KAWAII_TERMINAL = [
|
||||
"ヽ(>∀<☆)ノ", "(ノ°∀°)ノ", "٩(^ᴗ^)۶", "ヾ(⌐■_■)ノ♪", "(•̀ᴗ•́)و",
|
||||
"┗(^0^)┓", "(`・ω・´)", "\( ̄▽ ̄)/", "(ง •̀_•́)ง", "ヽ(´▽`)/",
|
||||
]
|
||||
KAWAII_BROWSER = [
|
||||
"(ノ°∀°)ノ", "(☞゚ヮ゚)☞", "( ͡° ͜ʖ ͡°)", "┌( ಠ_ಠ)┘", "(⊙_⊙)?",
|
||||
"ヾ(•ω•`)o", "( ̄ω ̄)", "( ˇωˇ )", "(ᵔᴥᵔ)", "\(◎o◎)/",
|
||||
]
|
||||
KAWAII_CREATE = [
|
||||
"✧*。٩(ˊᗜˋ*)و✧", "(ノ◕ヮ◕)ノ*:・゚✧", "ヽ(>∀<☆)ノ", "٩(♡ε♡)۶", "(◕‿◕)♡",
|
||||
"✿◕ ‿ ◕✿", "(*≧▽≦)", "ヾ(^-^)ノ", "(☆▽☆)", "°˖✧◝(⁰▿⁰)◜✧˖°",
|
||||
]
|
||||
KAWAII_SKILL = [
|
||||
"ヾ(@⌒ー⌒@)ノ", "(๑˃ᴗ˂)ﻭ", "٩(◕‿◕。)۶", "(✿╹◡╹)", "ヽ(・∀・)ノ",
|
||||
"(ノ´ヮ`)ノ*:・゚✧", "♪(๑ᴖ◡ᴖ๑)♪", "(◠‿◠)", "٩(ˊᗜˋ*)و", "(^▽^)",
|
||||
"ヾ(^∇^)", "(★ω★)/", "٩(。•́‿•̀。)۶", "(◕ᴗ◕✿)", "\(◎o◎)/",
|
||||
"(✧ω✧)", "ヽ(>∀<☆)ノ", "( ˘▽˘)っ", "(≧◡≦) ♡", "ヾ( ̄▽ ̄)",
|
||||
]
|
||||
KAWAII_THINK = [
|
||||
"(っ°Д°;)っ", "(;′⌒`)", "(・_・ヾ", "( ´_ゝ`)", "( ̄ヘ ̄)",
|
||||
"(。-`ω´-)", "( ˘︹˘ )", "(¬_¬)", "ヽ(ー_ー )ノ", "(;一_一)",
|
||||
]
|
||||
KAWAII_GENERIC = [
|
||||
"♪(´ε` )", "(◕‿◕✿)", "ヾ(^∇^)", "٩(◕‿◕。)۶", "(✿◠‿◠)",
|
||||
"(ノ´ヮ`)ノ*:・゚✧", "ヽ(>∀<☆)ノ", "(☆▽☆)", "( ˘▽˘)っ", "(≧◡≦)",
|
||||
]
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Cute tool message (completion line that replaces the spinner)
|
||||
# =========================================================================
|
||||
@@ -970,22 +910,6 @@ _SKY_BLUE = "\033[38;5;117m"
|
||||
_ANSI_RESET = "\033[0m"
|
||||
|
||||
|
||||
def honcho_session_url(workspace: str, session_name: str) -> str:
|
||||
"""Build a Honcho app URL for a session."""
|
||||
from urllib.parse import quote
|
||||
return (
|
||||
f"https://app.honcho.dev/explore"
|
||||
f"?workspace={quote(workspace, safe='')}"
|
||||
f"&view=sessions"
|
||||
f"&session={quote(session_name, safe='')}"
|
||||
)
|
||||
|
||||
|
||||
def _osc8_link(url: str, text: str) -> str:
|
||||
"""OSC 8 terminal hyperlink (clickable in iTerm2, Ghostty, WezTerm, etc.)."""
|
||||
return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Context pressure display (CLI user-facing warnings)
|
||||
# =========================================================================
|
||||
|
||||
+28
-11
@@ -82,16 +82,6 @@ class ClassifiedError:
|
||||
def is_auth(self) -> bool:
|
||||
return self.reason in (FailoverReason.auth, FailoverReason.auth_permanent)
|
||||
|
||||
@property
|
||||
def is_transient(self) -> bool:
|
||||
"""Error is expected to resolve on retry (with or without backoff)."""
|
||||
return self.reason in (
|
||||
FailoverReason.rate_limit,
|
||||
FailoverReason.overloaded,
|
||||
FailoverReason.server_error,
|
||||
FailoverReason.timeout,
|
||||
FailoverReason.unknown,
|
||||
)
|
||||
|
||||
|
||||
# ── Provider-specific patterns ──────────────────────────────────────────
|
||||
@@ -122,6 +112,7 @@ _RATE_LIMIT_PATTERNS = [
|
||||
"try again in",
|
||||
"please retry after",
|
||||
"resource_exhausted",
|
||||
"rate increased too quickly", # Alibaba/DashScope throttling
|
||||
]
|
||||
|
||||
# Usage-limit patterns that need disambiguation (could be billing OR rate_limit)
|
||||
@@ -677,6 +668,27 @@ def _classify_by_message(
|
||||
should_compress=True,
|
||||
)
|
||||
|
||||
# Usage-limit patterns need the same disambiguation as 402: some providers
|
||||
# surface "usage limit" errors without an HTTP status code. A transient
|
||||
# signal ("try again", "resets at", …) means it's a periodic quota, not
|
||||
# billing exhaustion.
|
||||
has_usage_limit = any(p in error_msg for p in _USAGE_LIMIT_PATTERNS)
|
||||
if has_usage_limit:
|
||||
has_transient_signal = any(p in error_msg for p in _USAGE_LIMIT_TRANSIENT_SIGNALS)
|
||||
if has_transient_signal:
|
||||
return result_fn(
|
||||
FailoverReason.rate_limit,
|
||||
retryable=True,
|
||||
should_rotate_credential=True,
|
||||
should_fallback=True,
|
||||
)
|
||||
return result_fn(
|
||||
FailoverReason.billing,
|
||||
retryable=False,
|
||||
should_rotate_credential=True,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
# Billing patterns
|
||||
if any(p in error_msg for p in _BILLING_PATTERNS):
|
||||
return result_fn(
|
||||
@@ -704,11 +716,16 @@ def _classify_by_message(
|
||||
)
|
||||
|
||||
# Auth patterns
|
||||
# Auth errors should NOT be retried directly — the credential is invalid and
|
||||
# retrying with the same key will always fail. Set retryable=False so the
|
||||
# caller triggers credential rotation (should_rotate_credential=True) or
|
||||
# provider fallback rather than an immediate retry loop.
|
||||
if any(p in error_msg for p in _AUTH_PATTERNS):
|
||||
return result_fn(
|
||||
FailoverReason.auth,
|
||||
retryable=True,
|
||||
retryable=False,
|
||||
should_rotate_credential=True,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
# Model not found patterns
|
||||
|
||||
@@ -39,15 +39,6 @@ def _has_known_pricing(model_name: str, provider: str = None, base_url: str = No
|
||||
return has_known_pricing(model_name, provider=provider, base_url=base_url)
|
||||
|
||||
|
||||
def _get_pricing(model_name: str) -> Dict[str, float]:
|
||||
"""Look up pricing for a model. Uses fuzzy matching on model name.
|
||||
|
||||
Returns _DEFAULT_PRICING (zero cost) for unknown/custom models —
|
||||
we can't assume costs for self-hosted endpoints, local inference, etc.
|
||||
"""
|
||||
return get_pricing(model_name)
|
||||
|
||||
|
||||
def _estimate_cost(
|
||||
session_or_model: Dict[str, Any] | str,
|
||||
input_tokens: int = 0,
|
||||
|
||||
@@ -134,11 +134,6 @@ class MemoryManager:
|
||||
"""All registered providers in order."""
|
||||
return list(self._providers)
|
||||
|
||||
@property
|
||||
def provider_names(self) -> List[str]:
|
||||
"""Names of all registered providers."""
|
||||
return [p.name for p in self._providers]
|
||||
|
||||
def get_provider(self, name: str) -> Optional[MemoryProvider]:
|
||||
"""Get a provider by name, or None if not registered."""
|
||||
for p in self._providers:
|
||||
|
||||
@@ -126,6 +126,21 @@ DEFAULT_CONTEXT_LENGTHS = {
|
||||
"minimax": 1048576,
|
||||
# GLM
|
||||
"glm": 202752,
|
||||
# xAI Grok — xAI /v1/models does not return context_length metadata,
|
||||
# so these hardcoded fallbacks prevent Hermes from probing-down to
|
||||
# the default 128k when the user points at https://api.x.ai/v1
|
||||
# via a custom provider. Values sourced from models.dev (2026-04).
|
||||
# Keys use substring matching (longest-first), so e.g. "grok-4.20"
|
||||
# matches "grok-4.20-0309-reasoning" / "-non-reasoning" / "-multi-agent-0309".
|
||||
"grok-code-fast": 256000, # grok-code-fast-1
|
||||
"grok-4-1-fast": 2000000, # grok-4-1-fast-(non-)reasoning
|
||||
"grok-2-vision": 8192, # grok-2-vision, -1212, -latest
|
||||
"grok-4-fast": 2000000, # grok-4-fast-(non-)reasoning
|
||||
"grok-4.20": 2000000, # grok-4.20-0309-(non-)reasoning, -multi-agent-0309
|
||||
"grok-4": 256000, # grok-4, grok-4-0709
|
||||
"grok-3": 131072, # grok-3, grok-3-mini, grok-3-fast, grok-3-mini-fast
|
||||
"grok-2": 131072, # grok-2, grok-2-1212, grok-2-latest
|
||||
"grok": 131072, # catch-all (grok-beta, unknown grok-*)
|
||||
# Kimi
|
||||
"kimi": 262144,
|
||||
# Arcee
|
||||
@@ -198,6 +213,7 @@ _URL_TO_PROVIDER: Dict[str, str] = {
|
||||
"models.github.ai": "copilot",
|
||||
"api.fireworks.ai": "fireworks",
|
||||
"opencode.ai": "opencode-go",
|
||||
"api.x.ai": "xai",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -135,9 +135,6 @@ class ProviderInfo:
|
||||
doc: str = "" # documentation URL
|
||||
model_count: int = 0
|
||||
|
||||
def has_api_url(self) -> bool:
|
||||
return bool(self.api)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider ID mapping: Hermes ↔ models.dev
|
||||
@@ -634,43 +631,6 @@ def get_provider_info(provider_id: str) -> Optional[ProviderInfo]:
|
||||
return _parse_provider_info(mdev_id, raw)
|
||||
|
||||
|
||||
def list_all_providers() -> Dict[str, ProviderInfo]:
|
||||
"""Return all providers from models.dev as {provider_id: ProviderInfo}.
|
||||
|
||||
Returns the full catalog — 109+ providers. For providers that have
|
||||
a Hermes alias, both the models.dev ID and the Hermes ID are included.
|
||||
"""
|
||||
data = fetch_models_dev()
|
||||
result: Dict[str, ProviderInfo] = {}
|
||||
|
||||
for pid, pdata in data.items():
|
||||
if isinstance(pdata, dict):
|
||||
info = _parse_provider_info(pid, pdata)
|
||||
result[pid] = info
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_providers_for_env_var(env_var: str) -> List[str]:
|
||||
"""Reverse lookup: find all providers that use a given env var.
|
||||
|
||||
Useful for auto-detection: "user has ANTHROPIC_API_KEY set, which
|
||||
providers does that enable?"
|
||||
|
||||
Returns list of models.dev provider IDs.
|
||||
"""
|
||||
data = fetch_models_dev()
|
||||
matches: List[str] = []
|
||||
|
||||
for pid, pdata in data.items():
|
||||
if isinstance(pdata, dict):
|
||||
env = pdata.get("env", [])
|
||||
if isinstance(env, list) and env_var in env:
|
||||
matches.append(pid)
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model-level queries (rich ModelInfo)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -708,74 +668,3 @@ def get_model_info(
|
||||
return None
|
||||
|
||||
|
||||
def get_model_info_any_provider(model_id: str) -> Optional[ModelInfo]:
|
||||
"""Search all providers for a model by ID.
|
||||
|
||||
Useful when you have a full slug like "anthropic/claude-sonnet-4.6" or
|
||||
a bare name and want to find it anywhere. Checks Hermes-mapped providers
|
||||
first, then falls back to all models.dev providers.
|
||||
"""
|
||||
data = fetch_models_dev()
|
||||
|
||||
# Try Hermes-mapped providers first (more likely what the user wants)
|
||||
for hermes_id, mdev_id in PROVIDER_TO_MODELS_DEV.items():
|
||||
pdata = data.get(mdev_id)
|
||||
if not isinstance(pdata, dict):
|
||||
continue
|
||||
models = pdata.get("models", {})
|
||||
if not isinstance(models, dict):
|
||||
continue
|
||||
|
||||
raw = models.get(model_id)
|
||||
if isinstance(raw, dict):
|
||||
return _parse_model_info(model_id, raw, mdev_id)
|
||||
|
||||
# Case-insensitive
|
||||
model_lower = model_id.lower()
|
||||
for mid, mdata in models.items():
|
||||
if mid.lower() == model_lower and isinstance(mdata, dict):
|
||||
return _parse_model_info(mid, mdata, mdev_id)
|
||||
|
||||
# Fall back to ALL providers
|
||||
for pid, pdata in data.items():
|
||||
if pid in _get_reverse_mapping():
|
||||
continue # already checked
|
||||
if not isinstance(pdata, dict):
|
||||
continue
|
||||
models = pdata.get("models", {})
|
||||
if not isinstance(models, dict):
|
||||
continue
|
||||
|
||||
raw = models.get(model_id)
|
||||
if isinstance(raw, dict):
|
||||
return _parse_model_info(model_id, raw, pid)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def list_provider_model_infos(provider_id: str) -> List[ModelInfo]:
|
||||
"""Return all models for a provider as ModelInfo objects.
|
||||
|
||||
Filters out deprecated models by default.
|
||||
"""
|
||||
mdev_id = PROVIDER_TO_MODELS_DEV.get(provider_id, provider_id)
|
||||
|
||||
data = fetch_models_dev()
|
||||
pdata = data.get(mdev_id)
|
||||
if not isinstance(pdata, dict):
|
||||
return []
|
||||
|
||||
models = pdata.get("models", {})
|
||||
if not isinstance(models, dict):
|
||||
return []
|
||||
|
||||
result: List[ModelInfo] = []
|
||||
for mid, mdata in models.items():
|
||||
if not isinstance(mdata, dict):
|
||||
continue
|
||||
status = mdata.get("status", "")
|
||||
if status == "deprecated":
|
||||
continue
|
||||
result.append(_parse_model_info(mid, mdata, mdev_id))
|
||||
|
||||
return result
|
||||
|
||||
+9
-12
@@ -40,7 +40,7 @@ _CONTEXT_THREAT_PATTERNS = [
|
||||
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
|
||||
(r'act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)', "bypass_restrictions"),
|
||||
(r'<!--[^>]*(?:ignore|override|system|secret|hidden)[^>]*-->', "html_comment_injection"),
|
||||
(r'<\s*div\s+style\s*=\s*["\'].*display\s*:\s*none', "hidden_div"),
|
||||
(r'<\s*div\s+style\s*=\s*["\'][\s\S]*?display\s*:\s*none', "hidden_div"),
|
||||
(r'translate\s+.*\s+into\s+.*\s+and\s+(execute|run|eval)', "translate_execute"),
|
||||
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
|
||||
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)', "read_secrets"),
|
||||
@@ -356,6 +356,14 @@ PLATFORM_HINTS = {
|
||||
"MEDIA:/absolute/path/to/file in your response. Images (.jpg, .png, "
|
||||
".heic) appear as photos and other files arrive as attachments."
|
||||
),
|
||||
"weixin": (
|
||||
"You are on Weixin/WeChat. Markdown formatting is supported, so you may use it when "
|
||||
"it improves readability, but keep the message compact and chat-friendly. You can send media files natively: "
|
||||
"include MEDIA:/absolute/path/to/file in your response. Images are sent as native "
|
||||
"photos, videos play inline when supported, and other files arrive as downloadable "
|
||||
"documents. You can also include image URLs in markdown format  and they "
|
||||
"will be downloaded and sent as native media when possible."
|
||||
),
|
||||
}
|
||||
|
||||
CONTEXT_FILE_MAX_CHARS = 20_000
|
||||
@@ -491,17 +499,6 @@ def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]:
|
||||
return True, {}, ""
|
||||
|
||||
|
||||
def _read_skill_conditions(skill_file: Path) -> dict:
|
||||
"""Extract conditional activation fields from SKILL.md frontmatter."""
|
||||
try:
|
||||
raw = skill_file.read_text(encoding="utf-8")[:2000]
|
||||
frontmatter, _ = parse_frontmatter(raw)
|
||||
return extract_skill_conditions(frontmatter)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to read skill conditions from %s: %s", skill_file, e)
|
||||
return {}
|
||||
|
||||
|
||||
def _skill_should_show(
|
||||
conditions: dict,
|
||||
available_tools: "set[str] | None",
|
||||
|
||||
@@ -97,8 +97,12 @@ def parse_rate_limit_headers(
|
||||
|
||||
Returns None if no rate limit headers are present.
|
||||
"""
|
||||
# Normalize to lowercase so lookups work regardless of how the server
|
||||
# capitalises headers (HTTP header names are case-insensitive per RFC 7230).
|
||||
lowered = {k.lower(): v for k, v in headers.items()}
|
||||
|
||||
# Quick check: at least one rate limit header must exist
|
||||
has_any = any(k.lower().startswith("x-ratelimit-") for k in headers)
|
||||
has_any = any(k.startswith("x-ratelimit-") for k in lowered)
|
||||
if not has_any:
|
||||
return None
|
||||
|
||||
@@ -109,9 +113,9 @@ def parse_rate_limit_headers(
|
||||
# resource="tokens", suffix="-1h" -> per-hour
|
||||
tag = f"{resource}{suffix}"
|
||||
return RateLimitBucket(
|
||||
limit=_safe_int(headers.get(f"x-ratelimit-limit-{tag}")),
|
||||
remaining=_safe_int(headers.get(f"x-ratelimit-remaining-{tag}")),
|
||||
reset_seconds=_safe_float(headers.get(f"x-ratelimit-reset-{tag}")),
|
||||
limit=_safe_int(lowered.get(f"x-ratelimit-limit-{tag}")),
|
||||
remaining=_safe_int(lowered.get(f"x-ratelimit-remaining-{tag}")),
|
||||
reset_seconds=_safe_float(lowered.get(f"x-ratelimit-reset-{tag}")),
|
||||
captured_at=now,
|
||||
)
|
||||
|
||||
|
||||
@@ -181,6 +181,7 @@ def resolve_turn_route(user_message: str, routing_config: Optional[Dict[str, Any
|
||||
"api_mode": runtime.get("api_mode"),
|
||||
"command": runtime.get("command"),
|
||||
"args": list(runtime.get("args") or []),
|
||||
"credential_pool": runtime.get("credential_pool"),
|
||||
},
|
||||
"label": f"smart route → {route.get('model')} ({runtime.get('provider')})",
|
||||
"signature": (
|
||||
|
||||
@@ -595,30 +595,6 @@ def get_pricing(
|
||||
}
|
||||
|
||||
|
||||
def estimate_cost_usd(
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
*,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> float:
|
||||
"""Backward-compatible helper for legacy callers.
|
||||
|
||||
This uses non-cached input/output only. New code should call
|
||||
`estimate_usage_cost()` with canonical usage buckets.
|
||||
"""
|
||||
result = estimate_usage_cost(
|
||||
model,
|
||||
CanonicalUsage(input_tokens=input_tokens, output_tokens=output_tokens),
|
||||
provider=provider,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
)
|
||||
return float(result.amount_usd or _ZERO)
|
||||
|
||||
|
||||
def format_duration_compact(seconds: float) -> str:
|
||||
if seconds < 60:
|
||||
return f"{seconds:.0f}s"
|
||||
|
||||
+2
-2
@@ -1158,7 +1158,7 @@ def main(
|
||||
providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
|
||||
provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
|
||||
max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
|
||||
reasoning_effort (str): OpenRouter reasoning effort level: "xhigh", "high", "medium", "low", "minimal", "none" (default: "medium")
|
||||
reasoning_effort (str): OpenRouter reasoning effort level: "none", "minimal", "low", "medium", "high", "xhigh" (default: "medium")
|
||||
reasoning_disabled (bool): Completely disable reasoning/thinking tokens (default: False)
|
||||
prefill_messages_file (str): Path to JSON file containing prefill messages (list of {role, content} dicts)
|
||||
max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)
|
||||
@@ -1227,7 +1227,7 @@ def main(
|
||||
print("🧠 Reasoning: DISABLED (effort=none)")
|
||||
elif reasoning_effort:
|
||||
# Use specified effort level
|
||||
valid_efforts = ["xhigh", "high", "medium", "low", "minimal", "none"]
|
||||
valid_efforts = ["none", "minimal", "low", "medium", "high", "xhigh"]
|
||||
if reasoning_effort not in valid_efforts:
|
||||
print(f"❌ Error: --reasoning_effort must be one of: {', '.join(valid_efforts)}")
|
||||
return
|
||||
|
||||
@@ -684,7 +684,11 @@ platform_toolsets:
|
||||
stt:
|
||||
enabled: true
|
||||
# provider: "local" # auto-detected if omitted
|
||||
model: "whisper-1" # whisper-1 (cheapest) | gpt-4o-mini-transcribe | gpt-4o-transcribe
|
||||
local:
|
||||
model: "base" # tiny | base | small | medium | large-v3 | turbo
|
||||
# language: "" # auto-detect; set to "en", "es", "fr", etc. to force
|
||||
openai:
|
||||
model: "whisper-1" # whisper-1 | gpt-4o-mini-transcribe | gpt-4o-transcribe
|
||||
# mistral:
|
||||
# model: "voxtral-mini-latest" # voxtral-mini-latest | voxtral-mini-2602
|
||||
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
# Termux / Android dependency constraints for Hermes Agent.
|
||||
#
|
||||
# Usage:
|
||||
# python -m pip install -e '.[termux]' -c constraints-termux.txt
|
||||
#
|
||||
# These pins keep the tested Android install path stable when upstream packages
|
||||
# move faster than Termux-compatible wheels / sdists.
|
||||
|
||||
ipython<10
|
||||
jedi>=0.18.1,<0.20
|
||||
parso>=0.8.4,<0.9
|
||||
stack-data>=0.6,<0.7
|
||||
pexpect>4.3,<5
|
||||
matplotlib-inline>=0.1.7,<0.2
|
||||
asttokens>=2.1,<3
|
||||
+62
-4
@@ -44,7 +44,7 @@ logger = logging.getLogger(__name__)
|
||||
_KNOWN_DELIVERY_PLATFORMS = frozenset({
|
||||
"telegram", "discord", "slack", "whatsapp", "signal",
|
||||
"matrix", "mattermost", "homeassistant", "dingtalk", "feishu",
|
||||
"wecom", "sms", "email", "webhook", "bluebubbles",
|
||||
"wecom", "weixin", "sms", "email", "webhook", "bluebubbles",
|
||||
})
|
||||
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run
|
||||
@@ -234,6 +234,7 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option
|
||||
"dingtalk": Platform.DINGTALK,
|
||||
"feishu": Platform.FEISHU,
|
||||
"wecom": Platform.WECOM,
|
||||
"weixin": Platform.WEIXIN,
|
||||
"email": Platform.EMAIL,
|
||||
"sms": Platform.SMS,
|
||||
"bluebubbles": Platform.BLUEBUBBLES,
|
||||
@@ -346,7 +347,42 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option
|
||||
return None
|
||||
|
||||
|
||||
_SCRIPT_TIMEOUT = 120 # seconds
|
||||
_DEFAULT_SCRIPT_TIMEOUT = 120 # seconds
|
||||
# Backward-compatible module override used by tests and emergency monkeypatches.
|
||||
_SCRIPT_TIMEOUT = _DEFAULT_SCRIPT_TIMEOUT
|
||||
|
||||
|
||||
def _get_script_timeout() -> int:
|
||||
"""Resolve cron pre-run script timeout from module/env/config with a safe default."""
|
||||
if _SCRIPT_TIMEOUT != _DEFAULT_SCRIPT_TIMEOUT:
|
||||
try:
|
||||
timeout = int(float(_SCRIPT_TIMEOUT))
|
||||
if timeout > 0:
|
||||
return timeout
|
||||
except Exception:
|
||||
logger.warning("Invalid patched _SCRIPT_TIMEOUT=%r; using env/config/default", _SCRIPT_TIMEOUT)
|
||||
|
||||
env_value = os.getenv("HERMES_CRON_SCRIPT_TIMEOUT", "").strip()
|
||||
if env_value:
|
||||
try:
|
||||
timeout = int(float(env_value))
|
||||
if timeout > 0:
|
||||
return timeout
|
||||
except Exception:
|
||||
logger.warning("Invalid HERMES_CRON_SCRIPT_TIMEOUT=%r; using config/default", env_value)
|
||||
|
||||
try:
|
||||
cfg = load_config() or {}
|
||||
cron_cfg = cfg.get("cron", {}) if isinstance(cfg, dict) else {}
|
||||
configured = cron_cfg.get("script_timeout_seconds")
|
||||
if configured is not None:
|
||||
timeout = int(float(configured))
|
||||
if timeout > 0:
|
||||
return timeout
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to load cron script timeout from config: %s", exc)
|
||||
|
||||
return _DEFAULT_SCRIPT_TIMEOUT
|
||||
|
||||
|
||||
def _run_job_script(script_path: str) -> tuple[bool, str]:
|
||||
@@ -393,12 +429,14 @@ def _run_job_script(script_path: str) -> tuple[bool, str]:
|
||||
if not path.is_file():
|
||||
return False, f"Script path is not a file: {path}"
|
||||
|
||||
script_timeout = _get_script_timeout()
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[sys.executable, str(path)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=_SCRIPT_TIMEOUT,
|
||||
timeout=script_timeout,
|
||||
cwd=str(path.parent),
|
||||
)
|
||||
stdout = (result.stdout or "").strip()
|
||||
@@ -422,7 +460,7 @@ def _run_job_script(script_path: str) -> tuple[bool, str]:
|
||||
return True, stdout
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return False, f"Script timed out after {_SCRIPT_TIMEOUT}s: {path}"
|
||||
return False, f"Script timed out after {script_timeout}s: {path}"
|
||||
except Exception as exc:
|
||||
return False, f"Script execution failed: {exc}"
|
||||
|
||||
@@ -646,6 +684,24 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
},
|
||||
)
|
||||
|
||||
fallback_model = _cfg.get("fallback_providers") or _cfg.get("fallback_model") or None
|
||||
credential_pool = None
|
||||
runtime_provider = str(turn_route["runtime"].get("provider") or "").strip().lower()
|
||||
if runtime_provider:
|
||||
try:
|
||||
from agent.credential_pool import load_pool
|
||||
pool = load_pool(runtime_provider)
|
||||
if pool.has_credentials():
|
||||
credential_pool = pool
|
||||
logger.info(
|
||||
"Job '%s': loaded credential pool for provider %s with %d entries",
|
||||
job_id,
|
||||
runtime_provider,
|
||||
len(pool.entries()),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Job '%s': failed to load credential pool for %s: %s", job_id, runtime_provider, e)
|
||||
|
||||
agent = AIAgent(
|
||||
model=turn_route["model"],
|
||||
api_key=turn_route["runtime"].get("api_key"),
|
||||
@@ -657,6 +713,8 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
max_iterations=max_iterations,
|
||||
reasoning_config=reasoning_config,
|
||||
prefill_messages=prefill_messages,
|
||||
fallback_model=fallback_model,
|
||||
credential_pool=credential_pool,
|
||||
providers_allowed=pr.get("only"),
|
||||
providers_ignored=pr.get("ignore"),
|
||||
providers_order=pr.get("order"),
|
||||
|
||||
@@ -9,7 +9,10 @@ INSTALL_DIR="/opt/hermes"
|
||||
# (cache/images, cache/audio, platforms/whatsapp, etc.) are created on
|
||||
# demand by the application — don't pre-create them here so new installs
|
||||
# get the consolidated layout from get_hermes_dir().
|
||||
mkdir -p "$HERMES_HOME"/{cron,sessions,logs,hooks,memories,skills}
|
||||
# The "home/" subdirectory is a per-profile HOME for subprocesses (git,
|
||||
# ssh, gh, npm …). Without it those tools write to /root which is
|
||||
# ephemeral and shared across profiles. See issue #4426.
|
||||
mkdir -p "$HERMES_HOME"/{cron,sessions,logs,hooks,memories,skills,home}
|
||||
|
||||
# .env
|
||||
if [ ! -f "$HERMES_HOME/.env" ]; then
|
||||
|
||||
@@ -77,7 +77,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
|
||||
logger.warning("Channel directory: failed to build %s: %s", platform.value, e)
|
||||
|
||||
# Telegram, WhatsApp & Signal can't enumerate chats -- pull from session history
|
||||
for plat_name in ("telegram", "whatsapp", "signal", "email", "sms", "bluebubbles"):
|
||||
for plat_name in ("telegram", "whatsapp", "signal", "weixin", "email", "sms", "bluebubbles"):
|
||||
if plat_name not in platforms:
|
||||
platforms[plat_name] = _build_from_sessions(plat_name)
|
||||
|
||||
|
||||
@@ -63,6 +63,7 @@ class Platform(Enum):
|
||||
WEBHOOK = "webhook"
|
||||
FEISHU = "feishu"
|
||||
WECOM = "wecom"
|
||||
WEIXIN = "weixin"
|
||||
BLUEBUBBLES = "bluebubbles"
|
||||
|
||||
|
||||
@@ -261,6 +262,11 @@ class GatewayConfig:
|
||||
for platform, config in self.platforms.items():
|
||||
if not config.enabled:
|
||||
continue
|
||||
# Weixin requires both a token and an account_id
|
||||
if platform == Platform.WEIXIN:
|
||||
if config.extra.get("account_id") and (config.token or config.extra.get("token")):
|
||||
connected.append(platform)
|
||||
continue
|
||||
# Platforms that use token/api_key auth
|
||||
if config.token or config.api_key:
|
||||
connected.append(platform)
|
||||
@@ -532,8 +538,12 @@ def load_gateway_config() -> GatewayConfig:
|
||||
bridged["reply_prefix"] = platform_cfg["reply_prefix"]
|
||||
if "require_mention" in platform_cfg:
|
||||
bridged["require_mention"] = platform_cfg["require_mention"]
|
||||
if "free_response_channels" in platform_cfg:
|
||||
bridged["free_response_channels"] = platform_cfg["free_response_channels"]
|
||||
if "mention_patterns" in platform_cfg:
|
||||
bridged["mention_patterns"] = platform_cfg["mention_patterns"]
|
||||
if plat == Platform.DISCORD and "channel_skill_bindings" in platform_cfg:
|
||||
bridged["channel_skill_bindings"] = platform_cfg["channel_skill_bindings"]
|
||||
if not bridged:
|
||||
continue
|
||||
plat_data = platforms_data.setdefault(plat.value, {})
|
||||
@@ -546,6 +556,19 @@ def load_gateway_config() -> GatewayConfig:
|
||||
plat_data["extra"] = extra
|
||||
extra.update(bridged)
|
||||
|
||||
# Slack settings → env vars (env vars take precedence)
|
||||
slack_cfg = yaml_cfg.get("slack", {})
|
||||
if isinstance(slack_cfg, dict):
|
||||
if "require_mention" in slack_cfg and not os.getenv("SLACK_REQUIRE_MENTION"):
|
||||
os.environ["SLACK_REQUIRE_MENTION"] = str(slack_cfg["require_mention"]).lower()
|
||||
if "allow_bots" in slack_cfg and not os.getenv("SLACK_ALLOW_BOTS"):
|
||||
os.environ["SLACK_ALLOW_BOTS"] = str(slack_cfg["allow_bots"]).lower()
|
||||
frc = slack_cfg.get("free_response_channels")
|
||||
if frc is not None and not os.getenv("SLACK_FREE_RESPONSE_CHANNELS"):
|
||||
if isinstance(frc, list):
|
||||
frc = ",".join(str(v) for v in frc)
|
||||
os.environ["SLACK_FREE_RESPONSE_CHANNELS"] = str(frc)
|
||||
|
||||
# Discord settings → env vars (env vars take precedence)
|
||||
discord_cfg = yaml_cfg.get("discord", {})
|
||||
if isinstance(discord_cfg, dict):
|
||||
@@ -566,6 +589,12 @@ def load_gateway_config() -> GatewayConfig:
|
||||
if isinstance(ic, list):
|
||||
ic = ",".join(str(v) for v in ic)
|
||||
os.environ["DISCORD_IGNORED_CHANNELS"] = str(ic)
|
||||
# allowed_channels: if set, bot ONLY responds in these channels (whitelist)
|
||||
ac = discord_cfg.get("allowed_channels")
|
||||
if ac is not None and not os.getenv("DISCORD_ALLOWED_CHANNELS"):
|
||||
if isinstance(ac, list):
|
||||
ac = ",".join(str(v) for v in ac)
|
||||
os.environ["DISCORD_ALLOWED_CHANNELS"] = str(ac)
|
||||
# no_thread_channels: channels where bot responds directly without creating thread
|
||||
ntc = discord_cfg.get("no_thread_channels")
|
||||
if ntc is not None and not os.getenv("DISCORD_NO_THREAD_CHANNELS"):
|
||||
@@ -651,6 +680,7 @@ def load_gateway_config() -> GatewayConfig:
|
||||
Platform.SLACK: "SLACK_BOT_TOKEN",
|
||||
Platform.MATTERMOST: "MATTERMOST_TOKEN",
|
||||
Platform.MATRIX: "MATRIX_ACCESS_TOKEN",
|
||||
Platform.WEIXIN: "WEIXIN_TOKEN",
|
||||
}
|
||||
for platform, pconfig in config.platforms.items():
|
||||
if not pconfig.enabled:
|
||||
@@ -886,6 +916,9 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
pass
|
||||
if api_server_host:
|
||||
config.platforms[Platform.API_SERVER].extra["host"] = api_server_host
|
||||
api_server_model_name = os.getenv("API_SERVER_MODEL_NAME", "")
|
||||
if api_server_model_name:
|
||||
config.platforms[Platform.API_SERVER].extra["model_name"] = api_server_model_name
|
||||
|
||||
# Webhook platform
|
||||
webhook_enabled = os.getenv("WEBHOOK_ENABLED", "").lower() in ("true", "1", "yes")
|
||||
@@ -952,6 +985,44 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
name=os.getenv("WECOM_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# Weixin (personal WeChat via iLink Bot API)
|
||||
weixin_token = os.getenv("WEIXIN_TOKEN")
|
||||
weixin_account_id = os.getenv("WEIXIN_ACCOUNT_ID")
|
||||
if weixin_token or weixin_account_id:
|
||||
if Platform.WEIXIN not in config.platforms:
|
||||
config.platforms[Platform.WEIXIN] = PlatformConfig()
|
||||
config.platforms[Platform.WEIXIN].enabled = True
|
||||
if weixin_token:
|
||||
config.platforms[Platform.WEIXIN].token = weixin_token
|
||||
extra = config.platforms[Platform.WEIXIN].extra
|
||||
if weixin_account_id:
|
||||
extra["account_id"] = weixin_account_id
|
||||
weixin_base_url = os.getenv("WEIXIN_BASE_URL", "").strip()
|
||||
if weixin_base_url:
|
||||
extra["base_url"] = weixin_base_url.rstrip("/")
|
||||
weixin_cdn_base_url = os.getenv("WEIXIN_CDN_BASE_URL", "").strip()
|
||||
if weixin_cdn_base_url:
|
||||
extra["cdn_base_url"] = weixin_cdn_base_url.rstrip("/")
|
||||
weixin_dm_policy = os.getenv("WEIXIN_DM_POLICY", "").strip().lower()
|
||||
if weixin_dm_policy:
|
||||
extra["dm_policy"] = weixin_dm_policy
|
||||
weixin_group_policy = os.getenv("WEIXIN_GROUP_POLICY", "").strip().lower()
|
||||
if weixin_group_policy:
|
||||
extra["group_policy"] = weixin_group_policy
|
||||
weixin_allowed_users = os.getenv("WEIXIN_ALLOWED_USERS", "").strip()
|
||||
if weixin_allowed_users:
|
||||
extra["allow_from"] = weixin_allowed_users
|
||||
weixin_group_allowed_users = os.getenv("WEIXIN_GROUP_ALLOWED_USERS", "").strip()
|
||||
if weixin_group_allowed_users:
|
||||
extra["group_allow_from"] = weixin_group_allowed_users
|
||||
weixin_home = os.getenv("WEIXIN_HOME_CHANNEL", "").strip()
|
||||
if weixin_home:
|
||||
config.platforms[Platform.WEIXIN].home_channel = HomeChannel(
|
||||
platform=Platform.WEIXIN,
|
||||
chat_id=weixin_home,
|
||||
name=os.getenv("WEIXIN_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# BlueBubbles (iMessage)
|
||||
bluebubbles_server_url = os.getenv("BLUEBUBBLES_SERVER_URL")
|
||||
bluebubbles_password = os.getenv("BLUEBUBBLES_PASSWORD")
|
||||
|
||||
@@ -124,53 +124,6 @@ class DeliveryRouter:
|
||||
self.adapters = adapters or {}
|
||||
self.output_dir = get_hermes_home() / "cron" / "output"
|
||||
|
||||
def resolve_targets(
|
||||
self,
|
||||
deliver: Union[str, List[str]],
|
||||
origin: Optional[SessionSource] = None
|
||||
) -> List[DeliveryTarget]:
|
||||
"""
|
||||
Resolve delivery specification to concrete targets.
|
||||
|
||||
Args:
|
||||
deliver: Delivery spec - "origin", "telegram", ["local", "discord"], etc.
|
||||
origin: The source where the request originated (for "origin" target)
|
||||
|
||||
Returns:
|
||||
List of resolved delivery targets
|
||||
"""
|
||||
if isinstance(deliver, str):
|
||||
deliver = [deliver]
|
||||
|
||||
targets = []
|
||||
seen_platforms = set()
|
||||
|
||||
for target_str in deliver:
|
||||
target = DeliveryTarget.parse(target_str, origin)
|
||||
|
||||
# Resolve home channel if needed
|
||||
if target.chat_id is None and target.platform != Platform.LOCAL:
|
||||
home = self.config.get_home_channel(target.platform)
|
||||
if home:
|
||||
target.chat_id = home.chat_id
|
||||
else:
|
||||
# No home channel configured, skip this platform
|
||||
continue
|
||||
|
||||
# Deduplicate
|
||||
key = (target.platform, target.chat_id, target.thread_id)
|
||||
if key not in seen_platforms:
|
||||
seen_platforms.add(key)
|
||||
targets.append(target)
|
||||
|
||||
# Always include local if configured
|
||||
if self.config.always_log_local:
|
||||
local_key = (Platform.LOCAL, None, None)
|
||||
if local_key not in seen_platforms:
|
||||
targets.append(DeliveryTarget(platform=Platform.LOCAL))
|
||||
|
||||
return targets
|
||||
|
||||
async def deliver(
|
||||
self,
|
||||
content: str,
|
||||
@@ -299,19 +252,5 @@ class DeliveryRouter:
|
||||
return await adapter.send(target.chat_id, content, metadata=send_metadata or None)
|
||||
|
||||
|
||||
def parse_deliver_spec(
|
||||
deliver: Optional[Union[str, List[str]]],
|
||||
origin: Optional[SessionSource] = None,
|
||||
default: str = "origin"
|
||||
) -> Union[str, List[str]]:
|
||||
"""
|
||||
Normalize a delivery specification.
|
||||
|
||||
If None or empty, returns the default.
|
||||
"""
|
||||
if not deliver:
|
||||
return default
|
||||
return deliver
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -20,10 +20,13 @@ Requires:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket as _socket
|
||||
import re
|
||||
import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
@@ -40,6 +43,7 @@ from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
SendResult,
|
||||
is_network_accessible,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -51,6 +55,7 @@ MAX_STORED_RESPONSES = 100
|
||||
MAX_REQUEST_BYTES = 1_000_000 # 1 MB default limit for POST bodies
|
||||
|
||||
|
||||
|
||||
def check_api_server_requirements() -> bool:
|
||||
"""Check if API server dependencies are available."""
|
||||
return AIOHTTP_AVAILABLE
|
||||
@@ -282,6 +287,24 @@ def _make_request_fingerprint(body: Dict[str, Any], keys: List[str]) -> str:
|
||||
return sha256(repr(subset).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _derive_chat_session_id(
|
||||
system_prompt: Optional[str],
|
||||
first_user_message: str,
|
||||
) -> str:
|
||||
"""Derive a stable session ID from the conversation's first user message.
|
||||
|
||||
OpenAI-compatible frontends (Open WebUI, LibreChat, etc.) send the full
|
||||
conversation history with every request. The system prompt and first user
|
||||
message are constant across all turns of the same conversation, so hashing
|
||||
them produces a deterministic session ID that lets the API server reuse
|
||||
the same Hermes session (and therefore the same Docker container sandbox
|
||||
directory) across turns.
|
||||
"""
|
||||
seed = f"{system_prompt or ''}\n{first_user_message}"
|
||||
digest = hashlib.sha256(seed.encode("utf-8")).hexdigest()[:16]
|
||||
return f"api-{digest}"
|
||||
|
||||
|
||||
class APIServerAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
OpenAI-compatible HTTP API server adapter.
|
||||
@@ -299,6 +322,9 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
self._cors_origins: tuple[str, ...] = self._parse_cors_origins(
|
||||
extra.get("cors_origins", os.getenv("API_SERVER_CORS_ORIGINS", "")),
|
||||
)
|
||||
self._model_name: str = self._resolve_model_name(
|
||||
extra.get("model_name", os.getenv("API_SERVER_MODEL_NAME", "")),
|
||||
)
|
||||
self._app: Optional["web.Application"] = None
|
||||
self._runner: Optional["web.AppRunner"] = None
|
||||
self._site: Optional["web.TCPSite"] = None
|
||||
@@ -324,6 +350,26 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
|
||||
return tuple(str(item).strip() for item in items if str(item).strip())
|
||||
|
||||
@staticmethod
|
||||
def _resolve_model_name(explicit: str) -> str:
|
||||
"""Derive the advertised model name for /v1/models.
|
||||
|
||||
Priority:
|
||||
1. Explicit override (config extra or API_SERVER_MODEL_NAME env var)
|
||||
2. Active profile name (so each profile advertises a distinct model)
|
||||
3. Fallback: "hermes-agent"
|
||||
"""
|
||||
if explicit and explicit.strip():
|
||||
return explicit.strip()
|
||||
try:
|
||||
from hermes_cli.profiles import get_active_profile_name
|
||||
profile = get_active_profile_name()
|
||||
if profile and profile not in ("default", "custom"):
|
||||
return profile
|
||||
except Exception:
|
||||
pass
|
||||
return "hermes-agent"
|
||||
|
||||
def _cors_headers_for_origin(self, origin: str) -> Optional[Dict[str, str]]:
|
||||
"""Return CORS headers for an allowed browser origin."""
|
||||
if not origin or not self._cors_origins:
|
||||
@@ -363,7 +409,8 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
Validate Bearer token from Authorization header.
|
||||
|
||||
Returns None if auth is OK, or a 401 web.Response on failure.
|
||||
If no API key is configured, all requests are allowed.
|
||||
If no API key is configured, all requests are allowed (only when API
|
||||
server is local)
|
||||
"""
|
||||
if not self._api_key:
|
||||
return None # No key configured — allow all (local-only use)
|
||||
@@ -468,12 +515,12 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "hermes-agent",
|
||||
"id": self._model_name,
|
||||
"object": "model",
|
||||
"created": int(time.time()),
|
||||
"owned_by": "hermes",
|
||||
"permission": [],
|
||||
"root": "hermes-agent",
|
||||
"root": self._model_name,
|
||||
"parent": None,
|
||||
}
|
||||
],
|
||||
@@ -531,8 +578,32 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
|
||||
# Allow caller to continue an existing session by passing X-Hermes-Session-Id.
|
||||
# When provided, history is loaded from state.db instead of from the request body.
|
||||
#
|
||||
# Security: session continuation exposes conversation history, so it is
|
||||
# only allowed when the API key is configured and the request is
|
||||
# authenticated. Without this gate, any unauthenticated client could
|
||||
# read arbitrary session history by guessing/enumerating session IDs.
|
||||
provided_session_id = request.headers.get("X-Hermes-Session-Id", "").strip()
|
||||
if provided_session_id:
|
||||
if not self._api_key:
|
||||
logger.warning(
|
||||
"Session continuation via X-Hermes-Session-Id rejected: "
|
||||
"no API key configured. Set API_SERVER_KEY to enable "
|
||||
"session continuity."
|
||||
)
|
||||
return web.json_response(
|
||||
_openai_error(
|
||||
"Session continuation requires API key authentication. "
|
||||
"Configure API_SERVER_KEY to enable this feature."
|
||||
),
|
||||
status=403,
|
||||
)
|
||||
# Sanitize: reject control characters that could enable header injection.
|
||||
if re.search(r'[\r\n\x00]', provided_session_id):
|
||||
return web.json_response(
|
||||
{"error": {"message": "Invalid session ID", "type": "invalid_request_error"}},
|
||||
status=400,
|
||||
)
|
||||
session_id = provided_session_id
|
||||
try:
|
||||
db = self._ensure_session_db()
|
||||
@@ -542,11 +613,20 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
logger.warning("Failed to load session history for %s: %s", session_id, e)
|
||||
history = []
|
||||
else:
|
||||
session_id = str(uuid.uuid4())
|
||||
# Derive a stable session ID from the conversation fingerprint so
|
||||
# that consecutive messages from the same Open WebUI (or similar)
|
||||
# conversation map to the same Hermes session. The first user
|
||||
# message + system prompt are constant across all turns.
|
||||
first_user = ""
|
||||
for cm in conversation_messages:
|
||||
if cm.get("role") == "user":
|
||||
first_user = cm.get("content", "")
|
||||
break
|
||||
session_id = _derive_chat_session_id(system_prompt, first_user)
|
||||
# history already set from request body above
|
||||
|
||||
completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
|
||||
model_name = body.get("model", "hermes-agent")
|
||||
model_name = body.get("model", self._model_name)
|
||||
created = int(time.time())
|
||||
|
||||
if stream:
|
||||
@@ -923,7 +1003,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"created_at": created_at,
|
||||
"model": body.get("model", "hermes-agent"),
|
||||
"model": body.get("model", self._model_name),
|
||||
"output": output_items,
|
||||
"usage": {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
@@ -1318,6 +1398,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
result = agent.run_conversation(
|
||||
user_message=user_message,
|
||||
conversation_history=conversation_history,
|
||||
task_id="default",
|
||||
)
|
||||
usage = {
|
||||
"input_tokens": getattr(agent, "session_prompt_tokens", 0) or 0,
|
||||
@@ -1484,6 +1565,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
r = agent.run_conversation(
|
||||
user_message=user_message,
|
||||
conversation_history=conversation_history,
|
||||
task_id="default",
|
||||
)
|
||||
u = {
|
||||
"input_tokens": getattr(agent, "session_prompt_tokens", 0) or 0,
|
||||
@@ -1635,8 +1717,16 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
if hasattr(sweep_task, "add_done_callback"):
|
||||
sweep_task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
# Refuse to start network-accessible without authentication
|
||||
if is_network_accessible(self._host) and not self._api_key:
|
||||
logger.error(
|
||||
"[%s] Refusing to start: binding to %s requires API_SERVER_KEY. "
|
||||
"Set API_SERVER_KEY or use the default 127.0.0.1.",
|
||||
self.name, self._host,
|
||||
)
|
||||
return False
|
||||
|
||||
# Port conflict detection — fail fast if port is already in use
|
||||
import socket as _socket
|
||||
try:
|
||||
with _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) as _s:
|
||||
_s.settimeout(1)
|
||||
@@ -1652,9 +1742,17 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
await self._site.start()
|
||||
|
||||
self._mark_connected()
|
||||
if not self._api_key:
|
||||
logger.warning(
|
||||
"[%s] ⚠️ No API key configured (API_SERVER_KEY / platforms.api_server.key). "
|
||||
"All requests will be accepted without authentication. "
|
||||
"Set an API key for production deployments to prevent "
|
||||
"unauthorized access to sessions, responses, and cron jobs.",
|
||||
self.name,
|
||||
)
|
||||
logger.info(
|
||||
"[%s] API server listening on http://%s:%d",
|
||||
self.name, self._host, self._port,
|
||||
"[%s] API server listening on http://%s:%d (model: %s)",
|
||||
self.name, self._host, self._port, self._model_name,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
+253
-16
@@ -6,22 +6,183 @@ and implement the required methods.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import socket as _socket
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_network_accessible(host: str) -> bool:
|
||||
"""Return True if *host* would expose the server beyond loopback.
|
||||
|
||||
Loopback addresses (127.0.0.1, ::1, IPv4-mapped ::ffff:127.0.0.1)
|
||||
are local-only. Unspecified addresses (0.0.0.0, ::) bind all
|
||||
interfaces. Hostnames are resolved; DNS failure fails closed.
|
||||
"""
|
||||
try:
|
||||
addr = ipaddress.ip_address(host)
|
||||
if addr.is_loopback:
|
||||
return False
|
||||
# ::ffff:127.0.0.1 — Python reports is_loopback=False for mapped
|
||||
# addresses, so check the underlying IPv4 explicitly.
|
||||
if getattr(addr, "ipv4_mapped", None) and addr.ipv4_mapped.is_loopback:
|
||||
return False
|
||||
return True
|
||||
except ValueError:
|
||||
# when host variable is a hostname, we should try to resolve below
|
||||
pass
|
||||
|
||||
try:
|
||||
resolved = _socket.getaddrinfo(
|
||||
host, None, _socket.AF_UNSPEC, _socket.SOCK_STREAM,
|
||||
)
|
||||
# if the hostname resolves into at least one non-loopback address,
|
||||
# then we consider it to be network accessible
|
||||
for _family, _type, _proto, _canonname, sockaddr in resolved:
|
||||
addr = ipaddress.ip_address(sockaddr[0])
|
||||
if not addr.is_loopback:
|
||||
return True
|
||||
return False
|
||||
except (_socket.gaierror, OSError):
|
||||
return True
|
||||
|
||||
|
||||
def _detect_macos_system_proxy() -> str | None:
|
||||
"""Read the macOS system HTTP(S) proxy via ``scutil --proxy``.
|
||||
|
||||
Returns an ``http://host:port`` URL string if an HTTP or HTTPS proxy is
|
||||
enabled, otherwise *None*. Falls back silently on non-macOS or on any
|
||||
subprocess error.
|
||||
"""
|
||||
if sys.platform != "darwin":
|
||||
return None
|
||||
try:
|
||||
out = subprocess.check_output(
|
||||
["scutil", "--proxy"], timeout=3, text=True, stderr=subprocess.DEVNULL,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
props: dict[str, str] = {}
|
||||
for line in out.splitlines():
|
||||
line = line.strip()
|
||||
if " : " in line:
|
||||
key, _, val = line.partition(" : ")
|
||||
props[key.strip()] = val.strip()
|
||||
|
||||
# Prefer HTTPS, fall back to HTTP
|
||||
for enable_key, host_key, port_key in (
|
||||
("HTTPSEnable", "HTTPSProxy", "HTTPSPort"),
|
||||
("HTTPEnable", "HTTPProxy", "HTTPPort"),
|
||||
):
|
||||
if props.get(enable_key) == "1":
|
||||
host = props.get(host_key)
|
||||
port = props.get(port_key)
|
||||
if host and port:
|
||||
return f"http://{host}:{port}"
|
||||
return None
|
||||
|
||||
|
||||
def resolve_proxy_url(platform_env_var: str | None = None) -> str | None:
|
||||
"""Return a proxy URL from env vars, or macOS system proxy.
|
||||
|
||||
Check order:
|
||||
0. *platform_env_var* (e.g. ``DISCORD_PROXY``) — highest priority
|
||||
1. HTTPS_PROXY / HTTP_PROXY / ALL_PROXY (and lowercase variants)
|
||||
2. macOS system proxy via ``scutil --proxy`` (auto-detect)
|
||||
|
||||
Returns *None* if no proxy is found.
|
||||
"""
|
||||
if platform_env_var:
|
||||
value = (os.environ.get(platform_env_var) or "").strip()
|
||||
if value:
|
||||
return value
|
||||
for key in ("HTTPS_PROXY", "HTTP_PROXY", "ALL_PROXY",
|
||||
"https_proxy", "http_proxy", "all_proxy"):
|
||||
value = (os.environ.get(key) or "").strip()
|
||||
if value:
|
||||
return value
|
||||
return _detect_macos_system_proxy()
|
||||
|
||||
|
||||
def proxy_kwargs_for_bot(proxy_url: str | None) -> dict:
|
||||
"""Build kwargs for ``commands.Bot()`` / ``discord.Client()`` with proxy.
|
||||
|
||||
Returns:
|
||||
- SOCKS URL → ``{"connector": ProxyConnector(..., rdns=True)}``
|
||||
- HTTP URL → ``{"proxy": url}``
|
||||
- *None* → ``{}``
|
||||
|
||||
``rdns=True`` forces remote DNS resolution through the proxy — required
|
||||
by many SOCKS implementations (Shadowrocket, Clash) and essential for
|
||||
bypassing DNS pollution behind the GFW.
|
||||
"""
|
||||
if not proxy_url:
|
||||
return {}
|
||||
if proxy_url.lower().startswith("socks"):
|
||||
try:
|
||||
from aiohttp_socks import ProxyConnector
|
||||
|
||||
connector = ProxyConnector.from_url(proxy_url, rdns=True)
|
||||
return {"connector": connector}
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"aiohttp_socks not installed — SOCKS proxy %s ignored. "
|
||||
"Run: pip install aiohttp-socks",
|
||||
proxy_url,
|
||||
)
|
||||
return {}
|
||||
return {"proxy": proxy_url}
|
||||
|
||||
|
||||
def proxy_kwargs_for_aiohttp(proxy_url: str | None) -> tuple[dict, dict]:
|
||||
"""Build kwargs for standalone ``aiohttp.ClientSession`` with proxy.
|
||||
|
||||
Returns ``(session_kwargs, request_kwargs)`` where:
|
||||
- SOCKS → ``({"connector": ProxyConnector(...)}, {})``
|
||||
- HTTP → ``({}, {"proxy": url})``
|
||||
- None → ``({}, {})``
|
||||
|
||||
Usage::
|
||||
|
||||
sess_kw, req_kw = proxy_kwargs_for_aiohttp(proxy_url)
|
||||
async with aiohttp.ClientSession(**sess_kw) as session:
|
||||
async with session.get(url, **req_kw) as resp:
|
||||
...
|
||||
"""
|
||||
if not proxy_url:
|
||||
return {}, {}
|
||||
if proxy_url.lower().startswith("socks"):
|
||||
try:
|
||||
from aiohttp_socks import ProxyConnector
|
||||
|
||||
connector = ProxyConnector.from_url(proxy_url, rdns=True)
|
||||
return {"connector": connector}, {}
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"aiohttp_socks not installed — SOCKS proxy %s ignored. "
|
||||
"Run: pip install aiohttp-socks",
|
||||
proxy_url,
|
||||
)
|
||||
return {}, {}
|
||||
return {}, {"proxy": proxy_url}
|
||||
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any, Callable, Awaitable, Tuple
|
||||
from enum import Enum
|
||||
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
@@ -36,7 +197,7 @@ GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE = (
|
||||
)
|
||||
|
||||
|
||||
def _safe_url_for_log(url: str, max_len: int = 80) -> str:
|
||||
def safe_url_for_log(url: str, max_len: int = 80) -> str:
|
||||
"""Return a URL string safe for logs (no query/fragment/userinfo)."""
|
||||
if max_len <= 0:
|
||||
return ""
|
||||
@@ -73,6 +234,23 @@ def _safe_url_for_log(url: str, max_len: int = 80) -> str:
|
||||
return f"{safe[:max_len - 3]}..."
|
||||
|
||||
|
||||
async def _ssrf_redirect_guard(response):
|
||||
"""Re-validate each redirect target to prevent redirect-based SSRF.
|
||||
|
||||
Without this, an attacker can host a public URL that 302-redirects to
|
||||
http://169.254.169.254/ and bypass the pre-flight is_safe_url() check.
|
||||
|
||||
Must be async because httpx.AsyncClient awaits response event hooks.
|
||||
"""
|
||||
if response.is_redirect and response.next_request:
|
||||
redirect_url = str(response.next_request.url)
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(redirect_url):
|
||||
raise ValueError(
|
||||
f"Blocked redirect to private/internal address: {safe_url_for_log(redirect_url)}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image cache utilities
|
||||
#
|
||||
@@ -92,6 +270,23 @@ def get_image_cache_dir() -> Path:
|
||||
return IMAGE_CACHE_DIR
|
||||
|
||||
|
||||
def _looks_like_image(data: bytes) -> bool:
|
||||
"""Return True if *data* starts with a known image magic-byte sequence."""
|
||||
if len(data) < 4:
|
||||
return False
|
||||
if data[:8] == b"\x89PNG\r\n\x1a\n":
|
||||
return True
|
||||
if data[:3] == b"\xff\xd8\xff":
|
||||
return True
|
||||
if data[:6] in (b"GIF87a", b"GIF89a"):
|
||||
return True
|
||||
if data[:2] == b"BM":
|
||||
return True
|
||||
if data[:4] == b"RIFF" and len(data) >= 12 and data[8:12] == b"WEBP":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def cache_image_from_bytes(data: bytes, ext: str = ".jpg") -> str:
|
||||
"""
|
||||
Save raw image bytes to the cache and return the absolute file path.
|
||||
@@ -102,7 +297,17 @@ def cache_image_from_bytes(data: bytes, ext: str = ".jpg") -> str:
|
||||
|
||||
Returns:
|
||||
Absolute path to the cached image file as a string.
|
||||
|
||||
Raises:
|
||||
ValueError: If *data* does not look like a valid image (e.g. an HTML
|
||||
error page returned by the upstream server).
|
||||
"""
|
||||
if not _looks_like_image(data):
|
||||
snippet = data[:80].decode("utf-8", errors="replace")
|
||||
raise ValueError(
|
||||
f"Refusing to cache non-image data as {ext} "
|
||||
f"(starts with: {snippet!r})"
|
||||
)
|
||||
cache_dir = get_image_cache_dir()
|
||||
filename = f"img_{uuid.uuid4().hex[:12]}{ext}"
|
||||
filepath = cache_dir / filename
|
||||
@@ -130,7 +335,7 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
|
||||
"""
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(url):
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}")
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {safe_url_for_log(url)}")
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
@@ -138,7 +343,11 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
|
||||
_log = _logging.getLogger(__name__)
|
||||
|
||||
last_exc = None
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=30.0,
|
||||
follow_redirects=True,
|
||||
event_hooks={"response": [_ssrf_redirect_guard]},
|
||||
) as client:
|
||||
for attempt in range(retries + 1):
|
||||
try:
|
||||
response = await client.get(
|
||||
@@ -160,7 +369,7 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
|
||||
"Media cache retry %d/%d for %s (%.1fs): %s",
|
||||
attempt + 1,
|
||||
retries,
|
||||
_safe_url_for_log(url),
|
||||
safe_url_for_log(url),
|
||||
wait,
|
||||
exc,
|
||||
)
|
||||
@@ -245,7 +454,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
|
||||
"""
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(url):
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}")
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {safe_url_for_log(url)}")
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
@@ -253,7 +462,11 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
|
||||
_log = _logging.getLogger(__name__)
|
||||
|
||||
last_exc = None
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=30.0,
|
||||
follow_redirects=True,
|
||||
event_hooks={"response": [_ssrf_redirect_guard]},
|
||||
) as client:
|
||||
for attempt in range(retries + 1):
|
||||
try:
|
||||
response = await client.get(
|
||||
@@ -275,7 +488,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
|
||||
"Audio cache retry %d/%d for %s (%.1fs): %s",
|
||||
attempt + 1,
|
||||
retries,
|
||||
_safe_url_for_log(url),
|
||||
safe_url_for_log(url),
|
||||
wait,
|
||||
exc,
|
||||
)
|
||||
@@ -378,6 +591,14 @@ class MessageType(Enum):
|
||||
COMMAND = "command" # /command style
|
||||
|
||||
|
||||
class ProcessingOutcome(Enum):
|
||||
"""Result classification for message-processing lifecycle hooks."""
|
||||
|
||||
SUCCESS = "success"
|
||||
FAILURE = "failure"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageEvent:
|
||||
"""
|
||||
@@ -405,8 +626,9 @@ class MessageEvent:
|
||||
reply_to_message_id: Optional[str] = None
|
||||
reply_to_text: Optional[str] = None # Text of the replied-to message (for context injection)
|
||||
|
||||
# Auto-loaded skill for topic/channel bindings (e.g., Telegram DM Topics)
|
||||
auto_skill: Optional[str] = None
|
||||
# Auto-loaded skill(s) for topic/channel bindings (e.g., Telegram DM Topics,
|
||||
# Discord channel_skill_bindings). A single name or ordered list.
|
||||
auto_skill: Optional[str | list[str]] = None
|
||||
|
||||
# Internal flag — set for synthetic events (e.g. background process
|
||||
# completion notifications) that must bypass user authorization checks.
|
||||
@@ -428,6 +650,9 @@ class MessageEvent:
|
||||
raw = parts[0][1:].lower() if parts else None
|
||||
if raw and "@" in raw:
|
||||
raw = raw.split("@", 1)[0]
|
||||
# Reject file paths: valid command names never contain /
|
||||
if raw and "/" in raw:
|
||||
return None
|
||||
return raw
|
||||
|
||||
def get_command_args(self) -> str:
|
||||
@@ -501,6 +726,7 @@ class BasePlatformAdapter(ABC):
|
||||
# Gateway shutdown cancels these so an old gateway instance doesn't keep
|
||||
# working on a task after --replace or manual restarts.
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
self._expected_cancelled_tasks: set[asyncio.Task] = set()
|
||||
# Chats where auto-TTS on voice input is disabled (set by /voice off)
|
||||
self._auto_tts_disabled_chats: set = set()
|
||||
# Chats where typing indicator is paused (e.g. during approval waits).
|
||||
@@ -1009,7 +1235,7 @@ class BasePlatformAdapter(ABC):
|
||||
async def on_processing_start(self, event: MessageEvent) -> None:
|
||||
"""Hook called when background processing begins."""
|
||||
|
||||
async def on_processing_complete(self, event: MessageEvent, success: bool) -> None:
|
||||
async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
|
||||
"""Hook called when background processing completes."""
|
||||
|
||||
async def _run_processing_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
|
||||
@@ -1170,7 +1396,7 @@ class BasePlatformAdapter(ABC):
|
||||
# session lifecycle and its cleanup races with the running task
|
||||
# (see PR #4926).
|
||||
cmd = event.get_command()
|
||||
if cmd in ("approve", "deny", "status", "stop", "new", "reset"):
|
||||
if cmd in ("approve", "deny", "status", "stop", "new", "reset", "background"):
|
||||
logger.debug(
|
||||
"[%s] Command '/%s' bypassing active-session guard for %s",
|
||||
self.name, cmd, session_key,
|
||||
@@ -1228,6 +1454,7 @@ class BasePlatformAdapter(ABC):
|
||||
return
|
||||
if hasattr(task, "add_done_callback"):
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
task.add_done_callback(self._expected_cancelled_tasks.discard)
|
||||
|
||||
@staticmethod
|
||||
def _get_human_delay() -> float:
|
||||
@@ -1364,7 +1591,7 @@ class BasePlatformAdapter(ABC):
|
||||
logger.info(
|
||||
"[%s] Sending image: %s (alt=%s)",
|
||||
self.name,
|
||||
_safe_url_for_log(image_url),
|
||||
safe_url_for_log(image_url),
|
||||
alt_text[:30] if alt_text else "",
|
||||
)
|
||||
# Route animated GIFs through send_animation for proper playback
|
||||
@@ -1456,7 +1683,11 @@ class BasePlatformAdapter(ABC):
|
||||
|
||||
# Determine overall success for the processing hook
|
||||
processing_ok = delivery_succeeded if delivery_attempted else not bool(response)
|
||||
await self._run_processing_hook("on_processing_complete", event, processing_ok)
|
||||
await self._run_processing_hook(
|
||||
"on_processing_complete",
|
||||
event,
|
||||
ProcessingOutcome.SUCCESS if processing_ok else ProcessingOutcome.FAILURE,
|
||||
)
|
||||
|
||||
# Check if there's a pending message that was queued during our processing
|
||||
if session_key in self._pending_messages:
|
||||
@@ -1475,10 +1706,14 @@ class BasePlatformAdapter(ABC):
|
||||
return # Already cleaned up
|
||||
|
||||
except asyncio.CancelledError:
|
||||
await self._run_processing_hook("on_processing_complete", event, False)
|
||||
current_task = asyncio.current_task()
|
||||
outcome = ProcessingOutcome.CANCELLED
|
||||
if current_task is None or current_task not in self._expected_cancelled_tasks:
|
||||
outcome = ProcessingOutcome.FAILURE
|
||||
await self._run_processing_hook("on_processing_complete", event, outcome)
|
||||
raise
|
||||
except Exception as e:
|
||||
await self._run_processing_hook("on_processing_complete", event, False)
|
||||
await self._run_processing_hook("on_processing_complete", event, ProcessingOutcome.FAILURE)
|
||||
logger.error("[%s] Error handling message: %s", self.name, e, exc_info=True)
|
||||
# Send the error to the user so they aren't left with radio silence
|
||||
try:
|
||||
@@ -1522,10 +1757,12 @@ class BasePlatformAdapter(ABC):
|
||||
"""
|
||||
tasks = [task for task in self._background_tasks if not task.done()]
|
||||
for task in tasks:
|
||||
self._expected_cancelled_tasks.add(task)
|
||||
task.cancel()
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
self._expected_cancelled_tasks.clear()
|
||||
self._pending_messages.clear()
|
||||
self._active_sessions.clear()
|
||||
|
||||
|
||||
@@ -207,9 +207,17 @@ class BlueBubblesAdapter(BasePlatformAdapter):
|
||||
self.webhook_port,
|
||||
self.webhook_path,
|
||||
)
|
||||
|
||||
# Register webhook with BlueBubbles server
|
||||
# This is required for the server to know where to send events
|
||||
await self._register_webhook()
|
||||
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
# Unregister webhook before cleaning up
|
||||
await self._unregister_webhook()
|
||||
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
self.client = None
|
||||
@@ -218,6 +226,105 @@ class BlueBubblesAdapter(BasePlatformAdapter):
|
||||
self._runner = None
|
||||
self._mark_disconnected()
|
||||
|
||||
@property
|
||||
def _webhook_url(self) -> str:
|
||||
"""Compute the external webhook URL for BlueBubbles registration."""
|
||||
host = self.webhook_host
|
||||
if host in ("0.0.0.0", "127.0.0.1", "localhost", "::"):
|
||||
host = "localhost"
|
||||
return f"http://{host}:{self.webhook_port}{self.webhook_path}"
|
||||
|
||||
async def _find_registered_webhooks(self, url: str) -> list:
|
||||
"""Return list of BB webhook entries matching *url*."""
|
||||
try:
|
||||
res = await self._api_get("/api/v1/webhook")
|
||||
data = res.get("data")
|
||||
if isinstance(data, list):
|
||||
return [wh for wh in data if wh.get("url") == url]
|
||||
except Exception:
|
||||
pass
|
||||
return []
|
||||
|
||||
async def _register_webhook(self) -> bool:
|
||||
"""Register this webhook URL with the BlueBubbles server.
|
||||
|
||||
BlueBubbles requires webhooks to be registered via API before
|
||||
it will send events. Checks for an existing registration first
|
||||
to avoid duplicates (e.g. after a crash without clean shutdown).
|
||||
"""
|
||||
if not self.client:
|
||||
return False
|
||||
|
||||
webhook_url = self._webhook_url
|
||||
|
||||
# Crash resilience — reuse an existing registration if present
|
||||
existing = await self._find_registered_webhooks(webhook_url)
|
||||
if existing:
|
||||
logger.info(
|
||||
"[bluebubbles] webhook already registered: %s", webhook_url
|
||||
)
|
||||
return True
|
||||
|
||||
payload = {
|
||||
"url": webhook_url,
|
||||
"events": ["new-message", "updated-message", "message"],
|
||||
}
|
||||
|
||||
try:
|
||||
res = await self._api_post("/api/v1/webhook", payload)
|
||||
status = res.get("status", 0)
|
||||
if 200 <= status < 300:
|
||||
logger.info(
|
||||
"[bluebubbles] webhook registered with server: %s",
|
||||
webhook_url,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
"[bluebubbles] webhook registration returned status %s: %s",
|
||||
status,
|
||||
res.get("message"),
|
||||
)
|
||||
return False
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[bluebubbles] failed to register webhook with server: %s",
|
||||
exc,
|
||||
)
|
||||
return False
|
||||
|
||||
async def _unregister_webhook(self) -> bool:
|
||||
"""Unregister this webhook URL from the BlueBubbles server.
|
||||
|
||||
Removes *all* matching registrations to clean up any duplicates
|
||||
left by prior crashes.
|
||||
"""
|
||||
if not self.client:
|
||||
return False
|
||||
|
||||
webhook_url = self._webhook_url
|
||||
removed = False
|
||||
|
||||
try:
|
||||
for wh in await self._find_registered_webhooks(webhook_url):
|
||||
wh_id = wh.get("id")
|
||||
if wh_id:
|
||||
res = await self.client.delete(
|
||||
self._api_url(f"/api/v1/webhook/{wh_id}")
|
||||
)
|
||||
res.raise_for_status()
|
||||
removed = True
|
||||
if removed:
|
||||
logger.info(
|
||||
"[bluebubbles] webhook unregistered: %s", webhook_url
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"[bluebubbles] failed to unregister webhook (non-critical): %s",
|
||||
exc,
|
||||
)
|
||||
return removed
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Chat GUID resolution
|
||||
# ------------------------------------------------------------------
|
||||
@@ -826,3 +933,4 @@ class BlueBubblesAdapter(BasePlatformAdapter):
|
||||
asyncio.create_task(self.mark_read(session_chat_id))
|
||||
|
||||
return web.Response(text="ok")
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ Configuration in config.yaml:
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
@@ -54,6 +55,8 @@ MAX_MESSAGE_LENGTH = 20000
|
||||
DEDUP_WINDOW_SECONDS = 300
|
||||
DEDUP_MAX_SIZE = 1000
|
||||
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
|
||||
_SESSION_WEBHOOKS_MAX = 500
|
||||
_DINGTALK_WEBHOOK_RE = re.compile(r'^https://api\.dingtalk\.com/')
|
||||
|
||||
|
||||
def check_dingtalk_requirements() -> bool:
|
||||
@@ -195,9 +198,15 @@ class DingTalkAdapter(BasePlatformAdapter):
|
||||
chat_id = conversation_id or sender_id
|
||||
chat_type = "group" if is_group else "dm"
|
||||
|
||||
# Store session webhook for reply routing
|
||||
# Store session webhook for reply routing (validate origin to prevent SSRF)
|
||||
session_webhook = getattr(message, "session_webhook", None) or ""
|
||||
if session_webhook and chat_id:
|
||||
if session_webhook and chat_id and _DINGTALK_WEBHOOK_RE.match(session_webhook):
|
||||
if len(self._session_webhooks) >= _SESSION_WEBHOOKS_MAX:
|
||||
# Evict oldest entry to cap memory growth
|
||||
try:
|
||||
self._session_webhooks.pop(next(iter(self._session_webhooks)))
|
||||
except StopIteration:
|
||||
pass
|
||||
self._session_webhooks[chat_id] = session_webhook
|
||||
|
||||
source = self.build_source(
|
||||
|
||||
+202
-34
@@ -49,6 +49,7 @@ from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
ProcessingOutcome,
|
||||
SendResult,
|
||||
cache_image_from_url,
|
||||
cache_audio_from_url,
|
||||
@@ -422,6 +423,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
# Discord message limits
|
||||
MAX_MESSAGE_LENGTH = 2000
|
||||
_SPLIT_THRESHOLD = 1900 # near the 2000-char split point
|
||||
|
||||
# Auto-disconnect from voice channel after this many seconds of inactivity
|
||||
VOICE_TIMEOUT = 300
|
||||
@@ -433,6 +435,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self._allowed_user_ids: set = set() # For button approval authorization
|
||||
# Voice channel state (per-guild)
|
||||
self._voice_clients: Dict[int, Any] = {} # guild_id -> VoiceClient
|
||||
# Text batching: merge rapid successive messages (Telegram-style)
|
||||
self._text_batch_delay_seconds = float(os.getenv("HERMES_DISCORD_TEXT_BATCH_DELAY_SECONDS", "0.6"))
|
||||
self._text_batch_split_delay_seconds = float(os.getenv("HERMES_DISCORD_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0"))
|
||||
self._pending_text_batches: Dict[str, MessageEvent] = {}
|
||||
self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._voice_text_channels: Dict[int, int] = {} # guild_id -> text_channel_id
|
||||
self._voice_timeout_tasks: Dict[int, asyncio.Task] = {} # guild_id -> timeout task
|
||||
# Phase 2: voice listening
|
||||
@@ -529,10 +536,17 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
intents.members = any(not entry.isdigit() for entry in self._allowed_user_ids)
|
||||
intents.voice_states = True
|
||||
|
||||
# Create bot
|
||||
# Resolve proxy (DISCORD_PROXY > generic env vars > macOS system proxy)
|
||||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_bot
|
||||
proxy_url = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||||
if proxy_url:
|
||||
logger.info("[%s] Using proxy for Discord: %s", self.name, proxy_url)
|
||||
|
||||
# Create bot — proxy= for HTTP, connector= for SOCKS
|
||||
self._client = commands.Bot(
|
||||
command_prefix="!", # Not really used, we handle raw messages
|
||||
intents=intents,
|
||||
**proxy_kwargs_for_bot(proxy_url),
|
||||
)
|
||||
adapter_self = self # capture for closure
|
||||
|
||||
@@ -592,22 +606,35 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if not self._client.user or self._client.user not in message.mentions:
|
||||
return
|
||||
# "all" falls through to handle_message
|
||||
|
||||
# If the message @mentions other users but NOT the bot, the
|
||||
# sender is talking to someone else — stay silent. Only
|
||||
# applies in server channels; in DMs the user is always
|
||||
# talking to the bot (mentions are just references).
|
||||
# Controlled by DISCORD_IGNORE_NO_MENTION (default: true).
|
||||
_ignore_no_mention = os.getenv(
|
||||
"DISCORD_IGNORE_NO_MENTION", "true"
|
||||
).lower() in ("true", "1", "yes")
|
||||
if _ignore_no_mention and message.mentions and not isinstance(message.channel, discord.DMChannel):
|
||||
_bot_mentioned = (
|
||||
|
||||
# Multi-agent filtering: if the message mentions specific bots
|
||||
# but NOT this bot, the sender is talking to another agent —
|
||||
# stay silent. Messages with no bot mentions (general chat)
|
||||
# still fall through to _handle_message for the existing
|
||||
# DISCORD_REQUIRE_MENTION check.
|
||||
#
|
||||
# This replaces the older DISCORD_IGNORE_NO_MENTION logic
|
||||
# with bot-aware filtering that works correctly when multiple
|
||||
# agents share a channel.
|
||||
if not isinstance(message.channel, discord.DMChannel) and message.mentions:
|
||||
_self_mentioned = (
|
||||
self._client.user is not None
|
||||
and self._client.user in message.mentions
|
||||
)
|
||||
if not _bot_mentioned:
|
||||
return # Talking to someone else, don't interrupt
|
||||
_other_bots_mentioned = any(
|
||||
m.bot and m != self._client.user
|
||||
for m in message.mentions
|
||||
)
|
||||
# If other bots are mentioned but we're not → not for us
|
||||
if _other_bots_mentioned and not _self_mentioned:
|
||||
return
|
||||
# If humans are mentioned but we're not → not for us
|
||||
# (preserves old DISCORD_IGNORE_NO_MENTION=true behavior)
|
||||
_ignore_no_mention = os.getenv(
|
||||
"DISCORD_IGNORE_NO_MENTION", "true"
|
||||
).lower() in ("true", "1", "yes")
|
||||
if _ignore_no_mention and not _self_mentioned and not _other_bots_mentioned:
|
||||
return
|
||||
|
||||
await self._handle_message(message)
|
||||
|
||||
@@ -741,14 +768,17 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if hasattr(message, "add_reaction"):
|
||||
await self._add_reaction(message, "👀")
|
||||
|
||||
async def on_processing_complete(self, event: MessageEvent, success: bool) -> None:
|
||||
async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
|
||||
"""Swap the in-progress reaction for a final success/failure reaction."""
|
||||
if not self._reactions_enabled():
|
||||
return
|
||||
message = event.raw_message
|
||||
if hasattr(message, "add_reaction"):
|
||||
await self._remove_reaction(message, "👀")
|
||||
await self._add_reaction(message, "✅" if success else "❌")
|
||||
if outcome == ProcessingOutcome.SUCCESS:
|
||||
await self._add_reaction(message, "✅")
|
||||
elif outcome == ProcessingOutcome.FAILURE:
|
||||
await self._add_reaction(message, "❌")
|
||||
|
||||
async def send(
|
||||
self,
|
||||
@@ -757,18 +787,34 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> SendResult:
|
||||
"""Send a message to a Discord channel."""
|
||||
"""Send a message to a Discord channel or thread.
|
||||
|
||||
When metadata contains a thread_id, the message is sent to that
|
||||
thread instead of the parent channel identified by chat_id.
|
||||
"""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
# Get the channel
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
# Determine target channel: thread_id in metadata takes precedence.
|
||||
thread_id = None
|
||||
if metadata and metadata.get("thread_id"):
|
||||
thread_id = metadata["thread_id"]
|
||||
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
if thread_id:
|
||||
# Fetch the thread directly — threads are addressed by their own ID.
|
||||
channel = self._client.get_channel(int(thread_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(thread_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Thread {thread_id} not found")
|
||||
else:
|
||||
# Get the parent channel
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
# Format and split message if needed
|
||||
formatted = self.format_message(content)
|
||||
@@ -1231,9 +1277,8 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path)
|
||||
|
||||
from tools.transcription_tools import transcribe_audio, get_stt_model_from_config
|
||||
stt_model = get_stt_model_from_config()
|
||||
result = await asyncio.to_thread(transcribe_audio, wav_path, model=stt_model)
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = await asyncio.to_thread(transcribe_audio, wav_path)
|
||||
|
||||
if not result.get("success"):
|
||||
return
|
||||
@@ -1307,8 +1352,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
# Download the image and send as a Discord file attachment
|
||||
# (Discord renders attachments inline, unlike plain URLs)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
|
||||
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||||
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
|
||||
async with aiohttp.ClientSession(**_sess_kw) as session:
|
||||
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30), **_req_kw) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"Failed to download image: HTTP {resp.status}")
|
||||
|
||||
@@ -1585,7 +1633,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
await self._run_simple_slash(interaction, f"/model {name}".strip())
|
||||
|
||||
@tree.command(name="reasoning", description="Show or change reasoning effort")
|
||||
@discord.app_commands.describe(effort="Reasoning effort: xhigh, high, medium, low, minimal, or none.")
|
||||
@discord.app_commands.describe(effort="Reasoning effort: none, minimal, low, medium, high, or xhigh.")
|
||||
async def slash_reasoning(interaction: discord.Interaction, effort: str = ""):
|
||||
await self._run_simple_slash(interaction, f"/reasoning {effort}".strip())
|
||||
|
||||
@@ -1857,14 +1905,42 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
chat_topic=chat_topic,
|
||||
)
|
||||
|
||||
_parent_id = str(getattr(getattr(interaction, "channel", None), "parent_id", "") or "")
|
||||
_skills = self._resolve_channel_skills(thread_id, _parent_id or None)
|
||||
event = MessageEvent(
|
||||
text=text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
raw_message=interaction,
|
||||
auto_skill=_skills,
|
||||
)
|
||||
await self.handle_message(event)
|
||||
|
||||
def _resolve_channel_skills(self, channel_id: str, parent_id: str | None = None) -> list[str] | None:
|
||||
"""Look up auto-skill bindings for a Discord channel/forum thread.
|
||||
|
||||
Config format (in platform extra):
|
||||
channel_skill_bindings:
|
||||
- id: "123456"
|
||||
skills: ["skill-a", "skill-b"]
|
||||
Also checks parent_id so forum threads inherit the forum's bindings.
|
||||
"""
|
||||
bindings = self.config.extra.get("channel_skill_bindings", [])
|
||||
if not bindings:
|
||||
return None
|
||||
ids_to_check = {channel_id}
|
||||
if parent_id:
|
||||
ids_to_check.add(parent_id)
|
||||
for entry in bindings:
|
||||
entry_id = str(entry.get("id", ""))
|
||||
if entry_id in ids_to_check:
|
||||
skills = entry.get("skills") or entry.get("skill")
|
||||
if isinstance(skills, str):
|
||||
return [skills]
|
||||
if isinstance(skills, list) and skills:
|
||||
return list(dict.fromkeys(skills)) # dedup, preserve order
|
||||
return None
|
||||
|
||||
def _thread_parent_channel(self, channel: Any) -> Any:
|
||||
"""Return the parent text channel when invoked from a thread."""
|
||||
return getattr(channel, "parent", None) or channel
|
||||
@@ -2218,6 +2294,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# discord.require_mention: Require @mention in server channels (default: true)
|
||||
# discord.free_response_channels: Channel IDs where bot responds without mention
|
||||
# discord.ignored_channels: Channel IDs where bot NEVER responds (even when mentioned)
|
||||
# discord.allowed_channels: If set, bot ONLY responds in these channels (whitelist)
|
||||
# discord.no_thread_channels: Channel IDs where bot responds directly without creating thread
|
||||
# discord.auto_thread: Auto-create thread on @mention in channels (default: true)
|
||||
|
||||
@@ -2229,12 +2306,21 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
parent_channel_id = self._get_parent_channel_id(message.channel)
|
||||
|
||||
if not isinstance(message.channel, discord.DMChannel):
|
||||
# Check ignored channels first - never respond even when mentioned
|
||||
ignored_channels_raw = os.getenv("DISCORD_IGNORED_CHANNELS", "")
|
||||
ignored_channels = {ch.strip() for ch in ignored_channels_raw.split(",") if ch.strip()}
|
||||
channel_ids = {str(message.channel.id)}
|
||||
if parent_channel_id:
|
||||
channel_ids.add(parent_channel_id)
|
||||
|
||||
# Check allowed channels - if set, only respond in these channels
|
||||
allowed_channels_raw = os.getenv("DISCORD_ALLOWED_CHANNELS", "")
|
||||
if allowed_channels_raw:
|
||||
allowed_channels = {ch.strip() for ch in allowed_channels_raw.split(",") if ch.strip()}
|
||||
if not (channel_ids & allowed_channels):
|
||||
logger.debug("[%s] Ignoring message in non-allowed channel: %s", self.name, channel_ids)
|
||||
return
|
||||
|
||||
# Check ignored channels - never respond even when mentioned
|
||||
ignored_channels_raw = os.getenv("DISCORD_IGNORED_CHANNELS", "")
|
||||
ignored_channels = {ch.strip() for ch in ignored_channels_raw.split(",") if ch.strip()}
|
||||
if channel_ids & ignored_channels:
|
||||
logger.debug("[%s] Ignoring message in ignored channel: %s", self.name, channel_ids)
|
||||
return
|
||||
@@ -2391,10 +2477,14 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
try:
|
||||
import aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
|
||||
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||||
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
|
||||
async with aiohttp.ClientSession(**_sess_kw) as session:
|
||||
async with session.get(
|
||||
att.url,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
**_req_kw,
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"HTTP {resp.status}")
|
||||
@@ -2435,6 +2525,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if not event_text or not event_text.strip():
|
||||
event_text = "(The user sent a message with no text content)"
|
||||
|
||||
_chan = message.channel
|
||||
_parent_id = str(getattr(_chan, "parent_id", "") or "")
|
||||
_chan_id = str(getattr(_chan, "id", ""))
|
||||
_skills = self._resolve_channel_skills(_chan_id, _parent_id or None)
|
||||
event = MessageEvent(
|
||||
text=event_text,
|
||||
message_type=msg_type,
|
||||
@@ -2445,6 +2539,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
media_types=media_types,
|
||||
reply_to_message_id=str(message.reference.message_id) if message.reference else None,
|
||||
timestamp=message.created_at,
|
||||
auto_skill=_skills,
|
||||
)
|
||||
|
||||
# Track thread participation so the bot won't require @mention for
|
||||
@@ -2452,7 +2547,80 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if thread_id:
|
||||
self._track_thread(thread_id)
|
||||
|
||||
await self.handle_message(event)
|
||||
# Only batch plain text messages — commands, media, etc. dispatch
|
||||
# immediately since they won't be split by the Discord client.
|
||||
if msg_type == MessageType.TEXT and self._text_batch_delay_seconds > 0:
|
||||
self._enqueue_text_event(event)
|
||||
else:
|
||||
await self.handle_message(event)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Text message aggregation (handles Discord client-side splits)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _text_batch_key(self, event: MessageEvent) -> str:
|
||||
"""Session-scoped key for text message batching."""
|
||||
from gateway.session import build_session_key
|
||||
return build_session_key(
|
||||
event.source,
|
||||
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
|
||||
thread_sessions_per_user=self.config.extra.get("thread_sessions_per_user", False),
|
||||
)
|
||||
|
||||
def _enqueue_text_event(self, event: MessageEvent) -> None:
|
||||
"""Buffer a text event and reset the flush timer.
|
||||
|
||||
When Discord splits a long user message at 2000 chars, the chunks
|
||||
arrive within a few hundred milliseconds. This merges them into
|
||||
a single event before dispatching.
|
||||
"""
|
||||
key = self._text_batch_key(event)
|
||||
existing = self._pending_text_batches.get(key)
|
||||
chunk_len = len(event.text or "")
|
||||
if existing is None:
|
||||
event._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||||
self._pending_text_batches[key] = event
|
||||
else:
|
||||
if event.text:
|
||||
existing.text = f"{existing.text}\n{event.text}" if existing.text else event.text
|
||||
existing._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||||
if event.media_urls:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
|
||||
prior_task = self._pending_text_batch_tasks.get(key)
|
||||
if prior_task and not prior_task.done():
|
||||
prior_task.cancel()
|
||||
self._pending_text_batch_tasks[key] = asyncio.create_task(
|
||||
self._flush_text_batch(key)
|
||||
)
|
||||
|
||||
async def _flush_text_batch(self, key: str) -> None:
|
||||
"""Wait for the quiet period then dispatch the aggregated text.
|
||||
|
||||
Uses a longer delay when the latest chunk is near Discord's 2000-char
|
||||
split point, since a continuation chunk is almost certain.
|
||||
"""
|
||||
current_task = asyncio.current_task()
|
||||
try:
|
||||
pending = self._pending_text_batches.get(key)
|
||||
last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0
|
||||
if last_len >= self._SPLIT_THRESHOLD:
|
||||
delay = self._text_batch_split_delay_seconds
|
||||
else:
|
||||
delay = self._text_batch_delay_seconds
|
||||
await asyncio.sleep(delay)
|
||||
event = self._pending_text_batches.pop(key, None)
|
||||
if not event:
|
||||
return
|
||||
logger.info(
|
||||
"[Discord] Flushing text batch %s (%d chars)",
|
||||
key, len(event.text or ""),
|
||||
)
|
||||
await self.handle_message(event)
|
||||
finally:
|
||||
if self._pending_text_batch_tasks.get(key) is current_task:
|
||||
self._pending_text_batch_tasks.pop(key, None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -195,7 +195,11 @@ def _extract_attachments(
|
||||
|
||||
ext = Path(filename).suffix.lower()
|
||||
if ext in _IMAGE_EXTS:
|
||||
cached_path = cache_image_from_bytes(payload, ext)
|
||||
try:
|
||||
cached_path = cache_image_from_bytes(payload, ext)
|
||||
except ValueError:
|
||||
logger.debug("Skipping non-image attachment %s (invalid magic bytes)", filename)
|
||||
continue
|
||||
attachments.append({
|
||||
"path": cached_path,
|
||||
"filename": filename,
|
||||
|
||||
+40
-10
@@ -264,6 +264,7 @@ class FeishuAdapterSettings:
|
||||
bot_name: str
|
||||
dedup_cache_size: int
|
||||
text_batch_delay_seconds: float
|
||||
text_batch_split_delay_seconds: float
|
||||
text_batch_max_messages: int
|
||||
text_batch_max_chars: int
|
||||
media_batch_delay_seconds: float
|
||||
@@ -972,7 +973,8 @@ def _run_official_feishu_ws_client(ws_client: Any, adapter: Any) -> None:
|
||||
return await original_connect(*args, **kwargs)
|
||||
|
||||
def _configure_with_overrides(conf: Any) -> Any:
|
||||
assert original_configure is not None
|
||||
if original_configure is None:
|
||||
raise RuntimeError("Feishu _configure_with_overrides called but original_configure is None")
|
||||
result = original_configure(conf)
|
||||
_apply_runtime_ws_overrides()
|
||||
return result
|
||||
@@ -1014,6 +1016,10 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
"""Feishu/Lark bot adapter."""
|
||||
|
||||
MAX_MESSAGE_LENGTH = 8000
|
||||
# Threshold for detecting Feishu client-side message splits.
|
||||
# When a chunk is near the ~4096-char practical limit, a continuation
|
||||
# is almost certain.
|
||||
_SPLIT_THRESHOLD = 4000
|
||||
|
||||
# =========================================================================
|
||||
# Lifecycle — init / settings / connect / disconnect
|
||||
@@ -1105,6 +1111,9 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
text_batch_delay_seconds=float(
|
||||
os.getenv("HERMES_FEISHU_TEXT_BATCH_DELAY_SECONDS", str(_DEFAULT_TEXT_BATCH_DELAY_SECONDS))
|
||||
),
|
||||
text_batch_split_delay_seconds=float(
|
||||
os.getenv("HERMES_FEISHU_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0")
|
||||
),
|
||||
text_batch_max_messages=max(
|
||||
1,
|
||||
int(os.getenv("HERMES_FEISHU_TEXT_BATCH_MAX_MESSAGES", str(_DEFAULT_TEXT_BATCH_MAX_MESSAGES))),
|
||||
@@ -1152,6 +1161,7 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
self._bot_name = settings.bot_name
|
||||
self._dedup_cache_size = settings.dedup_cache_size
|
||||
self._text_batch_delay_seconds = settings.text_batch_delay_seconds
|
||||
self._text_batch_split_delay_seconds = settings.text_batch_split_delay_seconds
|
||||
self._text_batch_max_messages = settings.text_batch_max_messages
|
||||
self._text_batch_max_chars = settings.text_batch_max_chars
|
||||
self._media_batch_delay_seconds = settings.media_batch_delay_seconds
|
||||
@@ -1570,13 +1580,18 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
try:
|
||||
with open(image_path, "rb") as image_file:
|
||||
body = self._build_image_upload_body(
|
||||
image_type=_FEISHU_IMAGE_UPLOAD_TYPE,
|
||||
image=image_file,
|
||||
)
|
||||
request = self._build_image_upload_request(body)
|
||||
upload_response = await asyncio.to_thread(self._client.im.v1.image.create, request)
|
||||
import io as _io
|
||||
with open(image_path, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
# Wrap in BytesIO so lark SDK's MultipartEncoder can read .name and .tell()
|
||||
image_file = _io.BytesIO(image_bytes)
|
||||
image_file.name = os.path.basename(image_path)
|
||||
body = self._build_image_upload_body(
|
||||
image_type=_FEISHU_IMAGE_UPLOAD_TYPE,
|
||||
image=image_file,
|
||||
)
|
||||
request = self._build_image_upload_request(body)
|
||||
upload_response = await asyncio.to_thread(self._client.im.v1.image.create, request)
|
||||
image_key = self._extract_response_field(upload_response, "image_key")
|
||||
if not image_key:
|
||||
return self._response_error_result(
|
||||
@@ -2478,8 +2493,10 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
async def _enqueue_text_event(self, event: MessageEvent) -> None:
|
||||
"""Debounce rapid Feishu text bursts into a single MessageEvent."""
|
||||
key = self._text_batch_key(event)
|
||||
chunk_len = len(event.text or "")
|
||||
existing = self._pending_text_batches.get(key)
|
||||
if existing is None:
|
||||
event._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||||
self._pending_text_batches[key] = event
|
||||
self._pending_text_batch_counts[key] = 1
|
||||
self._schedule_text_batch_flush(key)
|
||||
@@ -2504,6 +2521,7 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
return
|
||||
|
||||
existing.text = next_text
|
||||
existing._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||||
existing.timestamp = event.timestamp
|
||||
if event.message_id:
|
||||
existing.message_id = event.message_id
|
||||
@@ -2530,10 +2548,22 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
task_map[key] = asyncio.create_task(flush_fn(key))
|
||||
|
||||
async def _flush_text_batch(self, key: str) -> None:
|
||||
"""Flush a pending text batch after the quiet period."""
|
||||
"""Flush a pending text batch after the quiet period.
|
||||
|
||||
Uses a longer delay when the latest chunk is near Feishu's ~4096-char
|
||||
split point, since a continuation chunk is almost certain.
|
||||
"""
|
||||
current_task = asyncio.current_task()
|
||||
try:
|
||||
await asyncio.sleep(self._text_batch_delay_seconds)
|
||||
# Adaptive delay: if the latest chunk is near the split threshold,
|
||||
# a continuation is almost certain — wait longer.
|
||||
pending = self._pending_text_batches.get(key)
|
||||
last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0
|
||||
if last_len >= self._SPLIT_THRESHOLD:
|
||||
delay = self._text_batch_split_delay_seconds
|
||||
else:
|
||||
delay = self._text_batch_delay_seconds
|
||||
await asyncio.sleep(delay)
|
||||
await self._flush_text_batch_now(key)
|
||||
finally:
|
||||
if self._pending_text_batch_tasks.get(key) is current_task:
|
||||
|
||||
+114
-13
@@ -40,6 +40,7 @@ from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
ProcessingOutcome,
|
||||
SendResult,
|
||||
)
|
||||
|
||||
@@ -120,6 +121,11 @@ def check_matrix_requirements() -> bool:
|
||||
class MatrixAdapter(BasePlatformAdapter):
|
||||
"""Gateway adapter for Matrix (any homeserver)."""
|
||||
|
||||
# Threshold for detecting Matrix client-side message splits.
|
||||
# When a chunk is near the ~4000-char practical limit, a continuation
|
||||
# is almost certain.
|
||||
_SPLIT_THRESHOLD = 3900
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.MATRIX)
|
||||
|
||||
@@ -171,6 +177,16 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
self._reactions_enabled: bool = os.getenv(
|
||||
"MATRIX_REACTIONS", "true"
|
||||
).lower() not in ("false", "0", "no")
|
||||
# Tracks the reaction event_id for in-progress (eyes) reactions.
|
||||
# Key: (room_id, message_event_id) → reaction_event_id (for the eyes reaction).
|
||||
self._pending_reactions: dict[tuple[str, str], str] = {}
|
||||
|
||||
# Text batching: merge rapid successive messages (Telegram-style).
|
||||
# Matrix clients split long messages around 4000 chars.
|
||||
self._text_batch_delay_seconds = float(os.getenv("HERMES_MATRIX_TEXT_BATCH_DELAY_SECONDS", "0.6"))
|
||||
self._text_batch_split_delay_seconds = float(os.getenv("HERMES_MATRIX_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0"))
|
||||
self._pending_text_batches: Dict[str, MessageEvent] = {}
|
||||
self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}
|
||||
|
||||
def _is_duplicate_event(self, event_id) -> bool:
|
||||
"""Return True if this event was already processed. Tracks the ID otherwise."""
|
||||
@@ -1088,7 +1104,81 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
# Acknowledge receipt so the room shows as read (fire-and-forget).
|
||||
self._background_read_receipt(room.room_id, event.event_id)
|
||||
|
||||
await self.handle_message(msg_event)
|
||||
# Only batch plain text messages — commands dispatch immediately.
|
||||
if msg_type == MessageType.TEXT and self._text_batch_delay_seconds > 0:
|
||||
self._enqueue_text_event(msg_event)
|
||||
else:
|
||||
await self.handle_message(msg_event)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Text message aggregation (handles Matrix client-side splits)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _text_batch_key(self, event: MessageEvent) -> str:
|
||||
"""Session-scoped key for text message batching."""
|
||||
from gateway.session import build_session_key
|
||||
return build_session_key(
|
||||
event.source,
|
||||
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
|
||||
thread_sessions_per_user=self.config.extra.get("thread_sessions_per_user", False),
|
||||
)
|
||||
|
||||
def _enqueue_text_event(self, event: MessageEvent) -> None:
|
||||
"""Buffer a text event and reset the flush timer.
|
||||
|
||||
When a Matrix client splits a long message, the chunks arrive within
|
||||
a few hundred milliseconds. This merges them into a single event
|
||||
before dispatching.
|
||||
"""
|
||||
key = self._text_batch_key(event)
|
||||
existing = self._pending_text_batches.get(key)
|
||||
chunk_len = len(event.text or "")
|
||||
if existing is None:
|
||||
event._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||||
self._pending_text_batches[key] = event
|
||||
else:
|
||||
if event.text:
|
||||
existing.text = f"{existing.text}\n{event.text}" if existing.text else event.text
|
||||
existing._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||||
# Merge any media that might be attached
|
||||
if event.media_urls:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
|
||||
# Cancel any pending flush and restart the timer
|
||||
prior_task = self._pending_text_batch_tasks.get(key)
|
||||
if prior_task and not prior_task.done():
|
||||
prior_task.cancel()
|
||||
self._pending_text_batch_tasks[key] = asyncio.create_task(
|
||||
self._flush_text_batch(key)
|
||||
)
|
||||
|
||||
async def _flush_text_batch(self, key: str) -> None:
|
||||
"""Wait for the quiet period then dispatch the aggregated text.
|
||||
|
||||
Uses a longer delay when the latest chunk is near Matrix's ~4000-char
|
||||
split point, since a continuation chunk is almost certain.
|
||||
"""
|
||||
current_task = asyncio.current_task()
|
||||
try:
|
||||
pending = self._pending_text_batches.get(key)
|
||||
last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0
|
||||
if last_len >= self._SPLIT_THRESHOLD:
|
||||
delay = self._text_batch_split_delay_seconds
|
||||
else:
|
||||
delay = self._text_batch_delay_seconds
|
||||
await asyncio.sleep(delay)
|
||||
event = self._pending_text_batches.pop(key, None)
|
||||
if not event:
|
||||
return
|
||||
logger.info(
|
||||
"[Matrix] Flushing text batch %s (%d chars)",
|
||||
key, len(event.text or ""),
|
||||
)
|
||||
await self.handle_message(event)
|
||||
finally:
|
||||
if self._pending_text_batch_tasks.get(key) is current_task:
|
||||
self._pending_text_batch_tasks.pop(key, None)
|
||||
|
||||
async def _on_room_message_media(self, room: Any, event: Any) -> None:
|
||||
"""Handle incoming media messages (images, audio, video, files)."""
|
||||
@@ -1350,12 +1440,14 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
|
||||
async def _send_reaction(
|
||||
self, room_id: str, event_id: str, emoji: str,
|
||||
) -> bool:
|
||||
"""Send an emoji reaction to a message in a room."""
|
||||
) -> Optional[str]:
|
||||
"""Send an emoji reaction to a message in a room.
|
||||
Returns the reaction event_id on success, None on failure.
|
||||
"""
|
||||
import nio
|
||||
|
||||
if not self._client:
|
||||
return False
|
||||
return None
|
||||
content = {
|
||||
"m.relates_to": {
|
||||
"rel_type": "m.annotation",
|
||||
@@ -1370,12 +1462,12 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
)
|
||||
if isinstance(resp, nio.RoomSendResponse):
|
||||
logger.debug("Matrix: sent reaction %s to %s", emoji, event_id)
|
||||
return True
|
||||
return resp.event_id
|
||||
logger.debug("Matrix: reaction send failed: %s", resp)
|
||||
return False
|
||||
return None
|
||||
except Exception as exc:
|
||||
logger.debug("Matrix: reaction send error: %s", exc)
|
||||
return False
|
||||
return None
|
||||
|
||||
async def _redact_reaction(
|
||||
self, room_id: str, reaction_event_id: str, reason: str = "",
|
||||
@@ -1390,10 +1482,12 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
msg_id = event.message_id
|
||||
room_id = event.source.chat_id
|
||||
if msg_id and room_id:
|
||||
await self._send_reaction(room_id, msg_id, "\U0001f440")
|
||||
reaction_event_id = await self._send_reaction(room_id, msg_id, "\U0001f440")
|
||||
if reaction_event_id:
|
||||
self._pending_reactions[(room_id, msg_id)] = reaction_event_id
|
||||
|
||||
async def on_processing_complete(
|
||||
self, event: MessageEvent, success: bool,
|
||||
self, event: MessageEvent, outcome: ProcessingOutcome,
|
||||
) -> None:
|
||||
"""Replace eyes with checkmark (success) or cross (failure)."""
|
||||
if not self._reactions_enabled:
|
||||
@@ -1402,11 +1496,18 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
room_id = event.source.chat_id
|
||||
if not msg_id or not room_id:
|
||||
return
|
||||
# Note: Matrix doesn't support removing a specific reaction easily
|
||||
# without tracking the reaction event_id. We send the new reaction;
|
||||
# the eyes stays (acceptable UX — both are visible).
|
||||
if outcome == ProcessingOutcome.CANCELLED:
|
||||
return
|
||||
# Remove the eyes reaction first, if we tracked its event_id.
|
||||
reaction_key = (room_id, msg_id)
|
||||
if reaction_key in self._pending_reactions:
|
||||
eyes_event_id = self._pending_reactions.pop(reaction_key)
|
||||
if not await self._redact_reaction(room_id, eyes_event_id):
|
||||
logger.debug("Matrix: failed to redact eyes reaction %s", eyes_event_id)
|
||||
await self._send_reaction(
|
||||
room_id, msg_id, "\u2705" if success else "\u274c",
|
||||
room_id,
|
||||
msg_id,
|
||||
"\u2705" if outcome == ProcessingOutcome.SUCCESS else "\u274c",
|
||||
)
|
||||
|
||||
async def _on_reaction(self, room: Any, event: Any) -> None:
|
||||
|
||||
+245
-67
@@ -14,6 +14,7 @@ import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional, Any, Tuple
|
||||
|
||||
try:
|
||||
@@ -38,6 +39,7 @@ from gateway.platforms.base import (
|
||||
MessageType,
|
||||
SendResult,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
safe_url_for_log,
|
||||
cache_document_from_bytes,
|
||||
)
|
||||
|
||||
@@ -45,6 +47,14 @@ from gateway.platforms.base import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ThreadContextCache:
|
||||
"""Cache entry for fetched thread context."""
|
||||
content: str
|
||||
fetched_at: float = field(default_factory=time.monotonic)
|
||||
message_count: int = 0
|
||||
|
||||
|
||||
def check_slack_requirements() -> bool:
|
||||
"""Check if Slack dependencies are available."""
|
||||
return SLACK_AVAILABLE
|
||||
@@ -101,6 +111,9 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
# session + memory scoping.
|
||||
self._assistant_threads: Dict[Tuple[str, str], Dict[str, str]] = {}
|
||||
self._ASSISTANT_THREADS_MAX = 5000
|
||||
# Cache for _fetch_thread_context results: cache_key → _ThreadContextCache
|
||||
self._thread_context_cache: Dict[str, _ThreadContextCache] = {}
|
||||
self._THREAD_CACHE_TTL = 60.0
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Slack via Socket Mode."""
|
||||
@@ -281,6 +294,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
kwargs = {
|
||||
"channel": chat_id,
|
||||
"text": chunk,
|
||||
"mrkdwn": True,
|
||||
}
|
||||
if thread_ts:
|
||||
kwargs["thread_ts"] = thread_ts
|
||||
@@ -323,9 +337,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
if not self._app:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
# Convert standard markdown → Slack mrkdwn
|
||||
formatted = self.format_message(content)
|
||||
|
||||
await self._get_client(chat_id).chat_update(
|
||||
channel=chat_id,
|
||||
ts=message_id,
|
||||
@@ -457,13 +469,36 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
text = re.sub(r'(`[^`]+`)', lambda m: _ph(m.group(0)), text)
|
||||
|
||||
# 3) Convert markdown links [text](url) → <url|text>
|
||||
def _convert_markdown_link(m):
|
||||
label = m.group(1)
|
||||
url = m.group(2).strip()
|
||||
if url.startswith('<') and url.endswith('>'):
|
||||
url = url[1:-1].strip()
|
||||
return _ph(f'<{url}|{label}>')
|
||||
|
||||
text = re.sub(
|
||||
r'\[([^\]]+)\]\(([^)]+)\)',
|
||||
lambda m: _ph(f'<{m.group(2)}|{m.group(1)}>'),
|
||||
r'\[([^\]]+)\]\(([^()]*(?:\([^()]*\)[^()]*)*)\)',
|
||||
_convert_markdown_link,
|
||||
text,
|
||||
)
|
||||
|
||||
# 4) Convert headers (## Title) → *Title* (bold)
|
||||
# 4) Protect existing Slack entities/manual links so escaping and later
|
||||
# formatting passes don't break them.
|
||||
text = re.sub(
|
||||
r'(<(?:[@#!]|(?:https?|mailto|tel):)[^>\n]+>)',
|
||||
lambda m: _ph(m.group(1)),
|
||||
text,
|
||||
)
|
||||
|
||||
# 5) Protect blockquote markers before escaping
|
||||
text = re.sub(r'^(>+\s)', lambda m: _ph(m.group(0)), text, flags=re.MULTILINE)
|
||||
|
||||
# 6) Escape Slack control characters in remaining plain text.
|
||||
# Unescape first so already-escaped input doesn't get double-escaped.
|
||||
text = text.replace('&', '&').replace('<', '<').replace('>', '>')
|
||||
text = text.replace('&', '&').replace('<', '<').replace('>', '>')
|
||||
|
||||
# 7) Convert headers (## Title) → *Title* (bold)
|
||||
def _convert_header(m):
|
||||
inner = m.group(1).strip()
|
||||
# Strip redundant bold markers inside a header
|
||||
@@ -474,34 +509,39 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
r'^#{1,6}\s+(.+)$', _convert_header, text, flags=re.MULTILINE
|
||||
)
|
||||
|
||||
# 5) Convert bold: **text** → *text* (Slack bold)
|
||||
# 8) Convert bold+italic: ***text*** → *_text_* (Slack bold wrapping italic)
|
||||
text = re.sub(
|
||||
r'\*\*\*(.+?)\*\*\*',
|
||||
lambda m: _ph(f'*_{m.group(1)}_*'),
|
||||
text,
|
||||
)
|
||||
|
||||
# 9) Convert bold: **text** → *text* (Slack bold)
|
||||
text = re.sub(
|
||||
r'\*\*(.+?)\*\*',
|
||||
lambda m: _ph(f'*{m.group(1)}*'),
|
||||
text,
|
||||
)
|
||||
|
||||
# 6) Convert italic: _text_ stays as _text_ (already Slack italic)
|
||||
# Single *text* → _text_ (Slack italic)
|
||||
# 10) Convert italic: _text_ stays as _text_ (already Slack italic)
|
||||
# Single *text* → _text_ (Slack italic)
|
||||
text = re.sub(
|
||||
r'(?<!\*)\*([^*\n]+)\*(?!\*)',
|
||||
lambda m: _ph(f'_{m.group(1)}_'),
|
||||
text,
|
||||
)
|
||||
|
||||
# 7) Convert strikethrough: ~~text~~ → ~text~
|
||||
# 11) Convert strikethrough: ~~text~~ → ~text~
|
||||
text = re.sub(
|
||||
r'~~(.+?)~~',
|
||||
lambda m: _ph(f'~{m.group(1)}~'),
|
||||
text,
|
||||
)
|
||||
|
||||
# 8) Convert blockquotes: > text → > text (same syntax, just ensure
|
||||
# no extra escaping happens to the > character)
|
||||
# Slack uses the same > prefix, so this is a no-op for content.
|
||||
# 12) Blockquotes: > prefix is already protected by step 5 above.
|
||||
|
||||
# 9) Restore placeholders in reverse order
|
||||
for key in reversed(list(placeholders.keys())):
|
||||
# 13) Restore placeholders in reverse order
|
||||
for key in reversed(placeholders):
|
||||
text = text.replace(key, placeholders[key])
|
||||
|
||||
return text
|
||||
@@ -617,8 +657,19 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
import httpx
|
||||
|
||||
async def _ssrf_redirect_guard(response):
|
||||
"""Re-check redirect targets so public URLs cannot bounce into private IPs."""
|
||||
if response.is_redirect and response.next_request:
|
||||
redirect_url = str(response.next_request.url)
|
||||
if not is_safe_url(redirect_url):
|
||||
raise ValueError("Blocked redirect to private/internal address")
|
||||
|
||||
# Download the image first
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=30.0,
|
||||
follow_redirects=True,
|
||||
event_hooks={"response": [_ssrf_redirect_guard]},
|
||||
) as client:
|
||||
response = await client.get(image_url)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -635,7 +686,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.warning(
|
||||
"[Slack] Failed to upload image from URL %s, falling back to text: %s",
|
||||
image_url,
|
||||
safe_url_for_log(image_url),
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
@@ -914,9 +965,26 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
if v > cutoff
|
||||
}
|
||||
|
||||
# Ignore bot messages (including our own)
|
||||
# Bot message filtering (SLACK_ALLOW_BOTS / config allow_bots):
|
||||
# "none" — ignore all bot messages (default, backward-compatible)
|
||||
# "mentions" — accept bot messages only when they @mention us
|
||||
# "all" — accept all bot messages (except our own)
|
||||
if event.get("bot_id") or event.get("subtype") == "bot_message":
|
||||
return
|
||||
allow_bots = self.config.extra.get("allow_bots", "")
|
||||
if not allow_bots:
|
||||
allow_bots = os.getenv("SLACK_ALLOW_BOTS", "none")
|
||||
allow_bots = str(allow_bots).lower().strip()
|
||||
if allow_bots == "none":
|
||||
return
|
||||
elif allow_bots == "mentions":
|
||||
text_check = event.get("text", "")
|
||||
if self._bot_user_id and f"<@{self._bot_user_id}>" not in text_check:
|
||||
return
|
||||
# "all" falls through to process the message
|
||||
# Always ignore our own messages to prevent echo loops
|
||||
msg_user = event.get("user", "")
|
||||
if msg_user and self._bot_user_id and msg_user == self._bot_user_id:
|
||||
return
|
||||
|
||||
# Ignore message edits and deletions
|
||||
subtype = event.get("subtype")
|
||||
@@ -948,7 +1016,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
channel_type = event.get("channel_type", "")
|
||||
if not channel_type and channel_id.startswith("D"):
|
||||
channel_type = "im"
|
||||
is_dm = channel_type == "im"
|
||||
is_dm = channel_type in ("im", "mpim") # Both 1:1 and group DMs
|
||||
|
||||
# Build thread_ts for session keying.
|
||||
# In channels: fall back to ts so each top-level @mention starts a
|
||||
@@ -961,6 +1029,8 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
thread_ts = event.get("thread_ts") or ts # ts fallback for channels
|
||||
|
||||
# In channels, respond if:
|
||||
# 0. Channel is in free_response_channels, OR require_mention is
|
||||
# disabled — always process regardless of mention.
|
||||
# 1. The bot is @mentioned in this message, OR
|
||||
# 2. The message is a reply in a thread the bot started/participated in, OR
|
||||
# 3. The message is in a thread where the bot was previously @mentioned, OR
|
||||
@@ -970,24 +1040,29 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
event_thread_ts = event.get("thread_ts")
|
||||
is_thread_reply = bool(event_thread_ts and event_thread_ts != ts)
|
||||
|
||||
if not is_dm and bot_uid and not is_mentioned:
|
||||
reply_to_bot_thread = (
|
||||
is_thread_reply and event_thread_ts in self._bot_message_ts
|
||||
)
|
||||
in_mentioned_thread = (
|
||||
event_thread_ts is not None
|
||||
and event_thread_ts in self._mentioned_threads
|
||||
)
|
||||
has_session = (
|
||||
is_thread_reply
|
||||
and self._has_active_session_for_thread(
|
||||
channel_id=channel_id,
|
||||
thread_ts=event_thread_ts,
|
||||
user_id=user_id,
|
||||
if not is_dm and bot_uid:
|
||||
if channel_id in self._slack_free_response_channels():
|
||||
pass # Free-response channel — always process
|
||||
elif not self._slack_require_mention():
|
||||
pass # Mention requirement disabled globally for Slack
|
||||
elif not is_mentioned:
|
||||
reply_to_bot_thread = (
|
||||
is_thread_reply and event_thread_ts in self._bot_message_ts
|
||||
)
|
||||
)
|
||||
if not reply_to_bot_thread and not in_mentioned_thread and not has_session:
|
||||
return
|
||||
in_mentioned_thread = (
|
||||
event_thread_ts is not None
|
||||
and event_thread_ts in self._mentioned_threads
|
||||
)
|
||||
has_session = (
|
||||
is_thread_reply
|
||||
and self._has_active_session_for_thread(
|
||||
channel_id=channel_id,
|
||||
thread_ts=event_thread_ts,
|
||||
user_id=user_id,
|
||||
)
|
||||
)
|
||||
if not reply_to_bot_thread and not in_mentioned_thread and not has_session:
|
||||
return
|
||||
|
||||
if is_mentioned:
|
||||
# Strip the bot mention from the text
|
||||
@@ -1128,14 +1203,19 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
reply_to_message_id=thread_ts if thread_ts != ts else None,
|
||||
)
|
||||
|
||||
# Add 👀 reaction to acknowledge receipt
|
||||
await self._add_reaction(channel_id, ts, "eyes")
|
||||
# Only react when bot is directly addressed (DM or @mention).
|
||||
# In listen-all channels (require_mention=false), reacting to every
|
||||
# casual message would be noisy.
|
||||
_should_react = is_dm or is_mentioned
|
||||
|
||||
if _should_react:
|
||||
await self._add_reaction(channel_id, ts, "eyes")
|
||||
|
||||
await self.handle_message(msg_event)
|
||||
|
||||
# Replace 👀 with ✅ when done
|
||||
await self._remove_reaction(channel_id, ts, "eyes")
|
||||
await self._add_reaction(channel_id, ts, "white_check_mark")
|
||||
if _should_react:
|
||||
await self._remove_reaction(channel_id, ts, "eyes")
|
||||
await self._add_reaction(channel_id, ts, "white_check_mark")
|
||||
|
||||
# ----- Approval button support (Block Kit) -----
|
||||
|
||||
@@ -1229,6 +1309,20 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
msg_ts = message.get("ts", "")
|
||||
channel_id = body.get("channel", {}).get("id", "")
|
||||
user_name = body.get("user", {}).get("name", "unknown")
|
||||
user_id = body.get("user", {}).get("id", "")
|
||||
|
||||
# Only authorized users may click approval buttons. Button clicks
|
||||
# bypass the normal message auth flow in gateway/run.py, so we must
|
||||
# check here as well.
|
||||
allowed_csv = os.getenv("SLACK_ALLOWED_USERS", "").strip()
|
||||
if allowed_csv:
|
||||
allowed_ids = {uid.strip() for uid in allowed_csv.split(",") if uid.strip()}
|
||||
if "*" not in allowed_ids and user_id not in allowed_ids:
|
||||
logger.warning(
|
||||
"[Slack] Unauthorized approval click by %s (%s) — ignoring",
|
||||
user_name, user_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Map action_id to approval choice
|
||||
choice_map = {
|
||||
@@ -1239,10 +1333,9 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
}
|
||||
choice = choice_map.get(action_id, "deny")
|
||||
|
||||
# Prevent double-clicks
|
||||
if self._approval_resolved.get(msg_ts, False):
|
||||
# Prevent double-clicks — atomic pop; first caller gets False, others get True (default)
|
||||
if self._approval_resolved.pop(msg_ts, True):
|
||||
return
|
||||
self._approval_resolved[msg_ts] = True
|
||||
|
||||
# Update the message to show the decision and remove buttons
|
||||
label_map = {
|
||||
@@ -1297,8 +1390,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
except Exception as exc:
|
||||
logger.error("Failed to resolve gateway approval from Slack button: %s", exc)
|
||||
|
||||
# Clean up stale approval state
|
||||
self._approval_resolved.pop(msg_ts, None)
|
||||
# (approval state already consumed by atomic pop above)
|
||||
|
||||
# ----- Thread context fetching -----
|
||||
|
||||
@@ -1309,57 +1401,104 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
"""Fetch recent thread messages to provide context when the bot is
|
||||
mentioned mid-thread for the first time.
|
||||
|
||||
Returns a formatted string with thread history, or empty string on
|
||||
failure or if the thread is empty (just the parent message).
|
||||
This method is only called when there is NO active session for the
|
||||
thread (guarded at the call site by _has_active_session_for_thread).
|
||||
That guard ensures thread messages are prepended only on the very
|
||||
first turn — after that the session history already holds them, so
|
||||
there is no duplication across subsequent turns.
|
||||
|
||||
Results are cached for _THREAD_CACHE_TTL seconds per thread to avoid
|
||||
hammering conversations.replies (Tier 3, ~50 req/min).
|
||||
|
||||
Returns a formatted string with prior thread history, or empty string
|
||||
on failure or if the thread has no prior messages.
|
||||
"""
|
||||
cache_key = f"{channel_id}:{thread_ts}"
|
||||
now = time.monotonic()
|
||||
cached = self._thread_context_cache.get(cache_key)
|
||||
if cached and (now - cached.fetched_at) < self._THREAD_CACHE_TTL:
|
||||
return cached.content
|
||||
|
||||
try:
|
||||
client = self._get_client(channel_id)
|
||||
result = await client.conversations_replies(
|
||||
channel=channel_id,
|
||||
ts=thread_ts,
|
||||
limit=limit + 1, # +1 because it includes the current message
|
||||
inclusive=True,
|
||||
)
|
||||
|
||||
# Retry with exponential backoff for Tier-3 rate limits (429).
|
||||
result = None
|
||||
for attempt in range(3):
|
||||
try:
|
||||
result = await client.conversations_replies(
|
||||
channel=channel_id,
|
||||
ts=thread_ts,
|
||||
limit=limit + 1, # +1 because it includes the current message
|
||||
inclusive=True,
|
||||
)
|
||||
break
|
||||
except Exception as exc:
|
||||
# Check for rate-limit error from slack_sdk
|
||||
err_str = str(exc).lower()
|
||||
is_rate_limit = (
|
||||
"ratelimited" in err_str
|
||||
or "429" in err_str
|
||||
or "rate_limited" in err_str
|
||||
)
|
||||
if is_rate_limit and attempt < 2:
|
||||
retry_after = 1.0 * (2 ** attempt) # 1s, 2s
|
||||
logger.warning(
|
||||
"[Slack] conversations.replies rate limited; retrying in %.1fs (attempt %d/3)",
|
||||
retry_after, attempt + 1,
|
||||
)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
raise
|
||||
|
||||
if result is None:
|
||||
return ""
|
||||
|
||||
messages = result.get("messages", [])
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
bot_uid = self._team_bot_user_ids.get(team_id, self._bot_user_id)
|
||||
context_parts = []
|
||||
for msg in messages:
|
||||
msg_ts = msg.get("ts", "")
|
||||
# Skip the current message (the one that triggered this fetch)
|
||||
# Exclude the current triggering message — it will be delivered
|
||||
# as the user message itself, so including it here would duplicate it.
|
||||
if msg_ts == current_ts:
|
||||
continue
|
||||
# Skip bot messages from ourselves
|
||||
# Exclude our own bot messages to avoid circular context.
|
||||
if msg.get("bot_id") or msg.get("subtype") == "bot_message":
|
||||
continue
|
||||
|
||||
msg_user = msg.get("user", "unknown")
|
||||
msg_text = msg.get("text", "").strip()
|
||||
if not msg_text:
|
||||
continue
|
||||
|
||||
# Strip bot mentions from context messages
|
||||
bot_uid = self._team_bot_user_ids.get(team_id, self._bot_user_id)
|
||||
if bot_uid:
|
||||
msg_text = msg_text.replace(f"<@{bot_uid}>", "").strip()
|
||||
|
||||
# Mark the thread parent
|
||||
msg_user = msg.get("user", "unknown")
|
||||
is_parent = msg_ts == thread_ts
|
||||
prefix = "[thread parent] " if is_parent else ""
|
||||
|
||||
# Resolve user name (cached)
|
||||
name = await self._resolve_user_name(msg_user, chat_id=channel_id)
|
||||
context_parts.append(f"{prefix}{name}: {msg_text}")
|
||||
|
||||
if not context_parts:
|
||||
return ""
|
||||
content = ""
|
||||
if context_parts:
|
||||
content = (
|
||||
"[Thread context — prior messages in this thread (not yet in conversation history):]\n"
|
||||
+ "\n".join(context_parts)
|
||||
+ "\n[End of thread context]\n\n"
|
||||
)
|
||||
|
||||
return (
|
||||
"[Thread context — previous messages in this thread:]\n"
|
||||
+ "\n".join(context_parts)
|
||||
+ "\n[End of thread context]\n\n"
|
||||
self._thread_context_cache[cache_key] = _ThreadContextCache(
|
||||
content=content,
|
||||
fetched_at=now,
|
||||
message_count=len(context_parts),
|
||||
)
|
||||
return content
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("[Slack] Failed to fetch thread context: %s", e)
|
||||
return ""
|
||||
@@ -1469,6 +1608,18 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
# Slack may return an HTML sign-in/redirect page
|
||||
# instead of actual media bytes (e.g. expired token,
|
||||
# restricted file access). Detect this early so we
|
||||
# don't cache bogus data and confuse downstream tools.
|
||||
ct = response.headers.get("content-type", "")
|
||||
if "text/html" in ct:
|
||||
raise ValueError(
|
||||
"Slack returned HTML instead of media "
|
||||
f"(content-type: {ct}); "
|
||||
"check bot token scopes and file permissions"
|
||||
)
|
||||
|
||||
if audio:
|
||||
from gateway.platforms.base import cache_audio_from_bytes
|
||||
return cache_audio_from_bytes(response.content, ext)
|
||||
@@ -1515,3 +1666,30 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
continue
|
||||
raise
|
||||
raise last_exc
|
||||
|
||||
# ── Channel mention gating ─────────────────────────────────────────────
|
||||
|
||||
def _slack_require_mention(self) -> bool:
|
||||
"""Return whether channel messages require an explicit bot mention.
|
||||
|
||||
Uses explicit-false parsing (like Discord/Matrix) rather than
|
||||
truthy parsing, since the safe default is True (gating on).
|
||||
Unrecognised or empty values keep gating enabled.
|
||||
"""
|
||||
configured = self.config.extra.get("require_mention")
|
||||
if configured is not None:
|
||||
if isinstance(configured, str):
|
||||
return configured.lower() not in ("false", "0", "no", "off")
|
||||
return bool(configured)
|
||||
return os.getenv("SLACK_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no", "off")
|
||||
|
||||
def _slack_free_response_channels(self) -> set:
|
||||
"""Return channel IDs where no @mention is required."""
|
||||
raw = self.config.extra.get("free_response_channels")
|
||||
if raw is None:
|
||||
raw = os.getenv("SLACK_FREE_RESPONSE_CHANNELS", "")
|
||||
if isinstance(raw, list):
|
||||
return {str(part).strip() for part in raw if str(part).strip()}
|
||||
if isinstance(raw, str) and raw.strip():
|
||||
return {part.strip() for part in raw.split(",") if part.strip()}
|
||||
return set()
|
||||
|
||||
@@ -60,6 +60,7 @@ from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
ProcessingOutcome,
|
||||
SendResult,
|
||||
cache_image_from_bytes,
|
||||
cache_audio_from_bytes,
|
||||
@@ -121,6 +122,9 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
|
||||
# Telegram message limits
|
||||
MAX_MESSAGE_LENGTH = 4096
|
||||
# Threshold for detecting Telegram client-side message splits.
|
||||
# When a chunk is near this limit, a continuation is almost certain.
|
||||
_SPLIT_THRESHOLD = 4000
|
||||
MEDIA_GROUP_WAIT_SECONDS = 0.8
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
@@ -140,6 +144,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
# Buffer rapid text messages so Telegram client-side splits of long
|
||||
# messages are aggregated into a single MessageEvent.
|
||||
self._text_batch_delay_seconds = float(os.getenv("HERMES_TELEGRAM_TEXT_BATCH_DELAY_SECONDS", "0.6"))
|
||||
self._text_batch_split_delay_seconds = float(os.getenv("HERMES_TELEGRAM_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0"))
|
||||
self._pending_text_batches: Dict[str, MessageEvent] = {}
|
||||
self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._token_lock_identity: Optional[str] = None
|
||||
@@ -513,6 +518,45 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
|
||||
# Build the application
|
||||
builder = Application.builder().token(self.config.token)
|
||||
custom_base_url = self.config.extra.get("base_url")
|
||||
if custom_base_url:
|
||||
builder = builder.base_url(custom_base_url)
|
||||
builder = builder.base_file_url(
|
||||
self.config.extra.get("base_file_url", custom_base_url)
|
||||
)
|
||||
logger.info(
|
||||
"[%s] Using custom Telegram base_url: %s",
|
||||
self.name, custom_base_url,
|
||||
)
|
||||
|
||||
# PTB defaults (pool_timeout=1s) are too aggressive on flaky networks and
|
||||
# can trigger "Pool timeout: All connections in the connection pool are occupied"
|
||||
# during reconnect/bootstrap. Use safer defaults and allow env overrides.
|
||||
def _env_int(name: str, default: int) -> int:
|
||||
try:
|
||||
return int(os.getenv(name, str(default)))
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
def _env_float(name: str, default: float) -> float:
|
||||
try:
|
||||
return float(os.getenv(name, str(default)))
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
request_kwargs = {
|
||||
"connection_pool_size": _env_int("HERMES_TELEGRAM_HTTP_POOL_SIZE", 512),
|
||||
"pool_timeout": _env_float("HERMES_TELEGRAM_HTTP_POOL_TIMEOUT", 8.0),
|
||||
"connect_timeout": _env_float("HERMES_TELEGRAM_HTTP_CONNECT_TIMEOUT", 10.0),
|
||||
"read_timeout": _env_float("HERMES_TELEGRAM_HTTP_READ_TIMEOUT", 20.0),
|
||||
"write_timeout": _env_float("HERMES_TELEGRAM_HTTP_WRITE_TIMEOUT", 20.0),
|
||||
}
|
||||
|
||||
proxy_configured = any(
|
||||
(os.getenv(k) or "").strip()
|
||||
for k in ("HTTPS_PROXY", "HTTP_PROXY", "ALL_PROXY", "https_proxy", "http_proxy", "all_proxy")
|
||||
)
|
||||
disable_fallback = (os.getenv("HERMES_TELEGRAM_DISABLE_FALLBACK_IPS", "").strip().lower() in ("1", "true", "yes", "on"))
|
||||
fallback_ips = self._fallback_ips()
|
||||
if not fallback_ips:
|
||||
fallback_ips = await discover_fallback_ips()
|
||||
@@ -521,16 +565,32 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
self.name,
|
||||
", ".join(fallback_ips),
|
||||
)
|
||||
if fallback_ips:
|
||||
|
||||
if fallback_ips and not proxy_configured and not disable_fallback:
|
||||
logger.info(
|
||||
"[%s] Telegram fallback IPs active: %s",
|
||||
self.name,
|
||||
", ".join(fallback_ips),
|
||||
)
|
||||
transport = TelegramFallbackTransport(fallback_ips)
|
||||
request = HTTPXRequest(httpx_kwargs={"transport": transport})
|
||||
get_updates_request = HTTPXRequest(httpx_kwargs={"transport": transport})
|
||||
builder = builder.request(request).get_updates_request(get_updates_request)
|
||||
# Keep request/update pools separate to reduce contention during
|
||||
# polling reconnect + bot API bootstrap/delete_webhook calls.
|
||||
request = HTTPXRequest(
|
||||
**request_kwargs,
|
||||
httpx_kwargs={"transport": TelegramFallbackTransport(fallback_ips)},
|
||||
)
|
||||
get_updates_request = HTTPXRequest(
|
||||
**request_kwargs,
|
||||
httpx_kwargs={"transport": TelegramFallbackTransport(fallback_ips)},
|
||||
)
|
||||
else:
|
||||
if proxy_configured:
|
||||
logger.info("[%s] Proxy configured; skipping Telegram fallback-IP transport", self.name)
|
||||
elif disable_fallback:
|
||||
logger.info("[%s] Telegram fallback-IP transport disabled via env", self.name)
|
||||
request = HTTPXRequest(**request_kwargs)
|
||||
get_updates_request = HTTPXRequest(**request_kwargs)
|
||||
|
||||
builder = builder.request(request).get_updates_request(get_updates_request)
|
||||
self._app = builder.build()
|
||||
self._bot = self._app.bot
|
||||
|
||||
@@ -1398,6 +1458,15 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
await query.answer(text="Invalid approval data.")
|
||||
return
|
||||
|
||||
# Only authorized users may click approval buttons.
|
||||
caller_id = str(getattr(query.from_user, "id", ""))
|
||||
allowed_csv = os.getenv("TELEGRAM_ALLOWED_USERS", "").strip()
|
||||
if allowed_csv:
|
||||
allowed_ids = {uid.strip() for uid in allowed_csv.split(",") if uid.strip()}
|
||||
if "*" not in allowed_ids and caller_id not in allowed_ids:
|
||||
await query.answer(text="⛔ You are not authorized to approve commands.")
|
||||
return
|
||||
|
||||
session_key = self._approval_state.pop(approval_id, None)
|
||||
if not session_key:
|
||||
await query.answer(text="This approval has already been resolved.")
|
||||
@@ -2151,12 +2220,15 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
key = self._text_batch_key(event)
|
||||
existing = self._pending_text_batches.get(key)
|
||||
chunk_len = len(event.text or "")
|
||||
if existing is None:
|
||||
event._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||||
self._pending_text_batches[key] = event
|
||||
else:
|
||||
# Append text from the follow-up chunk
|
||||
if event.text:
|
||||
existing.text = f"{existing.text}\n{event.text}" if existing.text else event.text
|
||||
existing._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||||
# Merge any media that might be attached
|
||||
if event.media_urls:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
@@ -2171,10 +2243,22 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
)
|
||||
|
||||
async def _flush_text_batch(self, key: str) -> None:
|
||||
"""Wait for the quiet period then dispatch the aggregated text."""
|
||||
"""Wait for the quiet period then dispatch the aggregated text.
|
||||
|
||||
Uses a longer delay when the latest chunk is near Telegram's 4096-char
|
||||
split point, since a continuation chunk is almost certain.
|
||||
"""
|
||||
current_task = asyncio.current_task()
|
||||
try:
|
||||
await asyncio.sleep(self._text_batch_delay_seconds)
|
||||
# Adaptive delay: if the latest chunk is near Telegram's 4096-char
|
||||
# split point, a continuation is almost certain — wait longer.
|
||||
pending = self._pending_text_batches.get(key)
|
||||
last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0
|
||||
if last_len >= self._SPLIT_THRESHOLD:
|
||||
delay = self._text_batch_split_delay_seconds
|
||||
else:
|
||||
delay = self._text_batch_delay_seconds
|
||||
await asyncio.sleep(delay)
|
||||
event = self._pending_text_batches.pop(key, None)
|
||||
if not event:
|
||||
return
|
||||
@@ -2704,7 +2788,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
if chat_id and message_id:
|
||||
await self._set_reaction(chat_id, message_id, "\U0001f440")
|
||||
|
||||
async def on_processing_complete(self, event: MessageEvent, success: bool) -> None:
|
||||
async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
|
||||
"""Swap the in-progress reaction for a final success/failure reaction.
|
||||
|
||||
Unlike Discord (additive reactions), Telegram's set_message_reaction
|
||||
@@ -2714,5 +2798,9 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
return
|
||||
chat_id = getattr(event.source, "chat_id", None)
|
||||
message_id = getattr(event, "message_id", None)
|
||||
if chat_id and message_id:
|
||||
await self._set_reaction(chat_id, message_id, "\u2705" if success else "\u274c")
|
||||
if chat_id and message_id and outcome != ProcessingOutcome.CANCELLED:
|
||||
await self._set_reaction(
|
||||
chat_id,
|
||||
message_id,
|
||||
"\U0001f44d" if outcome == ProcessingOutcome.SUCCESS else "\U0001f44e",
|
||||
)
|
||||
|
||||
@@ -45,11 +45,9 @@ _SEED_FALLBACK_IPS: list[str] = ["149.154.167.220"]
|
||||
|
||||
|
||||
def _resolve_proxy_url() -> str | None:
|
||||
for key in ("HTTPS_PROXY", "HTTP_PROXY", "ALL_PROXY", "https_proxy", "http_proxy", "all_proxy"):
|
||||
value = (os.environ.get(key) or "").strip()
|
||||
if value:
|
||||
return value
|
||||
return None
|
||||
# Delegate to shared implementation (env vars + macOS system proxy detection)
|
||||
from gateway.platforms.base import resolve_proxy_url
|
||||
return resolve_proxy_url()
|
||||
|
||||
|
||||
class TelegramFallbackTransport(httpx.AsyncBaseTransport):
|
||||
@@ -112,7 +110,8 @@ class TelegramFallbackTransport(httpx.AsyncBaseTransport):
|
||||
logger.warning("[Telegram] Fallback IP %s failed: %s", ip, exc)
|
||||
continue
|
||||
|
||||
assert last_error is not None
|
||||
if last_error is None:
|
||||
raise RuntimeError("All Telegram fallback IPs exhausted but no error was recorded")
|
||||
raise last_error
|
||||
|
||||
async def aclose(self) -> None:
|
||||
|
||||
@@ -186,13 +186,23 @@ class WebhookAdapter(BasePlatformAdapter):
|
||||
if deliver_type == "github_comment":
|
||||
return await self._deliver_github_comment(content, delivery)
|
||||
|
||||
# Cross-platform delivery (telegram, discord, etc.)
|
||||
# Cross-platform delivery — any platform with a gateway adapter
|
||||
if self.gateway_runner and deliver_type in (
|
||||
"telegram",
|
||||
"discord",
|
||||
"slack",
|
||||
"signal",
|
||||
"sms",
|
||||
"whatsapp",
|
||||
"matrix",
|
||||
"mattermost",
|
||||
"homeassistant",
|
||||
"email",
|
||||
"dingtalk",
|
||||
"feishu",
|
||||
"wecom",
|
||||
"weixin",
|
||||
"bluebubbles",
|
||||
):
|
||||
return await self._deliver_cross_platform(
|
||||
deliver_type, content, delivery
|
||||
@@ -262,7 +272,7 @@ class WebhookAdapter(BasePlatformAdapter):
|
||||
", ".join(self._dynamic_routes.keys()) or "(none)",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[webhook] Failed to reload dynamic routes: %s", e)
|
||||
logger.error("[webhook] Failed to reload dynamic routes: %s", e)
|
||||
|
||||
async def _handle_webhook(self, request: "web.Request") -> "web.Response":
|
||||
"""POST /webhooks/{route_name} — receive and process a webhook event."""
|
||||
|
||||
@@ -143,6 +143,9 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
"""WeCom AI Bot adapter backed by a persistent WebSocket connection."""
|
||||
|
||||
MAX_MESSAGE_LENGTH = MAX_MESSAGE_LENGTH
|
||||
# Threshold for detecting WeCom client-side message splits.
|
||||
# When a chunk is near the 4000-char limit, a continuation is almost certain.
|
||||
_SPLIT_THRESHOLD = 3900
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.WECOM)
|
||||
@@ -172,6 +175,13 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
self._reply_req_ids: Dict[str, str] = {}
|
||||
|
||||
# Text batching: merge rapid successive messages (Telegram-style).
|
||||
# WeCom clients split long messages around 4000 chars.
|
||||
self._text_batch_delay_seconds = float(os.getenv("HERMES_WECOM_TEXT_BATCH_DELAY_SECONDS", "0.6"))
|
||||
self._text_batch_split_delay_seconds = float(os.getenv("HERMES_WECOM_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0"))
|
||||
self._pending_text_batches: Dict[str, MessageEvent] = {}
|
||||
self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connection lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
@@ -519,7 +529,82 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
timestamp=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
|
||||
await self.handle_message(event)
|
||||
# Only batch plain text messages — commands, media, etc. dispatch
|
||||
# immediately since they won't be split by the WeCom client.
|
||||
if message_type == MessageType.TEXT and self._text_batch_delay_seconds > 0:
|
||||
self._enqueue_text_event(event)
|
||||
else:
|
||||
await self.handle_message(event)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Text message aggregation (handles WeCom client-side splits)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _text_batch_key(self, event: MessageEvent) -> str:
|
||||
"""Session-scoped key for text message batching."""
|
||||
from gateway.session import build_session_key
|
||||
return build_session_key(
|
||||
event.source,
|
||||
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
|
||||
thread_sessions_per_user=self.config.extra.get("thread_sessions_per_user", False),
|
||||
)
|
||||
|
||||
def _enqueue_text_event(self, event: MessageEvent) -> None:
|
||||
"""Buffer a text event and reset the flush timer.
|
||||
|
||||
When WeCom splits a long user message at 4000 chars, the chunks
|
||||
arrive within a few hundred milliseconds. This merges them into
|
||||
a single event before dispatching.
|
||||
"""
|
||||
key = self._text_batch_key(event)
|
||||
existing = self._pending_text_batches.get(key)
|
||||
chunk_len = len(event.text or "")
|
||||
if existing is None:
|
||||
event._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||||
self._pending_text_batches[key] = event
|
||||
else:
|
||||
if event.text:
|
||||
existing.text = f"{existing.text}\n{event.text}" if existing.text else event.text
|
||||
existing._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||||
# Merge any media that might be attached
|
||||
if event.media_urls:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
|
||||
# Cancel any pending flush and restart the timer
|
||||
prior_task = self._pending_text_batch_tasks.get(key)
|
||||
if prior_task and not prior_task.done():
|
||||
prior_task.cancel()
|
||||
self._pending_text_batch_tasks[key] = asyncio.create_task(
|
||||
self._flush_text_batch(key)
|
||||
)
|
||||
|
||||
async def _flush_text_batch(self, key: str) -> None:
|
||||
"""Wait for the quiet period then dispatch the aggregated text.
|
||||
|
||||
Uses a longer delay when the latest chunk is near WeCom's 4000-char
|
||||
split point, since a continuation chunk is almost certain.
|
||||
"""
|
||||
current_task = asyncio.current_task()
|
||||
try:
|
||||
pending = self._pending_text_batches.get(key)
|
||||
last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0
|
||||
if last_len >= self._SPLIT_THRESHOLD:
|
||||
delay = self._text_batch_split_delay_seconds
|
||||
else:
|
||||
delay = self._text_batch_delay_seconds
|
||||
await asyncio.sleep(delay)
|
||||
event = self._pending_text_batches.pop(key, None)
|
||||
if not event:
|
||||
return
|
||||
logger.info(
|
||||
"[WeCom] Flushing text batch %s (%d chars)",
|
||||
key, len(event.text or ""),
|
||||
)
|
||||
await self.handle_message(event)
|
||||
finally:
|
||||
if self._pending_text_batch_tasks.get(key) is current_task:
|
||||
self._pending_text_batch_tasks.pop(key, None)
|
||||
|
||||
@staticmethod
|
||||
def _extract_text(body: Dict[str, Any]) -> Tuple[str, Optional[str]]:
|
||||
@@ -611,7 +696,11 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
|
||||
if kind == "image":
|
||||
ext = self._detect_image_ext(raw)
|
||||
return cache_image_from_bytes(raw, ext), self._mime_for_ext(ext, fallback="image/jpeg")
|
||||
try:
|
||||
return cache_image_from_bytes(raw, ext), self._mime_for_ext(ext, fallback="image/jpeg")
|
||||
except ValueError as exc:
|
||||
logger.warning("[%s] Rejected non-image bytes: %s", self.name, exc)
|
||||
return None
|
||||
|
||||
filename = str(media.get("filename") or media.get("name") or "wecom_file")
|
||||
return cache_document_from_bytes(raw, filename), mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||
@@ -637,7 +726,11 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
content_type = str(headers.get("content-type") or "").split(";", 1)[0].strip() or "application/octet-stream"
|
||||
if kind == "image":
|
||||
ext = self._guess_extension(url, content_type, fallback=self._detect_image_ext(raw))
|
||||
return cache_image_from_bytes(raw, ext), content_type or self._mime_for_ext(ext, fallback="image/jpeg")
|
||||
try:
|
||||
return cache_image_from_bytes(raw, ext), content_type or self._mime_for_ext(ext, fallback="image/jpeg")
|
||||
except ValueError as exc:
|
||||
logger.warning("[%s] Rejected non-image bytes from %s: %s", self.name, url, exc)
|
||||
return None
|
||||
|
||||
filename = self._guess_filename(url, headers.get("content-disposition"), content_type)
|
||||
return cache_document_from_bytes(raw, filename), content_type
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+288
-71
@@ -481,6 +481,7 @@ class GatewayRunner:
|
||||
self._prefill_messages = self._load_prefill_messages()
|
||||
self._ephemeral_system_prompt = self._load_ephemeral_system_prompt()
|
||||
self._reasoning_config = self._load_reasoning_config()
|
||||
self._service_tier = self._load_service_tier()
|
||||
self._show_reasoning = self._load_show_reasoning()
|
||||
self._provider_routing = self._load_provider_routing()
|
||||
self._fallback_model = self._load_fallback_model()
|
||||
@@ -514,12 +515,6 @@ class GatewayRunner:
|
||||
self._agent_cache: Dict[str, tuple] = {}
|
||||
self._agent_cache_lock = _threading.Lock()
|
||||
|
||||
# Track active fallback model/provider when primary is rate-limited.
|
||||
# Set after an agent run where fallback was activated; cleared when
|
||||
# the primary model succeeds again or the user switches via /model.
|
||||
self._effective_model: Optional[str] = None
|
||||
self._effective_provider: Optional[str] = None
|
||||
|
||||
# Per-session model overrides from /model command.
|
||||
# Key: session_key, Value: dict with model/provider/api_key/base_url/api_mode
|
||||
self._session_model_overrides: Dict[str, Dict[str, str]] = {}
|
||||
@@ -782,6 +777,7 @@ class GatewayRunner:
|
||||
|
||||
def _resolve_turn_agent_config(self, user_message: str, model: str, runtime_kwargs: dict) -> dict:
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
from hermes_cli.models import resolve_fast_mode_overrides
|
||||
|
||||
primary = {
|
||||
"model": model,
|
||||
@@ -793,7 +789,19 @@ class GatewayRunner:
|
||||
"args": list(runtime_kwargs.get("args") or []),
|
||||
"credential_pool": runtime_kwargs.get("credential_pool"),
|
||||
}
|
||||
return resolve_turn_route(user_message, getattr(self, "_smart_model_routing", {}), primary)
|
||||
route = resolve_turn_route(user_message, getattr(self, "_smart_model_routing", {}), primary)
|
||||
|
||||
service_tier = getattr(self, "_service_tier", None)
|
||||
if not service_tier:
|
||||
route["request_overrides"] = None
|
||||
return route
|
||||
|
||||
try:
|
||||
overrides = resolve_fast_mode_overrides(route.get("model"))
|
||||
except Exception:
|
||||
overrides = None
|
||||
route["request_overrides"] = overrides
|
||||
return route
|
||||
|
||||
async def _handle_adapter_fatal_error(self, adapter: BasePlatformAdapter) -> None:
|
||||
"""React to an adapter failure after startup.
|
||||
@@ -925,8 +933,8 @@ class GatewayRunner:
|
||||
def _load_reasoning_config() -> dict | None:
|
||||
"""Load reasoning effort from config.yaml.
|
||||
|
||||
Reads agent.reasoning_effort from config.yaml. Valid: "xhigh",
|
||||
"high", "medium", "low", "minimal", "none". Returns None to use
|
||||
Reads agent.reasoning_effort from config.yaml. Valid: "none",
|
||||
"minimal", "low", "medium", "high", "xhigh". Returns None to use
|
||||
default (medium).
|
||||
"""
|
||||
from hermes_constants import parse_reasoning_effort
|
||||
@@ -945,6 +953,33 @@ class GatewayRunner:
|
||||
logger.warning("Unknown reasoning_effort '%s', using default (medium)", effort)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _load_service_tier() -> str | None:
|
||||
"""Load Priority Processing setting from config.yaml.
|
||||
|
||||
Reads agent.service_tier from config.yaml. Accepted values mirror the CLI:
|
||||
"fast"/"priority"/"on" => "priority", while "normal"/"off" disables it.
|
||||
Returns None when unset or unsupported.
|
||||
"""
|
||||
raw = ""
|
||||
try:
|
||||
import yaml as _y
|
||||
cfg_path = _hermes_home / "config.yaml"
|
||||
if cfg_path.exists():
|
||||
with open(cfg_path, encoding="utf-8") as _f:
|
||||
cfg = _y.safe_load(_f) or {}
|
||||
raw = str(cfg.get("agent", {}).get("service_tier", "") or "").strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
value = raw.lower()
|
||||
if not value or value in {"normal", "default", "standard", "off", "none"}:
|
||||
return None
|
||||
if value in {"fast", "priority", "on"}:
|
||||
return "priority"
|
||||
logger.warning("Unknown service_tier '%s', ignoring", raw)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _load_show_reasoning() -> bool:
|
||||
"""Load show_reasoning toggle from config.yaml display section."""
|
||||
@@ -1075,6 +1110,7 @@ class GatewayRunner:
|
||||
"MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS",
|
||||
"FEISHU_ALLOWED_USERS",
|
||||
"WECOM_ALLOWED_USERS",
|
||||
"WEIXIN_ALLOWED_USERS",
|
||||
"BLUEBUBBLES_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOWED_USERS")
|
||||
)
|
||||
@@ -1087,6 +1123,7 @@ class GatewayRunner:
|
||||
"MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS",
|
||||
"FEISHU_ALLOW_ALL_USERS",
|
||||
"WECOM_ALLOW_ALL_USERS",
|
||||
"WEIXIN_ALLOW_ALL_USERS",
|
||||
"BLUEBUBBLES_ALLOW_ALL_USERS")
|
||||
)
|
||||
if not _any_allowlist and not _allow_all:
|
||||
@@ -1628,6 +1665,13 @@ class GatewayRunner:
|
||||
return None
|
||||
return WeComAdapter(config)
|
||||
|
||||
elif platform == Platform.WEIXIN:
|
||||
from gateway.platforms.weixin import WeixinAdapter, check_weixin_requirements
|
||||
if not check_weixin_requirements():
|
||||
logger.warning("Weixin: aiohttp/cryptography not installed")
|
||||
return None
|
||||
return WeixinAdapter(config)
|
||||
|
||||
elif platform == Platform.MATTERMOST:
|
||||
from gateway.platforms.mattermost import MattermostAdapter, check_mattermost_requirements
|
||||
if not check_mattermost_requirements():
|
||||
@@ -1703,6 +1747,7 @@ class GatewayRunner:
|
||||
Platform.DINGTALK: "DINGTALK_ALLOWED_USERS",
|
||||
Platform.FEISHU: "FEISHU_ALLOWED_USERS",
|
||||
Platform.WECOM: "WECOM_ALLOWED_USERS",
|
||||
Platform.WEIXIN: "WEIXIN_ALLOWED_USERS",
|
||||
Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOWED_USERS",
|
||||
}
|
||||
platform_allow_all_map = {
|
||||
@@ -1718,6 +1763,7 @@ class GatewayRunner:
|
||||
Platform.DINGTALK: "DINGTALK_ALLOW_ALL_USERS",
|
||||
Platform.FEISHU: "FEISHU_ALLOW_ALL_USERS",
|
||||
Platform.WECOM: "WECOM_ALLOW_ALL_USERS",
|
||||
Platform.WEIXIN: "WEIXIN_ALLOW_ALL_USERS",
|
||||
Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOW_ALL_USERS",
|
||||
}
|
||||
|
||||
@@ -1997,6 +2043,11 @@ class GatewayRunner:
|
||||
return await self._handle_approve_command(event)
|
||||
return await self._handle_deny_command(event)
|
||||
|
||||
# /background must bypass the running-agent guard — it starts a
|
||||
# parallel task and must never interrupt the active conversation.
|
||||
if _cmd_def_inner and _cmd_def_inner.name == "background":
|
||||
return await self._handle_background_command(event)
|
||||
|
||||
if event.message_type == MessageType.PHOTO:
|
||||
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
|
||||
adapter = self.adapters.get(source.platform)
|
||||
@@ -2078,6 +2129,9 @@ class GatewayRunner:
|
||||
if canonical == "reasoning":
|
||||
return await self._handle_reasoning_command(event)
|
||||
|
||||
if canonical == "fast":
|
||||
return await self._handle_fast_command(event)
|
||||
|
||||
if canonical == "verbose":
|
||||
return await self._handle_verbose_command(event)
|
||||
|
||||
@@ -2420,37 +2474,41 @@ class GatewayRunner:
|
||||
session_entry.was_auto_reset = False
|
||||
session_entry.auto_reset_reason = None
|
||||
|
||||
# Auto-load skill for DM topic bindings (e.g., Telegram Private Chat Topics)
|
||||
# Only inject on NEW sessions — for ongoing conversations the skill content
|
||||
# is already in the conversation history from the first message.
|
||||
if _is_new_session and getattr(event, "auto_skill", None):
|
||||
# Auto-load skill(s) for topic/channel bindings (Telegram DM Topics,
|
||||
# Discord channel_skill_bindings). Supports a single name or ordered list.
|
||||
# Only inject on NEW sessions — ongoing conversations already have the
|
||||
# skill content in their conversation history from the first message.
|
||||
_auto = getattr(event, "auto_skill", None)
|
||||
if _is_new_session and _auto:
|
||||
_skill_names = [_auto] if isinstance(_auto, str) else list(_auto)
|
||||
try:
|
||||
from agent.skill_commands import _load_skill_payload, _build_skill_message
|
||||
_skill_name = event.auto_skill
|
||||
_loaded = _load_skill_payload(_skill_name, task_id=_quick_key)
|
||||
if _loaded:
|
||||
_loaded_skill, _skill_dir, _display_name = _loaded
|
||||
_activation_note = (
|
||||
f'[SYSTEM: This conversation is in a topic with the "{_display_name}" skill '
|
||||
f"auto-loaded. Follow its instructions for the duration of this session.]"
|
||||
)
|
||||
_skill_msg = _build_skill_message(
|
||||
_loaded_skill, _skill_dir, _activation_note,
|
||||
user_instruction=event.text,
|
||||
)
|
||||
if _skill_msg:
|
||||
event.text = _skill_msg
|
||||
logger.info(
|
||||
"[Gateway] Auto-loaded skill '%s' for DM topic session %s",
|
||||
_skill_name, session_key,
|
||||
_combined_parts: list[str] = []
|
||||
_loaded_names: list[str] = []
|
||||
for _sname in _skill_names:
|
||||
_loaded = _load_skill_payload(_sname, task_id=_quick_key)
|
||||
if _loaded:
|
||||
_loaded_skill, _skill_dir, _display_name = _loaded
|
||||
_note = (
|
||||
f'[SYSTEM: The "{_display_name}" skill is auto-loaded. '
|
||||
f"Follow its instructions for this session.]"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"[Gateway] DM topic skill '%s' not found in available skills",
|
||||
_skill_name,
|
||||
_part = _build_skill_message(_loaded_skill, _skill_dir, _note)
|
||||
if _part:
|
||||
_combined_parts.append(_part)
|
||||
_loaded_names.append(_sname)
|
||||
else:
|
||||
logger.warning("[Gateway] Auto-skill '%s' not found", _sname)
|
||||
if _combined_parts:
|
||||
# Append the user's original text after all skill payloads
|
||||
_combined_parts.append(event.text)
|
||||
event.text = "\n\n".join(_combined_parts)
|
||||
logger.info(
|
||||
"[Gateway] Auto-loaded skill(s) %s for session %s",
|
||||
_loaded_names, session_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[Gateway] Failed to auto-load topic skill '%s': %s", event.auto_skill, e)
|
||||
logger.warning("[Gateway] Failed to auto-load skill(s) %s: %s", _skill_names, e)
|
||||
|
||||
# Load conversation history from transcript
|
||||
history = self.session_store.load_transcript(session_entry.session_id)
|
||||
@@ -3546,6 +3604,7 @@ class GatewayRunner:
|
||||
current_base_url = ""
|
||||
current_api_key = ""
|
||||
user_provs = None
|
||||
custom_provs = None
|
||||
config_path = _hermes_home / "config.yaml"
|
||||
try:
|
||||
if config_path.exists():
|
||||
@@ -3557,6 +3616,7 @@ class GatewayRunner:
|
||||
current_provider = model_cfg.get("provider", current_provider)
|
||||
current_base_url = model_cfg.get("base_url", "")
|
||||
user_provs = cfg.get("providers")
|
||||
custom_provs = cfg.get("custom_providers")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -3584,6 +3644,7 @@ class GatewayRunner:
|
||||
providers = list_authenticated_providers(
|
||||
current_provider=current_provider,
|
||||
user_providers=user_provs,
|
||||
custom_providers=custom_provs,
|
||||
max_models=50,
|
||||
)
|
||||
except Exception:
|
||||
@@ -3611,6 +3672,8 @@ class GatewayRunner:
|
||||
current_api_key=_cur_api_key,
|
||||
is_global=False,
|
||||
explicit_provider=provider_slug,
|
||||
user_providers=user_provs,
|
||||
custom_providers=custom_provs,
|
||||
)
|
||||
if not result.success:
|
||||
return f"Error: {result.error_message}"
|
||||
@@ -3689,6 +3752,7 @@ class GatewayRunner:
|
||||
providers = list_authenticated_providers(
|
||||
current_provider=current_provider,
|
||||
user_providers=user_provs,
|
||||
custom_providers=custom_provs,
|
||||
max_models=5,
|
||||
)
|
||||
for p in providers:
|
||||
@@ -3718,6 +3782,8 @@ class GatewayRunner:
|
||||
current_api_key=current_api_key,
|
||||
is_global=persist_global,
|
||||
explicit_provider=explicit_provider,
|
||||
user_providers=user_provs,
|
||||
custom_providers=custom_provs,
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
@@ -3839,6 +3905,7 @@ class GatewayRunner:
|
||||
|
||||
# Resolve current provider from config
|
||||
current_provider = "openrouter"
|
||||
model_cfg = {}
|
||||
config_path = _hermes_home / 'config.yaml'
|
||||
try:
|
||||
if config_path.exists():
|
||||
@@ -4579,6 +4646,7 @@ class GatewayRunner:
|
||||
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90"))
|
||||
reasoning_config = self._load_reasoning_config()
|
||||
self._reasoning_config = reasoning_config
|
||||
self._service_tier = self._load_service_tier()
|
||||
turn_route = self._resolve_turn_agent_config(prompt, model, runtime_kwargs)
|
||||
|
||||
def run_sync():
|
||||
@@ -4590,6 +4658,8 @@ class GatewayRunner:
|
||||
verbose_logging=False,
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
reasoning_config=reasoning_config,
|
||||
service_tier=self._service_tier,
|
||||
request_overrides=turn_route.get("request_overrides"),
|
||||
providers_allowed=pr.get("only"),
|
||||
providers_ignored=pr.get("ignore"),
|
||||
providers_order=pr.get("order"),
|
||||
@@ -4739,6 +4809,7 @@ class GatewayRunner:
|
||||
model = _resolve_gateway_model(user_config)
|
||||
platform_key = _platform_config_key(source.platform)
|
||||
reasoning_config = self._load_reasoning_config()
|
||||
self._service_tier = self._load_service_tier()
|
||||
turn_route = self._resolve_turn_agent_config(question, model, runtime_kwargs)
|
||||
pr = self._provider_routing
|
||||
|
||||
@@ -4765,6 +4836,8 @@ class GatewayRunner:
|
||||
verbose_logging=False,
|
||||
enabled_toolsets=[],
|
||||
reasoning_config=reasoning_config,
|
||||
service_tier=self._service_tier,
|
||||
request_overrides=turn_route.get("request_overrides"),
|
||||
providers_allowed=pr.get("only"),
|
||||
providers_ignored=pr.get("ignore"),
|
||||
providers_order=pr.get("order"),
|
||||
@@ -4840,7 +4913,7 @@ class GatewayRunner:
|
||||
|
||||
Usage:
|
||||
/reasoning Show current effort level and display state
|
||||
/reasoning <level> Set reasoning effort (none, low, medium, high, xhigh)
|
||||
/reasoning <level> Set reasoning effort (none, minimal, low, medium, high, xhigh)
|
||||
/reasoning show|on Show model reasoning in responses
|
||||
/reasoning hide|off Hide model reasoning from responses
|
||||
"""
|
||||
@@ -4885,7 +4958,7 @@ class GatewayRunner:
|
||||
"🧠 **Reasoning Settings**\n\n"
|
||||
f"**Effort:** `{level}`\n"
|
||||
f"**Display:** {display_state}\n\n"
|
||||
"_Usage:_ `/reasoning <none|low|medium|high|xhigh|show|hide>`"
|
||||
"_Usage:_ `/reasoning <none|minimal|low|medium|high|xhigh|show|hide>`"
|
||||
)
|
||||
|
||||
# Display toggle
|
||||
@@ -4903,12 +4976,12 @@ class GatewayRunner:
|
||||
effort = args.strip()
|
||||
if effort == "none":
|
||||
parsed = {"enabled": False}
|
||||
elif effort in ("xhigh", "high", "medium", "low", "minimal"):
|
||||
elif effort in ("minimal", "low", "medium", "high", "xhigh"):
|
||||
parsed = {"enabled": True, "effort": effort}
|
||||
else:
|
||||
return (
|
||||
f"⚠️ Unknown argument: `{effort}`\n\n"
|
||||
"**Valid levels:** none, low, minimal, medium, high, xhigh\n"
|
||||
"**Valid levels:** none, minimal, low, medium, high, xhigh\n"
|
||||
"**Display:** show, hide"
|
||||
)
|
||||
|
||||
@@ -4918,15 +4991,82 @@ class GatewayRunner:
|
||||
else:
|
||||
return f"🧠 ✓ Reasoning effort set to `{effort}` (this session only)"
|
||||
|
||||
async def _handle_yolo_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /yolo — toggle dangerous command approval bypass."""
|
||||
current = bool(os.environ.get("HERMES_YOLO_MODE"))
|
||||
if current:
|
||||
os.environ.pop("HERMES_YOLO_MODE", None)
|
||||
return "⚠️ YOLO mode **OFF** — dangerous commands will require approval."
|
||||
async def _handle_fast_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /fast — mirror the CLI Priority Processing toggle in gateway chats."""
|
||||
import yaml
|
||||
from hermes_cli.models import model_supports_fast_mode
|
||||
|
||||
args = event.get_command_args().strip().lower()
|
||||
config_path = _hermes_home / "config.yaml"
|
||||
self._service_tier = self._load_service_tier()
|
||||
|
||||
user_config = _load_gateway_config()
|
||||
model = _resolve_gateway_model(user_config)
|
||||
if not model_supports_fast_mode(model):
|
||||
return "⚡ /fast is only available for OpenAI models that support Priority Processing."
|
||||
|
||||
def _save_config_key(key_path: str, value):
|
||||
"""Save a dot-separated key to config.yaml."""
|
||||
try:
|
||||
user_config = {}
|
||||
if config_path.exists():
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
user_config = yaml.safe_load(f) or {}
|
||||
keys = key_path.split(".")
|
||||
current = user_config
|
||||
for k in keys[:-1]:
|
||||
if k not in current or not isinstance(current[k], dict):
|
||||
current[k] = {}
|
||||
current = current[k]
|
||||
current[keys[-1]] = value
|
||||
atomic_yaml_write(config_path, user_config)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Failed to save config key %s: %s", key_path, e)
|
||||
return False
|
||||
|
||||
if not args or args == "status":
|
||||
status = "fast" if self._service_tier == "priority" else "normal"
|
||||
return (
|
||||
"⚡ Priority Processing\n\n"
|
||||
f"Current mode: `{status}`\n\n"
|
||||
"_Usage:_ `/fast <normal|fast|status>`"
|
||||
)
|
||||
|
||||
if args in {"fast", "on"}:
|
||||
self._service_tier = "priority"
|
||||
saved_value = "fast"
|
||||
label = "FAST"
|
||||
elif args in {"normal", "off"}:
|
||||
self._service_tier = None
|
||||
saved_value = "normal"
|
||||
label = "NORMAL"
|
||||
else:
|
||||
os.environ["HERMES_YOLO_MODE"] = "1"
|
||||
return "⚡ YOLO mode **ON** — all commands auto-approved. Use with caution."
|
||||
return (
|
||||
f"⚠️ Unknown argument: `{args}`\n\n"
|
||||
"**Valid options:** normal, fast, status"
|
||||
)
|
||||
|
||||
if _save_config_key("agent.service_tier", saved_value):
|
||||
return f"⚡ ✓ Priority Processing: **{label}** (saved to config)\n_(takes effect on next message)_"
|
||||
return f"⚡ ✓ Priority Processing: **{label}** (this session only)"
|
||||
|
||||
async def _handle_yolo_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /yolo — toggle dangerous command approval bypass for this session only."""
|
||||
from tools.approval import (
|
||||
disable_session_yolo,
|
||||
enable_session_yolo,
|
||||
is_session_yolo_enabled,
|
||||
)
|
||||
|
||||
session_key = self._session_key_for_source(event.source)
|
||||
current = is_session_yolo_enabled(session_key)
|
||||
if current:
|
||||
disable_session_yolo(session_key)
|
||||
return "⚠️ YOLO mode **OFF** for this session — dangerous commands will require approval."
|
||||
else:
|
||||
enable_session_yolo(session_key)
|
||||
return "⚡ YOLO mode **ON** for this session — all commands auto-approved. Use with caution."
|
||||
|
||||
async def _handle_verbose_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /verbose command — cycle tool progress display mode.
|
||||
@@ -5274,27 +5414,76 @@ class GatewayRunner:
|
||||
)
|
||||
|
||||
async def _handle_usage_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /usage command -- show token usage for the session's last agent run."""
|
||||
"""Handle /usage command -- show token usage for the current session.
|
||||
|
||||
Checks both _running_agents (mid-turn) and _agent_cache (between turns)
|
||||
so that rate limits, cost estimates, and detailed token breakdowns are
|
||||
available whenever the user asks, not only while the agent is running.
|
||||
"""
|
||||
source = event.source
|
||||
session_key = self._session_key_for_source(source)
|
||||
|
||||
# Try running agent first (mid-turn), then cached agent (between turns)
|
||||
agent = self._running_agents.get(session_key)
|
||||
if not agent or agent is _AGENT_PENDING_SENTINEL:
|
||||
_cache_lock = getattr(self, "_agent_cache_lock", None)
|
||||
_cache = getattr(self, "_agent_cache", None)
|
||||
if _cache_lock and _cache is not None:
|
||||
with _cache_lock:
|
||||
cached = _cache.get(session_key)
|
||||
if cached:
|
||||
agent = cached[0]
|
||||
|
||||
if agent and hasattr(agent, "session_total_tokens") and agent.session_api_calls > 0:
|
||||
lines = []
|
||||
|
||||
# Rate limits first (when available from provider headers)
|
||||
# Rate limits (when available from provider headers)
|
||||
rl_state = agent.get_rate_limit_state()
|
||||
if rl_state and rl_state.has_data:
|
||||
from agent.rate_limit_tracker import format_rate_limit_compact
|
||||
lines.append(f"⏱️ **Rate Limits:** {format_rate_limit_compact(rl_state)}")
|
||||
lines.append("")
|
||||
|
||||
# Session token usage
|
||||
# Session token usage — detailed breakdown matching CLI
|
||||
input_tokens = getattr(agent, "session_input_tokens", 0) or 0
|
||||
output_tokens = getattr(agent, "session_output_tokens", 0) or 0
|
||||
cache_read = getattr(agent, "session_cache_read_tokens", 0) or 0
|
||||
cache_write = getattr(agent, "session_cache_write_tokens", 0) or 0
|
||||
|
||||
lines.append("📊 **Session Token Usage**")
|
||||
lines.append(f"Prompt (input): {agent.session_prompt_tokens:,}")
|
||||
lines.append(f"Completion (output): {agent.session_completion_tokens:,}")
|
||||
lines.append(f"Model: `{agent.model}`")
|
||||
lines.append(f"Input tokens: {input_tokens:,}")
|
||||
if cache_read:
|
||||
lines.append(f"Cache read tokens: {cache_read:,}")
|
||||
if cache_write:
|
||||
lines.append(f"Cache write tokens: {cache_write:,}")
|
||||
lines.append(f"Output tokens: {output_tokens:,}")
|
||||
lines.append(f"Total: {agent.session_total_tokens:,}")
|
||||
lines.append(f"API calls: {agent.session_api_calls}")
|
||||
|
||||
# Cost estimation
|
||||
try:
|
||||
from agent.usage_pricing import CanonicalUsage, estimate_usage_cost
|
||||
cost_result = estimate_usage_cost(
|
||||
agent.model,
|
||||
CanonicalUsage(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_read_tokens=cache_read,
|
||||
cache_write_tokens=cache_write,
|
||||
),
|
||||
provider=getattr(agent, "provider", None),
|
||||
base_url=getattr(agent, "base_url", None),
|
||||
)
|
||||
if cost_result.amount_usd is not None:
|
||||
prefix = "~" if cost_result.status == "estimated" else ""
|
||||
lines.append(f"Cost: {prefix}${float(cost_result.amount_usd):.4f}")
|
||||
elif cost_result.status == "included":
|
||||
lines.append("Cost: included")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Context window and compressions
|
||||
ctx = agent.context_compressor
|
||||
if ctx.last_prompt_tokens:
|
||||
pct = min(100, ctx.last_prompt_tokens / ctx.context_length * 100) if ctx.context_length else 0
|
||||
@@ -5304,7 +5493,7 @@ class GatewayRunner:
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
# No running agent -- check session history for a rough count
|
||||
# No agent at all -- check session history for a rough count
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
history = self.session_store.load_transcript(session_entry.session_id)
|
||||
if history:
|
||||
@@ -5315,7 +5504,7 @@ class GatewayRunner:
|
||||
f"📊 **Session Info**\n"
|
||||
f"Messages: {len(msgs)}\n"
|
||||
f"Estimated context: ~{approx:,} tokens\n"
|
||||
f"_(Detailed usage available during active conversations)_"
|
||||
f"_(Detailed usage available after the first agent response)_"
|
||||
)
|
||||
return "No usage data available for this session."
|
||||
|
||||
@@ -5543,7 +5732,7 @@ class GatewayRunner:
|
||||
Platform.TELEGRAM, Platform.DISCORD, Platform.SLACK, Platform.WHATSAPP,
|
||||
Platform.SIGNAL, Platform.MATTERMOST, Platform.MATRIX,
|
||||
Platform.HOMEASSISTANT, Platform.EMAIL, Platform.SMS, Platform.DINGTALK,
|
||||
Platform.FEISHU, Platform.WECOM, Platform.BLUEBUBBLES, Platform.LOCAL,
|
||||
Platform.FEISHU, Platform.WECOM, Platform.WEIXIN, Platform.BLUEBUBBLES, Platform.LOCAL,
|
||||
})
|
||||
|
||||
async def _handle_update_command(self, event: MessageEvent) -> str:
|
||||
@@ -6042,16 +6231,14 @@ class GatewayRunner:
|
||||
return f"{disabled_note}\n\n{user_text}"
|
||||
return disabled_note
|
||||
|
||||
from tools.transcription_tools import transcribe_audio, get_stt_model_from_config
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
import asyncio
|
||||
|
||||
stt_model = get_stt_model_from_config()
|
||||
|
||||
enriched_parts = []
|
||||
for path in audio_paths:
|
||||
try:
|
||||
logger.debug("Transcribing user voice: %s", path)
|
||||
result = await asyncio.to_thread(transcribe_audio, path, model=stt_model)
|
||||
result = await asyncio.to_thread(transcribe_audio, path)
|
||||
if result["success"]:
|
||||
transcript = result["transcript"]
|
||||
enriched_parts.append(
|
||||
@@ -6283,6 +6470,32 @@ class GatewayRunner:
|
||||
)
|
||||
return hashlib.sha256(blob.encode()).hexdigest()[:16]
|
||||
|
||||
def _apply_session_model_override(
|
||||
self, session_key: str, model: str, runtime_kwargs: dict
|
||||
) -> tuple:
|
||||
"""Apply /model session overrides if present, returning (model, runtime_kwargs).
|
||||
|
||||
The gateway /model command stores per-session overrides in
|
||||
``_session_model_overrides``. These must take precedence over
|
||||
config.yaml defaults so the switched model is actually used for
|
||||
subsequent messages. Fields with ``None`` values are skipped so
|
||||
partial overrides don't clobber valid config defaults.
|
||||
"""
|
||||
override = self._session_model_overrides.get(session_key)
|
||||
if not override:
|
||||
return model, runtime_kwargs
|
||||
model = override.get("model", model)
|
||||
for key in ("provider", "api_key", "base_url", "api_mode"):
|
||||
val = override.get(key)
|
||||
if val is not None:
|
||||
runtime_kwargs[key] = val
|
||||
return model, runtime_kwargs
|
||||
|
||||
def _is_intentional_model_switch(self, session_key: str, agent_model: str) -> bool:
|
||||
"""Return True if *agent_model* matches an active /model session override."""
|
||||
override = self._session_model_overrides.get(session_key)
|
||||
return override is not None and override.get("model") == agent_model
|
||||
|
||||
def _evict_cached_agent(self, session_key: str) -> None:
|
||||
"""Remove a cached agent for a session (called on /new, /model, etc)."""
|
||||
_lock = getattr(self, "_agent_cache_lock", None)
|
||||
@@ -6660,9 +6873,15 @@ class GatewayRunner:
|
||||
"tools": [],
|
||||
}
|
||||
|
||||
# /model overrides take precedence over config.yaml defaults.
|
||||
model, runtime_kwargs = self._apply_session_model_override(
|
||||
session_key, model, runtime_kwargs
|
||||
)
|
||||
|
||||
pr = self._provider_routing
|
||||
reasoning_config = self._load_reasoning_config()
|
||||
self._reasoning_config = reasoning_config
|
||||
self._service_tier = self._load_service_tier()
|
||||
# Set up streaming consumer if enabled
|
||||
_stream_consumer = None
|
||||
_stream_delta_cb = None
|
||||
@@ -6725,6 +6944,8 @@ class GatewayRunner:
|
||||
ephemeral_system_prompt=combined_ephemeral or None,
|
||||
prefill_messages=self._prefill_messages or None,
|
||||
reasoning_config=reasoning_config,
|
||||
service_tier=self._service_tier,
|
||||
request_overrides=turn_route.get("request_overrides"),
|
||||
providers_allowed=pr.get("only"),
|
||||
providers_ignored=pr.get("ignore"),
|
||||
providers_order=pr.get("order"),
|
||||
@@ -6749,6 +6970,8 @@ class GatewayRunner:
|
||||
agent.stream_delta_callback = _stream_delta_cb
|
||||
agent.status_callback = _status_callback_sync
|
||||
agent.reasoning_config = reasoning_config
|
||||
agent.service_tier = self._service_tier
|
||||
agent.request_overrides = turn_route.get("request_overrides")
|
||||
|
||||
# Background review delivery — send "💾 Memory updated" etc. to user
|
||||
def _bg_review_send(message: str) -> None:
|
||||
@@ -7279,16 +7502,10 @@ class GatewayRunner:
|
||||
_agent = agent_holder[0]
|
||||
if _agent is not None and hasattr(_agent, 'model'):
|
||||
_cfg_model = _resolve_gateway_model()
|
||||
if _agent.model != _cfg_model:
|
||||
self._effective_model = _agent.model
|
||||
self._effective_provider = getattr(_agent, 'provider', None)
|
||||
if _agent.model != _cfg_model and not self._is_intentional_model_switch(session_key, _agent.model):
|
||||
# Fallback activated — evict cached agent so the next
|
||||
# message starts fresh and retries the primary model.
|
||||
self._evict_cached_agent(session_key)
|
||||
else:
|
||||
# Primary model worked — clear any stale fallback state
|
||||
self._effective_model = None
|
||||
self._effective_provider = None
|
||||
|
||||
# Check if we were interrupted OR have a queued message (/queue).
|
||||
result = result_holder[0]
|
||||
@@ -7496,7 +7713,7 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
||||
# setups (each profile using a distinct HERMES_HOME) will naturally
|
||||
# allow concurrent instances without tripping this guard.
|
||||
import time as _time
|
||||
from gateway.status import get_running_pid, remove_pid_file
|
||||
from gateway.status import get_running_pid, remove_pid_file, terminate_pid
|
||||
existing_pid = get_running_pid()
|
||||
if existing_pid is not None and existing_pid != os.getpid():
|
||||
if replace:
|
||||
@@ -7505,10 +7722,10 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
||||
existing_pid,
|
||||
)
|
||||
try:
|
||||
os.kill(existing_pid, signal.SIGTERM)
|
||||
terminate_pid(existing_pid, force=False)
|
||||
except ProcessLookupError:
|
||||
pass # Already gone
|
||||
except PermissionError:
|
||||
except (PermissionError, OSError):
|
||||
logger.error(
|
||||
"Permission denied killing PID %d. Cannot replace.",
|
||||
existing_pid,
|
||||
@@ -7528,9 +7745,9 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
||||
existing_pid,
|
||||
)
|
||||
try:
|
||||
os.kill(existing_pid, signal.SIGKILL)
|
||||
terminate_pid(existing_pid, force=True)
|
||||
_time.sleep(0.5)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
except (ProcessLookupError, PermissionError, OSError):
|
||||
pass
|
||||
remove_pid_file()
|
||||
# Also release all scoped locks left by the old process.
|
||||
|
||||
+1
-53
@@ -32,9 +32,6 @@ def _now() -> datetime:
|
||||
# PII redaction helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PHONE_RE = re.compile(r"^\+?\d[\d\-\s]{6,}$")
|
||||
|
||||
|
||||
def _hash_id(value: str) -> str:
|
||||
"""Deterministic 12-char hex hash of an identifier."""
|
||||
return hashlib.sha256(value.encode("utf-8")).hexdigest()[:12]
|
||||
@@ -58,10 +55,6 @@ def _hash_chat_id(value: str) -> str:
|
||||
return _hash_id(value)
|
||||
|
||||
|
||||
def _looks_like_phone(value: str) -> bool:
|
||||
"""Return True if *value* looks like a phone number (E.164 or similar)."""
|
||||
return bool(_PHONE_RE.match(value.strip()))
|
||||
|
||||
from .config import (
|
||||
Platform,
|
||||
GatewayConfig,
|
||||
@@ -144,15 +137,6 @@ class SessionSource:
|
||||
chat_id_alt=data.get("chat_id_alt"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def local_cli(cls) -> "SessionSource":
|
||||
"""Create a source representing the local CLI."""
|
||||
return cls(
|
||||
platform=Platform.LOCAL,
|
||||
chat_id="cli",
|
||||
chat_name="CLI terminal",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -510,8 +494,7 @@ class SessionStore:
|
||||
"""
|
||||
|
||||
def __init__(self, sessions_dir: Path, config: GatewayConfig,
|
||||
has_active_processes_fn=None,
|
||||
on_auto_reset=None):
|
||||
has_active_processes_fn=None):
|
||||
self.sessions_dir = sessions_dir
|
||||
self.config = config
|
||||
self._entries: Dict[str, SessionEntry] = {}
|
||||
@@ -770,41 +753,6 @@ class SessionStore:
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: Failed to create SQLite session: {e}")
|
||||
|
||||
# Seed new DM thread sessions with parent DM session history.
|
||||
# When a bot reply creates a Slack thread and the user responds in it,
|
||||
# the thread gets a new session (keyed by thread_ts). Without seeding,
|
||||
# the thread session starts with zero context — the user's original
|
||||
# question and the bot's answer are invisible. Fix: copy the parent
|
||||
# DM session's transcript into the new thread session so context carries
|
||||
# over while still keeping threads isolated from each other.
|
||||
if (
|
||||
source.chat_type == "dm"
|
||||
and source.thread_id
|
||||
and entry.created_at == entry.updated_at # brand-new session
|
||||
and not was_auto_reset
|
||||
):
|
||||
parent_source = SessionSource(
|
||||
platform=source.platform,
|
||||
chat_id=source.chat_id,
|
||||
chat_type="dm",
|
||||
user_id=source.user_id,
|
||||
# no thread_id — this is the parent DM session
|
||||
)
|
||||
parent_key = self._generate_session_key(parent_source)
|
||||
with self._lock:
|
||||
parent_entry = self._entries.get(parent_key)
|
||||
if parent_entry and parent_entry.session_id != entry.session_id:
|
||||
try:
|
||||
parent_history = self.load_transcript(parent_entry.session_id)
|
||||
if parent_history:
|
||||
self.rewrite_transcript(entry.session_id, parent_history)
|
||||
logger.info(
|
||||
"[Session] Seeded DM thread session %s with %d messages from parent %s",
|
||||
entry.session_id, len(parent_history), parent_entry.session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[Session] Failed to seed thread session: %s", e)
|
||||
|
||||
return entry
|
||||
|
||||
def update_session(
|
||||
|
||||
@@ -14,6 +14,8 @@ concurrently under distinct configurations).
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
@@ -23,6 +25,7 @@ from typing import Any, Optional
|
||||
_GATEWAY_KIND = "hermes-gateway"
|
||||
_RUNTIME_STATUS_FILE = "gateway_state.json"
|
||||
_LOCKS_DIRNAME = "gateway-locks"
|
||||
_IS_WINDOWS = sys.platform == "win32"
|
||||
|
||||
|
||||
def _get_pid_path() -> Path:
|
||||
@@ -49,6 +52,33 @@ def _utc_now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def terminate_pid(pid: int, *, force: bool = False) -> None:
|
||||
"""Terminate a PID with platform-appropriate force semantics.
|
||||
|
||||
POSIX uses SIGTERM/SIGKILL. Windows uses taskkill /T /F for true force-kill
|
||||
because os.kill(..., SIGTERM) is not equivalent to a tree-killing hard stop.
|
||||
"""
|
||||
if force and _IS_WINDOWS:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["taskkill", "/PID", str(pid), "/T", "/F"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
return
|
||||
|
||||
if result.returncode != 0:
|
||||
details = (result.stderr or result.stdout or "").strip()
|
||||
raise OSError(details or f"taskkill failed for PID {pid}")
|
||||
return
|
||||
|
||||
sig = signal.SIGTERM if not force else getattr(signal, "SIGKILL", signal.SIGTERM)
|
||||
os.kill(pid, sig)
|
||||
|
||||
|
||||
def _scope_hash(identity: str) -> str:
|
||||
return hashlib.sha256(identity.encode("utf-8")).hexdigest()[:16]
|
||||
|
||||
|
||||
@@ -136,7 +136,34 @@ class GatewayStreamConsumer:
|
||||
|
||||
if should_edit and self._accumulated:
|
||||
# Split overflow: if accumulated text exceeds the platform
|
||||
# limit, finalize the current message and start a new one.
|
||||
# limit, split into properly sized chunks.
|
||||
if (
|
||||
len(self._accumulated) > _safe_limit
|
||||
and self._message_id is None
|
||||
):
|
||||
# No existing message to edit (first message or after a
|
||||
# segment break). Use truncate_message — the same
|
||||
# helper the non-streaming path uses — to split with
|
||||
# proper word/code-fence boundaries and chunk
|
||||
# indicators like "(1/2)".
|
||||
chunks = self.adapter.truncate_message(
|
||||
self._accumulated, _safe_limit
|
||||
)
|
||||
for chunk in chunks:
|
||||
await self._send_new_chunk(chunk, self._message_id)
|
||||
self._accumulated = ""
|
||||
self._last_sent_text = ""
|
||||
self._last_edit_time = time.monotonic()
|
||||
if got_done:
|
||||
return
|
||||
if got_segment_break:
|
||||
self._message_id = None
|
||||
self._fallback_final_send = False
|
||||
self._fallback_prefix = ""
|
||||
continue
|
||||
|
||||
# Existing message: edit it with the first chunk, then
|
||||
# start a new message for the overflow remainder.
|
||||
while (
|
||||
len(self._accumulated) > _safe_limit
|
||||
and self._message_id is not None
|
||||
@@ -178,11 +205,20 @@ class GatewayStreamConsumer:
|
||||
await self._send_or_edit(self._accumulated)
|
||||
return
|
||||
|
||||
# Tool boundary: the should_edit block above already flushed
|
||||
# accumulated text without a cursor. Reset state so the next
|
||||
# text chunk creates a fresh message below any tool-progress
|
||||
# messages the gateway sent in between.
|
||||
if got_segment_break:
|
||||
# Tool boundary: reset message state so the next text chunk
|
||||
# creates a fresh message below any tool-progress messages.
|
||||
#
|
||||
# Exception: when _message_id is "__no_edit__" the platform
|
||||
# never returned a real message ID (e.g. Signal, webhook with
|
||||
# github_comment delivery). Resetting to None would re-enter
|
||||
# the "first send" path on every tool boundary and post one
|
||||
# platform message per tool call — that is what caused 155
|
||||
# comments under a single PR. Instead, keep all state so the
|
||||
# full continuation is delivered once via _send_fallback_final.
|
||||
# (When editing fails mid-stream due to flood control the id is
|
||||
# a real string like "msg_1", not "__no_edit__", so that case
|
||||
# still resets and creates a fresh segment as intended.)
|
||||
if got_segment_break and self._message_id != "__no_edit__":
|
||||
self._message_id = None
|
||||
self._accumulated = ""
|
||||
self._last_sent_text = ""
|
||||
@@ -226,6 +262,34 @@ class GatewayStreamConsumer:
|
||||
# Strip trailing whitespace/newlines but preserve leading content
|
||||
return cleaned.rstrip()
|
||||
|
||||
async def _send_new_chunk(self, text: str, reply_to_id: Optional[str]) -> Optional[str]:
|
||||
"""Send a new message chunk, optionally threaded to a previous message.
|
||||
|
||||
Returns the message_id so callers can thread subsequent chunks.
|
||||
"""
|
||||
text = self._clean_for_display(text)
|
||||
if not text.strip():
|
||||
return reply_to_id
|
||||
try:
|
||||
meta = dict(self.metadata) if self.metadata else {}
|
||||
result = await self.adapter.send(
|
||||
chat_id=self.chat_id,
|
||||
content=text,
|
||||
reply_to=reply_to_id,
|
||||
metadata=meta,
|
||||
)
|
||||
if result.success and result.message_id:
|
||||
self._message_id = str(result.message_id)
|
||||
self._already_sent = True
|
||||
self._last_sent_text = text
|
||||
return str(result.message_id)
|
||||
else:
|
||||
self._edit_supported = False
|
||||
return reply_to_id
|
||||
except Exception as e:
|
||||
logger.error("Stream send chunk error: %s", e)
|
||||
return reply_to_id
|
||||
|
||||
def _visible_prefix(self) -> str:
|
||||
"""Return the visible text already shown in the streamed message."""
|
||||
prefix = self._last_sent_text or ""
|
||||
|
||||
+93
-31
@@ -70,7 +70,6 @@ DEFAULT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
DEFAULT_QWEN_BASE_URL = "https://portal.qwen.ai/v1"
|
||||
DEFAULT_GITHUB_MODELS_BASE_URL = "https://api.githubcopilot.com"
|
||||
DEFAULT_COPILOT_ACP_BASE_URL = "acp://copilot"
|
||||
DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
CODEX_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
CODEX_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
|
||||
@@ -199,6 +198,14 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
api_key_env_vars=("DEEPSEEK_API_KEY",),
|
||||
base_url_env_var="DEEPSEEK_BASE_URL",
|
||||
),
|
||||
"xai": ProviderConfig(
|
||||
id="xai",
|
||||
name="xAI",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.x.ai/v1",
|
||||
api_key_env_vars=("XAI_API_KEY",),
|
||||
base_url_env_var="XAI_BASE_URL",
|
||||
),
|
||||
"ai-gateway": ProviderConfig(
|
||||
id="ai-gateway",
|
||||
name="AI Gateway",
|
||||
@@ -705,6 +712,27 @@ def write_credential_pool(provider_id: str, entries: List[Dict[str, Any]]) -> Pa
|
||||
return _save_auth_store(auth_store)
|
||||
|
||||
|
||||
def suppress_credential_source(provider_id: str, source: str) -> None:
|
||||
"""Mark a credential source as suppressed so it won't be re-seeded."""
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
suppressed = auth_store.setdefault("suppressed_sources", {})
|
||||
provider_list = suppressed.setdefault(provider_id, [])
|
||||
if source not in provider_list:
|
||||
provider_list.append(source)
|
||||
_save_auth_store(auth_store)
|
||||
|
||||
|
||||
def is_source_suppressed(provider_id: str, source: str) -> bool:
|
||||
"""Check if a credential source has been suppressed by the user."""
|
||||
try:
|
||||
auth_store = _load_auth_store()
|
||||
suppressed = auth_store.get("suppressed_sources", {})
|
||||
return source in suppressed.get(provider_id, [])
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def get_provider_auth_state(provider_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Return persisted auth state for a provider, or None."""
|
||||
auth_store = _load_auth_store()
|
||||
@@ -717,6 +745,57 @@ def get_active_provider() -> Optional[str]:
|
||||
return auth_store.get("active_provider")
|
||||
|
||||
|
||||
def is_provider_explicitly_configured(provider_id: str) -> bool:
|
||||
"""Return True only if the user has explicitly configured this provider.
|
||||
|
||||
Checks:
|
||||
1. active_provider in auth.json matches
|
||||
2. model.provider in config.yaml matches
|
||||
3. Provider-specific env vars are set (e.g. ANTHROPIC_API_KEY)
|
||||
|
||||
This is used to gate auto-discovery of external credentials (e.g.
|
||||
Claude Code's ~/.claude/.credentials.json) so they are never used
|
||||
without the user's explicit choice. See PR #4210 for the same
|
||||
pattern applied to the setup wizard gate.
|
||||
"""
|
||||
normalized = (provider_id or "").strip().lower()
|
||||
|
||||
# 1. Check auth.json active_provider
|
||||
try:
|
||||
auth_store = _load_auth_store()
|
||||
active = (auth_store.get("active_provider") or "").strip().lower()
|
||||
if active and active == normalized:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2. Check config.yaml model.provider
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
cfg = load_config()
|
||||
model_cfg = cfg.get("model")
|
||||
if isinstance(model_cfg, dict):
|
||||
cfg_provider = (model_cfg.get("provider") or "").strip().lower()
|
||||
if cfg_provider == normalized:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 3. Check provider-specific env vars
|
||||
# Exclude CLAUDE_CODE_OAUTH_TOKEN — it's set by Claude Code itself,
|
||||
# not by the user explicitly configuring anthropic in Hermes.
|
||||
_IMPLICIT_ENV_VARS = {"CLAUDE_CODE_OAUTH_TOKEN"}
|
||||
pconfig = PROVIDER_REGISTRY.get(normalized)
|
||||
if pconfig and pconfig.auth_type == "api_key":
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
if env_var in _IMPLICIT_ENV_VARS:
|
||||
continue
|
||||
if has_usable_secret(os.getenv(env_var, "")):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def clear_provider_auth(provider_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Clear auth state for a provider. Used by `hermes logout`.
|
||||
@@ -819,7 +898,7 @@ def resolve_provider(
|
||||
_PROVIDER_ALIASES = {
|
||||
"glm": "zai", "z-ai": "zai", "z.ai": "zai", "zhipu": "zai",
|
||||
"google": "gemini", "google-gemini": "gemini", "google-ai-studio": "gemini",
|
||||
"kimi": "kimi-coding", "moonshot": "kimi-coding",
|
||||
"kimi": "kimi-coding", "kimi-for-coding": "kimi-coding", "moonshot": "kimi-coding",
|
||||
"minimax-china": "minimax-cn", "minimax_cn": "minimax-cn",
|
||||
"claude": "anthropic", "claude-code": "anthropic",
|
||||
"github": "copilot", "github-copilot": "copilot",
|
||||
@@ -1442,7 +1521,15 @@ def _resolve_verify(
|
||||
if effective_insecure:
|
||||
return False
|
||||
if effective_ca:
|
||||
return str(effective_ca)
|
||||
ca_path = str(effective_ca)
|
||||
if not os.path.isfile(ca_path):
|
||||
import logging
|
||||
logging.getLogger("hermes.auth").warning(
|
||||
"CA bundle path does not exist: %s — falling back to default certificates",
|
||||
ca_path,
|
||||
)
|
||||
return True
|
||||
return ca_path
|
||||
return True
|
||||
|
||||
|
||||
@@ -2342,33 +2429,6 @@ def resolve_external_process_provider_credentials(provider_id: str) -> Dict[str,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# External credential detection
|
||||
# =============================================================================
|
||||
|
||||
def detect_external_credentials() -> List[Dict[str, Any]]:
|
||||
"""Scan for credentials from other CLI tools that Hermes can reuse.
|
||||
|
||||
Returns a list of dicts, each with:
|
||||
- provider: str -- Hermes provider id (e.g. "openai-codex")
|
||||
- path: str -- filesystem path where creds were found
|
||||
- label: str -- human-friendly description for the setup UI
|
||||
"""
|
||||
found: List[Dict[str, Any]] = []
|
||||
|
||||
# Codex CLI: ~/.codex/auth.json (importable, not shared)
|
||||
cli_tokens = _import_codex_cli_tokens()
|
||||
if cli_tokens:
|
||||
codex_path = Path.home() / ".codex" / "auth.json"
|
||||
found.append({
|
||||
"provider": "openai-codex",
|
||||
"path": str(codex_path),
|
||||
"label": f"Codex CLI credentials found ({codex_path}) — run `hermes auth` to create a separate session",
|
||||
})
|
||||
|
||||
return found
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CLI Commands — login / logout
|
||||
# =============================================================================
|
||||
@@ -2572,6 +2632,8 @@ def _prompt_model_selection(
|
||||
title=effective_title,
|
||||
)
|
||||
idx = menu.show()
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
if idx is None:
|
||||
return None
|
||||
print()
|
||||
@@ -2581,7 +2643,7 @@ def _prompt_model_selection(
|
||||
custom = input("Enter model name: ").strip()
|
||||
return custom if custom else None
|
||||
return None
|
||||
except (ImportError, NotImplementedError):
|
||||
except (ImportError, NotImplementedError, OSError, subprocess.SubprocessError):
|
||||
pass
|
||||
|
||||
# Fallback: numbered list
|
||||
|
||||
@@ -347,8 +347,11 @@ def auth_remove_command(args) -> None:
|
||||
print("Cleared Hermes Anthropic OAuth credentials")
|
||||
|
||||
elif removed.source == "claude_code" and provider == "anthropic":
|
||||
print("Note: Claude Code credentials live in ~/.claude/.credentials.json")
|
||||
print(" Remove them manually if you want to deauthorize Claude Code.")
|
||||
from hermes_cli.auth import suppress_credential_source
|
||||
suppress_credential_source(provider, "claude_code")
|
||||
print("Suppressed claude_code credential — it will not be re-seeded.")
|
||||
print("Note: Claude Code credentials still live in ~/.claude/.credentials.json")
|
||||
print("Run `hermes auth add anthropic` to re-enable if needed.")
|
||||
|
||||
|
||||
def auth_reset_command(args) -> None:
|
||||
|
||||
@@ -90,12 +90,6 @@ HERMES_CADUCEUS = """[#CD7F32]⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⡀⠀⣀⣀
|
||||
[#B8860B]⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠳⠈⣡⠞⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[/]
|
||||
[#B8860B]⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[/]"""
|
||||
|
||||
COMPACT_BANNER = """
|
||||
[bold #FFD700]╔══════════════════════════════════════════════════════════════╗[/]
|
||||
[bold #FFD700]║[/] [#FFBF00]⚕ NOUS HERMES[/] [dim #B8860B]- AI Agent Framework[/] [bold #FFD700]║[/]
|
||||
[bold #FFD700]║[/] [#CD7F32]Messenger of the Digital Gods[/] [dim #B8860B]Nous Research[/] [bold #FFD700]║[/]
|
||||
[bold #FFD700]╚══════════════════════════════════════════════════════════════╝[/]
|
||||
"""
|
||||
|
||||
|
||||
# =========================================================================
|
||||
|
||||
@@ -1,140 +0,0 @@
|
||||
"""Shared curses-based multi-select checklist for Hermes CLI.
|
||||
|
||||
Used by both ``hermes tools`` and ``hermes skills`` to present a
|
||||
toggleable list of items. Falls back to a numbered text UI when
|
||||
curses is unavailable (Windows without curses, piped stdin, etc.).
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import List, Set
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
|
||||
def curses_checklist(
|
||||
title: str,
|
||||
items: List[str],
|
||||
pre_selected: Set[int],
|
||||
) -> Set[int]:
|
||||
"""Multi-select checklist. Returns set of **selected** indices.
|
||||
|
||||
Args:
|
||||
title: Header text shown at the top of the checklist.
|
||||
items: Display labels for each row.
|
||||
pre_selected: Indices that start checked.
|
||||
|
||||
Returns:
|
||||
The indices the user confirmed as checked. On cancel (ESC/q),
|
||||
returns ``pre_selected`` unchanged.
|
||||
"""
|
||||
# Safety: return defaults when stdin is not a terminal.
|
||||
if not sys.stdin.isatty():
|
||||
return set(pre_selected)
|
||||
|
||||
try:
|
||||
import curses
|
||||
selected = set(pre_selected)
|
||||
result = [None]
|
||||
|
||||
def _ui(stdscr):
|
||||
curses.curs_set(0)
|
||||
if curses.has_colors():
|
||||
curses.start_color()
|
||||
curses.use_default_colors()
|
||||
curses.init_pair(1, curses.COLOR_GREEN, -1)
|
||||
curses.init_pair(2, curses.COLOR_YELLOW, -1)
|
||||
curses.init_pair(3, 8, -1) # dim gray
|
||||
cursor = 0
|
||||
scroll_offset = 0
|
||||
|
||||
while True:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
|
||||
# Header
|
||||
try:
|
||||
hattr = curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0)
|
||||
stdscr.addnstr(0, 0, title, max_x - 1, hattr)
|
||||
stdscr.addnstr(
|
||||
1, 0,
|
||||
" ↑↓ navigate SPACE toggle ENTER confirm ESC cancel",
|
||||
max_x - 1, curses.A_DIM,
|
||||
)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
# Scrollable item list
|
||||
visible_rows = max_y - 3
|
||||
if cursor < scroll_offset:
|
||||
scroll_offset = cursor
|
||||
elif cursor >= scroll_offset + visible_rows:
|
||||
scroll_offset = cursor - visible_rows + 1
|
||||
|
||||
for draw_i, i in enumerate(
|
||||
range(scroll_offset, min(len(items), scroll_offset + visible_rows))
|
||||
):
|
||||
y = draw_i + 3
|
||||
if y >= max_y - 1:
|
||||
break
|
||||
check = "✓" if i in selected else " "
|
||||
arrow = "→" if i == cursor else " "
|
||||
line = f" {arrow} [{check}] {items[i]}"
|
||||
|
||||
attr = curses.A_NORMAL
|
||||
if i == cursor:
|
||||
attr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
attr |= curses.color_pair(1)
|
||||
try:
|
||||
stdscr.addnstr(y, 0, line, max_x - 1, attr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
|
||||
if key in (curses.KEY_UP, ord("k")):
|
||||
cursor = (cursor - 1) % len(items)
|
||||
elif key in (curses.KEY_DOWN, ord("j")):
|
||||
cursor = (cursor + 1) % len(items)
|
||||
elif key == ord(" "):
|
||||
selected.symmetric_difference_update({cursor})
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result[0] = set(selected)
|
||||
return
|
||||
elif key in (27, ord("q")):
|
||||
result[0] = set(pre_selected)
|
||||
return
|
||||
|
||||
curses.wrapper(_ui)
|
||||
return result[0] if result[0] is not None else set(pre_selected)
|
||||
|
||||
except Exception:
|
||||
pass # fall through to numbered fallback
|
||||
|
||||
# ── Numbered text fallback ────────────────────────────────────────────
|
||||
selected = set(pre_selected)
|
||||
print(color(f"\n {title}", Colors.YELLOW))
|
||||
print(color(" Toggle by number, Enter to confirm.\n", Colors.DIM))
|
||||
|
||||
while True:
|
||||
for i, label in enumerate(items):
|
||||
check = "✓" if i in selected else " "
|
||||
print(f" {i + 1:3}. [{check}] {label}")
|
||||
print()
|
||||
|
||||
try:
|
||||
raw = input(color(" Number to toggle, 's' to save, 'q' to cancel: ", Colors.DIM)).strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return set(pre_selected)
|
||||
|
||||
if raw.lower() == "s" or raw == "":
|
||||
return selected
|
||||
if raw.lower() == "q":
|
||||
return set(pre_selected)
|
||||
try:
|
||||
idx = int(raw) - 1
|
||||
if 0 <= idx < len(items):
|
||||
selected.symmetric_difference_update({idx})
|
||||
except ValueError:
|
||||
print(color(" Invalid input", Colors.DIM))
|
||||
+36
-12
@@ -16,8 +16,18 @@ from collections.abc import Callable, Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from prompt_toolkit.auto_suggest import AutoSuggest, Suggestion
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
# prompt_toolkit is an optional CLI dependency — only needed for
|
||||
# SlashCommandCompleter and SlashCommandAutoSuggest. Gateway and test
|
||||
# environments that lack it must still be able to import this module
|
||||
# for resolve_command, gateway_help_lines, and COMMAND_REGISTRY.
|
||||
try:
|
||||
from prompt_toolkit.auto_suggest import AutoSuggest, Suggestion
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
except ImportError: # pragma: no cover
|
||||
AutoSuggest = object # type: ignore[assignment,misc]
|
||||
Completer = object # type: ignore[assignment,misc]
|
||||
Suggestion = None # type: ignore[assignment]
|
||||
Completion = None # type: ignore[assignment]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -73,8 +83,7 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
args_hint="<question>"),
|
||||
CommandDef("queue", "Queue a prompt for the next turn (doesn't interrupt)", "Session",
|
||||
aliases=("q",), args_hint="<prompt>"),
|
||||
CommandDef("status", "Show session info", "Session",
|
||||
gateway_only=True),
|
||||
CommandDef("status", "Show session info", "Session"),
|
||||
CommandDef("profile", "Show active profile name and home directory", "Info"),
|
||||
CommandDef("sethome", "Set this chat as the home channel", "Session",
|
||||
gateway_only=True, aliases=("set-home",)),
|
||||
@@ -99,7 +108,10 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
"Configuration"),
|
||||
CommandDef("reasoning", "Manage reasoning effort and display", "Configuration",
|
||||
args_hint="[level|show|hide]",
|
||||
subcommands=("none", "low", "minimal", "medium", "high", "xhigh", "show", "hide", "on", "off")),
|
||||
subcommands=("none", "minimal", "low", "medium", "high", "xhigh", "show", "hide", "on", "off")),
|
||||
CommandDef("fast", "Toggle fast mode — OpenAI Priority Processing / Anthropic Fast Mode (Normal/Fast)", "Configuration",
|
||||
args_hint="[normal|fast|status]",
|
||||
subcommands=("normal", "fast", "status", "on", "off")),
|
||||
CommandDef("skin", "Show or change the display skin/theme", "Configuration",
|
||||
cli_only=True, args_hint="[name]"),
|
||||
CommandDef("voice", "Toggle voice mode", "Configuration",
|
||||
@@ -135,6 +147,8 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
cli_only=True, aliases=("gateway",)),
|
||||
CommandDef("paste", "Check clipboard for an image and attach it", "Info",
|
||||
cli_only=True),
|
||||
CommandDef("image", "Attach a local image file for your next prompt", "Info",
|
||||
cli_only=True, args_hint="<path>"),
|
||||
CommandDef("update", "Update Hermes Agent to the latest version", "Info",
|
||||
gateway_only=True),
|
||||
|
||||
@@ -169,12 +183,6 @@ def resolve_command(name: str) -> CommandDef | None:
|
||||
return _COMMAND_LOOKUP.get(name.lower().lstrip("/"))
|
||||
|
||||
|
||||
def register_plugin_command(cmd: CommandDef) -> None:
|
||||
"""Append a plugin-defined command to the registry and refresh lookups."""
|
||||
COMMAND_REGISTRY.append(cmd)
|
||||
rebuild_lookups()
|
||||
|
||||
|
||||
def rebuild_lookups() -> None:
|
||||
"""Rebuild all derived lookup dicts from the current COMMAND_REGISTRY.
|
||||
|
||||
@@ -637,8 +645,18 @@ class SlashCommandCompleter(Completer):
|
||||
def __init__(
|
||||
self,
|
||||
skill_commands_provider: Callable[[], Mapping[str, dict[str, Any]]] | None = None,
|
||||
command_filter: Callable[[str], bool] | None = None,
|
||||
) -> None:
|
||||
self._skill_commands_provider = skill_commands_provider
|
||||
self._command_filter = command_filter
|
||||
|
||||
def _command_allowed(self, slash_command: str) -> bool:
|
||||
if self._command_filter is None:
|
||||
return True
|
||||
try:
|
||||
return bool(self._command_filter(slash_command))
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
def _iter_skill_commands(self) -> Mapping[str, dict[str, Any]]:
|
||||
if self._skill_commands_provider is None:
|
||||
@@ -916,7 +934,7 @@ class SlashCommandCompleter(Completer):
|
||||
return
|
||||
|
||||
# Static subcommand completions
|
||||
if " " not in sub_text and base_cmd in SUBCOMMANDS:
|
||||
if " " not in sub_text and base_cmd in SUBCOMMANDS and self._command_allowed(base_cmd):
|
||||
for sub in SUBCOMMANDS[base_cmd]:
|
||||
if sub.startswith(sub_lower) and sub != sub_lower:
|
||||
yield Completion(
|
||||
@@ -929,6 +947,8 @@ class SlashCommandCompleter(Completer):
|
||||
word = text[1:]
|
||||
|
||||
for cmd, desc in COMMANDS.items():
|
||||
if not self._command_allowed(cmd):
|
||||
continue
|
||||
cmd_name = cmd[1:]
|
||||
if cmd_name.startswith(word):
|
||||
yield Completion(
|
||||
@@ -987,6 +1007,8 @@ class SlashCommandAutoSuggest(AutoSuggest):
|
||||
# Still typing the command name: /upd → suggest "ate"
|
||||
word = text[1:].lower()
|
||||
for cmd in COMMANDS:
|
||||
if self._completer is not None and not self._completer._command_allowed(cmd):
|
||||
continue
|
||||
cmd_name = cmd[1:] # strip leading /
|
||||
if cmd_name.startswith(word) and cmd_name != word:
|
||||
return Suggestion(cmd_name[len(word):])
|
||||
@@ -997,6 +1019,8 @@ class SlashCommandAutoSuggest(AutoSuggest):
|
||||
sub_lower = sub_text.lower()
|
||||
|
||||
# Static subcommands
|
||||
if self._completer is not None and not self._completer._command_allowed(base_cmd):
|
||||
return None
|
||||
if base_cmd in SUBCOMMANDS and SUBCOMMANDS[base_cmd]:
|
||||
if " " not in sub_text:
|
||||
for sub in SUBCOMMANDS[base_cmd]:
|
||||
|
||||
+115
-11
@@ -39,6 +39,9 @@ _EXTRA_ENV_KEYS = frozenset({
|
||||
"DINGTALK_CLIENT_ID", "DINGTALK_CLIENT_SECRET",
|
||||
"FEISHU_APP_ID", "FEISHU_APP_SECRET", "FEISHU_ENCRYPT_KEY", "FEISHU_VERIFICATION_TOKEN",
|
||||
"WECOM_BOT_ID", "WECOM_SECRET",
|
||||
"WEIXIN_ACCOUNT_ID", "WEIXIN_TOKEN", "WEIXIN_BASE_URL", "WEIXIN_CDN_BASE_URL",
|
||||
"WEIXIN_HOME_CHANNEL", "WEIXIN_HOME_CHANNEL_NAME", "WEIXIN_DM_POLICY", "WEIXIN_GROUP_POLICY",
|
||||
"WEIXIN_ALLOWED_USERS", "WEIXIN_GROUP_ALLOWED_USERS", "WEIXIN_ALLOW_ALL_USERS",
|
||||
"BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_PASSWORD",
|
||||
"TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT",
|
||||
"WHATSAPP_MODE", "WHATSAPP_ENABLED",
|
||||
@@ -158,16 +161,27 @@ def get_project_root() -> Path:
|
||||
return Path(__file__).parent.parent.resolve()
|
||||
|
||||
def _secure_dir(path):
|
||||
"""Set directory to owner-only access (0700). No-op on Windows.
|
||||
"""Set directory to owner-only access (0700 by default). No-op on Windows.
|
||||
|
||||
Skipped in managed mode — the NixOS module sets group-readable
|
||||
permissions (0750) so interactive users in the hermes group can
|
||||
share state with the gateway service.
|
||||
|
||||
The mode can be overridden via the HERMES_HOME_MODE environment variable
|
||||
(e.g. HERMES_HOME_MODE=0701) for deployments where a web server (nginx,
|
||||
caddy, etc.) needs to traverse HERMES_HOME to reach a served subdirectory.
|
||||
The execute-only bit on a directory permits cd-through without exposing
|
||||
directory listings.
|
||||
"""
|
||||
if is_managed():
|
||||
return
|
||||
try:
|
||||
os.chmod(path, 0o700)
|
||||
mode_str = os.environ.get("HERMES_HOME_MODE", "").strip()
|
||||
mode = int(mode_str, 8) if mode_str else 0o700
|
||||
except ValueError:
|
||||
mode = 0o700
|
||||
try:
|
||||
os.chmod(path, mode)
|
||||
except (OSError, NotImplementedError):
|
||||
pass
|
||||
|
||||
@@ -197,14 +211,44 @@ def _ensure_default_soul_md(home: Path) -> None:
|
||||
|
||||
|
||||
def ensure_hermes_home():
|
||||
"""Ensure ~/.hermes directory structure exists with secure permissions."""
|
||||
"""Ensure ~/.hermes directory structure exists with secure permissions.
|
||||
|
||||
In managed mode (NixOS), dirs are created by the activation script with
|
||||
setgid + group-writable (2770). We skip mkdir and set umask(0o007) so
|
||||
any files created (e.g. SOUL.md) are group-writable (0660).
|
||||
"""
|
||||
home = get_hermes_home()
|
||||
home.mkdir(parents=True, exist_ok=True)
|
||||
_secure_dir(home)
|
||||
if is_managed():
|
||||
old_umask = os.umask(0o007)
|
||||
try:
|
||||
_ensure_hermes_home_managed(home)
|
||||
finally:
|
||||
os.umask(old_umask)
|
||||
else:
|
||||
home.mkdir(parents=True, exist_ok=True)
|
||||
_secure_dir(home)
|
||||
for subdir in ("cron", "sessions", "logs", "memories"):
|
||||
d = home / subdir
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
_secure_dir(d)
|
||||
_ensure_default_soul_md(home)
|
||||
|
||||
|
||||
def _ensure_hermes_home_managed(home: Path):
|
||||
"""Managed-mode variant: verify dirs exist (activation creates them), seed SOUL.md."""
|
||||
if not home.is_dir():
|
||||
raise RuntimeError(
|
||||
f"HERMES_HOME {home} does not exist. "
|
||||
"Run 'sudo nixos-rebuild switch' first."
|
||||
)
|
||||
for subdir in ("cron", "sessions", "logs", "memories"):
|
||||
d = home / subdir
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
_secure_dir(d)
|
||||
if not d.is_dir():
|
||||
raise RuntimeError(
|
||||
f"{d} does not exist. "
|
||||
"Run 'sudo nixos-rebuild switch' first."
|
||||
)
|
||||
# Inside umask(0o007) scope — SOUL.md will be created as 0660
|
||||
_ensure_default_soul_md(home)
|
||||
|
||||
|
||||
@@ -225,6 +269,7 @@ DEFAULT_CONFIG = {
|
||||
# tools or receiving API responses. Only fires when the agent has
|
||||
# been completely idle for this duration. 0 = unlimited.
|
||||
"gateway_timeout": 1800,
|
||||
"service_tier": "",
|
||||
# Tool-use enforcement: injects system prompt guidance that tells the
|
||||
# model to actually call tools instead of describing intended actions.
|
||||
# Values: "auto" (default — applies to gpt/codex models), true/false
|
||||
@@ -510,6 +555,7 @@ DEFAULT_CONFIG = {
|
||||
"discord": {
|
||||
"require_mention": True, # Require @mention to respond in server channels
|
||||
"free_response_channels": "", # Comma-separated channel IDs where bot responds without mention
|
||||
"allowed_channels": "", # If set, bot ONLY responds in these channel IDs (whitelist)
|
||||
"auto_thread": True, # Auto-create threads on @mention in channels (like Slack)
|
||||
"reactions": True, # Add 👀/✅/❌ reactions to messages during processing
|
||||
},
|
||||
@@ -569,7 +615,7 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
|
||||
# Config schema version - bump this when adding new required fields
|
||||
"_config_version": 13,
|
||||
"_config_version": 14,
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
@@ -1163,8 +1209,8 @@ OPTIONAL_ENV_VARS = {
|
||||
"advanced": True,
|
||||
},
|
||||
"API_SERVER_KEY": {
|
||||
"description": "Bearer token for API server authentication. If empty, all requests are allowed (local use only).",
|
||||
"prompt": "API server auth key (optional)",
|
||||
"description": "Bearer token for API server authentication. Required for non-loopback binding; server refuses to start without it. On loopback (127.0.0.1), all requests are allowed if empty.",
|
||||
"prompt": "API server auth key (required for network access)",
|
||||
"url": None,
|
||||
"password": True,
|
||||
"category": "messaging",
|
||||
@@ -1179,13 +1225,21 @@ OPTIONAL_ENV_VARS = {
|
||||
"advanced": True,
|
||||
},
|
||||
"API_SERVER_HOST": {
|
||||
"description": "Host/bind address for the API server (default: 127.0.0.1). Use 0.0.0.0 for network access — requires API_SERVER_KEY for security.",
|
||||
"description": "Host/bind address for the API server (default: 127.0.0.1). Use 0.0.0.0 for network access — server refuses to start without API_SERVER_KEY.",
|
||||
"prompt": "API server host",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
"advanced": True,
|
||||
},
|
||||
"API_SERVER_MODEL_NAME": {
|
||||
"description": "Model name advertised on /v1/models. Defaults to the profile name (or 'hermes-agent' for the default profile). Useful for multi-user setups with OpenWebUI.",
|
||||
"prompt": "API server model name",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
"advanced": True,
|
||||
},
|
||||
"WEBHOOK_ENABLED": {
|
||||
"description": "Enable the webhook platform adapter for receiving events from GitHub, GitLab, etc.",
|
||||
"prompt": "Enable webhooks (true/false)",
|
||||
@@ -1716,6 +1770,56 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ── Version 13 → 14: migrate legacy flat stt.model to provider section ──
|
||||
# Old configs (and cli-config.yaml.example) had a flat `stt.model` key
|
||||
# that was provider-agnostic. When the provider was "local" this caused
|
||||
# OpenAI model names (e.g. "whisper-1") to be fed to faster-whisper,
|
||||
# crashing with "Invalid model size". Move the value into the correct
|
||||
# provider-specific section and remove the flat key.
|
||||
if current_ver < 14:
|
||||
# Read raw config (no defaults merged) to check what the user actually
|
||||
# wrote, then apply changes to the merged config for saving.
|
||||
raw = read_raw_config()
|
||||
raw_stt = raw.get("stt", {})
|
||||
if isinstance(raw_stt, dict) and "model" in raw_stt:
|
||||
legacy_model = raw_stt["model"]
|
||||
provider = raw_stt.get("provider", "local")
|
||||
config = load_config()
|
||||
stt = config.get("stt", {})
|
||||
# Remove the legacy flat key
|
||||
stt.pop("model", None)
|
||||
# Place it in the appropriate provider section only if the
|
||||
# user didn't already set a model there
|
||||
if provider in ("local", "local_command"):
|
||||
# Don't migrate an OpenAI model name into the local section
|
||||
_local_models = {
|
||||
"tiny.en", "tiny", "base.en", "base", "small.en", "small",
|
||||
"medium.en", "medium", "large-v1", "large-v2", "large-v3",
|
||||
"large", "distil-large-v2", "distil-medium.en",
|
||||
"distil-small.en", "distil-large-v3", "distil-large-v3.5",
|
||||
"large-v3-turbo", "turbo",
|
||||
}
|
||||
if legacy_model in _local_models:
|
||||
# Check raw config — only set if user didn't already
|
||||
# have a nested local.model
|
||||
raw_local = raw_stt.get("local", {})
|
||||
if not isinstance(raw_local, dict) or "model" not in raw_local:
|
||||
local_cfg = stt.setdefault("local", {})
|
||||
local_cfg["model"] = legacy_model
|
||||
# else: drop it — it was an OpenAI model name, local section
|
||||
# already defaults to "base" via DEFAULT_CONFIG
|
||||
else:
|
||||
# Cloud provider — put it in that provider's section only
|
||||
# if user didn't already set a nested model
|
||||
raw_provider = raw_stt.get(provider, {})
|
||||
if not isinstance(raw_provider, dict) or "model" not in raw_provider:
|
||||
provider_cfg = stt.setdefault(provider, {})
|
||||
provider_cfg["model"] = legacy_model
|
||||
config["stt"] = stt
|
||||
save_config(config)
|
||||
if not quiet:
|
||||
print(f" ✓ Migrated legacy stt.model to provider-specific config")
|
||||
|
||||
if current_ver < latest_ver and not quiet:
|
||||
print(f"Config version: {current_ver} → {latest_ver}")
|
||||
|
||||
|
||||
@@ -31,13 +31,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# OAuth device code flow constants (same client ID as opencode/Copilot CLI)
|
||||
COPILOT_OAUTH_CLIENT_ID = "Ov23li8tweQw6odWQebz"
|
||||
COPILOT_DEVICE_CODE_URL = "https://github.com/login/device/code"
|
||||
COPILOT_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token"
|
||||
|
||||
# Copilot API constants
|
||||
COPILOT_TOKEN_EXCHANGE_URL = "https://api.github.com/copilot_internal/v2/token"
|
||||
COPILOT_API_BASE_URL = "https://api.githubcopilot.com"
|
||||
|
||||
# Token type prefixes
|
||||
_CLASSIC_PAT_PREFIX = "ghp_"
|
||||
_SUPPORTED_PREFIXES = ("gho_", "github_pat_", "ghu_")
|
||||
@@ -50,11 +43,6 @@ _DEVICE_CODE_POLL_INTERVAL = 5 # seconds
|
||||
_DEVICE_CODE_POLL_SAFETY_MARGIN = 3 # seconds
|
||||
|
||||
|
||||
def is_classic_pat(token: str) -> bool:
|
||||
"""Check if a token is a classic PAT (ghp_*), which Copilot doesn't support."""
|
||||
return token.strip().startswith(_CLASSIC_PAT_PREFIX)
|
||||
|
||||
|
||||
def validate_copilot_token(token: str) -> tuple[bool, str]:
|
||||
"""Validate that a token is usable with the Copilot API.
|
||||
|
||||
@@ -285,6 +273,7 @@ def copilot_request_headers(
|
||||
headers: dict[str, str] = {
|
||||
"Editor-Version": "vscode/1.104.1",
|
||||
"User-Agent": "HermesAgent/1.0",
|
||||
"Copilot-Integration-Id": "vscode-chat",
|
||||
"Openai-Intent": "conversation-edits",
|
||||
"x-initiator": "agent" if is_agent_turn else "user",
|
||||
}
|
||||
|
||||
@@ -10,6 +10,28 @@ from typing import Callable, List, Optional, Set
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
|
||||
def flush_stdin() -> None:
|
||||
"""Flush any stray bytes from the stdin input buffer.
|
||||
|
||||
Must be called after ``curses.wrapper()`` (or any terminal-mode library
|
||||
like simple_term_menu) returns, **before** the next ``input()`` /
|
||||
``getpass.getpass()`` call. ``curses.endwin()`` restores the terminal
|
||||
but does NOT drain the OS input buffer — leftover escape-sequence bytes
|
||||
(from arrow keys, terminal mode-switch responses, or rapid keypresses)
|
||||
remain buffered and silently get consumed by the next ``input()`` call,
|
||||
corrupting user data (e.g. writing ``^[^[`` into .env files).
|
||||
|
||||
On non-TTY stdin (piped, redirected) or Windows, this is a no-op.
|
||||
"""
|
||||
try:
|
||||
if not sys.stdin.isatty():
|
||||
return
|
||||
import termios
|
||||
termios.tcflush(sys.stdin, termios.TCIFLUSH)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def curses_checklist(
|
||||
title: str,
|
||||
items: List[str],
|
||||
@@ -131,6 +153,7 @@ def curses_checklist(
|
||||
return
|
||||
|
||||
curses.wrapper(_draw)
|
||||
flush_stdin()
|
||||
return result_holder[0] if result_holder[0] is not None else cancel_returns
|
||||
|
||||
except Exception:
|
||||
|
||||
+52
-8
@@ -54,6 +54,32 @@ _PROVIDER_ENV_HINTS = (
|
||||
)
|
||||
|
||||
|
||||
from hermes_constants import is_termux as _is_termux
|
||||
|
||||
|
||||
def _python_install_cmd() -> str:
|
||||
return "python -m pip install" if _is_termux() else "uv pip install"
|
||||
|
||||
|
||||
def _system_package_install_cmd(pkg: str) -> str:
|
||||
if _is_termux():
|
||||
return f"pkg install {pkg}"
|
||||
if sys.platform == "darwin":
|
||||
return f"brew install {pkg}"
|
||||
return f"sudo apt install {pkg}"
|
||||
|
||||
|
||||
def _termux_browser_setup_steps(node_installed: bool) -> list[str]:
|
||||
steps: list[str] = []
|
||||
step = 1
|
||||
if not node_installed:
|
||||
steps.append(f"{step}) pkg install nodejs")
|
||||
step += 1
|
||||
steps.append(f"{step}) npm install -g agent-browser")
|
||||
steps.append(f"{step + 1}) agent-browser install")
|
||||
return steps
|
||||
|
||||
|
||||
def _has_provider_env_config(content: str) -> bool:
|
||||
"""Return True when ~/.hermes/.env contains provider auth/base URL settings."""
|
||||
return any(key in content for key in _PROVIDER_ENV_HINTS)
|
||||
@@ -200,7 +226,7 @@ def run_doctor(args):
|
||||
check_ok(name)
|
||||
except ImportError:
|
||||
check_fail(name, "(missing)")
|
||||
issues.append(f"Install {name}: uv pip install {module}")
|
||||
issues.append(f"Install {name}: {_python_install_cmd()} {module}")
|
||||
|
||||
for module, name in optional_packages:
|
||||
try:
|
||||
@@ -503,7 +529,7 @@ def run_doctor(args):
|
||||
check_ok("ripgrep (rg)", "(faster file search)")
|
||||
else:
|
||||
check_warn("ripgrep (rg) not found", "(file search uses grep fallback)")
|
||||
check_info("Install for faster search: sudo apt install ripgrep")
|
||||
check_info(f"Install for faster search: {_system_package_install_cmd('ripgrep')}")
|
||||
|
||||
# Docker (optional)
|
||||
terminal_env = os.getenv("TERMINAL_ENV", "local")
|
||||
@@ -526,7 +552,10 @@ def run_doctor(args):
|
||||
if shutil.which("docker"):
|
||||
check_ok("docker", "(optional)")
|
||||
else:
|
||||
check_warn("docker not found", "(optional)")
|
||||
if _is_termux():
|
||||
check_info("Docker backend is not available inside Termux (expected on Android)")
|
||||
else:
|
||||
check_warn("docker not found", "(optional)")
|
||||
|
||||
# SSH (if using ssh backend)
|
||||
if terminal_env == "ssh":
|
||||
@@ -574,9 +603,23 @@ def run_doctor(args):
|
||||
if agent_browser_path.exists():
|
||||
check_ok("agent-browser (Node.js)", "(browser automation)")
|
||||
else:
|
||||
check_warn("agent-browser not installed", "(run: npm install)")
|
||||
if _is_termux():
|
||||
check_info("agent-browser is not installed (expected in the tested Termux path)")
|
||||
check_info("Install it manually later with: npm install -g agent-browser && agent-browser install")
|
||||
check_info("Termux browser setup:")
|
||||
for step in _termux_browser_setup_steps(node_installed=True):
|
||||
check_info(step)
|
||||
else:
|
||||
check_warn("agent-browser not installed", "(run: npm install)")
|
||||
else:
|
||||
check_warn("Node.js not found", "(optional, needed for browser tools)")
|
||||
if _is_termux():
|
||||
check_info("Node.js not found (browser tools are optional in the tested Termux path)")
|
||||
check_info("Install Node.js on Termux with: pkg install nodejs")
|
||||
check_info("Termux browser setup:")
|
||||
for step in _termux_browser_setup_steps(node_installed=False):
|
||||
check_info(step)
|
||||
else:
|
||||
check_warn("Node.js not found", "(optional, needed for browser tools)")
|
||||
|
||||
# npm audit for all Node.js packages
|
||||
if shutil.which("npm"):
|
||||
@@ -709,7 +752,7 @@ def run_doctor(args):
|
||||
_url = (_base.rstrip("/") + "/models") if _base else _default_url
|
||||
_headers = {"Authorization": f"Bearer {_key}"}
|
||||
if "api.kimi.com" in _url.lower():
|
||||
_headers["User-Agent"] = "KimiCLI/1.0"
|
||||
_headers["User-Agent"] = "KimiCLI/1.30.0"
|
||||
_resp = httpx.get(
|
||||
_url,
|
||||
headers=_headers,
|
||||
@@ -739,8 +782,9 @@ def run_doctor(args):
|
||||
__import__("tinker_atropos")
|
||||
check_ok("tinker-atropos", "(RL training backend)")
|
||||
except ImportError:
|
||||
check_warn("tinker-atropos found but not installed", "(run: uv pip install -e ./tinker-atropos)")
|
||||
issues.append("Install tinker-atropos: uv pip install -e ./tinker-atropos")
|
||||
install_cmd = f"{_python_install_cmd()} -e ./tinker-atropos"
|
||||
check_warn("tinker-atropos found but not installed", f"(run: {install_cmd})")
|
||||
issues.append(f"Install tinker-atropos: {install_cmd}")
|
||||
else:
|
||||
check_warn("tinker-atropos requires Python 3.11+", f"(current: {py_version.major}.{py_version.minor})")
|
||||
else:
|
||||
|
||||
+1
-5
@@ -32,11 +32,6 @@ def _get_git_commit(project_root: Path) -> str:
|
||||
return "(unknown)"
|
||||
|
||||
|
||||
def _key_present(name: str) -> str:
|
||||
"""Return 'set' or 'not set' for an env var."""
|
||||
return "set" if os.getenv(name) else "not set"
|
||||
|
||||
|
||||
def _redact(value: str) -> str:
|
||||
"""Redact all but first 4 and last 4 chars."""
|
||||
if not value:
|
||||
@@ -124,6 +119,7 @@ def _configured_platforms() -> list[str]:
|
||||
"dingtalk": "DINGTALK_CLIENT_ID",
|
||||
"feishu": "FEISHU_APP_ID",
|
||||
"wecom": "WECOM_BOT_ID",
|
||||
"weixin": "WEIXIN_ACCOUNT_ID",
|
||||
}
|
||||
return [name for name, env in checks.items() if os.getenv(env)]
|
||||
|
||||
|
||||
+264
-61
@@ -14,6 +14,7 @@ from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
from gateway.status import terminate_pid
|
||||
from hermes_cli.config import get_env_value, get_hermes_home, save_env_value, is_managed, managed_error
|
||||
# display_hermes_home is imported lazily at call sites to avoid ImportError
|
||||
# when hermes_constants is cached from a pre-update version during `hermes update`.
|
||||
@@ -39,7 +40,7 @@ def _get_service_pids() -> set:
|
||||
pids: set = set()
|
||||
|
||||
# --- systemd (Linux): user and system scopes ---
|
||||
if is_linux():
|
||||
if supports_systemd_services():
|
||||
for scope_args in [["systemctl", "--user"], ["systemctl"]]:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
@@ -162,7 +163,7 @@ def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None)
|
||||
"""Kill any running gateway processes. Returns count killed.
|
||||
|
||||
Args:
|
||||
force: Use SIGKILL instead of SIGTERM.
|
||||
force: Use the platform's force-kill mechanism instead of graceful terminate.
|
||||
exclude_pids: PIDs to skip (e.g. service-managed PIDs that were just
|
||||
restarted and should not be killed).
|
||||
"""
|
||||
@@ -171,10 +172,7 @@ def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None)
|
||||
|
||||
for pid in pids:
|
||||
try:
|
||||
if force and not is_windows():
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
else:
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
terminate_pid(pid, force=force)
|
||||
killed += 1
|
||||
except ProcessLookupError:
|
||||
# Process already gone
|
||||
@@ -182,6 +180,8 @@ def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None)
|
||||
except PermissionError:
|
||||
print(f"⚠ Permission denied to kill PID {pid}")
|
||||
|
||||
except OSError as exc:
|
||||
print(f"Failed to kill PID {pid}: {exc}")
|
||||
return killed
|
||||
|
||||
|
||||
@@ -225,6 +225,14 @@ def stop_profile_gateway() -> bool:
|
||||
def is_linux() -> bool:
|
||||
return sys.platform.startswith('linux')
|
||||
|
||||
|
||||
from hermes_constants import is_termux
|
||||
|
||||
|
||||
def supports_systemd_services() -> bool:
|
||||
return is_linux() and not is_termux()
|
||||
|
||||
|
||||
def is_macos() -> bool:
|
||||
return sys.platform == 'darwin'
|
||||
|
||||
@@ -243,18 +251,18 @@ SERVICE_DESCRIPTION = "Hermes Agent Gateway - Messaging Platform Integration"
|
||||
def _profile_suffix() -> str:
|
||||
"""Derive a service-name suffix from the current HERMES_HOME.
|
||||
|
||||
Returns ``""`` for the default ``~/.hermes``, the profile name for
|
||||
``~/.hermes/profiles/<name>``, or a short hash for any other custom
|
||||
HERMES_HOME path.
|
||||
Returns ``""`` for the default root, the profile name for
|
||||
``<root>/profiles/<name>``, or a short hash for any other path.
|
||||
Works correctly in Docker (HERMES_HOME=/opt/data) and standard deployments.
|
||||
"""
|
||||
import hashlib
|
||||
import re
|
||||
from pathlib import Path as _Path
|
||||
from hermes_constants import get_default_hermes_root
|
||||
home = get_hermes_home().resolve()
|
||||
default = (_Path.home() / ".hermes").resolve()
|
||||
default = get_default_hermes_root().resolve()
|
||||
if home == default:
|
||||
return ""
|
||||
# Detect ~/.hermes/profiles/<name> pattern → use the profile name
|
||||
# Detect <root>/profiles/<name> pattern → use the profile name
|
||||
profiles_root = (default / "profiles").resolve()
|
||||
try:
|
||||
rel = home.relative_to(profiles_root)
|
||||
@@ -279,9 +287,9 @@ def _profile_arg(hermes_home: str | None = None) -> str:
|
||||
service definition for a different user (e.g. system service).
|
||||
"""
|
||||
import re
|
||||
from pathlib import Path as _Path
|
||||
from hermes_constants import get_default_hermes_root
|
||||
home = Path(hermes_home or str(get_hermes_home())).resolve()
|
||||
default = (_Path.home() / ".hermes").resolve()
|
||||
default = get_default_hermes_root().resolve()
|
||||
if home == default:
|
||||
return ""
|
||||
profiles_root = (default / "profiles").resolve()
|
||||
@@ -308,8 +316,6 @@ def get_service_name() -> str:
|
||||
return f"{_SERVICE_BASE}-{suffix}"
|
||||
|
||||
|
||||
SERVICE_NAME = _SERVICE_BASE # backward-compat for external importers; prefer get_service_name()
|
||||
|
||||
|
||||
def get_systemd_unit_path(system: bool = False) -> Path:
|
||||
name = get_service_name()
|
||||
@@ -477,13 +483,15 @@ def install_linux_gateway_from_setup(force: bool = False) -> tuple[str | None, b
|
||||
|
||||
|
||||
def get_systemd_linger_status() -> tuple[bool | None, str]:
|
||||
"""Return whether systemd user lingering is enabled for the current user.
|
||||
"""Return systemd linger status for the current user.
|
||||
|
||||
Returns:
|
||||
(True, "") when linger is enabled.
|
||||
(False, "") when linger is disabled.
|
||||
(None, detail) when the status could not be determined.
|
||||
"""
|
||||
if is_termux():
|
||||
return None, "not supported in Termux"
|
||||
if not is_linux():
|
||||
return None, "not supported on this platform"
|
||||
|
||||
@@ -581,17 +589,6 @@ def get_python_path() -> str:
|
||||
return str(venv_python)
|
||||
return sys.executable
|
||||
|
||||
def get_hermes_cli_path() -> str:
|
||||
"""Get the path to the hermes CLI."""
|
||||
# Check if installed via pip
|
||||
import shutil
|
||||
hermes_bin = shutil.which("hermes")
|
||||
if hermes_bin:
|
||||
return hermes_bin
|
||||
|
||||
# Fallback to direct module execution
|
||||
return f"{get_python_path()} -m hermes_cli.main"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Systemd (Linux)
|
||||
@@ -608,6 +605,24 @@ def _build_user_local_paths(home: Path, path_entries: list[str]) -> list[str]:
|
||||
return [p for p in candidates if p not in path_entries and Path(p).exists()]
|
||||
|
||||
|
||||
def _remap_path_for_user(path: str, target_home_dir: str) -> str:
|
||||
"""Remap *path* from the current user's home to *target_home_dir*.
|
||||
|
||||
If *path* lives under ``Path.home()`` the corresponding prefix is swapped
|
||||
to *target_home_dir*; otherwise the path is returned unchanged.
|
||||
|
||||
/root/.hermes/hermes-agent -> /home/alice/.hermes/hermes-agent
|
||||
/opt/hermes -> /opt/hermes (kept as-is)
|
||||
"""
|
||||
current_home = Path.home().resolve()
|
||||
resolved = Path(path).resolve()
|
||||
try:
|
||||
relative = resolved.relative_to(current_home)
|
||||
return str(Path(target_home_dir) / relative)
|
||||
except ValueError:
|
||||
return str(resolved)
|
||||
|
||||
|
||||
def _hermes_home_for_target_user(target_home_dir: str) -> str:
|
||||
"""Remap the current HERMES_HOME to the equivalent under a target user's home.
|
||||
|
||||
@@ -655,6 +670,15 @@ def generate_systemd_unit(system: bool = False, run_as_user: str | None = None)
|
||||
username, group_name, home_dir = _system_service_identity(run_as_user)
|
||||
hermes_home = _hermes_home_for_target_user(home_dir)
|
||||
profile_arg = _profile_arg(hermes_home)
|
||||
# Remap all paths that may resolve under the calling user's home
|
||||
# (e.g. /root/) to the target user's home so the service can
|
||||
# actually access them.
|
||||
python_path = _remap_path_for_user(python_path, home_dir)
|
||||
working_dir = _remap_path_for_user(working_dir, home_dir)
|
||||
venv_dir = _remap_path_for_user(venv_dir, home_dir)
|
||||
venv_bin = _remap_path_for_user(venv_bin, home_dir)
|
||||
node_bin = _remap_path_for_user(node_bin, home_dir)
|
||||
path_entries = [_remap_path_for_user(p, home_dir) for p in path_entries]
|
||||
path_entries.extend(_build_user_local_paths(Path(home_dir), path_entries))
|
||||
path_entries.extend(common_bin_paths)
|
||||
sane_path = ":".join(path_entries)
|
||||
@@ -766,7 +790,7 @@ def _print_linger_enable_warning(username: str, detail: str | None = None) -> No
|
||||
|
||||
def _ensure_linger_enabled() -> None:
|
||||
"""Enable linger when possible so the user gateway survives logout."""
|
||||
if not is_linux():
|
||||
if is_termux() or not is_linux():
|
||||
return
|
||||
|
||||
import getpass
|
||||
@@ -1172,7 +1196,19 @@ def launchd_start():
|
||||
|
||||
def launchd_stop():
|
||||
label = get_launchd_label()
|
||||
subprocess.run(["launchctl", "kill", "SIGTERM", f"{_launchd_domain()}/{label}"], check=True, timeout=30)
|
||||
target = f"{_launchd_domain()}/{label}"
|
||||
# bootout unloads the service definition so KeepAlive doesn't respawn
|
||||
# the process. A plain `kill SIGTERM` only signals the process — launchd
|
||||
# immediately restarts it because KeepAlive.SuccessfulExit = false.
|
||||
# `hermes gateway start` re-bootstraps when it detects the job is unloaded.
|
||||
try:
|
||||
subprocess.run(["launchctl", "bootout", target], check=True, timeout=90)
|
||||
except subprocess.CalledProcessError as e:
|
||||
if e.returncode in (3, 113):
|
||||
pass # Already unloaded — nothing to stop.
|
||||
else:
|
||||
raise
|
||||
_wait_for_gateway_exit(timeout=10.0, force_after=5.0)
|
||||
print("✓ Service stopped")
|
||||
|
||||
def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
|
||||
@@ -1184,7 +1220,7 @@ def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
|
||||
|
||||
Args:
|
||||
timeout: Total seconds to wait before giving up.
|
||||
force_after: Seconds of graceful waiting before sending SIGKILL.
|
||||
force_after: Seconds of graceful waiting before escalating to force-kill.
|
||||
"""
|
||||
import time
|
||||
from gateway.status import get_running_pid
|
||||
@@ -1201,15 +1237,15 @@ def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
|
||||
if not force_sent and time.monotonic() >= force_deadline:
|
||||
# Grace period expired — force-kill the specific PID.
|
||||
try:
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
terminate_pid(pid, force=True)
|
||||
print(f"⚠ Gateway PID {pid} did not exit gracefully; sent SIGKILL")
|
||||
except (ProcessLookupError, PermissionError):
|
||||
except (ProcessLookupError, PermissionError, OSError):
|
||||
return # Already gone or we can't touch it.
|
||||
force_sent = True
|
||||
|
||||
time.sleep(0.3)
|
||||
|
||||
# Timed out even after SIGKILL.
|
||||
# Timed out even after force-kill.
|
||||
remaining_pid = get_running_pid()
|
||||
if remaining_pid is not None:
|
||||
print(f"⚠ Gateway PID {remaining_pid} still running after {timeout}s — restart may fail")
|
||||
@@ -1588,6 +1624,12 @@ _PLATFORMS = [
|
||||
"help": "Chat ID for scheduled results and notifications."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"key": "weixin",
|
||||
"label": "Weixin / WeChat",
|
||||
"emoji": "💬",
|
||||
"token_var": "WEIXIN_ACCOUNT_ID",
|
||||
},
|
||||
{
|
||||
"key": "bluebubbles",
|
||||
"label": "BlueBubbles (iMessage)",
|
||||
@@ -1660,6 +1702,13 @@ def _platform_status(platform: dict) -> str:
|
||||
if val or password or homeserver:
|
||||
return "partially configured"
|
||||
return "not configured"
|
||||
if platform.get("key") == "weixin":
|
||||
token = get_env_value("WEIXIN_TOKEN")
|
||||
if val and token:
|
||||
return "configured"
|
||||
if val or token:
|
||||
return "partially configured"
|
||||
return "not configured"
|
||||
if val:
|
||||
return "configured"
|
||||
return "not configured"
|
||||
@@ -1763,7 +1812,7 @@ def _setup_standard_platform(platform: dict):
|
||||
print_warning(" Open access enabled — anyone can use your bot!")
|
||||
elif access_idx == 1:
|
||||
print_success(" DM pairing mode — users will receive a code to request access.")
|
||||
print_info(" Approve with: hermes pairing approve {platform} {code}")
|
||||
print_info(" Approve with: hermes pairing approve <platform> <code>")
|
||||
else:
|
||||
print_info(" Skipped — configure later with 'hermes gateway setup'")
|
||||
continue
|
||||
@@ -1801,7 +1850,7 @@ def _setup_whatsapp():
|
||||
|
||||
def _is_service_installed() -> bool:
|
||||
"""Check if the gateway is installed as a system service."""
|
||||
if is_linux():
|
||||
if supports_systemd_services():
|
||||
return get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()
|
||||
elif is_macos():
|
||||
return get_launchd_plist_path().exists()
|
||||
@@ -1810,7 +1859,7 @@ def _is_service_installed() -> bool:
|
||||
|
||||
def _is_service_running() -> bool:
|
||||
"""Check if the gateway service is currently running."""
|
||||
if is_linux():
|
||||
if supports_systemd_services():
|
||||
user_unit_exists = get_systemd_unit_path(system=False).exists()
|
||||
system_unit_exists = get_systemd_unit_path(system=True).exists()
|
||||
|
||||
@@ -1850,6 +1899,133 @@ def _is_service_running() -> bool:
|
||||
return len(find_gateway_pids()) > 0
|
||||
|
||||
|
||||
def _setup_weixin():
|
||||
"""Interactive setup for Weixin / WeChat personal accounts."""
|
||||
print()
|
||||
print(color(" ─── 💬 Weixin / WeChat Setup ───", Colors.CYAN))
|
||||
print()
|
||||
print_info(" 1. Hermes will open Tencent iLink QR login in this terminal.")
|
||||
print_info(" 2. Use WeChat to scan and confirm the QR code.")
|
||||
print_info(" 3. Hermes will store the returned account_id/token in ~/.hermes/.env.")
|
||||
print_info(" 4. This adapter supports native text, image, video, and document delivery.")
|
||||
|
||||
existing_account = get_env_value("WEIXIN_ACCOUNT_ID")
|
||||
existing_token = get_env_value("WEIXIN_TOKEN")
|
||||
if existing_account and existing_token:
|
||||
print()
|
||||
print_success("Weixin is already configured.")
|
||||
if not prompt_yes_no(" Reconfigure Weixin?", False):
|
||||
return
|
||||
|
||||
try:
|
||||
from gateway.platforms.weixin import check_weixin_requirements, qr_login
|
||||
except Exception as exc:
|
||||
print_error(f" Weixin adapter import failed: {exc}")
|
||||
print_info(" Install gateway dependencies first, then retry.")
|
||||
return
|
||||
|
||||
if not check_weixin_requirements():
|
||||
print_error(" Missing dependencies: Weixin needs aiohttp and cryptography.")
|
||||
print_info(" Install them, then rerun `hermes gateway setup`.")
|
||||
return
|
||||
|
||||
print()
|
||||
if not prompt_yes_no(" Start QR login now?", True):
|
||||
print_info(" Cancelled.")
|
||||
return
|
||||
|
||||
import asyncio
|
||||
try:
|
||||
credentials = asyncio.run(qr_login(str(get_hermes_home())))
|
||||
except KeyboardInterrupt:
|
||||
print()
|
||||
print_warning(" Weixin setup cancelled.")
|
||||
return
|
||||
except Exception as exc:
|
||||
print_error(f" QR login failed: {exc}")
|
||||
return
|
||||
|
||||
if not credentials:
|
||||
print_warning(" QR login did not complete.")
|
||||
return
|
||||
|
||||
account_id = credentials.get("account_id", "")
|
||||
token = credentials.get("token", "")
|
||||
base_url = credentials.get("base_url", "")
|
||||
user_id = credentials.get("user_id", "")
|
||||
|
||||
save_env_value("WEIXIN_ACCOUNT_ID", account_id)
|
||||
save_env_value("WEIXIN_TOKEN", token)
|
||||
if base_url:
|
||||
save_env_value("WEIXIN_BASE_URL", base_url)
|
||||
save_env_value("WEIXIN_CDN_BASE_URL", get_env_value("WEIXIN_CDN_BASE_URL") or "https://novac2c.cdn.weixin.qq.com/c2c")
|
||||
|
||||
print()
|
||||
access_choices = [
|
||||
"Use DM pairing approval (recommended)",
|
||||
"Allow all direct messages",
|
||||
"Only allow listed user IDs",
|
||||
"Disable direct messages",
|
||||
]
|
||||
access_idx = prompt_choice(" How should direct messages be authorized?", access_choices, 0)
|
||||
if access_idx == 0:
|
||||
save_env_value("WEIXIN_DM_POLICY", "pairing")
|
||||
save_env_value("WEIXIN_ALLOW_ALL_USERS", "false")
|
||||
save_env_value("WEIXIN_ALLOWED_USERS", "")
|
||||
print_success(" DM pairing enabled.")
|
||||
print_info(" Unknown DM users can request access and you approve them with `hermes pairing approve`.")
|
||||
elif access_idx == 1:
|
||||
save_env_value("WEIXIN_DM_POLICY", "open")
|
||||
save_env_value("WEIXIN_ALLOW_ALL_USERS", "true")
|
||||
save_env_value("WEIXIN_ALLOWED_USERS", "")
|
||||
print_warning(" Open DM access enabled for Weixin.")
|
||||
elif access_idx == 2:
|
||||
default_allow = user_id or ""
|
||||
allowlist = prompt(" Allowed Weixin user IDs (comma-separated)", default_allow, password=False).replace(" ", "")
|
||||
save_env_value("WEIXIN_DM_POLICY", "allowlist")
|
||||
save_env_value("WEIXIN_ALLOW_ALL_USERS", "false")
|
||||
save_env_value("WEIXIN_ALLOWED_USERS", allowlist)
|
||||
print_success(" Weixin allowlist saved.")
|
||||
else:
|
||||
save_env_value("WEIXIN_DM_POLICY", "disabled")
|
||||
save_env_value("WEIXIN_ALLOW_ALL_USERS", "false")
|
||||
save_env_value("WEIXIN_ALLOWED_USERS", "")
|
||||
print_warning(" Direct messages disabled.")
|
||||
|
||||
print()
|
||||
group_choices = [
|
||||
"Disable group chats (recommended)",
|
||||
"Allow all group chats",
|
||||
"Only allow listed group chat IDs",
|
||||
]
|
||||
group_idx = prompt_choice(" How should group chats be handled?", group_choices, 0)
|
||||
if group_idx == 0:
|
||||
save_env_value("WEIXIN_GROUP_POLICY", "disabled")
|
||||
save_env_value("WEIXIN_GROUP_ALLOWED_USERS", "")
|
||||
print_info(" Group chats disabled.")
|
||||
elif group_idx == 1:
|
||||
save_env_value("WEIXIN_GROUP_POLICY", "open")
|
||||
save_env_value("WEIXIN_GROUP_ALLOWED_USERS", "")
|
||||
print_warning(" All group chats enabled.")
|
||||
else:
|
||||
allow_groups = prompt(" Allowed group chat IDs (comma-separated)", "", password=False).replace(" ", "")
|
||||
save_env_value("WEIXIN_GROUP_POLICY", "allowlist")
|
||||
save_env_value("WEIXIN_GROUP_ALLOWED_USERS", allow_groups)
|
||||
print_success(" Group allowlist saved.")
|
||||
|
||||
if user_id:
|
||||
print()
|
||||
if prompt_yes_no(f" Use your Weixin user ID ({user_id}) as the home channel?", True):
|
||||
save_env_value("WEIXIN_HOME_CHANNEL", user_id)
|
||||
print_success(f" Home channel set to {user_id}")
|
||||
|
||||
print()
|
||||
print_success("Weixin configured!")
|
||||
print_info(f" Account ID: {account_id}")
|
||||
if user_id:
|
||||
print_info(f" User ID: {user_id}")
|
||||
|
||||
|
||||
def _setup_signal():
|
||||
"""Interactive setup for Signal messenger."""
|
||||
import shutil
|
||||
@@ -1983,7 +2159,7 @@ def gateway_setup():
|
||||
service_installed = _is_service_installed()
|
||||
service_running = _is_service_running()
|
||||
|
||||
if is_linux() and has_conflicting_systemd_units():
|
||||
if supports_systemd_services() and has_conflicting_systemd_units():
|
||||
print_systemd_scope_conflict_warning()
|
||||
print()
|
||||
|
||||
@@ -1993,7 +2169,7 @@ def gateway_setup():
|
||||
print_warning("Gateway service is installed but not running.")
|
||||
if prompt_yes_no(" Start it now?", True):
|
||||
try:
|
||||
if is_linux():
|
||||
if supports_systemd_services():
|
||||
systemd_start()
|
||||
elif is_macos():
|
||||
launchd_start()
|
||||
@@ -2025,6 +2201,8 @@ def gateway_setup():
|
||||
_setup_whatsapp()
|
||||
elif platform["key"] == "signal":
|
||||
_setup_signal()
|
||||
elif platform["key"] == "weixin":
|
||||
_setup_weixin()
|
||||
else:
|
||||
_setup_standard_platform(platform)
|
||||
|
||||
@@ -2044,7 +2222,7 @@ def gateway_setup():
|
||||
if service_running:
|
||||
if prompt_yes_no(" Restart the gateway to pick up changes?", True):
|
||||
try:
|
||||
if is_linux():
|
||||
if supports_systemd_services():
|
||||
systemd_restart()
|
||||
elif is_macos():
|
||||
launchd_restart()
|
||||
@@ -2056,7 +2234,7 @@ def gateway_setup():
|
||||
elif service_installed:
|
||||
if prompt_yes_no(" Start the gateway service?", True):
|
||||
try:
|
||||
if is_linux():
|
||||
if supports_systemd_services():
|
||||
systemd_start()
|
||||
elif is_macos():
|
||||
launchd_start()
|
||||
@@ -2064,13 +2242,13 @@ def gateway_setup():
|
||||
print_error(f" Start failed: {e}")
|
||||
else:
|
||||
print()
|
||||
if is_linux() or is_macos():
|
||||
platform_name = "systemd" if is_linux() else "launchd"
|
||||
if supports_systemd_services() or is_macos():
|
||||
platform_name = "systemd" if supports_systemd_services() else "launchd"
|
||||
if prompt_yes_no(f" Install the gateway as a {platform_name} service? (runs in background, starts on boot)", True):
|
||||
try:
|
||||
installed_scope = None
|
||||
did_install = False
|
||||
if is_linux():
|
||||
if supports_systemd_services():
|
||||
installed_scope, did_install = install_linux_gateway_from_setup(force=False)
|
||||
else:
|
||||
launchd_install(force=False)
|
||||
@@ -2078,7 +2256,7 @@ def gateway_setup():
|
||||
print()
|
||||
if did_install and prompt_yes_no(" Start the service now?", True):
|
||||
try:
|
||||
if is_linux():
|
||||
if supports_systemd_services():
|
||||
systemd_start(system=installed_scope == "system")
|
||||
else:
|
||||
launchd_start()
|
||||
@@ -2089,12 +2267,18 @@ def gateway_setup():
|
||||
print_info(" You can try manually: hermes gateway install")
|
||||
else:
|
||||
print_info(" You can install later: hermes gateway install")
|
||||
if is_linux():
|
||||
if supports_systemd_services():
|
||||
print_info(" Or as a boot-time service: sudo hermes gateway install --system")
|
||||
print_info(" Or run in foreground: hermes gateway")
|
||||
else:
|
||||
print_info(" Service install not supported on this platform.")
|
||||
print_info(" Run in foreground: hermes gateway")
|
||||
if is_termux():
|
||||
from hermes_constants import display_hermes_home as _dhh
|
||||
print_info(" Termux does not use systemd/launchd services.")
|
||||
print_info(" Run in foreground: hermes gateway")
|
||||
print_info(f" Or start it manually in the background (best effort): nohup hermes gateway >{_dhh()}/logs/gateway.log 2>&1 &")
|
||||
else:
|
||||
print_info(" Service install not supported on this platform.")
|
||||
print_info(" Run in foreground: hermes gateway")
|
||||
else:
|
||||
print()
|
||||
print_info("No platforms configured. Run 'hermes gateway setup' when ready.")
|
||||
@@ -2130,7 +2314,11 @@ def gateway_command(args):
|
||||
force = getattr(args, 'force', False)
|
||||
system = getattr(args, 'system', False)
|
||||
run_as_user = getattr(args, 'run_as_user', None)
|
||||
if is_linux():
|
||||
if is_termux():
|
||||
print("Gateway service installation is not supported on Termux.")
|
||||
print("Run manually: hermes gateway")
|
||||
sys.exit(1)
|
||||
if supports_systemd_services():
|
||||
systemd_install(force=force, system=system, run_as_user=run_as_user)
|
||||
elif is_macos():
|
||||
launchd_install(force)
|
||||
@@ -2144,7 +2332,11 @@ def gateway_command(args):
|
||||
managed_error("uninstall gateway service (managed by NixOS)")
|
||||
return
|
||||
system = getattr(args, 'system', False)
|
||||
if is_linux():
|
||||
if is_termux():
|
||||
print("Gateway service uninstall is not supported on Termux because there is no managed service to remove.")
|
||||
print("Stop manual runs with: hermes gateway stop")
|
||||
sys.exit(1)
|
||||
if supports_systemd_services():
|
||||
systemd_uninstall(system=system)
|
||||
elif is_macos():
|
||||
launchd_uninstall()
|
||||
@@ -2154,7 +2346,11 @@ def gateway_command(args):
|
||||
|
||||
elif subcmd == "start":
|
||||
system = getattr(args, 'system', False)
|
||||
if is_linux():
|
||||
if is_termux():
|
||||
print("Gateway service start is not supported on Termux because there is no system service manager.")
|
||||
print("Run manually: hermes gateway")
|
||||
sys.exit(1)
|
||||
if supports_systemd_services():
|
||||
systemd_start(system=system)
|
||||
elif is_macos():
|
||||
launchd_start()
|
||||
@@ -2169,7 +2365,7 @@ def gateway_command(args):
|
||||
if stop_all:
|
||||
# --all: kill every gateway process on the machine
|
||||
service_available = False
|
||||
if is_linux() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
if supports_systemd_services() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
try:
|
||||
systemd_stop(system=system)
|
||||
service_available = True
|
||||
@@ -2190,7 +2386,7 @@ def gateway_command(args):
|
||||
else:
|
||||
# Default: stop only the current profile's gateway
|
||||
service_available = False
|
||||
if is_linux() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
if supports_systemd_services() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
try:
|
||||
systemd_stop(system=system)
|
||||
service_available = True
|
||||
@@ -2218,7 +2414,7 @@ def gateway_command(args):
|
||||
system = getattr(args, 'system', False)
|
||||
service_configured = False
|
||||
|
||||
if is_linux() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
if supports_systemd_services() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
service_configured = True
|
||||
try:
|
||||
systemd_restart(system=system)
|
||||
@@ -2235,7 +2431,7 @@ def gateway_command(args):
|
||||
|
||||
if not service_available:
|
||||
# systemd/launchd restart failed — check if linger is the issue
|
||||
if is_linux():
|
||||
if supports_systemd_services():
|
||||
linger_ok, _detail = get_systemd_linger_status()
|
||||
if linger_ok is not True:
|
||||
import getpass
|
||||
@@ -2272,7 +2468,7 @@ def gateway_command(args):
|
||||
system = getattr(args, 'system', False)
|
||||
|
||||
# Check for service first
|
||||
if is_linux() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
if supports_systemd_services() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
systemd_status(deep, system=system)
|
||||
elif is_macos() and get_launchd_plist_path().exists():
|
||||
launchd_status(deep)
|
||||
@@ -2289,9 +2485,13 @@ def gateway_command(args):
|
||||
for line in runtime_lines:
|
||||
print(f" {line}")
|
||||
print()
|
||||
print("To install as a service:")
|
||||
print(" hermes gateway install")
|
||||
print(" sudo hermes gateway install --system")
|
||||
if is_termux():
|
||||
print("Termux note:")
|
||||
print(" Android may stop background jobs when Termux is suspended")
|
||||
else:
|
||||
print("To install as a service:")
|
||||
print(" hermes gateway install")
|
||||
print(" sudo hermes gateway install --system")
|
||||
else:
|
||||
print("✗ Gateway is not running")
|
||||
runtime_lines = _runtime_health_lines()
|
||||
@@ -2303,5 +2503,8 @@ def gateway_command(args):
|
||||
print()
|
||||
print("To start:")
|
||||
print(" hermes gateway # Run in foreground")
|
||||
print(" hermes gateway install # Install as user service")
|
||||
print(" sudo hermes gateway install --system # Install as boot-time system service")
|
||||
if is_termux():
|
||||
print(" nohup hermes gateway > ~/.hermes/logs/gateway.log 2>&1 & # Best-effort background start")
|
||||
else:
|
||||
print(" hermes gateway install # Install as user service")
|
||||
print(" sudo hermes gateway install --system # Install as boot-time system service")
|
||||
|
||||
+100
-83
@@ -97,10 +97,11 @@ def _apply_profile_override() -> None:
|
||||
consume = 1
|
||||
break
|
||||
|
||||
# 2. If no flag, check ~/.hermes/active_profile
|
||||
# 2. If no flag, check active_profile in the hermes root
|
||||
if profile_name is None:
|
||||
try:
|
||||
active_path = Path.home() / ".hermes" / "active_profile"
|
||||
from hermes_constants import get_default_hermes_root
|
||||
active_path = get_default_hermes_root() / "active_profile"
|
||||
if active_path.exists():
|
||||
name = active_path.read_text().strip()
|
||||
if name and name != "default":
|
||||
@@ -646,6 +647,7 @@ def cmd_chat(args):
|
||||
"verbose": args.verbose,
|
||||
"quiet": getattr(args, "quiet", False),
|
||||
"query": args.query,
|
||||
"image": getattr(args, "image", None),
|
||||
"resume": getattr(args, "resume", None),
|
||||
"worktree": getattr(args, "worktree", False),
|
||||
"checkpoints": getattr(args, "checkpoints", False),
|
||||
@@ -857,7 +859,6 @@ def cmd_whatsapp(args):
|
||||
|
||||
def cmd_setup(args):
|
||||
"""Interactive setup wizard."""
|
||||
_require_tty("setup")
|
||||
from hermes_cli.setup import run_setup_wizard
|
||||
run_setup_wizard(args)
|
||||
|
||||
@@ -967,10 +968,11 @@ def select_provider_and_model(args=None):
|
||||
("alibaba", "Alibaba Cloud / DashScope Coding (Qwen + multi-provider)"),
|
||||
]
|
||||
|
||||
# Add user-defined custom providers from config.yaml
|
||||
custom_providers_cfg = config.get("custom_providers") or []
|
||||
_custom_provider_map = {} # key → {name, base_url, api_key}
|
||||
if isinstance(custom_providers_cfg, list):
|
||||
def _named_custom_provider_map(cfg) -> dict[str, dict[str, str]]:
|
||||
custom_providers_cfg = cfg.get("custom_providers") or []
|
||||
custom_provider_map = {}
|
||||
if not isinstance(custom_providers_cfg, list):
|
||||
return custom_provider_map
|
||||
for entry in custom_providers_cfg:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
@@ -979,16 +981,23 @@ def select_provider_and_model(args=None):
|
||||
if not name or not base_url:
|
||||
continue
|
||||
key = "custom:" + name.lower().replace(" ", "-")
|
||||
short_url = base_url.replace("https://", "").replace("http://", "").rstrip("/")
|
||||
saved_model = entry.get("model", "")
|
||||
model_hint = f" — {saved_model}" if saved_model else ""
|
||||
top_providers.append((key, f"{name} ({short_url}){model_hint}"))
|
||||
_custom_provider_map[key] = {
|
||||
custom_provider_map[key] = {
|
||||
"name": name,
|
||||
"base_url": base_url,
|
||||
"api_key": entry.get("api_key", ""),
|
||||
"model": saved_model,
|
||||
"model": entry.get("model", ""),
|
||||
}
|
||||
return custom_provider_map
|
||||
|
||||
# Add user-defined custom providers from config.yaml
|
||||
_custom_provider_map = _named_custom_provider_map(config) # key → {name, base_url, api_key}
|
||||
for key, provider_info in _custom_provider_map.items():
|
||||
name = provider_info["name"]
|
||||
base_url = provider_info["base_url"]
|
||||
short_url = base_url.replace("https://", "").replace("http://", "").rstrip("/")
|
||||
saved_model = provider_info.get("model", "")
|
||||
model_hint = f" — {saved_model}" if saved_model else ""
|
||||
top_providers.append((key, f"{name} ({short_url}){model_hint}"))
|
||||
|
||||
top_keys = {k for k, _ in top_providers}
|
||||
extended_keys = {k for k, _ in extended_providers}
|
||||
@@ -1053,8 +1062,15 @@ def select_provider_and_model(args=None):
|
||||
_model_flow_copilot(config, current_model)
|
||||
elif selected_provider == "custom":
|
||||
_model_flow_custom(config)
|
||||
elif selected_provider.startswith("custom:") and selected_provider in _custom_provider_map:
|
||||
_model_flow_named_custom(config, _custom_provider_map[selected_provider])
|
||||
elif selected_provider.startswith("custom:"):
|
||||
provider_info = _named_custom_provider_map(load_config()).get(selected_provider)
|
||||
if provider_info is None:
|
||||
print(
|
||||
"Warning: the selected saved custom provider is no longer available. "
|
||||
"It may have been removed from config.yaml. No change."
|
||||
)
|
||||
return
|
||||
_model_flow_named_custom(config, provider_info)
|
||||
elif selected_provider == "remove-custom":
|
||||
_remove_custom_provider(config)
|
||||
elif selected_provider == "anthropic":
|
||||
@@ -1127,10 +1143,10 @@ def _model_flow_openrouter(config, current_model=""):
|
||||
print()
|
||||
|
||||
from hermes_cli.models import model_ids, get_pricing_for_provider
|
||||
openrouter_models = model_ids()
|
||||
openrouter_models = model_ids(force_refresh=True)
|
||||
|
||||
# Fetch live pricing (non-blocking — returns empty dict on failure)
|
||||
pricing = get_pricing_for_provider("openrouter")
|
||||
pricing = get_pricing_for_provider("openrouter", force_refresh=True)
|
||||
|
||||
selected = _prompt_model_selection(openrouter_models, current_model=current_model, pricing=pricing)
|
||||
if selected:
|
||||
@@ -1657,8 +1673,10 @@ def _remove_custom_provider(config):
|
||||
title="Select provider to remove:",
|
||||
)
|
||||
idx = menu.show()
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
print()
|
||||
except (ImportError, NotImplementedError):
|
||||
except (ImportError, NotImplementedError, OSError, subprocess.SubprocessError):
|
||||
for i, c in enumerate(choices, 1):
|
||||
print(f" {i}. {c}")
|
||||
print()
|
||||
@@ -1682,8 +1700,9 @@ def _remove_custom_provider(config):
|
||||
def _model_flow_named_custom(config, provider_info):
|
||||
"""Handle a named custom provider from config.yaml custom_providers list.
|
||||
|
||||
If the entry has a saved model name, activates it immediately.
|
||||
Otherwise probes the endpoint's /models API to let the user pick one.
|
||||
Always probes the endpoint's /models API to let the user pick a model.
|
||||
If a model was previously saved, it is pre-selected in the menu.
|
||||
Falls back to the saved model if probing fails.
|
||||
"""
|
||||
from hermes_cli.auth import _save_model_choice, deactivate_provider
|
||||
from hermes_cli.config import load_config, save_config
|
||||
@@ -1694,54 +1713,46 @@ def _model_flow_named_custom(config, provider_info):
|
||||
api_key = provider_info.get("api_key", "")
|
||||
saved_model = provider_info.get("model", "")
|
||||
|
||||
# If a model is saved, just activate immediately — no probing needed
|
||||
if saved_model:
|
||||
_save_model_choice(saved_model)
|
||||
|
||||
cfg = load_config()
|
||||
model = cfg.get("model")
|
||||
if not isinstance(model, dict):
|
||||
model = {"default": model} if model else {}
|
||||
cfg["model"] = model
|
||||
model["provider"] = "custom"
|
||||
model["base_url"] = base_url
|
||||
if api_key:
|
||||
model["api_key"] = api_key
|
||||
save_config(cfg)
|
||||
deactivate_provider()
|
||||
|
||||
print(f"✅ Switched to: {saved_model}")
|
||||
print(f" Provider: {name} ({base_url})")
|
||||
return
|
||||
|
||||
# No saved model — probe endpoint and let user pick
|
||||
print(f" Provider: {name}")
|
||||
print(f" URL: {base_url}")
|
||||
if saved_model:
|
||||
print(f" Current: {saved_model}")
|
||||
print()
|
||||
print("No model saved for this provider. Fetching available models...")
|
||||
|
||||
print("Fetching available models...")
|
||||
models = fetch_api_models(api_key, base_url, timeout=8.0)
|
||||
|
||||
if models:
|
||||
default_idx = 0
|
||||
if saved_model and saved_model in models:
|
||||
default_idx = models.index(saved_model)
|
||||
|
||||
print(f"Found {len(models)} model(s):\n")
|
||||
try:
|
||||
from simple_term_menu import TerminalMenu
|
||||
menu_items = [f" {m}" for m in models] + [" Cancel"]
|
||||
menu_items = [
|
||||
f" {m} (current)" if m == saved_model else f" {m}"
|
||||
for m in models
|
||||
] + [" Cancel"]
|
||||
menu = TerminalMenu(
|
||||
menu_items, cursor_index=0,
|
||||
menu_items, cursor_index=default_idx,
|
||||
menu_cursor="-> ", menu_cursor_style=("fg_green", "bold"),
|
||||
menu_highlight_style=("fg_green",),
|
||||
cycle_cursor=True, clear_screen=False,
|
||||
title=f"Select model from {name}:",
|
||||
)
|
||||
idx = menu.show()
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
print()
|
||||
if idx is None or idx >= len(models):
|
||||
print("Cancelled.")
|
||||
return
|
||||
model_name = models[idx]
|
||||
except (ImportError, NotImplementedError):
|
||||
except (ImportError, NotImplementedError, OSError, subprocess.SubprocessError):
|
||||
for i, m in enumerate(models, 1):
|
||||
print(f" {i}. {m}")
|
||||
suffix = " (current)" if m == saved_model else ""
|
||||
print(f" {i}. {m}{suffix}")
|
||||
print(f" {len(models) + 1}. Cancel")
|
||||
print()
|
||||
try:
|
||||
@@ -1757,6 +1768,13 @@ def _model_flow_named_custom(config, provider_info):
|
||||
except (ValueError, KeyboardInterrupt, EOFError):
|
||||
print("\nCancelled.")
|
||||
return
|
||||
elif saved_model:
|
||||
print("Could not fetch models from endpoint.")
|
||||
try:
|
||||
model_name = input(f"Model name [{saved_model}]: ").strip() or saved_model
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\nCancelled.")
|
||||
return
|
||||
else:
|
||||
print("Could not fetch models from endpoint. Enter model name manually.")
|
||||
try:
|
||||
@@ -1811,7 +1829,10 @@ def _set_reasoning_effort(config, effort: str) -> None:
|
||||
|
||||
def _prompt_reasoning_effort_selection(efforts, current_effort=""):
|
||||
"""Prompt for a reasoning effort. Returns effort, 'none', or None to keep current."""
|
||||
ordered = list(dict.fromkeys(str(effort).strip().lower() for effort in efforts if str(effort).strip()))
|
||||
deduped = list(dict.fromkeys(str(effort).strip().lower() for effort in efforts if str(effort).strip()))
|
||||
canonical_order = ("minimal", "low", "medium", "high", "xhigh")
|
||||
ordered = [effort for effort in canonical_order if effort in deduped]
|
||||
ordered.extend(effort for effort in deduped if effort not in canonical_order)
|
||||
if not ordered:
|
||||
return None
|
||||
|
||||
@@ -1849,6 +1870,8 @@ def _prompt_reasoning_effort_selection(efforts, current_effort=""):
|
||||
title="Select reasoning effort:",
|
||||
)
|
||||
idx = menu.show()
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
if idx is None:
|
||||
return None
|
||||
print()
|
||||
@@ -1857,7 +1880,7 @@ def _prompt_reasoning_effort_selection(efforts, current_effort=""):
|
||||
if idx == len(ordered):
|
||||
return "none"
|
||||
return None
|
||||
except (ImportError, NotImplementedError):
|
||||
except (ImportError, NotImplementedError, OSError, subprocess.SubprocessError):
|
||||
pass
|
||||
|
||||
print("Select reasoning effort:")
|
||||
@@ -3018,33 +3041,19 @@ def _restore_stashed_changes(
|
||||
print("\nYour stashed changes are preserved — nothing is lost.")
|
||||
print(f" Stash ref: {stash_ref}")
|
||||
|
||||
# Ask before resetting (if interactive)
|
||||
do_reset = True
|
||||
if prompt_user:
|
||||
print("\nReset working tree to clean state so Hermes can run?")
|
||||
print(" (You can re-apply your changes later with: git stash apply)")
|
||||
print("[Y/n] ", end="", flush=True)
|
||||
response = input().strip().lower()
|
||||
if response not in ("", "y", "yes"):
|
||||
do_reset = False
|
||||
|
||||
if do_reset:
|
||||
subprocess.run(
|
||||
git_cmd + ["reset", "--hard", "HEAD"],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
)
|
||||
print("Working tree reset to clean state.")
|
||||
else:
|
||||
print("Working tree left as-is (may have conflict markers).")
|
||||
print("Resolve conflicts manually, then run: git stash drop")
|
||||
|
||||
print(f"Restore your changes with: git stash apply {stash_ref}")
|
||||
# In non-interactive mode (gateway /update), don't abort — the code
|
||||
# update itself succeeded, only the stash restore had conflicts.
|
||||
# Aborting would report the entire update as failed.
|
||||
if prompt_user:
|
||||
sys.exit(1)
|
||||
# Always reset to clean state — leaving conflict markers in source
|
||||
# files makes hermes completely unrunnable (SyntaxError on import).
|
||||
# The user's changes are safe in the stash for manual recovery.
|
||||
subprocess.run(
|
||||
git_cmd + ["reset", "--hard", "HEAD"],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
)
|
||||
print("Working tree reset to clean state.")
|
||||
print(f"Restore your changes later with: git stash apply {stash_ref}")
|
||||
# Don't sys.exit — the code update itself succeeded, only the stash
|
||||
# restore had conflicts. Let cmd_update continue with pip install,
|
||||
# skill sync, and gateway restart.
|
||||
return False
|
||||
|
||||
stash_selector = _resolve_stash_selector(git_cmd, cwd, stash_ref)
|
||||
@@ -3305,10 +3314,11 @@ def _invalidate_update_cache():
|
||||
``hermes update``, every profile is now current.
|
||||
"""
|
||||
homes = []
|
||||
# Default profile home
|
||||
default_home = Path.home() / ".hermes"
|
||||
# Default profile home (Docker-aware — uses /opt/data in Docker)
|
||||
from hermes_constants import get_default_hermes_root
|
||||
default_home = get_default_hermes_root()
|
||||
homes.append(default_home)
|
||||
# Named profiles under ~/.hermes/profiles/
|
||||
# Named profiles under <root>/profiles/
|
||||
profiles_root = default_home / "profiles"
|
||||
if profiles_root.is_dir():
|
||||
for entry in profiles_root.iterdir():
|
||||
@@ -3760,7 +3770,7 @@ def cmd_update(args):
|
||||
# running gateway needs restarting to pick up the new code.
|
||||
try:
|
||||
from hermes_cli.gateway import (
|
||||
is_macos, is_linux, _ensure_user_systemd_env,
|
||||
is_macos, supports_systemd_services, _ensure_user_systemd_env,
|
||||
find_gateway_pids,
|
||||
_get_service_pids,
|
||||
)
|
||||
@@ -3771,7 +3781,7 @@ def cmd_update(args):
|
||||
|
||||
# --- Systemd services (Linux) ---
|
||||
# Discover all hermes-gateway* units (default + profiles)
|
||||
if is_linux():
|
||||
if supports_systemd_services():
|
||||
try:
|
||||
_ensure_user_systemd_env()
|
||||
except Exception:
|
||||
@@ -4045,7 +4055,10 @@ def cmd_profile(args):
|
||||
print(f" {name} chat Start chatting")
|
||||
print(f" {name} gateway start Start the messaging gateway")
|
||||
if clone or clone_all:
|
||||
profile_dir_display = f"~/.hermes/profiles/{name}"
|
||||
try:
|
||||
profile_dir_display = "~/" + str(profile_dir.relative_to(Path.home()))
|
||||
except ValueError:
|
||||
profile_dir_display = str(profile_dir)
|
||||
print(f"\n Edit {profile_dir_display}/.env for different API keys")
|
||||
print(f" Edit {profile_dir_display}/SOUL.md for different personality")
|
||||
print()
|
||||
@@ -4288,6 +4301,10 @@ For more help on a command:
|
||||
"-q", "--query",
|
||||
help="Single query (non-interactive mode)"
|
||||
)
|
||||
chat_parser.add_argument(
|
||||
"--image",
|
||||
help="Optional local image path to attach to a single query"
|
||||
)
|
||||
chat_parser.add_argument(
|
||||
"-m", "--model",
|
||||
help="Model to use (e.g., anthropic/claude-sonnet-4)"
|
||||
@@ -4478,12 +4495,12 @@ For more help on a command:
|
||||
"setup",
|
||||
help="Interactive setup wizard",
|
||||
description="Configure Hermes Agent with an interactive wizard. "
|
||||
"Run a specific section: hermes setup model|terminal|gateway|tools|agent"
|
||||
"Run a specific section: hermes setup model|tts|terminal|gateway|tools|agent"
|
||||
)
|
||||
setup_parser.add_argument(
|
||||
"section",
|
||||
nargs="?",
|
||||
choices=["model", "terminal", "gateway", "tools", "agent"],
|
||||
choices=["model", "tts", "terminal", "gateway", "tools", "agent"],
|
||||
default=None,
|
||||
help="Run a specific setup section instead of the full wizard"
|
||||
)
|
||||
|
||||
@@ -76,17 +76,22 @@ _STRIP_VENDOR_ONLY_PROVIDERS: frozenset[str] = frozenset({
|
||||
"copilot-acp",
|
||||
})
|
||||
|
||||
# Providers whose own naming is authoritative -- pass through unchanged.
|
||||
_PASSTHROUGH_PROVIDERS: frozenset[str] = frozenset({
|
||||
# Providers whose native naming is authoritative -- pass through unchanged.
|
||||
_AUTHORITATIVE_NATIVE_PROVIDERS: frozenset[str] = frozenset({
|
||||
"gemini",
|
||||
"huggingface",
|
||||
"openai-codex",
|
||||
})
|
||||
|
||||
# Direct providers that accept bare native names but should repair a matching
|
||||
# provider/ prefix when users copy the aggregator form into config.yaml.
|
||||
_MATCHING_PREFIX_STRIP_PROVIDERS: frozenset[str] = frozenset({
|
||||
"zai",
|
||||
"kimi-coding",
|
||||
"minimax",
|
||||
"minimax-cn",
|
||||
"alibaba",
|
||||
"qwen-oauth",
|
||||
"huggingface",
|
||||
"openai-codex",
|
||||
"custom",
|
||||
})
|
||||
|
||||
@@ -168,6 +173,40 @@ def _dots_to_hyphens(model_name: str) -> str:
|
||||
return model_name.replace(".", "-")
|
||||
|
||||
|
||||
def _normalize_provider_alias(provider_name: str) -> str:
|
||||
"""Resolve provider aliases to Hermes' canonical ids."""
|
||||
raw = (provider_name or "").strip().lower()
|
||||
if not raw:
|
||||
return raw
|
||||
try:
|
||||
from hermes_cli.models import normalize_provider
|
||||
|
||||
return normalize_provider(raw)
|
||||
except Exception:
|
||||
return raw
|
||||
|
||||
|
||||
def _strip_matching_provider_prefix(model_name: str, target_provider: str) -> str:
|
||||
"""Strip ``provider/`` only when the prefix matches the target provider.
|
||||
|
||||
This prevents arbitrary slash-bearing model IDs from being mangled on
|
||||
native providers while still repairing manual config values like
|
||||
``zai/glm-5.1`` for the ``zai`` provider.
|
||||
"""
|
||||
if "/" not in model_name:
|
||||
return model_name
|
||||
|
||||
prefix, remainder = model_name.split("/", 1)
|
||||
if not prefix.strip() or not remainder.strip():
|
||||
return model_name
|
||||
|
||||
normalized_prefix = _normalize_provider_alias(prefix)
|
||||
normalized_target = _normalize_provider_alias(target_provider)
|
||||
if normalized_prefix and normalized_prefix == normalized_target:
|
||||
return remainder.strip()
|
||||
return model_name
|
||||
|
||||
|
||||
def detect_vendor(model_name: str) -> Optional[str]:
|
||||
"""Detect the vendor slug from a bare model name.
|
||||
|
||||
@@ -305,24 +344,37 @@ def normalize_model_for_provider(model_input: str, target_provider: str) -> str:
|
||||
if not name:
|
||||
return name
|
||||
|
||||
provider = (target_provider or "").strip().lower()
|
||||
provider = _normalize_provider_alias(target_provider)
|
||||
|
||||
# --- Aggregators: need vendor/model format ---
|
||||
if provider in _AGGREGATOR_PROVIDERS:
|
||||
return _prepend_vendor(name)
|
||||
|
||||
# --- Anthropic / OpenCode: strip vendor, dots -> hyphens ---
|
||||
# --- Anthropic / OpenCode: strip matching provider prefix, dots -> hyphens ---
|
||||
if provider in _DOT_TO_HYPHEN_PROVIDERS:
|
||||
bare = _strip_vendor_prefix(name)
|
||||
bare = _strip_matching_provider_prefix(name, provider)
|
||||
if "/" in bare:
|
||||
return bare
|
||||
return _dots_to_hyphens(bare)
|
||||
|
||||
# --- Copilot: strip vendor, keep dots ---
|
||||
# --- Copilot: strip matching provider prefix, keep dots ---
|
||||
if provider in _STRIP_VENDOR_ONLY_PROVIDERS:
|
||||
return _strip_vendor_prefix(name)
|
||||
return _strip_matching_provider_prefix(name, provider)
|
||||
|
||||
# --- DeepSeek: map to one of two canonical names ---
|
||||
if provider == "deepseek":
|
||||
return _normalize_for_deepseek(name)
|
||||
bare = _strip_matching_provider_prefix(name, provider)
|
||||
if "/" in bare:
|
||||
return bare
|
||||
return _normalize_for_deepseek(bare)
|
||||
|
||||
# --- Direct providers: repair matching provider prefixes only ---
|
||||
if provider in _MATCHING_PREFIX_STRIP_PROVIDERS:
|
||||
return _strip_matching_provider_prefix(name, provider)
|
||||
|
||||
# --- Authoritative native providers: preserve user-facing slugs as-is ---
|
||||
if provider in _AUTHORITATIVE_NATIVE_PROVIDERS:
|
||||
return name
|
||||
|
||||
# --- Custom & all others: pass through as-is ---
|
||||
return name
|
||||
@@ -332,31 +384,3 @@ def normalize_model_for_provider(model_input: str, target_provider: str) -> str:
|
||||
# Batch / convenience helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def model_display_name(model_id: str) -> str:
|
||||
"""Return a short, human-readable display name for a model id.
|
||||
|
||||
Strips the vendor prefix (if any) for a cleaner display in menus
|
||||
and status bars, while preserving dots for readability.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> model_display_name("anthropic/claude-sonnet-4.6")
|
||||
'claude-sonnet-4.6'
|
||||
>>> model_display_name("claude-sonnet-4-6")
|
||||
'claude-sonnet-4-6'
|
||||
"""
|
||||
return _strip_vendor_prefix((model_id or "").strip())
|
||||
|
||||
|
||||
def is_aggregator_provider(provider: str) -> bool:
|
||||
"""Check if a provider is an aggregator that needs vendor/model format."""
|
||||
return (provider or "").strip().lower() in _AGGREGATOR_PROVIDERS
|
||||
|
||||
|
||||
def vendor_for_model(model_name: str) -> str:
|
||||
"""Return the vendor slug for a model, or ``""`` if unknown.
|
||||
|
||||
Convenience wrapper around :func:`detect_vendor` that never returns
|
||||
``None``.
|
||||
"""
|
||||
return detect_vendor(model_name) or ""
|
||||
|
||||
+92
-80
@@ -25,6 +25,7 @@ from dataclasses import dataclass
|
||||
from typing import List, NamedTuple, Optional
|
||||
|
||||
from hermes_cli.providers import (
|
||||
custom_provider_slug,
|
||||
determine_api_mode,
|
||||
get_label,
|
||||
is_aggregator,
|
||||
@@ -336,6 +337,7 @@ def resolve_alias(
|
||||
def get_authenticated_provider_slugs(
|
||||
current_provider: str = "",
|
||||
user_providers: dict = None,
|
||||
custom_providers: list | None = None,
|
||||
) -> list[str]:
|
||||
"""Return slugs of providers that have credentials.
|
||||
|
||||
@@ -346,6 +348,7 @@ def get_authenticated_provider_slugs(
|
||||
providers = list_authenticated_providers(
|
||||
current_provider=current_provider,
|
||||
user_providers=user_providers,
|
||||
custom_providers=custom_providers,
|
||||
max_models=0,
|
||||
)
|
||||
return [p["slug"] for p in providers]
|
||||
@@ -383,6 +386,7 @@ def switch_model(
|
||||
is_global: bool = False,
|
||||
explicit_provider: str = "",
|
||||
user_providers: dict = None,
|
||||
custom_providers: list | None = None,
|
||||
) -> ModelSwitchResult:
|
||||
"""Core model-switching pipeline shared between CLI and gateway.
|
||||
|
||||
@@ -416,6 +420,7 @@ def switch_model(
|
||||
is_global: Whether to persist the switch.
|
||||
explicit_provider: From --provider flag (empty = no explicit provider).
|
||||
user_providers: The ``providers:`` dict from config.yaml (for user endpoints).
|
||||
custom_providers: The ``custom_providers:`` list from config.yaml.
|
||||
|
||||
Returns:
|
||||
ModelSwitchResult with all information the caller needs.
|
||||
@@ -436,7 +441,11 @@ def switch_model(
|
||||
# =================================================================
|
||||
if explicit_provider:
|
||||
# Resolve the provider
|
||||
pdef = resolve_provider_full(explicit_provider, user_providers)
|
||||
pdef = resolve_provider_full(
|
||||
explicit_provider,
|
||||
user_providers,
|
||||
custom_providers,
|
||||
)
|
||||
if pdef is None:
|
||||
_switch_err = (
|
||||
f"Unknown provider '{explicit_provider}'. "
|
||||
@@ -516,6 +525,7 @@ def switch_model(
|
||||
authed = get_authenticated_provider_slugs(
|
||||
current_provider=current_provider,
|
||||
user_providers=user_providers,
|
||||
custom_providers=custom_providers,
|
||||
)
|
||||
fallback_result = _resolve_alias_fallback(raw_input, authed)
|
||||
if fallback_result is not None:
|
||||
@@ -590,6 +600,14 @@ def switch_model(
|
||||
|
||||
provider_changed = target_provider != current_provider
|
||||
provider_label = get_label(target_provider)
|
||||
if target_provider.startswith("custom:"):
|
||||
custom_pdef = resolve_provider_full(
|
||||
target_provider,
|
||||
user_providers,
|
||||
custom_providers,
|
||||
)
|
||||
if custom_pdef is not None:
|
||||
provider_label = custom_pdef.name
|
||||
|
||||
# --- Resolve credentials ---
|
||||
api_key = current_api_key
|
||||
@@ -708,6 +726,7 @@ def switch_model(
|
||||
def list_authenticated_providers(
|
||||
current_provider: str = "",
|
||||
user_providers: dict = None,
|
||||
custom_providers: list | None = None,
|
||||
max_models: int = 8,
|
||||
) -> List[dict]:
|
||||
"""Detect which providers have credentials and list their curated models.
|
||||
@@ -790,42 +809,69 @@ def list_authenticated_providers(
|
||||
})
|
||||
seen_slugs.add(slug)
|
||||
|
||||
# --- 2. Check Hermes-only providers (nous, openai-codex, copilot) ---
|
||||
# --- 2. Check Hermes-only providers (nous, openai-codex, copilot, opencode-go) ---
|
||||
from hermes_cli.providers import HERMES_OVERLAYS
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY as _auth_registry
|
||||
|
||||
# Build reverse mapping: models.dev ID → Hermes provider ID.
|
||||
# HERMES_OVERLAYS keys may be models.dev IDs (e.g. "github-copilot")
|
||||
# while _PROVIDER_MODELS and config.yaml use Hermes IDs ("copilot").
|
||||
_mdev_to_hermes = {v: k for k, v in PROVIDER_TO_MODELS_DEV.items()}
|
||||
|
||||
for pid, overlay in HERMES_OVERLAYS.items():
|
||||
if pid in seen_slugs:
|
||||
continue
|
||||
|
||||
# Resolve Hermes slug — e.g. "github-copilot" → "copilot"
|
||||
hermes_slug = _mdev_to_hermes.get(pid, pid)
|
||||
if hermes_slug in seen_slugs:
|
||||
continue
|
||||
|
||||
# Check if credentials exist
|
||||
has_creds = False
|
||||
if overlay.extra_env_vars:
|
||||
has_creds = any(os.environ.get(ev) for ev in overlay.extra_env_vars)
|
||||
if overlay.auth_type in ("oauth_device_code", "oauth_external", "external_process"):
|
||||
# Also check api_key_env_vars from PROVIDER_REGISTRY for api_key auth_type
|
||||
if not has_creds and overlay.auth_type == "api_key":
|
||||
for _key in (pid, hermes_slug):
|
||||
pcfg = _auth_registry.get(_key)
|
||||
if pcfg and pcfg.api_key_env_vars:
|
||||
if any(os.environ.get(ev) for ev in pcfg.api_key_env_vars):
|
||||
has_creds = True
|
||||
break
|
||||
if not has_creds and overlay.auth_type in ("oauth_device_code", "oauth_external", "external_process"):
|
||||
# These use auth stores, not env vars — check for auth.json entries
|
||||
try:
|
||||
from hermes_cli.auth import _load_auth_store
|
||||
store = _load_auth_store()
|
||||
if store and (pid in store.get("providers", {}) or pid in store.get("credential_pool", {})):
|
||||
providers_store = store.get("providers", {})
|
||||
pool_store = store.get("credential_pool", {})
|
||||
if store and (
|
||||
pid in providers_store or hermes_slug in providers_store
|
||||
or pid in pool_store or hermes_slug in pool_store
|
||||
):
|
||||
has_creds = True
|
||||
except Exception as exc:
|
||||
logger.debug("Auth store check failed for %s: %s", pid, exc)
|
||||
if not has_creds:
|
||||
continue
|
||||
|
||||
# Use curated list
|
||||
model_ids = curated.get(pid, [])
|
||||
# Use curated list — look up by Hermes slug, fall back to overlay key
|
||||
model_ids = curated.get(hermes_slug, []) or curated.get(pid, [])
|
||||
total = len(model_ids)
|
||||
top = model_ids[:max_models]
|
||||
|
||||
results.append({
|
||||
"slug": pid,
|
||||
"name": get_label(pid),
|
||||
"is_current": pid == current_provider,
|
||||
"slug": hermes_slug,
|
||||
"name": get_label(hermes_slug),
|
||||
"is_current": hermes_slug == current_provider or pid == current_provider,
|
||||
"is_user_defined": False,
|
||||
"models": top,
|
||||
"total_models": total,
|
||||
"source": "hermes",
|
||||
})
|
||||
seen_slugs.add(pid)
|
||||
seen_slugs.add(hermes_slug)
|
||||
|
||||
# --- 3. User-defined endpoints from config ---
|
||||
if user_providers and isinstance(user_providers, dict):
|
||||
@@ -853,80 +899,46 @@ def list_authenticated_providers(
|
||||
"api_url": api_url,
|
||||
})
|
||||
|
||||
# --- 4. Saved custom providers from config ---
|
||||
if custom_providers and isinstance(custom_providers, list):
|
||||
for entry in custom_providers:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
|
||||
display_name = (entry.get("name") or "").strip()
|
||||
api_url = (
|
||||
entry.get("base_url", "")
|
||||
or entry.get("url", "")
|
||||
or entry.get("api", "")
|
||||
or ""
|
||||
).strip()
|
||||
if not display_name or not api_url:
|
||||
continue
|
||||
|
||||
slug = custom_provider_slug(display_name)
|
||||
if slug in seen_slugs:
|
||||
continue
|
||||
|
||||
models_list = []
|
||||
default_model = (entry.get("model") or "").strip()
|
||||
if default_model:
|
||||
models_list.append(default_model)
|
||||
|
||||
results.append({
|
||||
"slug": slug,
|
||||
"name": display_name,
|
||||
"is_current": slug == current_provider,
|
||||
"is_user_defined": True,
|
||||
"models": models_list,
|
||||
"total_models": len(models_list),
|
||||
"source": "user-config",
|
||||
"api_url": api_url,
|
||||
})
|
||||
seen_slugs.add(slug)
|
||||
|
||||
# Sort: current provider first, then by model count descending
|
||||
results.sort(key=lambda r: (not r["is_current"], -r["total_models"]))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fuzzy suggestions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def suggest_models(raw_input: str, limit: int = 3) -> List[str]:
|
||||
"""Return fuzzy model suggestions for a (possibly misspelled) input."""
|
||||
query = raw_input.strip()
|
||||
if not query:
|
||||
return []
|
||||
|
||||
results = search_models_dev(query, limit=limit)
|
||||
suggestions: list[str] = []
|
||||
for r in results:
|
||||
mid = r.get("model_id", "")
|
||||
if mid:
|
||||
suggestions.append(mid)
|
||||
|
||||
return suggestions[:limit]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom provider switch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def switch_to_custom_provider() -> CustomAutoResult:
|
||||
"""Handle bare '/model --provider custom' — resolve endpoint and auto-detect model."""
|
||||
from hermes_cli.runtime_provider import (
|
||||
resolve_runtime_provider,
|
||||
_auto_detect_local_model,
|
||||
)
|
||||
|
||||
try:
|
||||
runtime = resolve_runtime_provider(requested="custom")
|
||||
except Exception as e:
|
||||
return CustomAutoResult(
|
||||
success=False,
|
||||
error_message=f"Could not resolve custom endpoint: {e}",
|
||||
)
|
||||
|
||||
cust_base = runtime.get("base_url", "")
|
||||
cust_key = runtime.get("api_key", "")
|
||||
|
||||
if not cust_base or "openrouter.ai" in cust_base:
|
||||
return CustomAutoResult(
|
||||
success=False,
|
||||
error_message=(
|
||||
"No custom endpoint configured. "
|
||||
"Set model.base_url in config.yaml, or set OPENAI_BASE_URL "
|
||||
"in .env, or run: hermes setup -> Custom OpenAI-compatible endpoint"
|
||||
),
|
||||
)
|
||||
|
||||
detected_model = _auto_detect_local_model(cust_base)
|
||||
if not detected_model:
|
||||
return CustomAutoResult(
|
||||
success=False,
|
||||
base_url=cust_base,
|
||||
api_key=cust_key,
|
||||
error_message=(
|
||||
f"Custom endpoint at {cust_base} is reachable but no single "
|
||||
f"model was auto-detected. Specify the model explicitly: "
|
||||
f"/model <model-name> --provider custom"
|
||||
),
|
||||
)
|
||||
|
||||
return CustomAutoResult(
|
||||
success=True,
|
||||
model=detected_model,
|
||||
base_url=cust_base,
|
||||
api_key=cust_key,
|
||||
)
|
||||
|
||||
+174
-48
@@ -20,22 +20,20 @@ COPILOT_EDITOR_VERSION = "vscode/1.104.1"
|
||||
COPILOT_REASONING_EFFORTS_GPT5 = ["minimal", "low", "medium", "high"]
|
||||
COPILOT_REASONING_EFFORTS_O_SERIES = ["low", "medium", "high"]
|
||||
|
||||
# Backward-compatible aliases for the earlier GitHub Models-backed Copilot work.
|
||||
GITHUB_MODELS_BASE_URL = COPILOT_BASE_URL
|
||||
GITHUB_MODELS_CATALOG_URL = COPILOT_MODELS_URL
|
||||
|
||||
# Fallback OpenRouter snapshot used when the live catalog is unavailable.
|
||||
# (model_id, display description shown in menus)
|
||||
OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("anthropic/claude-opus-4.6", "recommended"),
|
||||
("anthropic/claude-sonnet-4.6", ""),
|
||||
("qwen/qwen3.6-plus:free", "free"),
|
||||
("qwen/qwen3.6-plus", ""),
|
||||
("anthropic/claude-sonnet-4.5", ""),
|
||||
("anthropic/claude-haiku-4.5", ""),
|
||||
("openai/gpt-5.4", ""),
|
||||
("openai/gpt-5.4-mini", ""),
|
||||
("xiaomi/mimo-v2-pro", ""),
|
||||
("openai/gpt-5.3-codex", ""),
|
||||
("google/gemini-3-pro-preview", ""),
|
||||
("google/gemini-3-pro-image-preview", ""),
|
||||
("google/gemini-3-flash-preview", ""),
|
||||
("google/gemini-3.1-pro-preview", ""),
|
||||
("google/gemini-3.1-flash-lite-preview", ""),
|
||||
@@ -47,7 +45,7 @@ OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("z-ai/glm-5.1", ""),
|
||||
("z-ai/glm-5-turbo", ""),
|
||||
("moonshotai/kimi-k2.5", ""),
|
||||
("x-ai/grok-4.20-beta", ""),
|
||||
("x-ai/grok-4.20", ""),
|
||||
("nvidia/nemotron-3-super-120b-a12b", ""),
|
||||
("nvidia/nemotron-3-super-120b-a12b:free", "free"),
|
||||
("arcee-ai/trinity-large-preview:free", "free"),
|
||||
@@ -56,6 +54,8 @@ OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("openai/gpt-5.4-nano", ""),
|
||||
]
|
||||
|
||||
_openrouter_catalog_cache: list[tuple[str, str]] | None = None
|
||||
|
||||
_PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"nous": [
|
||||
"anthropic/claude-opus-4.6",
|
||||
@@ -129,6 +129,19 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"glm-4.5",
|
||||
"glm-4.5-flash",
|
||||
],
|
||||
"xai": [
|
||||
"grok-4.20-0309-reasoning",
|
||||
"grok-4.20-0309-non-reasoning",
|
||||
"grok-4.20-multi-agent-0309",
|
||||
"grok-4-1-fast-reasoning",
|
||||
"grok-4-1-fast-non-reasoning",
|
||||
"grok-4-fast-reasoning",
|
||||
"grok-4-fast-non-reasoning",
|
||||
"grok-4-0709",
|
||||
"grok-code-fast-1",
|
||||
"grok-3",
|
||||
"grok-3-mini",
|
||||
],
|
||||
"kimi-coding": [
|
||||
"kimi-for-coding",
|
||||
"kimi-k2.5",
|
||||
@@ -416,12 +429,6 @@ _FREE_TIER_CACHE_TTL: int = 180 # seconds (3 minutes)
|
||||
_free_tier_cache: tuple[bool, float] | None = None # (result, timestamp)
|
||||
|
||||
|
||||
def clear_nous_free_tier_cache() -> None:
|
||||
"""Invalidate the cached free-tier result (e.g. after login/logout)."""
|
||||
global _free_tier_cache
|
||||
_free_tier_cache = None
|
||||
|
||||
|
||||
def check_nous_free_tier() -> bool:
|
||||
"""Check if the current Nous Portal user is on a free (unpaid) tier.
|
||||
|
||||
@@ -530,19 +537,84 @@ _PROVIDER_ALIASES = {
|
||||
}
|
||||
|
||||
|
||||
def model_ids() -> list[str]:
|
||||
def _openrouter_model_is_free(pricing: Any) -> bool:
|
||||
"""Return True when both prompt and completion pricing are zero."""
|
||||
if not isinstance(pricing, dict):
|
||||
return False
|
||||
try:
|
||||
return float(pricing.get("prompt", "0")) == 0 and float(pricing.get("completion", "0")) == 0
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
|
||||
def fetch_openrouter_models(
|
||||
timeout: float = 8.0,
|
||||
*,
|
||||
force_refresh: bool = False,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Return the curated OpenRouter picker list, refreshed from the live catalog when possible."""
|
||||
global _openrouter_catalog_cache
|
||||
|
||||
if _openrouter_catalog_cache is not None and not force_refresh:
|
||||
return list(_openrouter_catalog_cache)
|
||||
|
||||
fallback = list(OPENROUTER_MODELS)
|
||||
preferred_ids = [mid for mid, _ in fallback]
|
||||
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
"https://openrouter.ai/api/v1/models",
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
payload = json.loads(resp.read().decode())
|
||||
except Exception:
|
||||
return list(_openrouter_catalog_cache or fallback)
|
||||
|
||||
live_items = payload.get("data", [])
|
||||
if not isinstance(live_items, list):
|
||||
return list(_openrouter_catalog_cache or fallback)
|
||||
|
||||
live_by_id: dict[str, dict[str, Any]] = {}
|
||||
for item in live_items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
mid = str(item.get("id") or "").strip()
|
||||
if not mid:
|
||||
continue
|
||||
live_by_id[mid] = item
|
||||
|
||||
curated: list[tuple[str, str]] = []
|
||||
for preferred_id in preferred_ids:
|
||||
live_item = live_by_id.get(preferred_id)
|
||||
if live_item is None:
|
||||
continue
|
||||
desc = "free" if _openrouter_model_is_free(live_item.get("pricing")) else ""
|
||||
curated.append((preferred_id, desc))
|
||||
|
||||
if not curated:
|
||||
return list(_openrouter_catalog_cache or fallback)
|
||||
|
||||
first_id, _ = curated[0]
|
||||
curated[0] = (first_id, "recommended")
|
||||
_openrouter_catalog_cache = curated
|
||||
return list(curated)
|
||||
|
||||
|
||||
def model_ids(*, force_refresh: bool = False) -> list[str]:
|
||||
"""Return just the OpenRouter model-id strings."""
|
||||
return [mid for mid, _ in OPENROUTER_MODELS]
|
||||
return [mid for mid, _ in fetch_openrouter_models(force_refresh=force_refresh)]
|
||||
|
||||
|
||||
def menu_labels() -> list[str]:
|
||||
def menu_labels(*, force_refresh: bool = False) -> list[str]:
|
||||
"""Return display labels like 'anthropic/claude-opus-4.6 (recommended)'."""
|
||||
labels = []
|
||||
for mid, desc in OPENROUTER_MODELS:
|
||||
for mid, desc in fetch_openrouter_models(force_refresh=force_refresh):
|
||||
labels.append(f"{mid} ({desc})" if desc else mid)
|
||||
return labels
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pricing helpers — fetch live pricing from OpenRouter-compatible /v1/models
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -575,31 +647,6 @@ def _format_price_per_mtok(per_token_str: str) -> str:
|
||||
return f"${per_m:.2f}"
|
||||
|
||||
|
||||
def format_pricing_label(pricing: dict[str, str] | None) -> str:
|
||||
"""Build a compact pricing label like 'in $3 · out $15 · cache $0.30/Mtok'.
|
||||
|
||||
Returns empty string when pricing is unavailable.
|
||||
"""
|
||||
if not pricing:
|
||||
return ""
|
||||
prompt_price = pricing.get("prompt", "")
|
||||
completion_price = pricing.get("completion", "")
|
||||
if not prompt_price and not completion_price:
|
||||
return ""
|
||||
inp = _format_price_per_mtok(prompt_price)
|
||||
out = _format_price_per_mtok(completion_price)
|
||||
if inp == "free" and out == "free":
|
||||
return "free"
|
||||
cache_read = pricing.get("input_cache_read", "")
|
||||
cache_str = _format_price_per_mtok(cache_read) if cache_read else ""
|
||||
if inp == out and not cache_str:
|
||||
return f"{inp}/Mtok"
|
||||
parts = [f"in {inp}", f"out {out}"]
|
||||
if cache_str and cache_str != "?" and cache_str != inp:
|
||||
parts.append(f"cache {cache_str}")
|
||||
return " · ".join(parts) + "/Mtok"
|
||||
|
||||
|
||||
def format_model_pricing_table(
|
||||
models: list[tuple[str, str]],
|
||||
pricing_map: dict[str, dict[str, str]],
|
||||
@@ -727,13 +774,14 @@ def _resolve_nous_pricing_credentials() -> tuple[str, str]:
|
||||
return ("", "")
|
||||
|
||||
|
||||
def get_pricing_for_provider(provider: str) -> dict[str, dict[str, str]]:
|
||||
def get_pricing_for_provider(provider: str, *, force_refresh: bool = False) -> dict[str, dict[str, str]]:
|
||||
"""Return live pricing for providers that support it (openrouter, nous)."""
|
||||
normalized = normalize_provider(provider)
|
||||
if normalized == "openrouter":
|
||||
return fetch_models_with_pricing(
|
||||
api_key=_resolve_openrouter_api_key(),
|
||||
base_url="https://openrouter.ai/api",
|
||||
force_refresh=force_refresh,
|
||||
)
|
||||
if normalized == "nous":
|
||||
api_key, base_url = _resolve_nous_pricing_credentials()
|
||||
@@ -746,6 +794,7 @@ def get_pricing_for_provider(provider: str) -> dict[str, dict[str, str]]:
|
||||
return fetch_models_with_pricing(
|
||||
api_key=api_key,
|
||||
base_url=stripped,
|
||||
force_refresh=force_refresh,
|
||||
)
|
||||
return {}
|
||||
|
||||
@@ -854,7 +903,11 @@ def _get_custom_base_url() -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str]]:
|
||||
def curated_models_for_provider(
|
||||
provider: Optional[str],
|
||||
*,
|
||||
force_refresh: bool = False,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Return ``(model_id, description)`` tuples for a provider's model list.
|
||||
|
||||
Tries to fetch the live model list from the provider's API first,
|
||||
@@ -863,7 +916,7 @@ def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str]
|
||||
"""
|
||||
normalized = normalize_provider(provider)
|
||||
if normalized == "openrouter":
|
||||
return list(OPENROUTER_MODELS)
|
||||
return fetch_openrouter_models(force_refresh=force_refresh)
|
||||
|
||||
# Try live API first (Codex, Nous, etc. all support /models)
|
||||
live = provider_model_ids(normalized)
|
||||
@@ -982,12 +1035,12 @@ def _find_openrouter_slug(model_name: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
# Exact match (already has provider/ prefix)
|
||||
for mid, _ in OPENROUTER_MODELS:
|
||||
for mid in model_ids():
|
||||
if name_lower == mid.lower():
|
||||
return mid
|
||||
|
||||
# Try matching just the model part (after the /)
|
||||
for mid, _ in OPENROUTER_MODELS:
|
||||
for mid in model_ids():
|
||||
if "/" in mid:
|
||||
_, model_part = mid.split("/", 1)
|
||||
if name_lower == model_part.lower():
|
||||
@@ -1017,6 +1070,79 @@ def provider_label(provider: Optional[str]) -> str:
|
||||
return _PROVIDER_LABELS.get(normalized, original or "OpenRouter")
|
||||
|
||||
|
||||
# Models that support OpenAI Priority Processing (service_tier="priority").
|
||||
# See https://openai.com/api-priority-processing/ for the canonical list.
|
||||
# Only the bare model slug is stored (no vendor prefix).
|
||||
_PRIORITY_PROCESSING_MODELS: frozenset[str] = frozenset({
|
||||
"gpt-5.4",
|
||||
"gpt-5.4-mini",
|
||||
"gpt-5.2",
|
||||
"gpt-5.1",
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini",
|
||||
"gpt-4.1-nano",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"o3",
|
||||
"o4-mini",
|
||||
})
|
||||
|
||||
# Models that support Anthropic Fast Mode (speed="fast").
|
||||
# See https://platform.claude.com/docs/en/build-with-claude/fast-mode
|
||||
# Currently only Claude Opus 4.6. Both hyphen and dot variants are stored
|
||||
# to handle native Anthropic (claude-opus-4-6) and OpenRouter (claude-opus-4.6).
|
||||
_ANTHROPIC_FAST_MODE_MODELS: frozenset[str] = frozenset({
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4.6",
|
||||
})
|
||||
|
||||
|
||||
def _strip_vendor_prefix(model_id: str) -> str:
|
||||
"""Strip vendor/ prefix from a model ID (e.g. 'anthropic/claude-opus-4-6' -> 'claude-opus-4-6')."""
|
||||
raw = str(model_id or "").strip().lower()
|
||||
if "/" in raw:
|
||||
raw = raw.split("/", 1)[1]
|
||||
return raw
|
||||
|
||||
|
||||
def model_supports_fast_mode(model_id: Optional[str]) -> bool:
|
||||
"""Return whether Hermes should expose the /fast toggle for this model."""
|
||||
raw = _strip_vendor_prefix(str(model_id or ""))
|
||||
if raw in _PRIORITY_PROCESSING_MODELS:
|
||||
return True
|
||||
# Anthropic fast mode — strip date suffixes (e.g. claude-opus-4-6-20260401)
|
||||
# and OpenRouter variant tags (:fast, :beta) for matching.
|
||||
base = raw.split(":")[0]
|
||||
return base in _ANTHROPIC_FAST_MODE_MODELS
|
||||
|
||||
|
||||
def _is_anthropic_fast_model(model_id: Optional[str]) -> bool:
|
||||
"""Return True if the model supports Anthropic's fast mode (speed='fast')."""
|
||||
raw = _strip_vendor_prefix(str(model_id or ""))
|
||||
base = raw.split(":")[0]
|
||||
return base in _ANTHROPIC_FAST_MODE_MODELS
|
||||
|
||||
|
||||
def resolve_fast_mode_overrides(model_id: Optional[str]) -> dict[str, Any] | None:
|
||||
"""Return request_overrides for fast/priority mode, or None if unsupported.
|
||||
|
||||
Returns provider-appropriate overrides:
|
||||
- OpenAI models: ``{"service_tier": "priority"}`` (Priority Processing)
|
||||
- Anthropic models: ``{"speed": "fast"}`` (Anthropic Fast Mode beta)
|
||||
|
||||
The overrides are injected into the API request kwargs by
|
||||
``_build_api_kwargs`` in run_agent.py — each API path handles its own
|
||||
keys (service_tier for OpenAI/Codex, speed for Anthropic Messages).
|
||||
"""
|
||||
if not model_supports_fast_mode(model_id):
|
||||
return None
|
||||
if _is_anthropic_fast_model(model_id):
|
||||
return {"speed": "fast"}
|
||||
return {"service_tier": "priority"}
|
||||
|
||||
|
||||
def _resolve_copilot_catalog_api_key() -> str:
|
||||
"""Best-effort GitHub token for fetching the Copilot model catalog."""
|
||||
try:
|
||||
@@ -1028,7 +1154,7 @@ def _resolve_copilot_catalog_api_key() -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def provider_model_ids(provider: Optional[str]) -> list[str]:
|
||||
def provider_model_ids(provider: Optional[str], *, force_refresh: bool = False) -> list[str]:
|
||||
"""Return the best known model catalog for a provider.
|
||||
|
||||
Tries live API endpoints for providers that support them (Codex, Nous),
|
||||
@@ -1036,7 +1162,7 @@ def provider_model_ids(provider: Optional[str]) -> list[str]:
|
||||
"""
|
||||
normalized = normalize_provider(provider)
|
||||
if normalized == "openrouter":
|
||||
return model_ids()
|
||||
return model_ids(force_refresh=force_refresh)
|
||||
if normalized == "openai-codex":
|
||||
from hermes_cli.codex_models import get_codex_model_ids
|
||||
|
||||
|
||||
+21
-6
@@ -42,6 +42,11 @@ _PROFILE_DIRS = [
|
||||
"plans",
|
||||
"workspace",
|
||||
"cron",
|
||||
# Per-profile HOME for subprocesses: isolates system tool configs (git,
|
||||
# ssh, gh, npm …) so credentials don't bleed between profiles. In Docker
|
||||
# this also ensures tool configs land inside the persistent volume.
|
||||
# See hermes_constants.get_subprocess_home() and issue #4426.
|
||||
"home",
|
||||
]
|
||||
|
||||
# Files copied during --clone (if they exist in the source)
|
||||
@@ -115,16 +120,26 @@ _HERMES_SUBCOMMANDS = frozenset({
|
||||
def _get_profiles_root() -> Path:
|
||||
"""Return the directory where named profiles are stored.
|
||||
|
||||
Always ``~/.hermes/profiles/`` — anchored to the user's home,
|
||||
NOT to the current HERMES_HOME (which may itself be a profile).
|
||||
This ensures ``coder profile list`` can see all profiles.
|
||||
Anchored to the hermes root, NOT to the current HERMES_HOME
|
||||
(which may itself be a profile). This ensures ``coder profile list``
|
||||
can see all profiles.
|
||||
|
||||
In Docker/custom deployments where HERMES_HOME points outside
|
||||
``~/.hermes``, profiles live under ``HERMES_HOME/profiles/`` so
|
||||
they persist on the mounted volume.
|
||||
"""
|
||||
return Path.home() / ".hermes" / "profiles"
|
||||
return _get_default_hermes_home() / "profiles"
|
||||
|
||||
|
||||
def _get_default_hermes_home() -> Path:
|
||||
"""Return the default (pre-profile) HERMES_HOME path."""
|
||||
return Path.home() / ".hermes"
|
||||
"""Return the default (pre-profile) HERMES_HOME path.
|
||||
|
||||
In standard deployments this is ``~/.hermes``.
|
||||
In Docker/custom deployments where HERMES_HOME is outside ``~/.hermes``
|
||||
(e.g. ``/opt/data``), returns HERMES_HOME directly.
|
||||
"""
|
||||
from hermes_constants import get_default_hermes_root
|
||||
return get_default_hermes_root()
|
||||
|
||||
|
||||
def _get_active_profile_path() -> Path:
|
||||
|
||||
+70
-40
@@ -127,6 +127,11 @@ HERMES_OVERLAYS: Dict[str, HermesOverlay] = {
|
||||
is_aggregator=True,
|
||||
base_url_env_var="HF_BASE_URL",
|
||||
),
|
||||
"xai": HermesOverlay(
|
||||
transport="openai_chat",
|
||||
base_url_override="https://api.x.ai/v1",
|
||||
base_url_env_var="XAI_BASE_URL",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -148,10 +153,6 @@ class ProviderDef:
|
||||
doc: str = ""
|
||||
source: str = "" # "models.dev", "hermes", "user-config"
|
||||
|
||||
@property
|
||||
def is_user_defined(self) -> bool:
|
||||
return self.source == "user-config"
|
||||
|
||||
|
||||
# -- Aliases ------------------------------------------------------------------
|
||||
# Maps human-friendly / legacy names to canonical provider IDs.
|
||||
@@ -167,6 +168,10 @@ ALIASES: Dict[str, str] = {
|
||||
"z.ai": "zai",
|
||||
"zhipu": "zai",
|
||||
|
||||
# xai
|
||||
"x-ai": "xai",
|
||||
"x.ai": "xai",
|
||||
|
||||
# kimi-for-coding (models.dev ID)
|
||||
"kimi": "kimi-for-coding",
|
||||
"kimi-coding": "kimi-for-coding",
|
||||
@@ -262,12 +267,6 @@ def normalize_provider(name: str) -> str:
|
||||
return ALIASES.get(key, key)
|
||||
|
||||
|
||||
def get_overlay(provider_id: str) -> Optional[HermesOverlay]:
|
||||
"""Get Hermes overlay for a provider, if one exists."""
|
||||
canonical = normalize_provider(provider_id)
|
||||
return HERMES_OVERLAYS.get(canonical)
|
||||
|
||||
|
||||
def get_provider(name: str) -> Optional[ProviderDef]:
|
||||
"""Look up a provider by id or alias, merging all data sources.
|
||||
|
||||
@@ -350,36 +349,6 @@ def get_label(provider_id: str) -> str:
|
||||
return canonical
|
||||
|
||||
|
||||
# For direct import compat, expose as module-level dict
|
||||
# Built on demand by get_label() calls
|
||||
LABELS: Dict[str, str] = {
|
||||
# Static entries for backward compat — get_label() is the proper API
|
||||
"openrouter": "OpenRouter",
|
||||
"nous": "Nous Portal",
|
||||
"openai-codex": "OpenAI Codex",
|
||||
"copilot-acp": "GitHub Copilot ACP",
|
||||
"github-copilot": "GitHub Copilot",
|
||||
"anthropic": "Anthropic",
|
||||
"zai": "Z.AI / GLM",
|
||||
"kimi-for-coding": "Kimi / Moonshot",
|
||||
"minimax": "MiniMax",
|
||||
"minimax-cn": "MiniMax (China)",
|
||||
"deepseek": "DeepSeek",
|
||||
"alibaba": "Alibaba Cloud (DashScope)",
|
||||
"vercel": "Vercel AI Gateway",
|
||||
"opencode": "OpenCode Zen",
|
||||
"opencode-go": "OpenCode Go",
|
||||
"kilo": "Kilo Gateway",
|
||||
"huggingface": "Hugging Face",
|
||||
"local": "Local endpoint",
|
||||
"custom": "Custom endpoint",
|
||||
# Legacy Hermes IDs (point to same providers)
|
||||
"ai-gateway": "Vercel AI Gateway",
|
||||
"kilocode": "Kilo Gateway",
|
||||
"copilot": "GitHub Copilot",
|
||||
"kimi-coding": "Kimi / Moonshot",
|
||||
"opencode-zen": "OpenCode Zen",
|
||||
}
|
||||
|
||||
|
||||
def is_aggregator(provider: str) -> bool:
|
||||
@@ -452,9 +421,64 @@ def resolve_user_provider(name: str, user_config: Dict[str, Any]) -> Optional[Pr
|
||||
)
|
||||
|
||||
|
||||
def custom_provider_slug(display_name: str) -> str:
|
||||
"""Build a canonical slug for a custom_providers entry.
|
||||
|
||||
Matches the convention used by runtime_provider and credential_pool
|
||||
(``custom:<normalized-name>``). Centralised here so all call-sites
|
||||
produce identical slugs.
|
||||
"""
|
||||
return "custom:" + display_name.strip().lower().replace(" ", "-")
|
||||
|
||||
|
||||
def resolve_custom_provider(
|
||||
name: str,
|
||||
custom_providers: Optional[List[Dict[str, Any]]],
|
||||
) -> Optional[ProviderDef]:
|
||||
"""Resolve a provider from the user's config.yaml ``custom_providers`` list."""
|
||||
if not custom_providers or not isinstance(custom_providers, list):
|
||||
return None
|
||||
|
||||
requested = (name or "").strip().lower()
|
||||
if not requested:
|
||||
return None
|
||||
|
||||
for entry in custom_providers:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
|
||||
display_name = (entry.get("name") or "").strip()
|
||||
api_url = (
|
||||
entry.get("base_url", "")
|
||||
or entry.get("url", "")
|
||||
or entry.get("api", "")
|
||||
or ""
|
||||
).strip()
|
||||
if not display_name or not api_url:
|
||||
continue
|
||||
|
||||
slug = custom_provider_slug(display_name)
|
||||
if requested not in {display_name.lower(), slug}:
|
||||
continue
|
||||
|
||||
return ProviderDef(
|
||||
id=slug,
|
||||
name=display_name,
|
||||
transport="openai_chat",
|
||||
api_key_env_vars=(),
|
||||
base_url=api_url,
|
||||
is_aggregator=False,
|
||||
auth_type="api_key",
|
||||
source="user-config",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def resolve_provider_full(
|
||||
name: str,
|
||||
user_providers: Optional[Dict[str, Any]] = None,
|
||||
custom_providers: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> Optional[ProviderDef]:
|
||||
"""Full resolution chain: built-in → models.dev → user config.
|
||||
|
||||
@@ -463,6 +487,7 @@ def resolve_provider_full(
|
||||
Args:
|
||||
name: Provider name or alias.
|
||||
user_providers: The ``providers:`` dict from config.yaml (optional).
|
||||
custom_providers: The ``custom_providers:`` list from config.yaml (optional).
|
||||
|
||||
Returns:
|
||||
ProviderDef if found, else None.
|
||||
@@ -485,6 +510,11 @@ def resolve_provider_full(
|
||||
if user_pdef is not None:
|
||||
return user_pdef
|
||||
|
||||
# 2b. Saved custom providers from config
|
||||
custom_pdef = resolve_custom_provider(name, custom_providers)
|
||||
if custom_pdef is not None:
|
||||
return custom_pdef
|
||||
|
||||
# 3. Try models.dev directly (for providers not in our ALIASES)
|
||||
try:
|
||||
from agent.models_dev import get_provider_info as _mdev_provider
|
||||
|
||||
@@ -16,6 +16,7 @@ from hermes_cli.auth import (
|
||||
DEFAULT_CODEX_BASE_URL,
|
||||
DEFAULT_QWEN_BASE_URL,
|
||||
PROVIDER_REGISTRY,
|
||||
_agent_key_is_usable,
|
||||
format_auth_error,
|
||||
resolve_provider,
|
||||
resolve_nous_runtime_credentials,
|
||||
@@ -644,6 +645,21 @@ def resolve_runtime_provider(
|
||||
getattr(entry, "runtime_api_key", None)
|
||||
or getattr(entry, "access_token", "")
|
||||
)
|
||||
# For Nous, the pool entry's runtime_api_key is the agent_key — a
|
||||
# short-lived inference credential (~30 min TTL). The pool doesn't
|
||||
# refresh it during selection (that would trigger network calls in
|
||||
# non-runtime contexts like `hermes auth list`). If the key is
|
||||
# expired, clear pool_api_key so we fall through to
|
||||
# resolve_nous_runtime_credentials() which handles refresh + mint.
|
||||
if provider == "nous" and entry is not None and pool_api_key:
|
||||
min_ttl = max(60, int(os.getenv("HERMES_NOUS_MIN_KEY_TTL_SECONDS", "1800")))
|
||||
nous_state = {
|
||||
"agent_key": getattr(entry, "agent_key", None),
|
||||
"agent_key_expires_at": getattr(entry, "agent_key_expires_at", None),
|
||||
}
|
||||
if not _agent_key_is_usable(nous_state, min_ttl):
|
||||
logger.debug("Nous pool entry agent_key expired/missing, falling through to runtime resolution")
|
||||
pool_api_key = ""
|
||||
if entry is not None and pool_api_key:
|
||||
return _resolve_runtime_from_pool_entry(
|
||||
provider=provider,
|
||||
|
||||
+24
-157
@@ -16,6 +16,7 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import copy
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
@@ -172,150 +173,10 @@ def _setup_copilot_reasoning_selection(
|
||||
_set_reasoning_effort(config, "none")
|
||||
|
||||
|
||||
def _setup_provider_model_selection(config, provider_id, current_model, prompt_choice, prompt_fn):
|
||||
"""Model selection for API-key providers with live /models detection.
|
||||
|
||||
Tries the provider's /models endpoint first. Falls back to a
|
||||
hardcoded default list with a warning if the endpoint is unreachable.
|
||||
Always offers a 'Custom model' escape hatch.
|
||||
"""
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY, resolve_api_key_provider_credentials
|
||||
from hermes_cli.config import get_env_value
|
||||
from hermes_cli.models import (
|
||||
copilot_model_api_mode,
|
||||
fetch_api_models,
|
||||
fetch_github_model_catalog,
|
||||
normalize_copilot_model_id,
|
||||
normalize_opencode_model_id,
|
||||
opencode_model_api_mode,
|
||||
)
|
||||
|
||||
pconfig = PROVIDER_REGISTRY[provider_id]
|
||||
is_copilot_catalog_provider = provider_id in {"copilot", "copilot-acp"}
|
||||
|
||||
# Resolve API key and base URL for the probe
|
||||
if is_copilot_catalog_provider:
|
||||
api_key = ""
|
||||
if provider_id == "copilot":
|
||||
creds = resolve_api_key_provider_credentials(provider_id)
|
||||
api_key = creds.get("api_key", "")
|
||||
base_url = creds.get("base_url", "") or pconfig.inference_base_url
|
||||
else:
|
||||
try:
|
||||
creds = resolve_api_key_provider_credentials("copilot")
|
||||
api_key = creds.get("api_key", "")
|
||||
except Exception:
|
||||
pass
|
||||
base_url = pconfig.inference_base_url
|
||||
catalog = fetch_github_model_catalog(api_key)
|
||||
current_model = normalize_copilot_model_id(
|
||||
current_model,
|
||||
catalog=catalog,
|
||||
api_key=api_key,
|
||||
) or current_model
|
||||
else:
|
||||
api_key = ""
|
||||
for ev in pconfig.api_key_env_vars:
|
||||
api_key = get_env_value(ev) or os.getenv(ev, "")
|
||||
if api_key:
|
||||
break
|
||||
base_url_env = pconfig.base_url_env_var or ""
|
||||
base_url = (get_env_value(base_url_env) if base_url_env else "") or pconfig.inference_base_url
|
||||
catalog = None
|
||||
|
||||
# Try live /models endpoint
|
||||
if is_copilot_catalog_provider and catalog:
|
||||
live_models = [item.get("id", "") for item in catalog if item.get("id")]
|
||||
else:
|
||||
live_models = fetch_api_models(api_key, base_url)
|
||||
|
||||
if live_models:
|
||||
provider_models = live_models
|
||||
print_info(f"Found {len(live_models)} model(s) from {pconfig.name} API")
|
||||
else:
|
||||
fallback_provider_id = "copilot" if provider_id == "copilot-acp" else provider_id
|
||||
provider_models = _DEFAULT_PROVIDER_MODELS.get(fallback_provider_id, [])
|
||||
if provider_models:
|
||||
print_warning(
|
||||
f"Could not auto-detect models from {pconfig.name} API — showing defaults.\n"
|
||||
f" Use \"Custom model\" if the model you expect isn't listed."
|
||||
)
|
||||
|
||||
if provider_id in {"opencode-zen", "opencode-go"}:
|
||||
provider_models = [normalize_opencode_model_id(provider_id, mid) for mid in provider_models]
|
||||
current_model = normalize_opencode_model_id(provider_id, current_model)
|
||||
provider_models = list(dict.fromkeys(mid for mid in provider_models if mid))
|
||||
|
||||
model_choices = list(provider_models)
|
||||
model_choices.append("Custom model")
|
||||
model_choices.append(f"Keep current ({current_model})")
|
||||
|
||||
keep_idx = len(model_choices) - 1
|
||||
model_idx = prompt_choice("Select default model:", model_choices, keep_idx)
|
||||
|
||||
selected_model = current_model
|
||||
|
||||
if model_idx < len(provider_models):
|
||||
selected_model = provider_models[model_idx]
|
||||
if is_copilot_catalog_provider:
|
||||
selected_model = normalize_copilot_model_id(
|
||||
selected_model,
|
||||
catalog=catalog,
|
||||
api_key=api_key,
|
||||
) or selected_model
|
||||
elif provider_id in {"opencode-zen", "opencode-go"}:
|
||||
selected_model = normalize_opencode_model_id(provider_id, selected_model)
|
||||
_set_default_model(config, selected_model)
|
||||
elif model_idx == len(provider_models):
|
||||
custom = prompt_fn("Enter model name")
|
||||
if custom:
|
||||
if is_copilot_catalog_provider:
|
||||
selected_model = normalize_copilot_model_id(
|
||||
custom,
|
||||
catalog=catalog,
|
||||
api_key=api_key,
|
||||
) or custom
|
||||
elif provider_id in {"opencode-zen", "opencode-go"}:
|
||||
selected_model = normalize_opencode_model_id(provider_id, custom)
|
||||
else:
|
||||
selected_model = custom
|
||||
_set_default_model(config, selected_model)
|
||||
else:
|
||||
# "Keep current" selected — validate it's compatible with the new
|
||||
# provider. OpenRouter-formatted names (containing "/") won't work
|
||||
# on direct-API providers and would silently break the gateway.
|
||||
if "/" in (current_model or "") and provider_models:
|
||||
print_warning(
|
||||
f"Current model \"{current_model}\" looks like an OpenRouter model "
|
||||
f"and won't work with {pconfig.name}. "
|
||||
f"Switching to {provider_models[0]}."
|
||||
)
|
||||
selected_model = provider_models[0]
|
||||
_set_default_model(config, provider_models[0])
|
||||
|
||||
if provider_id == "copilot" and selected_model:
|
||||
model_cfg = _model_config_dict(config)
|
||||
model_cfg["api_mode"] = copilot_model_api_mode(
|
||||
selected_model,
|
||||
catalog=catalog,
|
||||
api_key=api_key,
|
||||
)
|
||||
config["model"] = model_cfg
|
||||
_setup_copilot_reasoning_selection(
|
||||
config,
|
||||
selected_model,
|
||||
prompt_choice,
|
||||
catalog=catalog,
|
||||
api_key=api_key,
|
||||
)
|
||||
elif provider_id in {"opencode-zen", "opencode-go"} and selected_model:
|
||||
model_cfg = _model_config_dict(config)
|
||||
model_cfg["api_mode"] = opencode_model_api_mode(provider_id, selected_model)
|
||||
config["model"] = model_cfg
|
||||
|
||||
|
||||
# Import config helpers
|
||||
from hermes_cli.config import (
|
||||
DEFAULT_CONFIG,
|
||||
get_hermes_home,
|
||||
get_config_path,
|
||||
get_env_path,
|
||||
@@ -477,6 +338,8 @@ def _curses_prompt_choice(question: str, choices: list, default: int = 0) -> int
|
||||
return
|
||||
|
||||
curses.wrapper(_curses_menu)
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
return result_holder[0]
|
||||
except Exception:
|
||||
return -1
|
||||
@@ -921,8 +784,10 @@ def setup_model_provider(config: dict, *, quick: bool = False):
|
||||
# changes with stale values (#4172).
|
||||
_refreshed = load_config()
|
||||
config["model"] = _refreshed.get("model", config.get("model"))
|
||||
if _refreshed.get("custom_providers"):
|
||||
if "custom_providers" in _refreshed:
|
||||
config["custom_providers"] = _refreshed["custom_providers"]
|
||||
else:
|
||||
config.pop("custom_providers", None)
|
||||
|
||||
# Derive the selected provider for downstream steps (vision setup).
|
||||
selected_provider = None
|
||||
@@ -1006,8 +871,6 @@ def setup_model_provider(config: dict, *, quick: bool = False):
|
||||
strategy_value = ["fill_first", "round_robin", "random"][strategy_idx]
|
||||
_set_credential_pool_strategy(config, selected_provider, strategy_value)
|
||||
print_success(f"Saved {selected_provider} rotation strategy: {strategy_value}")
|
||||
else:
|
||||
_set_credential_pool_strategy(config, selected_provider, "fill_first")
|
||||
except Exception as exc:
|
||||
logger.debug("Could not configure same-provider fallback in setup: %s", exc)
|
||||
|
||||
@@ -2167,6 +2030,12 @@ def _setup_whatsapp():
|
||||
print_info("or personal self-chat) and pair via QR code.")
|
||||
|
||||
|
||||
def _setup_weixin():
|
||||
"""Configure Weixin (personal WeChat) via iLink Bot API QR login."""
|
||||
from hermes_cli.gateway import _setup_weixin as _gateway_setup_weixin
|
||||
_gateway_setup_weixin()
|
||||
|
||||
|
||||
def _setup_bluebubbles():
|
||||
"""Configure BlueBubbles iMessage gateway."""
|
||||
print_header("BlueBubbles (iMessage)")
|
||||
@@ -2286,6 +2155,7 @@ _GATEWAY_PLATFORMS = [
|
||||
("Matrix", "MATRIX_ACCESS_TOKEN", _setup_matrix),
|
||||
("Mattermost", "MATTERMOST_TOKEN", _setup_mattermost),
|
||||
("WhatsApp", "WHATSAPP_ENABLED", _setup_whatsapp),
|
||||
("Weixin (WeChat)", "WEIXIN_ACCOUNT_ID", _setup_weixin),
|
||||
("BlueBubbles (iMessage)", "BLUEBUBBLES_SERVER_URL", _setup_bluebubbles),
|
||||
("Webhooks (GitHub, GitLab, etc.)", "WEBHOOK_ENABLED", _setup_webhooks),
|
||||
]
|
||||
@@ -2844,6 +2714,7 @@ def run_setup_wizard(args):
|
||||
Supports full, quick, and section-specific setup:
|
||||
hermes setup — full or quick (auto-detected)
|
||||
hermes setup model — just model/provider
|
||||
hermes setup tts — just text-to-speech
|
||||
hermes setup terminal — just terminal backend
|
||||
hermes setup gateway — just messaging platforms
|
||||
hermes setup tools — just tool configuration
|
||||
@@ -2855,6 +2726,11 @@ def run_setup_wizard(args):
|
||||
return
|
||||
ensure_hermes_home()
|
||||
|
||||
reset_requested = bool(getattr(args, "reset", False))
|
||||
if reset_requested:
|
||||
save_config(copy.deepcopy(DEFAULT_CONFIG))
|
||||
print_success("Configuration reset to defaults.")
|
||||
|
||||
config = load_config()
|
||||
hermes_home = get_hermes_home()
|
||||
|
||||
@@ -2955,18 +2831,13 @@ def run_setup_wizard(args):
|
||||
menu_choices = [
|
||||
"Quick Setup - configure missing items only",
|
||||
"Full Setup - reconfigure everything",
|
||||
"---",
|
||||
"Model & Provider",
|
||||
"Terminal Backend",
|
||||
"Messaging Platforms (Gateway)",
|
||||
"Tools",
|
||||
"Agent Settings",
|
||||
"---",
|
||||
"Exit",
|
||||
]
|
||||
|
||||
# Separator indices (not selectable, but prompt_choice doesn't filter them,
|
||||
# so we handle them below)
|
||||
choice = prompt_choice("What would you like to do?", menu_choices, 0)
|
||||
|
||||
if choice == 0:
|
||||
@@ -2976,18 +2847,14 @@ def run_setup_wizard(args):
|
||||
elif choice == 1:
|
||||
# Full setup — fall through to run all sections
|
||||
pass
|
||||
elif choice in (2, 8):
|
||||
# Separator — treat as exit
|
||||
elif choice == 7:
|
||||
print_info("Exiting. Run 'hermes setup' again when ready.")
|
||||
return
|
||||
elif choice == 9:
|
||||
print_info("Exiting. Run 'hermes setup' again when ready.")
|
||||
return
|
||||
elif 3 <= choice <= 7:
|
||||
elif 2 <= choice <= 6:
|
||||
# Individual section — map by key, not by position.
|
||||
# SETUP_SECTIONS includes TTS but the returning-user menu skips it,
|
||||
# so positional indexing (choice - 3) would dispatch the wrong section.
|
||||
section_key = RETURNING_USER_MENU_SECTION_KEYS[choice - 3]
|
||||
# so positional indexing (choice - 2) would dispatch the wrong section.
|
||||
section_key = RETURNING_USER_MENU_SECTION_KEYS[choice - 2]
|
||||
section = next((s for s in SETUP_SECTIONS if s[0] == section_key), None)
|
||||
if section:
|
||||
_, label, func = section
|
||||
|
||||
@@ -31,6 +31,7 @@ PLATFORMS = {
|
||||
"dingtalk": "💬 DingTalk",
|
||||
"feishu": "🪽 Feishu",
|
||||
"wecom": "💬 WeCom",
|
||||
"weixin": "💬 Weixin",
|
||||
"webhook": "🔗 Webhook",
|
||||
}
|
||||
|
||||
|
||||
+27
-24
@@ -151,7 +151,8 @@ def do_search(query: str, source: str = "all", limit: int = 10,
|
||||
|
||||
auth = GitHubAuth()
|
||||
sources = create_source_router(auth)
|
||||
results = unified_search(query, sources, source_filter=source, limit=limit)
|
||||
with c.status("[bold]Searching registries..."):
|
||||
results = unified_search(query, sources, source_filter=source, limit=limit)
|
||||
|
||||
if not results:
|
||||
c.print("[dim]No skills found matching your query.[/]\n")
|
||||
@@ -187,7 +188,7 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
Official skills are always shown first, regardless of source filter.
|
||||
"""
|
||||
from tools.skills_hub import (
|
||||
GitHubAuth, create_source_router,
|
||||
GitHubAuth, create_source_router, parallel_search_sources,
|
||||
)
|
||||
|
||||
# Clamp page_size to safe range
|
||||
@@ -198,27 +199,23 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
auth = GitHubAuth()
|
||||
sources = create_source_router(auth)
|
||||
|
||||
# Collect results from all (or filtered) sources
|
||||
# Use empty query to get everything; per-source limits prevent overload
|
||||
# Collect results from all (or filtered) sources in parallel.
|
||||
# Per-source limits are generous — parallelism + 30s timeout cap prevents hangs.
|
||||
_TRUST_RANK = {"builtin": 3, "trusted": 2, "community": 1}
|
||||
_PER_SOURCE_LIMIT = {"official": 100, "skills-sh": 100, "well-known": 25, "github": 100, "clawhub": 50,
|
||||
"claude-marketplace": 50, "lobehub": 50}
|
||||
_PER_SOURCE_LIMIT = {
|
||||
"official": 200, "skills-sh": 200, "well-known": 50,
|
||||
"github": 200, "clawhub": 500, "claude-marketplace": 100,
|
||||
"lobehub": 500,
|
||||
}
|
||||
|
||||
all_results: list = []
|
||||
source_counts: dict = {}
|
||||
|
||||
for src in sources:
|
||||
sid = src.source_id()
|
||||
if source != "all" and sid != source and sid != "official":
|
||||
# Always include official source for the "first" placement
|
||||
continue
|
||||
try:
|
||||
limit = _PER_SOURCE_LIMIT.get(sid, 50)
|
||||
results = src.search("", limit=limit)
|
||||
source_counts[sid] = len(results)
|
||||
all_results.extend(results)
|
||||
except Exception:
|
||||
continue
|
||||
with c.status("[bold]Fetching skills from registries..."):
|
||||
all_results, source_counts, timed_out = parallel_search_sources(
|
||||
sources,
|
||||
query="",
|
||||
per_source_limits=_PER_SOURCE_LIMIT,
|
||||
source_filter=source,
|
||||
overall_timeout=30,
|
||||
)
|
||||
|
||||
if not all_results:
|
||||
c.print("[dim]No skills found in the Skills Hub.[/]\n")
|
||||
@@ -252,8 +249,11 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
|
||||
# Build header
|
||||
source_label = f"— {source}" if source != "all" else "— all sources"
|
||||
loaded_label = f"{total} skills loaded"
|
||||
if timed_out:
|
||||
loaded_label += f", {len(timed_out)} source(s) still loading"
|
||||
c.print(f"\n[bold]Skills Hub — Browse {source_label}[/]"
|
||||
f" [dim]({total} skills, page {page}/{total_pages})[/]")
|
||||
f" [dim]({loaded_label}, page {page}/{total_pages})[/]")
|
||||
if official_count > 0 and page == 1:
|
||||
c.print(f"[bright_cyan]★ {official_count} official optional skill(s) from Nous Research[/]")
|
||||
c.print()
|
||||
@@ -300,8 +300,11 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
parts = [f"{sid}: {ct}" for sid, ct in sorted(source_counts.items())]
|
||||
c.print(f" [dim]Sources: {', '.join(parts)}[/]")
|
||||
|
||||
c.print("[dim]Use: hermes skills inspect <identifier> to preview, "
|
||||
"hermes skills install <identifier> to install[/]\n")
|
||||
if timed_out:
|
||||
c.print(f" [yellow]⚡ Slow sources skipped: {', '.join(timed_out)} "
|
||||
f"— run again for cached results[/]")
|
||||
|
||||
c.print("[dim]Tip: 'hermes skills search <query>' searches deeper across all registries[/]\n")
|
||||
|
||||
|
||||
def do_install(identifier: str, category: str = "", force: bool = False,
|
||||
|
||||
+24
-2
@@ -79,6 +79,9 @@ def _effective_provider_label() -> str:
|
||||
return provider_label(effective)
|
||||
|
||||
|
||||
from hermes_constants import is_termux as _is_termux
|
||||
|
||||
|
||||
def show_status(args):
|
||||
"""Show status of all Hermes Agent components."""
|
||||
show_all = getattr(args, 'all', False)
|
||||
@@ -302,6 +305,7 @@ def show_status(args):
|
||||
"DingTalk": ("DINGTALK_CLIENT_ID", None),
|
||||
"Feishu": ("FEISHU_APP_ID", "FEISHU_HOME_CHANNEL"),
|
||||
"WeCom": ("WECOM_BOT_ID", "WECOM_HOME_CHANNEL"),
|
||||
"Weixin": ("WEIXIN_ACCOUNT_ID", "WEIXIN_HOME_CHANNEL"),
|
||||
"BlueBubbles": ("BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_HOME_CHANNEL"),
|
||||
}
|
||||
|
||||
@@ -325,7 +329,25 @@ def show_status(args):
|
||||
print()
|
||||
print(color("◆ Gateway Service", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
if sys.platform.startswith('linux'):
|
||||
if _is_termux():
|
||||
try:
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
gateway_pids = find_gateway_pids()
|
||||
except Exception:
|
||||
gateway_pids = []
|
||||
is_running = bool(gateway_pids)
|
||||
print(f" Status: {check_mark(is_running)} {'running' if is_running else 'stopped'}")
|
||||
print(" Manager: Termux / manual process")
|
||||
if gateway_pids:
|
||||
rendered = ", ".join(str(pid) for pid in gateway_pids[:3])
|
||||
if len(gateway_pids) > 3:
|
||||
rendered += ", ..."
|
||||
print(f" PID(s): {rendered}")
|
||||
else:
|
||||
print(" Start with: hermes gateway")
|
||||
print(" Note: Android may stop background jobs when Termux is suspended")
|
||||
|
||||
elif sys.platform.startswith('linux'):
|
||||
try:
|
||||
from hermes_cli.gateway import get_service_name
|
||||
_gw_svc = get_service_name()
|
||||
@@ -339,7 +361,7 @@ def show_status(args):
|
||||
timeout=5
|
||||
)
|
||||
is_active = result.stdout.strip() == "active"
|
||||
except subprocess.TimeoutExpired:
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
is_active = False
|
||||
print(f" Status: {check_mark(is_active)} {'running' if is_active else 'stopped'}")
|
||||
print(" Manager: systemd (user)")
|
||||
|
||||
@@ -133,6 +133,7 @@ PLATFORMS = {
|
||||
"dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"},
|
||||
"feishu": {"label": "🪽 Feishu", "default_toolset": "hermes-feishu"},
|
||||
"wecom": {"label": "💬 WeCom", "default_toolset": "hermes-wecom"},
|
||||
"weixin": {"label": "💬 Weixin", "default_toolset": "hermes-weixin"},
|
||||
"api_server": {"label": "🌐 API Server", "default_toolset": "hermes-api-server"},
|
||||
"mattermost": {"label": "💬 Mattermost", "default_toolset": "hermes-mattermost"},
|
||||
"webhook": {"label": "🔗 Webhook", "default_toolset": "hermes-webhook"},
|
||||
@@ -720,6 +721,8 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
return
|
||||
|
||||
curses.wrapper(_curses_menu)
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
return result_holder[0]
|
||||
|
||||
except Exception:
|
||||
|
||||
@@ -6,6 +6,8 @@ Provides options for:
|
||||
- Keep data: Remove code but keep ~/.hermes/ (configs, sessions, logs)
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
@@ -122,6 +124,10 @@ def uninstall_gateway_service():
|
||||
|
||||
if platform.system() != "Linux":
|
||||
return False
|
||||
|
||||
prefix = os.getenv("PREFIX", "")
|
||||
if os.getenv("TERMUX_VERSION") or "com.termux/files/usr" in prefix:
|
||||
return False
|
||||
|
||||
try:
|
||||
from hermes_cli.gateway import get_service_name
|
||||
|
||||
+77
-6
@@ -17,6 +17,45 @@ def get_hermes_home() -> Path:
|
||||
return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
||||
|
||||
def get_default_hermes_root() -> Path:
|
||||
"""Return the root Hermes directory for profile-level operations.
|
||||
|
||||
In standard deployments this is ``~/.hermes``.
|
||||
|
||||
In Docker or custom deployments where ``HERMES_HOME`` points outside
|
||||
``~/.hermes`` (e.g. ``/opt/data``), returns ``HERMES_HOME`` directly
|
||||
— that IS the root.
|
||||
|
||||
In profile mode where ``HERMES_HOME`` is ``<root>/profiles/<name>``,
|
||||
returns ``<root>`` so that ``profile list`` can see all profiles.
|
||||
Works both for standard (``~/.hermes/profiles/coder``) and Docker
|
||||
(``/opt/data/profiles/coder``) layouts.
|
||||
|
||||
Import-safe — no dependencies beyond stdlib.
|
||||
"""
|
||||
native_home = Path.home() / ".hermes"
|
||||
env_home = os.environ.get("HERMES_HOME", "")
|
||||
if not env_home:
|
||||
return native_home
|
||||
env_path = Path(env_home)
|
||||
try:
|
||||
env_path.resolve().relative_to(native_home.resolve())
|
||||
# HERMES_HOME is under ~/.hermes (normal or profile mode)
|
||||
return native_home
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Docker / custom deployment.
|
||||
# Check if this is a profile path: <root>/profiles/<name>
|
||||
# If the immediate parent dir is named "profiles", the root is
|
||||
# the grandparent — this covers Docker profiles correctly.
|
||||
if env_path.parent.name == "profiles":
|
||||
return env_path.parent.parent
|
||||
|
||||
# Not a profile path — HERMES_HOME itself is the root
|
||||
return env_path
|
||||
|
||||
|
||||
def get_optional_skills_dir(default: Path | None = None) -> Path:
|
||||
"""Return the optional-skills directory, honoring package-manager wrappers.
|
||||
|
||||
@@ -72,13 +111,39 @@ def display_hermes_home() -> str:
|
||||
return str(home)
|
||||
|
||||
|
||||
VALID_REASONING_EFFORTS = ("xhigh", "high", "medium", "low", "minimal")
|
||||
def get_subprocess_home() -> str | None:
|
||||
"""Return a per-profile HOME directory for subprocesses, or None.
|
||||
|
||||
When ``{HERMES_HOME}/home/`` exists on disk, subprocesses should use it
|
||||
as ``HOME`` so system tools (git, ssh, gh, npm …) write their configs
|
||||
inside the Hermes data directory instead of the OS-level ``/root`` or
|
||||
``~/``. This provides:
|
||||
|
||||
* **Docker persistence** — tool configs land inside the persistent volume.
|
||||
* **Profile isolation** — each profile gets its own git identity, SSH
|
||||
keys, gh tokens, etc.
|
||||
|
||||
The Python process's own ``os.environ["HOME"]`` and ``Path.home()`` are
|
||||
**never** modified — only subprocess environments should inject this value.
|
||||
Activation is directory-based: if the ``home/`` subdirectory doesn't
|
||||
exist, returns ``None`` and behavior is unchanged.
|
||||
"""
|
||||
hermes_home = os.getenv("HERMES_HOME")
|
||||
if not hermes_home:
|
||||
return None
|
||||
profile_home = os.path.join(hermes_home, "home")
|
||||
if os.path.isdir(profile_home):
|
||||
return profile_home
|
||||
return None
|
||||
|
||||
|
||||
VALID_REASONING_EFFORTS = ("minimal", "low", "medium", "high", "xhigh")
|
||||
|
||||
|
||||
def parse_reasoning_effort(effort: str) -> dict | None:
|
||||
"""Parse a reasoning effort level into a config dict.
|
||||
|
||||
Valid levels: "xhigh", "high", "medium", "low", "minimal", "none".
|
||||
Valid levels: "none", "minimal", "low", "medium", "high", "xhigh".
|
||||
Returns None when the input is empty or unrecognized (caller uses default).
|
||||
Returns {"enabled": False} for "none".
|
||||
Returns {"enabled": True, "effort": <level>} for valid effort levels.
|
||||
@@ -93,13 +158,19 @@ def parse_reasoning_effort(effort: str) -> dict | None:
|
||||
return None
|
||||
|
||||
|
||||
def is_termux() -> bool:
|
||||
"""Return True when running inside a Termux (Android) environment.
|
||||
|
||||
Checks ``TERMUX_VERSION`` (set by Termux) or the Termux-specific
|
||||
``PREFIX`` path. Import-safe — no heavy deps.
|
||||
"""
|
||||
prefix = os.getenv("PREFIX", "")
|
||||
return bool(os.getenv("TERMUX_VERSION") or "com.termux/files/usr" in prefix)
|
||||
|
||||
|
||||
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
|
||||
OPENROUTER_MODELS_URL = f"{OPENROUTER_BASE_URL}/models"
|
||||
OPENROUTER_CHAT_URL = f"{OPENROUTER_BASE_URL}/chat/completions"
|
||||
|
||||
AI_GATEWAY_BASE_URL = "https://ai-gateway.vercel.sh/v1"
|
||||
AI_GATEWAY_MODELS_URL = f"{AI_GATEWAY_BASE_URL}/models"
|
||||
AI_GATEWAY_CHAT_URL = f"{AI_GATEWAY_BASE_URL}/chat/completions"
|
||||
|
||||
NOUS_API_BASE_URL = "https://inference-api.nousresearch.com/v1"
|
||||
NOUS_API_CHAT_URL = f"{NOUS_API_BASE_URL}/chat/completions"
|
||||
|
||||
+34
-1
@@ -13,6 +13,7 @@ secrets are never written to disk.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@@ -177,6 +178,38 @@ def setup_verbose_logging() -> None:
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _ManagedRotatingFileHandler(RotatingFileHandler):
|
||||
"""RotatingFileHandler that ensures group-writable perms in managed mode.
|
||||
|
||||
In managed mode (NixOS), the stateDir uses setgid (2770) so new files
|
||||
inherit the hermes group. However, both _open() (initial creation) and
|
||||
doRollover() create files via open(), which uses the process umask —
|
||||
typically 0022, producing 0644. This subclass applies chmod 0660 after
|
||||
both operations so the gateway and interactive users can share log files.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
from hermes_cli.config import is_managed
|
||||
self._managed = is_managed()
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _chmod_if_managed(self):
|
||||
if self._managed:
|
||||
try:
|
||||
os.chmod(self.baseFilename, 0o660)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _open(self):
|
||||
stream = super()._open()
|
||||
self._chmod_if_managed()
|
||||
return stream
|
||||
|
||||
def doRollover(self):
|
||||
super().doRollover()
|
||||
self._chmod_if_managed()
|
||||
|
||||
|
||||
def _add_rotating_handler(
|
||||
logger: logging.Logger,
|
||||
path: Path,
|
||||
@@ -198,7 +231,7 @@ def _add_rotating_handler(
|
||||
return # already attached
|
||||
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
handler = RotatingFileHandler(
|
||||
handler = _ManagedRotatingFileHandler(
|
||||
str(path), maxBytes=max_bytes, backupCount=backup_count,
|
||||
)
|
||||
handler.setLevel(level)
|
||||
|
||||
+8
-70
@@ -520,72 +520,6 @@ class SessionDB:
|
||||
)
|
||||
self._execute_write(_do)
|
||||
|
||||
def set_token_counts(
|
||||
self,
|
||||
session_id: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
model: str = None,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_write_tokens: int = 0,
|
||||
reasoning_tokens: int = 0,
|
||||
estimated_cost_usd: Optional[float] = None,
|
||||
actual_cost_usd: Optional[float] = None,
|
||||
cost_status: Optional[str] = None,
|
||||
cost_source: Optional[str] = None,
|
||||
pricing_version: Optional[str] = None,
|
||||
billing_provider: Optional[str] = None,
|
||||
billing_base_url: Optional[str] = None,
|
||||
billing_mode: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Set token counters to absolute values (not increment).
|
||||
|
||||
Use this when the caller provides cumulative totals from a completed
|
||||
conversation run (e.g. the gateway, where the cached agent's
|
||||
session_prompt_tokens already reflects the running total).
|
||||
"""
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"""UPDATE sessions SET
|
||||
input_tokens = ?,
|
||||
output_tokens = ?,
|
||||
cache_read_tokens = ?,
|
||||
cache_write_tokens = ?,
|
||||
reasoning_tokens = ?,
|
||||
estimated_cost_usd = ?,
|
||||
actual_cost_usd = CASE
|
||||
WHEN ? IS NULL THEN actual_cost_usd
|
||||
ELSE ?
|
||||
END,
|
||||
cost_status = COALESCE(?, cost_status),
|
||||
cost_source = COALESCE(?, cost_source),
|
||||
pricing_version = COALESCE(?, pricing_version),
|
||||
billing_provider = COALESCE(billing_provider, ?),
|
||||
billing_base_url = COALESCE(billing_base_url, ?),
|
||||
billing_mode = COALESCE(billing_mode, ?),
|
||||
model = COALESCE(model, ?)
|
||||
WHERE id = ?""",
|
||||
(
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
reasoning_tokens,
|
||||
estimated_cost_usd,
|
||||
actual_cost_usd,
|
||||
actual_cost_usd,
|
||||
cost_status,
|
||||
cost_source,
|
||||
pricing_version,
|
||||
billing_provider,
|
||||
billing_base_url,
|
||||
billing_mode,
|
||||
model,
|
||||
session_id,
|
||||
),
|
||||
)
|
||||
self._execute_write(_do)
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a session by ID."""
|
||||
with self._lock:
|
||||
@@ -944,7 +878,8 @@ class SessionDB:
|
||||
try:
|
||||
msg["tool_calls"] = json.loads(msg["tool_calls"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
logger.warning("Failed to deserialize tool_calls in get_messages, falling back to []")
|
||||
msg["tool_calls"] = []
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
@@ -972,7 +907,8 @@ class SessionDB:
|
||||
try:
|
||||
msg["tool_calls"] = json.loads(row["tool_calls"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
logger.warning("Failed to deserialize tool_calls in conversation replay, falling back to []")
|
||||
msg["tool_calls"] = []
|
||||
# Restore reasoning fields on assistant messages so providers
|
||||
# that replay reasoning (OpenRouter, OpenAI, Nous) receive
|
||||
# coherent multi-turn reasoning context.
|
||||
@@ -983,12 +919,14 @@ class SessionDB:
|
||||
try:
|
||||
msg["reasoning_details"] = json.loads(row["reasoning_details"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
logger.warning("Failed to deserialize reasoning_details, falling back to None")
|
||||
msg["reasoning_details"] = None
|
||||
if row["codex_reasoning_items"]:
|
||||
try:
|
||||
msg["codex_reasoning_items"] = json.loads(row["codex_reasoning_items"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
logger.warning("Failed to deserialize codex_reasoning_items, falling back to None")
|
||||
msg["codex_reasoning_items"] = None
|
||||
messages.append(msg)
|
||||
return messages
|
||||
|
||||
|
||||
@@ -89,13 +89,6 @@ def get_timezone() -> Optional[ZoneInfo]:
|
||||
return _cached_tz
|
||||
|
||||
|
||||
def get_timezone_name() -> str:
|
||||
"""Return the IANA name of the configured timezone, or empty string."""
|
||||
if not _cache_resolved:
|
||||
get_timezone() # populates cache
|
||||
return _cached_tz_name or ""
|
||||
|
||||
|
||||
def now() -> datetime:
|
||||
"""
|
||||
Return the current time as a timezone-aware datetime.
|
||||
@@ -110,9 +103,3 @@ def now() -> datetime:
|
||||
return datetime.now().astimezone()
|
||||
|
||||
|
||||
def reset_cache() -> None:
|
||||
"""Clear the cached timezone. Used by tests and after config changes."""
|
||||
global _cached_tz, _cached_tz_name, _cache_resolved
|
||||
_cached_tz = None
|
||||
_cached_tz_name = None
|
||||
_cache_resolved = False
|
||||
|
||||
+26
-4
@@ -560,10 +560,14 @@
|
||||
# ── Directories ───────────────────────────────────────────────────
|
||||
{
|
||||
systemd.tmpfiles.rules = [
|
||||
"d ${cfg.stateDir} 0750 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.stateDir}/.hermes 0750 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.stateDir} 2770 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.stateDir}/.hermes 2770 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.stateDir}/.hermes/cron 2770 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.stateDir}/.hermes/sessions 2770 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.stateDir}/.hermes/logs 2770 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.stateDir}/.hermes/memories 2770 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.stateDir}/home 0750 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.workingDirectory} 0750 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.workingDirectory} 2770 ${cfg.user} ${cfg.group} - -"
|
||||
];
|
||||
}
|
||||
|
||||
@@ -575,7 +579,21 @@
|
||||
mkdir -p ${cfg.stateDir}/home
|
||||
mkdir -p ${cfg.workingDirectory}
|
||||
chown ${cfg.user}:${cfg.group} ${cfg.stateDir} ${cfg.stateDir}/.hermes ${cfg.stateDir}/home ${cfg.workingDirectory}
|
||||
chmod 0750 ${cfg.stateDir} ${cfg.stateDir}/.hermes ${cfg.stateDir}/home ${cfg.workingDirectory}
|
||||
chmod 2770 ${cfg.stateDir} ${cfg.stateDir}/.hermes ${cfg.workingDirectory}
|
||||
chmod 0750 ${cfg.stateDir}/home
|
||||
|
||||
# Create subdirs, set setgid + group-writable, migrate existing files.
|
||||
# Nix-managed files (config.yaml, .env, .managed) stay 0640/0644.
|
||||
find ${cfg.stateDir}/.hermes -maxdepth 1 \
|
||||
\( -name "*.db" -o -name "*.db-wal" -o -name "*.db-shm" -o -name "SOUL.md" \) \
|
||||
-exec chmod g+rw {} + 2>/dev/null || true
|
||||
for _subdir in cron sessions logs memories; do
|
||||
mkdir -p "${cfg.stateDir}/.hermes/$_subdir"
|
||||
chown ${cfg.user}:${cfg.group} "${cfg.stateDir}/.hermes/$_subdir"
|
||||
chmod 2770 "${cfg.stateDir}/.hermes/$_subdir"
|
||||
find "${cfg.stateDir}/.hermes/$_subdir" -type f \
|
||||
-exec chmod g+rw {} + 2>/dev/null || true
|
||||
done
|
||||
|
||||
# Merge Nix settings into existing config.yaml.
|
||||
# Preserves user-added keys (skills, streaming, etc.); Nix keys win.
|
||||
@@ -662,6 +680,10 @@ HERMES_NIX_ENV_EOF
|
||||
Restart = cfg.restart;
|
||||
RestartSec = cfg.restartSec;
|
||||
|
||||
# Shared-state: files created by the gateway should be group-writable
|
||||
# so interactive users in the hermes group can read/write them.
|
||||
UMask = "0007";
|
||||
|
||||
# Hardening
|
||||
NoNewPrivileges = true;
|
||||
ProtectSystem = "strict";
|
||||
|
||||
@@ -63,6 +63,17 @@ homeassistant = ["aiohttp>=3.9.0,<4"]
|
||||
sms = ["aiohttp>=3.9.0,<4"]
|
||||
acp = ["agent-client-protocol>=0.9.0,<1.0"]
|
||||
mistral = ["mistralai>=2.3.0,<3"]
|
||||
termux = [
|
||||
# Tested Android / Termux path: keeps the core CLI feature-rich while
|
||||
# avoiding extras that currently depend on non-Android wheels (notably
|
||||
# faster-whisper -> ctranslate2 via the voice extra).
|
||||
"hermes-agent[cron]",
|
||||
"hermes-agent[cli]",
|
||||
"hermes-agent[pty]",
|
||||
"hermes-agent[mcp]",
|
||||
"hermes-agent[honcho]",
|
||||
"hermes-agent[acp]",
|
||||
]
|
||||
dingtalk = ["dingtalk-stream>=0.1.0,<1"]
|
||||
feishu = ["lark-oapi>=1.5.3,<2"]
|
||||
rl = [
|
||||
|
||||
+369
-78
@@ -359,8 +359,9 @@ def _sanitize_surrogates(text: str) -> str:
|
||||
def _sanitize_messages_surrogates(messages: list) -> bool:
|
||||
"""Sanitize surrogate characters from all string content in a messages list.
|
||||
|
||||
Walks message dicts in-place. Returns True if any surrogates were found
|
||||
and replaced, False otherwise.
|
||||
Walks message dicts in-place. Returns True if any surrogates were found
|
||||
and replaced, False otherwise. Covers content/text, name, and tool call
|
||||
metadata/arguments so retries don't fail on a non-content field.
|
||||
"""
|
||||
found = False
|
||||
for msg in messages:
|
||||
@@ -377,6 +378,88 @@ def _sanitize_messages_surrogates(messages: list) -> bool:
|
||||
if isinstance(text, str) and _SURROGATE_RE.search(text):
|
||||
part["text"] = _SURROGATE_RE.sub('\ufffd', text)
|
||||
found = True
|
||||
name = msg.get("name")
|
||||
if isinstance(name, str) and _SURROGATE_RE.search(name):
|
||||
msg["name"] = _SURROGATE_RE.sub('\ufffd', name)
|
||||
found = True
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if isinstance(tool_calls, list):
|
||||
for tc in tool_calls:
|
||||
if not isinstance(tc, dict):
|
||||
continue
|
||||
tc_id = tc.get("id")
|
||||
if isinstance(tc_id, str) and _SURROGATE_RE.search(tc_id):
|
||||
tc["id"] = _SURROGATE_RE.sub('\ufffd', tc_id)
|
||||
found = True
|
||||
fn = tc.get("function")
|
||||
if isinstance(fn, dict):
|
||||
fn_name = fn.get("name")
|
||||
if isinstance(fn_name, str) and _SURROGATE_RE.search(fn_name):
|
||||
fn["name"] = _SURROGATE_RE.sub('\ufffd', fn_name)
|
||||
found = True
|
||||
fn_args = fn.get("arguments")
|
||||
if isinstance(fn_args, str) and _SURROGATE_RE.search(fn_args):
|
||||
fn["arguments"] = _SURROGATE_RE.sub('\ufffd', fn_args)
|
||||
found = True
|
||||
return found
|
||||
|
||||
|
||||
def _strip_non_ascii(text: str) -> str:
|
||||
"""Remove non-ASCII characters, replacing with closest ASCII equivalent or removing.
|
||||
|
||||
Used as a last resort when the system encoding is ASCII and can't handle
|
||||
any non-ASCII characters (e.g. LANG=C on Chromebooks).
|
||||
"""
|
||||
return text.encode('ascii', errors='ignore').decode('ascii')
|
||||
|
||||
|
||||
def _sanitize_messages_non_ascii(messages: list) -> bool:
|
||||
"""Strip non-ASCII characters from all string content in a messages list.
|
||||
|
||||
This is a last-resort recovery for systems with ASCII-only encoding
|
||||
(LANG=C, Chromebooks, minimal containers). Returns True if any
|
||||
non-ASCII content was found and sanitized.
|
||||
"""
|
||||
found = False
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
# Sanitize content (string)
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
sanitized = _strip_non_ascii(content)
|
||||
if sanitized != content:
|
||||
msg["content"] = sanitized
|
||||
found = True
|
||||
elif isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
text = part.get("text")
|
||||
if isinstance(text, str):
|
||||
sanitized = _strip_non_ascii(text)
|
||||
if sanitized != text:
|
||||
part["text"] = sanitized
|
||||
found = True
|
||||
# Sanitize name field (can contain non-ASCII in tool results)
|
||||
name = msg.get("name")
|
||||
if isinstance(name, str):
|
||||
sanitized = _strip_non_ascii(name)
|
||||
if sanitized != name:
|
||||
msg["name"] = sanitized
|
||||
found = True
|
||||
# Sanitize tool_calls
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if isinstance(tool_calls, list):
|
||||
for tc in tool_calls:
|
||||
if isinstance(tc, dict):
|
||||
fn = tc.get("function", {})
|
||||
if isinstance(fn, dict):
|
||||
fn_args = fn.get("arguments")
|
||||
if isinstance(fn_args, str):
|
||||
sanitized = _strip_non_ascii(fn_args)
|
||||
if sanitized != fn_args:
|
||||
fn["arguments"] = sanitized
|
||||
found = True
|
||||
return found
|
||||
|
||||
|
||||
@@ -500,6 +583,8 @@ class AIAgent:
|
||||
status_callback: callable = None,
|
||||
max_tokens: int = None,
|
||||
reasoning_config: Dict[str, Any] = None,
|
||||
service_tier: str = None,
|
||||
request_overrides: Dict[str, Any] = None,
|
||||
prefill_messages: List[Dict[str, Any]] = None,
|
||||
platform: str = None,
|
||||
user_id: str = None,
|
||||
@@ -604,6 +689,17 @@ class AIAgent:
|
||||
else:
|
||||
self.api_mode = "chat_completions"
|
||||
|
||||
try:
|
||||
from hermes_cli.model_normalize import (
|
||||
_AGGREGATOR_PROVIDERS,
|
||||
normalize_model_for_provider,
|
||||
)
|
||||
|
||||
if self.provider not in _AGGREGATOR_PROVIDERS:
|
||||
self.model = normalize_model_for_provider(self.model, self.provider)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Direct OpenAI sessions use the Responses API path. GPT-5.x tool
|
||||
# calls with reasoning are rejected on /v1/chat/completions, and
|
||||
# Hermes is a tool-using client by default.
|
||||
@@ -622,9 +718,9 @@ class AIAgent:
|
||||
self.tool_progress_callback = tool_progress_callback
|
||||
self.tool_start_callback = tool_start_callback
|
||||
self.tool_complete_callback = tool_complete_callback
|
||||
self.suppress_status_output = False
|
||||
self.thinking_callback = thinking_callback
|
||||
self.reasoning_callback = reasoning_callback
|
||||
self._reasoning_deltas_fired = False # Set by _fire_reasoning_delta, reset per API call
|
||||
self.clarify_callback = clarify_callback
|
||||
self.step_callback = step_callback
|
||||
self.stream_delta_callback = stream_delta_callback
|
||||
@@ -661,6 +757,8 @@ class AIAgent:
|
||||
# Model response configuration
|
||||
self.max_tokens = max_tokens # None = use model default
|
||||
self.reasoning_config = reasoning_config # None = use default (medium for OpenRouter)
|
||||
self.service_tier = service_tier
|
||||
self.request_overrides = dict(request_overrides or {})
|
||||
self.prefill_messages = prefill_messages or [] # Prefilled conversation turns
|
||||
|
||||
# Anthropic prompt caching: auto-enabled for Claude models via OpenRouter.
|
||||
@@ -789,7 +887,7 @@ class AIAgent:
|
||||
client_kwargs["default_headers"] = copilot_default_headers()
|
||||
elif "api.kimi.com" in effective_base.lower():
|
||||
client_kwargs["default_headers"] = {
|
||||
"User-Agent": "KimiCLI/1.3",
|
||||
"User-Agent": "KimiCLI/1.30.0",
|
||||
}
|
||||
elif "portal.qwen.ai" in effective_base.lower():
|
||||
client_kwargs["default_headers"] = _qwen_portal_headers()
|
||||
@@ -849,6 +947,7 @@ class AIAgent:
|
||||
client_kwargs["default_headers"] = headers
|
||||
|
||||
self.api_key = client_kwargs.get("api_key", "")
|
||||
self.base_url = client_kwargs.get("base_url", self.base_url)
|
||||
try:
|
||||
self.client = self._create_openai_client(client_kwargs, reason="agent_init", shared=True)
|
||||
if not self.quiet_mode:
|
||||
@@ -1145,6 +1244,9 @@ class AIAgent:
|
||||
except (TypeError, ValueError):
|
||||
_config_context_length = None
|
||||
|
||||
# Store for reuse in switch_model (so config override persists across model switches)
|
||||
self._config_context_length = _config_context_length
|
||||
|
||||
# Check custom_providers per-model context_length
|
||||
if _config_context_length is None:
|
||||
_custom_providers = _agent_cfg.get("custom_providers")
|
||||
@@ -1299,7 +1401,6 @@ class AIAgent:
|
||||
if hasattr(self, "context_compressor") and self.context_compressor:
|
||||
self.context_compressor.last_prompt_tokens = 0
|
||||
self.context_compressor.last_completion_tokens = 0
|
||||
self.context_compressor.last_total_tokens = 0
|
||||
self.context_compressor.compression_count = 0
|
||||
self.context_compressor._context_probed = False
|
||||
self.context_compressor._context_probe_persistable = False
|
||||
@@ -1383,6 +1484,7 @@ class AIAgent:
|
||||
base_url=self.base_url,
|
||||
api_key=self.api_key,
|
||||
provider=self.provider,
|
||||
config_context_length=getattr(self, "_config_context_length", None),
|
||||
)
|
||||
self.context_compressor.model = self.model
|
||||
self.context_compressor.base_url = self.base_url
|
||||
@@ -1460,7 +1562,14 @@ class AIAgent:
|
||||
After the main response has been delivered and the remaining tool
|
||||
calls are post-response housekeeping (``_mute_post_response``),
|
||||
all non-forced output is suppressed.
|
||||
|
||||
``suppress_status_output`` is a stricter CLI automation mode used by
|
||||
parseable single-query flows such as ``hermes chat -q``. In that mode,
|
||||
all status/diagnostic prints routed through ``_vprint`` are suppressed
|
||||
so stdout stays machine-readable.
|
||||
"""
|
||||
if getattr(self, "suppress_status_output", False):
|
||||
return
|
||||
if not force and getattr(self, "_mute_post_response", False):
|
||||
return
|
||||
if not force and self._has_stream_consumers() and not self._executing_tools:
|
||||
@@ -1486,6 +1595,17 @@ class AIAgent:
|
||||
except (AttributeError, ValueError, OSError):
|
||||
return False
|
||||
|
||||
def _should_emit_quiet_tool_messages(self) -> bool:
|
||||
"""Return True when quiet-mode tool summaries should print directly.
|
||||
|
||||
When the caller provides ``tool_progress_callback`` (for example the CLI
|
||||
TUI or a gateway progress renderer), that callback owns progress display.
|
||||
Emitting quiet-mode summary lines here duplicates progress and leaks tool
|
||||
previews into flows that are expected to stay silent, such as
|
||||
``hermes chat -q``.
|
||||
"""
|
||||
return self.quiet_mode and not self.tool_progress_callback
|
||||
|
||||
def _emit_status(self, message: str) -> None:
|
||||
"""Emit a lifecycle status message to both CLI and gateway channels.
|
||||
|
||||
@@ -2901,7 +3021,7 @@ class AIAgent:
|
||||
|
||||
@staticmethod
|
||||
def _cap_delegate_task_calls(tool_calls: list) -> list:
|
||||
"""Truncate excess delegate_task calls to MAX_CONCURRENT_CHILDREN.
|
||||
"""Truncate excess delegate_task calls to max_concurrent_children.
|
||||
|
||||
The delegate_tool caps the task list inside a single call, but the
|
||||
model can emit multiple separate delegate_task tool_calls in one
|
||||
@@ -2909,23 +3029,24 @@ class AIAgent:
|
||||
|
||||
Returns the original list if no truncation was needed.
|
||||
"""
|
||||
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
|
||||
from tools.delegate_tool import _get_max_concurrent_children
|
||||
max_children = _get_max_concurrent_children()
|
||||
delegate_count = sum(1 for tc in tool_calls if tc.function.name == "delegate_task")
|
||||
if delegate_count <= MAX_CONCURRENT_CHILDREN:
|
||||
if delegate_count <= max_children:
|
||||
return tool_calls
|
||||
kept_delegates = 0
|
||||
truncated = []
|
||||
for tc in tool_calls:
|
||||
if tc.function.name == "delegate_task":
|
||||
if kept_delegates < MAX_CONCURRENT_CHILDREN:
|
||||
if kept_delegates < max_children:
|
||||
truncated.append(tc)
|
||||
kept_delegates += 1
|
||||
else:
|
||||
truncated.append(tc)
|
||||
logger.warning(
|
||||
"Truncated %d excess delegate_task call(s) to enforce "
|
||||
"MAX_CONCURRENT_CHILDREN=%d limit",
|
||||
delegate_count - MAX_CONCURRENT_CHILDREN, MAX_CONCURRENT_CHILDREN,
|
||||
"max_concurrent_children=%d limit",
|
||||
delegate_count - max_children, max_children,
|
||||
)
|
||||
return truncated
|
||||
|
||||
@@ -3324,7 +3445,7 @@ class AIAgent:
|
||||
allowed_keys = {
|
||||
"model", "instructions", "input", "tools", "store",
|
||||
"reasoning", "include", "max_output_tokens", "temperature",
|
||||
"tool_choice", "parallel_tool_calls", "prompt_cache_key",
|
||||
"tool_choice", "parallel_tool_calls", "prompt_cache_key", "service_tier",
|
||||
}
|
||||
normalized: Dict[str, Any] = {
|
||||
"model": model,
|
||||
@@ -3342,6 +3463,9 @@ class AIAgent:
|
||||
include = api_kwargs.get("include")
|
||||
if isinstance(include, list):
|
||||
normalized["include"] = include
|
||||
service_tier = api_kwargs.get("service_tier")
|
||||
if isinstance(service_tier, str) and service_tier.strip():
|
||||
normalized["service_tier"] = service_tier.strip()
|
||||
|
||||
# Pass through max_output_tokens and temperature
|
||||
max_output_tokens = api_kwargs.get("max_output_tokens")
|
||||
@@ -3849,7 +3973,6 @@ class AIAgent:
|
||||
max_stream_retries = 1
|
||||
has_tool_calls = False
|
||||
first_delta_fired = False
|
||||
self._reasoning_deltas_fired = False
|
||||
# Accumulate streamed text so we can recover if get_final_response()
|
||||
# returns empty output (e.g. chatgpt.com backend-api sends
|
||||
# response.incomplete instead of response.completed).
|
||||
@@ -4155,7 +4278,7 @@ class AIAgent:
|
||||
|
||||
self._client_kwargs["default_headers"] = copilot_default_headers()
|
||||
elif "api.kimi.com" in normalized:
|
||||
self._client_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.3"}
|
||||
self._client_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.30.0"}
|
||||
elif "portal.qwen.ai" in normalized:
|
||||
self._client_kwargs["default_headers"] = _qwen_portal_headers()
|
||||
else:
|
||||
@@ -4193,49 +4316,80 @@ class AIAgent:
|
||||
*,
|
||||
status_code: Optional[int],
|
||||
has_retried_429: bool,
|
||||
classified_reason: Optional[FailoverReason] = None,
|
||||
error_context: Optional[Dict[str, Any]] = None,
|
||||
) -> tuple[bool, bool]:
|
||||
"""Attempt credential recovery via pool rotation.
|
||||
|
||||
Returns (recovered, has_retried_429).
|
||||
On 429: first occurrence retries same credential (sets flag True).
|
||||
second consecutive 429 rotates to next credential (resets flag).
|
||||
On 402: immediately rotates (billing exhaustion won't resolve with retry).
|
||||
On 401: attempts token refresh before rotating.
|
||||
On rate limits: first occurrence retries same credential (sets flag True).
|
||||
second consecutive failure rotates to next credential.
|
||||
On billing exhaustion: immediately rotates.
|
||||
On auth failures: attempts token refresh before rotating.
|
||||
|
||||
`classified_reason` lets the recovery path honor the structured error
|
||||
classifier instead of relying only on raw HTTP codes. This matters for
|
||||
providers that surface billing/rate-limit/auth conditions under a
|
||||
different status code, such as Anthropic returning HTTP 400 for
|
||||
"out of extra usage".
|
||||
"""
|
||||
pool = self._credential_pool
|
||||
if pool is None or status_code is None:
|
||||
if pool is None:
|
||||
return False, has_retried_429
|
||||
|
||||
if status_code == 402:
|
||||
next_entry = pool.mark_exhausted_and_rotate(status_code=402, error_context=error_context)
|
||||
effective_reason = classified_reason
|
||||
if effective_reason is None:
|
||||
if status_code == 402:
|
||||
effective_reason = FailoverReason.billing
|
||||
elif status_code == 429:
|
||||
effective_reason = FailoverReason.rate_limit
|
||||
elif status_code == 401:
|
||||
effective_reason = FailoverReason.auth
|
||||
|
||||
if effective_reason == FailoverReason.billing:
|
||||
rotate_status = status_code if status_code is not None else 402
|
||||
next_entry = pool.mark_exhausted_and_rotate(status_code=rotate_status, error_context=error_context)
|
||||
if next_entry is not None:
|
||||
logger.info(f"Credential 402 (billing) — rotated to pool entry {getattr(next_entry, 'id', '?')}")
|
||||
logger.info(
|
||||
"Credential %s (billing) — rotated to pool entry %s",
|
||||
rotate_status,
|
||||
getattr(next_entry, "id", "?"),
|
||||
)
|
||||
self._swap_credential(next_entry)
|
||||
return True, False
|
||||
return False, has_retried_429
|
||||
|
||||
if status_code == 429:
|
||||
if effective_reason == FailoverReason.rate_limit:
|
||||
if not has_retried_429:
|
||||
return False, True
|
||||
next_entry = pool.mark_exhausted_and_rotate(status_code=429, error_context=error_context)
|
||||
rotate_status = status_code if status_code is not None else 429
|
||||
next_entry = pool.mark_exhausted_and_rotate(status_code=rotate_status, error_context=error_context)
|
||||
if next_entry is not None:
|
||||
logger.info(f"Credential 429 (rate limit) — rotated to pool entry {getattr(next_entry, 'id', '?')}")
|
||||
logger.info(
|
||||
"Credential %s (rate limit) — rotated to pool entry %s",
|
||||
rotate_status,
|
||||
getattr(next_entry, "id", "?"),
|
||||
)
|
||||
self._swap_credential(next_entry)
|
||||
return True, False
|
||||
return False, True
|
||||
|
||||
if status_code == 401:
|
||||
if effective_reason == FailoverReason.auth:
|
||||
refreshed = pool.try_refresh_current()
|
||||
if refreshed is not None:
|
||||
logger.info(f"Credential 401 — refreshed pool entry {getattr(refreshed, 'id', '?')}")
|
||||
logger.info(f"Credential auth failure — refreshed pool entry {getattr(refreshed, 'id', '?')}")
|
||||
self._swap_credential(refreshed)
|
||||
return True, has_retried_429
|
||||
# Refresh failed — rotate to next credential instead of giving up.
|
||||
# The failed entry is already marked exhausted by try_refresh_current().
|
||||
next_entry = pool.mark_exhausted_and_rotate(status_code=401, error_context=error_context)
|
||||
rotate_status = status_code if status_code is not None else 401
|
||||
next_entry = pool.mark_exhausted_and_rotate(status_code=rotate_status, error_context=error_context)
|
||||
if next_entry is not None:
|
||||
logger.info(f"Credential 401 (refresh failed) — rotated to pool entry {getattr(next_entry, 'id', '?')}")
|
||||
logger.info(
|
||||
"Credential %s (auth refresh failed) — rotated to pool entry %s",
|
||||
rotate_status,
|
||||
getattr(next_entry, "id", "?"),
|
||||
)
|
||||
self._swap_credential(next_entry)
|
||||
return True, False
|
||||
|
||||
@@ -4327,7 +4481,6 @@ class AIAgent:
|
||||
|
||||
def _fire_reasoning_delta(self, text: str) -> None:
|
||||
"""Fire reasoning callback if registered."""
|
||||
self._reasoning_deltas_fired = True
|
||||
cb = self.reasoning_callback
|
||||
if cb is not None:
|
||||
try:
|
||||
@@ -4407,7 +4560,17 @@ class AIAgent:
|
||||
"""Stream a chat completions response."""
|
||||
import httpx as _httpx
|
||||
_base_timeout = float(os.getenv("HERMES_API_TIMEOUT", 1800.0))
|
||||
_stream_read_timeout = float(os.getenv("HERMES_STREAM_READ_TIMEOUT", 60.0))
|
||||
_stream_read_timeout = float(os.getenv("HERMES_STREAM_READ_TIMEOUT", 120.0))
|
||||
# Local providers (Ollama, llama.cpp, vLLM) can take minutes for
|
||||
# prefill on large contexts before producing the first token.
|
||||
# Auto-increase the httpx read timeout unless the user explicitly
|
||||
# overrode HERMES_STREAM_READ_TIMEOUT.
|
||||
if _stream_read_timeout == 120.0 and self.base_url and is_local_endpoint(self.base_url):
|
||||
_stream_read_timeout = _base_timeout
|
||||
logger.debug(
|
||||
"Local provider detected (%s) — stream read timeout raised to %.0fs",
|
||||
self.base_url, _stream_read_timeout,
|
||||
)
|
||||
stream_kwargs = {
|
||||
**api_kwargs,
|
||||
"stream": True,
|
||||
@@ -4447,10 +4610,6 @@ class AIAgent:
|
||||
role = "assistant"
|
||||
reasoning_parts: list = []
|
||||
usage_obj = None
|
||||
# Reset per-call reasoning tracking so _build_assistant_message
|
||||
# knows whether reasoning was already displayed during streaming.
|
||||
self._reasoning_deltas_fired = False
|
||||
|
||||
_first_chunk_seen = False
|
||||
for chunk in stream:
|
||||
last_chunk_time["t"] = time.time()
|
||||
@@ -4565,20 +4724,31 @@ class AIAgent:
|
||||
# Build mock response matching non-streaming shape
|
||||
full_content = "".join(content_parts) or None
|
||||
mock_tool_calls = None
|
||||
has_truncated_tool_args = False
|
||||
if tool_calls_acc:
|
||||
mock_tool_calls = []
|
||||
for idx in sorted(tool_calls_acc):
|
||||
tc = tool_calls_acc[idx]
|
||||
arguments = tc["function"]["arguments"]
|
||||
if arguments and arguments.strip():
|
||||
try:
|
||||
json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
has_truncated_tool_args = True
|
||||
mock_tool_calls.append(SimpleNamespace(
|
||||
id=tc["id"],
|
||||
type=tc["type"],
|
||||
extra_content=tc.get("extra_content"),
|
||||
function=SimpleNamespace(
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
arguments=arguments,
|
||||
),
|
||||
))
|
||||
|
||||
effective_finish_reason = finish_reason or "stop"
|
||||
if has_truncated_tool_args:
|
||||
effective_finish_reason = "length"
|
||||
|
||||
full_reasoning = "".join(reasoning_parts) or None
|
||||
mock_message = SimpleNamespace(
|
||||
role=role,
|
||||
@@ -4589,7 +4759,7 @@ class AIAgent:
|
||||
mock_choice = SimpleNamespace(
|
||||
index=0,
|
||||
message=mock_message,
|
||||
finish_reason=finish_reason or "stop",
|
||||
finish_reason=effective_finish_reason,
|
||||
)
|
||||
return SimpleNamespace(
|
||||
id="stream-" + str(uuid.uuid4()),
|
||||
@@ -4607,13 +4777,20 @@ class AIAgent:
|
||||
works unchanged.
|
||||
"""
|
||||
has_tool_use = False
|
||||
self._reasoning_deltas_fired = False
|
||||
|
||||
# Reset stale-stream timer for this attempt
|
||||
last_chunk_time["t"] = time.time()
|
||||
# Use the Anthropic SDK's streaming context manager
|
||||
with self._anthropic_client.messages.stream(**api_kwargs) as stream:
|
||||
for event in stream:
|
||||
# Update stale-stream timer on every event so the
|
||||
# outer poll loop knows data is flowing. Without
|
||||
# this, the detector kills healthy long-running
|
||||
# Opus streams after 180 s even when events are
|
||||
# actively arriving (the chat_completions path
|
||||
# already does this at the top of its chunk loop).
|
||||
last_chunk_time["t"] = time.time()
|
||||
|
||||
if self._interrupt_requested:
|
||||
break
|
||||
|
||||
@@ -4637,6 +4814,7 @@ class AIAgent:
|
||||
if text and not has_tool_use:
|
||||
_fire_first_delta()
|
||||
self._fire_stream_delta(text)
|
||||
deltas_were_sent["yes"] = True
|
||||
elif delta_type == "thinking_delta":
|
||||
thinking_text = getattr(delta, "thinking", "")
|
||||
if thinking_text:
|
||||
@@ -4927,7 +5105,7 @@ class AIAgent:
|
||||
# when no explicit key is in the fallback config.
|
||||
if fb_base_url_hint and "ollama.com" in fb_base_url_hint.lower() and not fb_api_key_hint:
|
||||
fb_api_key_hint = os.getenv("OLLAMA_API_KEY") or None
|
||||
fb_client, _ = resolve_provider_client(
|
||||
fb_client, _resolved_fb_model = resolve_provider_client(
|
||||
fb_provider, model=fb_model, raw_codex=True,
|
||||
explicit_base_url=fb_base_url_hint,
|
||||
explicit_api_key=fb_api_key_hint)
|
||||
@@ -4936,6 +5114,12 @@ class AIAgent:
|
||||
"Fallback to %s failed: provider not configured",
|
||||
fb_provider)
|
||||
return self._try_activate_fallback() # try next in chain
|
||||
try:
|
||||
from hermes_cli.model_normalize import normalize_model_for_provider
|
||||
|
||||
fb_model = normalize_model_for_provider(fb_model, fb_provider)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Determine api_mode from provider / base URL
|
||||
fb_api_mode = "chat_completions"
|
||||
@@ -5096,6 +5280,7 @@ class AIAgent:
|
||||
_TRANSIENT_TRANSPORT_ERRORS = frozenset({
|
||||
"ReadTimeout", "ConnectTimeout", "PoolTimeout",
|
||||
"ConnectError", "RemoteProtocolError",
|
||||
"APIConnectionError", "APITimeoutError",
|
||||
})
|
||||
|
||||
def _try_recover_primary_transport(
|
||||
@@ -5419,6 +5604,7 @@ class AIAgent:
|
||||
preserve_dots=self._anthropic_preserve_dots(),
|
||||
context_length=ctx_len,
|
||||
base_url=getattr(self, "_anthropic_base_url", None),
|
||||
fast_mode=(self.request_overrides or {}).get("speed") == "fast",
|
||||
)
|
||||
|
||||
if self.api_mode == "codex_responses":
|
||||
@@ -5434,6 +5620,10 @@ class AIAgent:
|
||||
"models.github.ai" in self.base_url.lower()
|
||||
or "api.githubcopilot.com" in self.base_url.lower()
|
||||
)
|
||||
is_codex_backend = (
|
||||
self.provider == "openai-codex"
|
||||
or "chatgpt.com/backend-api/codex" in self.base_url.lower()
|
||||
)
|
||||
|
||||
# Resolve reasoning effort: config > default (medium)
|
||||
reasoning_effort = "medium"
|
||||
@@ -5471,7 +5661,10 @@ class AIAgent:
|
||||
elif not is_github_responses:
|
||||
kwargs["include"] = []
|
||||
|
||||
if self.max_tokens is not None:
|
||||
if self.request_overrides:
|
||||
kwargs.update(self.request_overrides)
|
||||
|
||||
if self.max_tokens is not None and not is_codex_backend:
|
||||
kwargs["max_output_tokens"] = self.max_tokens
|
||||
|
||||
return kwargs
|
||||
@@ -5566,20 +5759,20 @@ class AIAgent:
|
||||
if self.max_tokens is not None:
|
||||
if not self._is_qwen_portal():
|
||||
api_kwargs.update(self._max_tokens_param(self.max_tokens))
|
||||
elif self._is_openrouter_url() and "claude" in (self.model or "").lower():
|
||||
# OpenRouter translates requests to Anthropic's Messages API,
|
||||
# which requires max_tokens as a mandatory field. When we omit
|
||||
# it, OpenRouter picks a default that can be too low — the model
|
||||
# spends its output budget on thinking and has almost nothing
|
||||
# left for the actual response (especially large tool calls like
|
||||
# write_file). Sending the model's real output limit ensures
|
||||
# full capacity. Other providers handle the default fine.
|
||||
elif (self._is_openrouter_url() or "nousresearch" in self._base_url_lower) and "claude" in (self.model or "").lower():
|
||||
# OpenRouter and Nous Portal translate requests to Anthropic's
|
||||
# Messages API, which requires max_tokens as a mandatory field.
|
||||
# When we omit it, the proxy picks a default that can be too
|
||||
# low — the model spends its output budget on thinking and has
|
||||
# almost nothing left for the actual response (especially large
|
||||
# tool calls like write_file). Sending the model's real output
|
||||
# limit ensures full capacity.
|
||||
try:
|
||||
from agent.anthropic_adapter import _get_anthropic_max_output
|
||||
_model_output_limit = _get_anthropic_max_output(self.model)
|
||||
api_kwargs["max_tokens"] = _model_output_limit
|
||||
except Exception:
|
||||
pass # fail open — let OpenRouter pick its default
|
||||
pass # fail open — let the proxy pick its default
|
||||
|
||||
extra_body = {}
|
||||
|
||||
@@ -5642,6 +5835,11 @@ class AIAgent:
|
||||
if "x.ai" in self._base_url_lower and hasattr(self, "session_id") and self.session_id:
|
||||
api_kwargs["extra_headers"] = {"x-grok-conv-id": self.session_id}
|
||||
|
||||
# Priority Processing / generic request overrides (e.g. service_tier).
|
||||
# Applied last so overrides win over any defaults set above.
|
||||
if self.request_overrides:
|
||||
api_kwargs.update(self.request_overrides)
|
||||
|
||||
return api_kwargs
|
||||
|
||||
def _supports_reasoning_extra_body(self) -> bool:
|
||||
@@ -6347,7 +6545,7 @@ class AIAgent:
|
||||
|
||||
# Start spinner for CLI mode (skip when TUI handles tool progress)
|
||||
spinner = None
|
||||
if self.quiet_mode and not self.tool_progress_callback and self._should_start_quiet_spinner():
|
||||
if self._should_emit_quiet_tool_messages() and self._should_start_quiet_spinner():
|
||||
face = random.choice(KawaiiSpinner.KAWAII_WAITING)
|
||||
spinner = KawaiiSpinner(f"{face} ⚡ running {num_tools} tools concurrently", spinner_type='dots', print_fn=self._print_fn)
|
||||
spinner.start()
|
||||
@@ -6397,7 +6595,7 @@ class AIAgent:
|
||||
logging.debug(f"Tool result ({len(function_result)} chars): {function_result}")
|
||||
|
||||
# Print cute message per tool
|
||||
if self.quiet_mode:
|
||||
if self._should_emit_quiet_tool_messages():
|
||||
cute_msg = _get_cute_tool_message_impl(name, args, tool_duration, result=function_result)
|
||||
self._safe_print(f" {cute_msg}")
|
||||
elif not self.quiet_mode:
|
||||
@@ -6554,7 +6752,7 @@ class AIAgent:
|
||||
store=self._todo_store,
|
||||
)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
if self.quiet_mode:
|
||||
if self._should_emit_quiet_tool_messages():
|
||||
self._vprint(f" {_get_cute_tool_message_impl('todo', function_args, tool_duration, result=function_result)}")
|
||||
elif function_name == "session_search":
|
||||
if not self._session_db:
|
||||
@@ -6569,7 +6767,7 @@ class AIAgent:
|
||||
current_session_id=self.session_id,
|
||||
)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
if self.quiet_mode:
|
||||
if self._should_emit_quiet_tool_messages():
|
||||
self._vprint(f" {_get_cute_tool_message_impl('session_search', function_args, tool_duration, result=function_result)}")
|
||||
elif function_name == "memory":
|
||||
target = function_args.get("target", "memory")
|
||||
@@ -6582,7 +6780,7 @@ class AIAgent:
|
||||
store=self._memory_store,
|
||||
)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
if self.quiet_mode:
|
||||
if self._should_emit_quiet_tool_messages():
|
||||
self._vprint(f" {_get_cute_tool_message_impl('memory', function_args, tool_duration, result=function_result)}")
|
||||
elif function_name == "clarify":
|
||||
from tools.clarify_tool import clarify_tool as _clarify_tool
|
||||
@@ -6592,7 +6790,7 @@ class AIAgent:
|
||||
callback=self.clarify_callback,
|
||||
)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
if self.quiet_mode:
|
||||
if self._should_emit_quiet_tool_messages():
|
||||
self._vprint(f" {_get_cute_tool_message_impl('clarify', function_args, tool_duration, result=function_result)}")
|
||||
elif function_name == "delegate_task":
|
||||
from tools.delegate_tool import delegate_task as _delegate_task
|
||||
@@ -6603,7 +6801,7 @@ class AIAgent:
|
||||
goal_preview = (function_args.get("goal") or "")[:30]
|
||||
spinner_label = f"🔀 {goal_preview}" if goal_preview else "🔀 delegating"
|
||||
spinner = None
|
||||
if self.quiet_mode and not self.tool_progress_callback and self._should_start_quiet_spinner():
|
||||
if self._should_emit_quiet_tool_messages() and self._should_start_quiet_spinner():
|
||||
face = random.choice(KawaiiSpinner.KAWAII_WAITING)
|
||||
spinner = KawaiiSpinner(f"{face} {spinner_label}", spinner_type='dots', print_fn=self._print_fn)
|
||||
spinner.start()
|
||||
@@ -6625,13 +6823,13 @@ class AIAgent:
|
||||
cute_msg = _get_cute_tool_message_impl('delegate_task', function_args, tool_duration, result=_delegate_result)
|
||||
if spinner:
|
||||
spinner.stop(cute_msg)
|
||||
elif self.quiet_mode:
|
||||
elif self._should_emit_quiet_tool_messages():
|
||||
self._vprint(f" {cute_msg}")
|
||||
elif self._memory_manager and self._memory_manager.has_tool(function_name):
|
||||
# Memory provider tools (hindsight_retain, honcho_search, etc.)
|
||||
# These are not in the tool registry — route through MemoryManager.
|
||||
spinner = None
|
||||
if self.quiet_mode and not self.tool_progress_callback:
|
||||
if self._should_emit_quiet_tool_messages() and self._should_start_quiet_spinner():
|
||||
face = random.choice(KawaiiSpinner.KAWAII_WAITING)
|
||||
emoji = _get_tool_emoji(function_name)
|
||||
preview = _build_tool_preview(function_name, function_args) or function_name
|
||||
@@ -6649,11 +6847,11 @@ class AIAgent:
|
||||
cute_msg = _get_cute_tool_message_impl(function_name, function_args, tool_duration, result=_mem_result)
|
||||
if spinner:
|
||||
spinner.stop(cute_msg)
|
||||
elif self.quiet_mode:
|
||||
elif self._should_emit_quiet_tool_messages():
|
||||
self._vprint(f" {cute_msg}")
|
||||
elif self.quiet_mode:
|
||||
spinner = None
|
||||
if not self.tool_progress_callback:
|
||||
if self._should_emit_quiet_tool_messages() and self._should_start_quiet_spinner():
|
||||
face = random.choice(KawaiiSpinner.KAWAII_WAITING)
|
||||
emoji = _get_tool_emoji(function_name)
|
||||
preview = _build_tool_preview(function_name, function_args) or function_name
|
||||
@@ -6676,7 +6874,7 @@ class AIAgent:
|
||||
cute_msg = _get_cute_tool_message_impl(function_name, function_args, tool_duration, result=_spinner_result)
|
||||
if spinner:
|
||||
spinner.stop(cute_msg)
|
||||
else:
|
||||
elif self._should_emit_quiet_tool_messages():
|
||||
self._vprint(f" {cute_msg}")
|
||||
else:
|
||||
try:
|
||||
@@ -7070,7 +7268,7 @@ class AIAgent:
|
||||
self._thinking_prefill_retries = 0
|
||||
self._last_content_with_tools = None
|
||||
self._mute_post_response = False
|
||||
self._surrogate_sanitized = False
|
||||
self._unicode_sanitization_passes = 0
|
||||
|
||||
# Pre-turn connection health check: detect and clean up dead TCP
|
||||
# connections left over from provider outages or dropped streams.
|
||||
@@ -7300,6 +7498,7 @@ class AIAgent:
|
||||
interrupted = False
|
||||
codex_ack_continuations = 0
|
||||
length_continue_retries = 0
|
||||
truncated_tool_call_retries = 0
|
||||
truncated_response_prefix = ""
|
||||
compression_attempts = 0
|
||||
_turn_exit_reason = "unknown" # Diagnostic: why the loop ended
|
||||
@@ -7768,9 +7967,11 @@ class AIAgent:
|
||||
# retries are pointless. Detect this early and give a
|
||||
# targeted error instead of wasting 3 API calls.
|
||||
_trunc_content = None
|
||||
_trunc_has_tool_calls = False
|
||||
if self.api_mode == "chat_completions":
|
||||
_trunc_msg = response.choices[0].message if (hasattr(response, "choices") and response.choices) else None
|
||||
_trunc_content = getattr(_trunc_msg, "content", None) if _trunc_msg else None
|
||||
_trunc_has_tool_calls = bool(getattr(_trunc_msg, "tool_calls", None)) if _trunc_msg else False
|
||||
elif self.api_mode == "anthropic_messages":
|
||||
# Anthropic response.content is a list of blocks
|
||||
_text_parts = []
|
||||
@@ -7780,9 +7981,11 @@ class AIAgent:
|
||||
_trunc_content = "\n".join(_text_parts) if _text_parts else None
|
||||
|
||||
_thinking_exhausted = (
|
||||
_trunc_content is not None
|
||||
and not self._has_content_after_think_block(_trunc_content)
|
||||
) or _trunc_content is None
|
||||
not _trunc_has_tool_calls and (
|
||||
(_trunc_content is not None and not self._has_content_after_think_block(_trunc_content))
|
||||
or _trunc_content is None
|
||||
)
|
||||
)
|
||||
|
||||
if _thinking_exhausted:
|
||||
_exhaust_error = (
|
||||
@@ -7858,6 +8061,34 @@ class AIAgent:
|
||||
"error": "Response remained truncated after 3 continuation attempts",
|
||||
}
|
||||
|
||||
if self.api_mode == "chat_completions":
|
||||
assistant_message = response.choices[0].message
|
||||
if assistant_message.tool_calls:
|
||||
if truncated_tool_call_retries < 1:
|
||||
truncated_tool_call_retries += 1
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Truncated tool call detected — retrying API call...",
|
||||
force=True,
|
||||
)
|
||||
# Don't append the broken response to messages;
|
||||
# just re-run the same API call from the current
|
||||
# message state, giving the model another chance.
|
||||
continue
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Truncated tool call response detected again — refusing to execute incomplete tool arguments.",
|
||||
force=True,
|
||||
)
|
||||
self._cleanup_task_resources(effective_task_id)
|
||||
self._persist_session(messages, conversation_history)
|
||||
return {
|
||||
"final_response": None,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"partial": True,
|
||||
"error": "Response truncated due to output length limit",
|
||||
}
|
||||
|
||||
# If we have prior messages, roll back to last complete state
|
||||
if len(messages) > 1:
|
||||
self._vprint(f"{self.log_prefix} ⏪ Rolling back to last complete assistant turn")
|
||||
@@ -8022,22 +8253,40 @@ class AIAgent:
|
||||
self.thinking_callback("")
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Surrogate character recovery. UnicodeEncodeError happens
|
||||
# when the messages contain lone surrogates (U+D800..U+DFFF)
|
||||
# that are invalid UTF-8. Common source: clipboard paste
|
||||
# from Google Docs or similar rich-text editors. We sanitize
|
||||
# the entire messages list in-place and retry once.
|
||||
# UnicodeEncodeError recovery. Two common causes:
|
||||
# 1. Lone surrogates (U+D800..U+DFFF) from clipboard paste
|
||||
# (Google Docs, rich-text editors) — sanitize and retry.
|
||||
# 2. ASCII codec on systems with LANG=C or non-UTF-8 locale
|
||||
# (e.g. Chromebooks) — any non-ASCII character fails.
|
||||
# Detect via the error message mentioning 'ascii' codec.
|
||||
# We sanitize messages in-place and may retry twice:
|
||||
# first to strip surrogates, then once more for pure
|
||||
# ASCII-only locale sanitization if needed.
|
||||
# -----------------------------------------------------------
|
||||
if isinstance(api_error, UnicodeEncodeError) and not getattr(self, '_surrogate_sanitized', False):
|
||||
self._surrogate_sanitized = True
|
||||
if _sanitize_messages_surrogates(messages):
|
||||
if isinstance(api_error, UnicodeEncodeError) and getattr(self, '_unicode_sanitization_passes', 0) < 2:
|
||||
_err_str = str(api_error).lower()
|
||||
_is_ascii_codec = "'ascii'" in _err_str or "ascii" in _err_str
|
||||
_surrogates_found = _sanitize_messages_surrogates(messages)
|
||||
if _surrogates_found:
|
||||
self._unicode_sanitization_passes += 1
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Stripped invalid surrogate characters from messages. Retrying...",
|
||||
force=True,
|
||||
)
|
||||
continue
|
||||
# Surrogates weren't in messages — might be in system
|
||||
# prompt or prefill. Fall through to normal error path.
|
||||
if _is_ascii_codec:
|
||||
# ASCII codec: the system encoding can't handle
|
||||
# non-ASCII characters at all. Sanitize all
|
||||
# non-ASCII content from messages and retry.
|
||||
if _sanitize_messages_non_ascii(messages):
|
||||
self._unicode_sanitization_passes += 1
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ System encoding is ASCII — stripped non-ASCII characters from messages. Retrying...",
|
||||
force=True,
|
||||
)
|
||||
continue
|
||||
# Nothing to sanitize in messages — might be in system
|
||||
# prompt or prefill. Fall through to normal error path.
|
||||
|
||||
status_code = getattr(api_error, "status_code", None)
|
||||
error_context = self._extract_api_error_context(api_error)
|
||||
@@ -8063,6 +8312,7 @@ class AIAgent:
|
||||
recovered_with_pool, has_retried_429 = self._recover_with_credential_pool(
|
||||
status_code=status_code,
|
||||
has_retried_429=has_retried_429,
|
||||
classified_reason=classified.reason,
|
||||
error_context=error_context,
|
||||
)
|
||||
if recovered_with_pool:
|
||||
@@ -8170,7 +8420,33 @@ class AIAgent:
|
||||
if _err_body_str:
|
||||
self._vprint(f"{self.log_prefix} 📋 Details: {_err_body_str}", force=True)
|
||||
self._vprint(f"{self.log_prefix} ⏱️ Elapsed: {elapsed_time:.2f}s Context: {len(api_messages)} msgs, ~{approx_tokens:,} tokens")
|
||||
|
||||
|
||||
# Actionable hint for OpenRouter "no tool endpoints" error.
|
||||
# This fires regardless of whether fallback succeeds — the
|
||||
# user needs to know WHY their model failed so they can fix
|
||||
# their provider routing, not just silently fall back.
|
||||
if (
|
||||
self._is_openrouter_url()
|
||||
and "support tool use" in error_msg
|
||||
):
|
||||
self._vprint(
|
||||
f"{self.log_prefix} 💡 No OpenRouter providers for {_model} support tool calling with your current settings.",
|
||||
force=True,
|
||||
)
|
||||
if self.providers_allowed:
|
||||
self._vprint(
|
||||
f"{self.log_prefix} Your provider_routing.only restriction is filtering out tool-capable providers.",
|
||||
force=True,
|
||||
)
|
||||
self._vprint(
|
||||
f"{self.log_prefix} Try removing the restriction or adding providers that support tools for this model.",
|
||||
force=True,
|
||||
)
|
||||
self._vprint(
|
||||
f"{self.log_prefix} Check which providers support tools: https://openrouter.ai/models/{_model}",
|
||||
force=True,
|
||||
)
|
||||
|
||||
# Check for interrupt before deciding to retry
|
||||
if self._interrupt_requested:
|
||||
self._vprint(f"{self.log_prefix}⚡ Interrupt detected during error handling, aborting retries.", force=True)
|
||||
@@ -8226,6 +8502,10 @@ class AIAgent:
|
||||
approx_tokens=approx_tokens,
|
||||
task_id=effective_task_id,
|
||||
)
|
||||
# Compression created a new session — clear history
|
||||
# so _flush_messages_to_session_db writes compressed
|
||||
# messages to the new session, not skipping them.
|
||||
conversation_history = None
|
||||
if len(messages) < original_len or old_ctx > _reduced_ctx:
|
||||
self._emit_status(
|
||||
f"🗜️ Context reduced to {_reduced_ctx:,} tokens "
|
||||
@@ -8283,6 +8563,10 @@ class AIAgent:
|
||||
messages, system_message, approx_tokens=approx_tokens,
|
||||
task_id=effective_task_id,
|
||||
)
|
||||
# Compression created a new session — clear history
|
||||
# so _flush_messages_to_session_db writes compressed
|
||||
# messages to the new session, not skipping them.
|
||||
conversation_history = None
|
||||
|
||||
if len(messages) < original_len:
|
||||
self._emit_status(f"🗜️ Compressed {original_len} → {len(messages)} messages, retrying...")
|
||||
@@ -8401,6 +8685,10 @@ class AIAgent:
|
||||
messages, system_message, approx_tokens=approx_tokens,
|
||||
task_id=effective_task_id,
|
||||
)
|
||||
# Compression created a new session — clear history
|
||||
# so _flush_messages_to_session_db writes compressed
|
||||
# messages to the new session, not skipping them.
|
||||
conversation_history = None
|
||||
|
||||
if len(messages) < original_len or new_ctx and new_ctx < old_ctx:
|
||||
if len(messages) < original_len:
|
||||
@@ -9008,6 +9296,11 @@ class AIAgent:
|
||||
|
||||
self._execute_tool_calls(assistant_message, messages, effective_task_id, api_call_count)
|
||||
|
||||
# Reset per-turn retry counters after successful tool
|
||||
# execution so a single truncation doesn't poison the
|
||||
# entire conversation.
|
||||
truncated_tool_call_retries = 0
|
||||
|
||||
# Signal that a paragraph break is needed before the next
|
||||
# streamed text. We don't emit it immediately because
|
||||
# multiple consecutive tool iterations would stack up
|
||||
@@ -9194,7 +9487,6 @@ class AIAgent:
|
||||
# Reset retry counter/signature on successful content
|
||||
if hasattr(self, '_empty_content_retries'):
|
||||
self._empty_content_retries = 0
|
||||
self._last_empty_content_signature = None
|
||||
self._thinking_prefill_retries = 0
|
||||
|
||||
if (
|
||||
@@ -9266,7 +9558,6 @@ class AIAgent:
|
||||
# If an assistant message with tool_calls was already appended,
|
||||
# the API expects a role="tool" result for every tool_call_id.
|
||||
# Fill in error results for any that weren't answered yet.
|
||||
pending_handled = False
|
||||
for idx in range(len(messages) - 1, -1, -1):
|
||||
msg = messages[idx]
|
||||
if not isinstance(msg, dict):
|
||||
|
||||
+250
-35
@@ -2,8 +2,8 @@
|
||||
# ============================================================================
|
||||
# Hermes Agent Installer
|
||||
# ============================================================================
|
||||
# Installation script for Linux and macOS.
|
||||
# Uses uv for fast Python provisioning and package management.
|
||||
# Installation script for Linux, macOS, and Android/Termux.
|
||||
# Uses uv for desktop/server installs and Python's stdlib venv + pip on Termux.
|
||||
#
|
||||
# Usage:
|
||||
# curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
|
||||
@@ -117,6 +117,36 @@ log_error() {
|
||||
echo -e "${RED}✗${NC} $1"
|
||||
}
|
||||
|
||||
is_termux() {
|
||||
[ -n "${TERMUX_VERSION:-}" ] || [[ "${PREFIX:-}" == *"com.termux/files/usr"* ]]
|
||||
}
|
||||
|
||||
get_command_link_dir() {
|
||||
if is_termux && [ -n "${PREFIX:-}" ]; then
|
||||
echo "$PREFIX/bin"
|
||||
else
|
||||
echo "$HOME/.local/bin"
|
||||
fi
|
||||
}
|
||||
|
||||
get_command_link_display_dir() {
|
||||
if is_termux && [ -n "${PREFIX:-}" ]; then
|
||||
echo '$PREFIX/bin'
|
||||
else
|
||||
echo '~/.local/bin'
|
||||
fi
|
||||
}
|
||||
|
||||
get_hermes_command_path() {
|
||||
local link_dir
|
||||
link_dir="$(get_command_link_dir)"
|
||||
if [ -x "$link_dir/hermes" ]; then
|
||||
echo "$link_dir/hermes"
|
||||
else
|
||||
echo "hermes"
|
||||
fi
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# System detection
|
||||
# ============================================================================
|
||||
@@ -124,12 +154,17 @@ log_error() {
|
||||
detect_os() {
|
||||
case "$(uname -s)" in
|
||||
Linux*)
|
||||
OS="linux"
|
||||
if [ -f /etc/os-release ]; then
|
||||
. /etc/os-release
|
||||
DISTRO="$ID"
|
||||
if is_termux; then
|
||||
OS="android"
|
||||
DISTRO="termux"
|
||||
else
|
||||
DISTRO="unknown"
|
||||
OS="linux"
|
||||
if [ -f /etc/os-release ]; then
|
||||
. /etc/os-release
|
||||
DISTRO="$ID"
|
||||
else
|
||||
DISTRO="unknown"
|
||||
fi
|
||||
fi
|
||||
;;
|
||||
Darwin*)
|
||||
@@ -158,6 +193,12 @@ detect_os() {
|
||||
# ============================================================================
|
||||
|
||||
install_uv() {
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
log_info "Termux detected — using Python's stdlib venv + pip instead of uv"
|
||||
UV_CMD=""
|
||||
return 0
|
||||
fi
|
||||
|
||||
log_info "Checking for uv package manager..."
|
||||
|
||||
# Check common locations for uv
|
||||
@@ -209,6 +250,25 @@ install_uv() {
|
||||
}
|
||||
|
||||
check_python() {
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
log_info "Checking Termux Python..."
|
||||
if command -v python >/dev/null 2>&1; then
|
||||
PYTHON_PATH="$(command -v python)"
|
||||
if "$PYTHON_PATH" -c 'import sys; raise SystemExit(0 if sys.version_info >= (3, 11) else 1)' 2>/dev/null; then
|
||||
PYTHON_FOUND_VERSION=$($PYTHON_PATH --version 2>/dev/null)
|
||||
log_success "Python found: $PYTHON_FOUND_VERSION"
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
|
||||
log_info "Installing Python via pkg..."
|
||||
pkg install -y python >/dev/null
|
||||
PYTHON_PATH="$(command -v python)"
|
||||
PYTHON_FOUND_VERSION=$($PYTHON_PATH --version 2>/dev/null)
|
||||
log_success "Python installed: $PYTHON_FOUND_VERSION"
|
||||
return 0
|
||||
fi
|
||||
|
||||
log_info "Checking Python $PYTHON_VERSION..."
|
||||
|
||||
# Let uv handle Python — it can download and manage Python versions
|
||||
@@ -243,6 +303,17 @@ check_git() {
|
||||
fi
|
||||
|
||||
log_error "Git not found"
|
||||
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
log_info "Installing Git via pkg..."
|
||||
pkg install -y git >/dev/null
|
||||
if command -v git >/dev/null 2>&1; then
|
||||
GIT_VERSION=$(git --version | awk '{print $3}')
|
||||
log_success "Git $GIT_VERSION installed"
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
|
||||
log_info "Please install Git:"
|
||||
|
||||
case "$OS" in
|
||||
@@ -262,6 +333,9 @@ check_git() {
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
android)
|
||||
log_info " pkg install git"
|
||||
;;
|
||||
macos)
|
||||
log_info " xcode-select --install"
|
||||
log_info " Or: brew install git"
|
||||
@@ -290,11 +364,29 @@ check_node() {
|
||||
return 0
|
||||
fi
|
||||
|
||||
log_info "Node.js not found — installing Node.js $NODE_VERSION LTS..."
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
log_info "Node.js not found — installing Node.js via pkg..."
|
||||
else
|
||||
log_info "Node.js not found — installing Node.js $NODE_VERSION LTS..."
|
||||
fi
|
||||
install_node
|
||||
}
|
||||
|
||||
install_node() {
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
log_info "Installing Node.js via pkg..."
|
||||
if pkg install -y nodejs >/dev/null; then
|
||||
local installed_ver
|
||||
installed_ver=$(node --version 2>/dev/null)
|
||||
log_success "Node.js $installed_ver installed via pkg"
|
||||
HAS_NODE=true
|
||||
else
|
||||
log_warn "Failed to install Node.js via pkg"
|
||||
HAS_NODE=false
|
||||
fi
|
||||
return 0
|
||||
fi
|
||||
|
||||
local arch=$(uname -m)
|
||||
local node_arch
|
||||
case "$arch" in
|
||||
@@ -413,6 +505,30 @@ install_system_packages() {
|
||||
need_ffmpeg=true
|
||||
fi
|
||||
|
||||
# Termux always needs the Android build toolchain for the tested pip path,
|
||||
# even when ripgrep/ffmpeg are already present.
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
local termux_pkgs=(clang rust make pkg-config libffi openssl)
|
||||
if [ "$need_ripgrep" = true ]; then
|
||||
termux_pkgs+=("ripgrep")
|
||||
fi
|
||||
if [ "$need_ffmpeg" = true ]; then
|
||||
termux_pkgs+=("ffmpeg")
|
||||
fi
|
||||
|
||||
log_info "Installing Termux packages: ${termux_pkgs[*]}"
|
||||
if pkg install -y "${termux_pkgs[@]}" >/dev/null; then
|
||||
[ "$need_ripgrep" = true ] && HAS_RIPGREP=true && log_success "ripgrep installed"
|
||||
[ "$need_ffmpeg" = true ] && HAS_FFMPEG=true && log_success "ffmpeg installed"
|
||||
log_success "Termux build dependencies installed"
|
||||
return 0
|
||||
fi
|
||||
|
||||
log_warn "Could not auto-install all Termux packages"
|
||||
log_info "Install manually: pkg install ${termux_pkgs[*]}"
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Nothing to install — done
|
||||
if [ "$need_ripgrep" = false ] && [ "$need_ffmpeg" = false ]; then
|
||||
return 0
|
||||
@@ -550,6 +666,9 @@ show_manual_install_hint() {
|
||||
*) log_info " Use your package manager or visit the project homepage" ;;
|
||||
esac
|
||||
;;
|
||||
android)
|
||||
log_info " pkg install $pkg"
|
||||
;;
|
||||
macos) log_info " brew install $pkg" ;;
|
||||
esac
|
||||
}
|
||||
@@ -646,6 +765,19 @@ setup_venv() {
|
||||
return 0
|
||||
fi
|
||||
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
log_info "Creating virtual environment with Termux Python..."
|
||||
|
||||
if [ -d "venv" ]; then
|
||||
log_info "Virtual environment already exists, recreating..."
|
||||
rm -rf venv
|
||||
fi
|
||||
|
||||
"$PYTHON_PATH" -m venv venv
|
||||
log_success "Virtual environment ready ($(./venv/bin/python --version 2>/dev/null))"
|
||||
return 0
|
||||
fi
|
||||
|
||||
log_info "Creating virtual environment with Python $PYTHON_VERSION..."
|
||||
|
||||
if [ -d "venv" ]; then
|
||||
@@ -662,6 +794,46 @@ setup_venv() {
|
||||
install_deps() {
|
||||
log_info "Installing dependencies..."
|
||||
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
if [ "$USE_VENV" = true ]; then
|
||||
export VIRTUAL_ENV="$INSTALL_DIR/venv"
|
||||
PIP_PYTHON="$INSTALL_DIR/venv/bin/python"
|
||||
else
|
||||
PIP_PYTHON="$PYTHON_PATH"
|
||||
fi
|
||||
|
||||
if [ -z "${ANDROID_API_LEVEL:-}" ]; then
|
||||
ANDROID_API_LEVEL="$(getprop ro.build.version.sdk 2>/dev/null || true)"
|
||||
if [ -z "$ANDROID_API_LEVEL" ]; then
|
||||
ANDROID_API_LEVEL=24
|
||||
fi
|
||||
export ANDROID_API_LEVEL
|
||||
log_info "Using ANDROID_API_LEVEL=$ANDROID_API_LEVEL for Android wheel builds"
|
||||
fi
|
||||
|
||||
"$PIP_PYTHON" -m pip install --upgrade pip setuptools wheel >/dev/null
|
||||
if ! "$PIP_PYTHON" -m pip install -e '.[termux]' -c constraints-termux.txt; then
|
||||
log_warn "Termux feature install (.[termux]) failed, trying base install..."
|
||||
if ! "$PIP_PYTHON" -m pip install -e '.' -c constraints-termux.txt; then
|
||||
log_error "Package installation failed on Termux."
|
||||
log_info "Ensure these packages are installed: pkg install clang rust make pkg-config libffi openssl"
|
||||
log_info "Then re-run: cd $INSTALL_DIR && python -m pip install -e '.[termux]' -c constraints-termux.txt"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
log_success "Main package installed"
|
||||
log_info "Termux note: browser/WhatsApp tooling is not installed by default; see the Termux guide for optional follow-up steps."
|
||||
|
||||
if [ -d "tinker-atropos" ] && [ -f "tinker-atropos/pyproject.toml" ]; then
|
||||
log_info "tinker-atropos submodule found — skipping install (optional, for RL training)"
|
||||
log_info " To install later: $PIP_PYTHON -m pip install -e \"./tinker-atropos\""
|
||||
fi
|
||||
|
||||
log_success "All dependencies installed"
|
||||
return 0
|
||||
fi
|
||||
|
||||
if [ "$USE_VENV" = true ]; then
|
||||
# Tell uv to install into our venv (no need to activate)
|
||||
export VIRTUAL_ENV="$INSTALL_DIR/venv"
|
||||
@@ -743,19 +915,35 @@ setup_path() {
|
||||
if [ ! -x "$HERMES_BIN" ]; then
|
||||
log_warn "hermes entry point not found at $HERMES_BIN"
|
||||
log_info "This usually means the pip install didn't complete successfully."
|
||||
log_info "Try: cd $INSTALL_DIR && uv pip install -e '.[all]'"
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
log_info "Try: cd $INSTALL_DIR && python -m pip install -e '.[termux]' -c constraints-termux.txt"
|
||||
else
|
||||
log_info "Try: cd $INSTALL_DIR && uv pip install -e '.[all]'"
|
||||
fi
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Create symlink in ~/.local/bin (standard user binary location, usually on PATH)
|
||||
mkdir -p "$HOME/.local/bin"
|
||||
ln -sf "$HERMES_BIN" "$HOME/.local/bin/hermes"
|
||||
log_success "Symlinked hermes → ~/.local/bin/hermes"
|
||||
local command_link_dir
|
||||
local command_link_display_dir
|
||||
command_link_dir="$(get_command_link_dir)"
|
||||
command_link_display_dir="$(get_command_link_display_dir)"
|
||||
|
||||
# Create a user-facing shim for the hermes command.
|
||||
mkdir -p "$command_link_dir"
|
||||
ln -sf "$HERMES_BIN" "$command_link_dir/hermes"
|
||||
log_success "Symlinked hermes → $command_link_display_dir/hermes"
|
||||
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
export PATH="$command_link_dir:$PATH"
|
||||
log_info "$command_link_display_dir is the native Termux command path"
|
||||
log_success "hermes command ready"
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Check if ~/.local/bin is on PATH; if not, add it to shell config.
|
||||
# Detect the user's actual login shell (not the shell running this script,
|
||||
# which is always bash when piped from curl).
|
||||
if ! echo "$PATH" | tr ':' '\n' | grep -q "^$HOME/.local/bin$"; then
|
||||
if ! echo "$PATH" | tr ':' '\n' | grep -q "^$command_link_dir$"; then
|
||||
SHELL_CONFIGS=()
|
||||
LOGIN_SHELL="$(basename "${SHELL:-/bin/bash}")"
|
||||
case "$LOGIN_SHELL" in
|
||||
@@ -801,7 +989,7 @@ setup_path() {
|
||||
fi
|
||||
|
||||
# Export for current session so hermes works immediately
|
||||
export PATH="$HOME/.local/bin:$PATH"
|
||||
export PATH="$command_link_dir:$PATH"
|
||||
|
||||
log_success "hermes command ready"
|
||||
}
|
||||
@@ -878,6 +1066,13 @@ install_node_deps() {
|
||||
return 0
|
||||
fi
|
||||
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
log_info "Skipping automatic Node/browser dependency setup on Termux"
|
||||
log_info "Browser automation and WhatsApp bridge are not part of the tested Termux install path yet."
|
||||
log_info "If you want to experiment manually later, run: cd $INSTALL_DIR && npm install"
|
||||
return 0
|
||||
fi
|
||||
|
||||
if [ -f "$INSTALL_DIR/package.json" ]; then
|
||||
log_info "Installing Node.js dependencies (browser tools)..."
|
||||
cd "$INSTALL_DIR"
|
||||
@@ -992,8 +1187,7 @@ maybe_start_gateway() {
|
||||
read -p "Pair WhatsApp now? [Y/n] " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
||||
HERMES_CMD="$HOME/.local/bin/hermes"
|
||||
[ ! -x "$HERMES_CMD" ] && HERMES_CMD="hermes"
|
||||
HERMES_CMD="$(get_hermes_command_path)"
|
||||
$HERMES_CMD whatsapp || true
|
||||
fi
|
||||
else
|
||||
@@ -1007,16 +1201,17 @@ maybe_start_gateway() {
|
||||
fi
|
||||
|
||||
echo ""
|
||||
read -p "Would you like to install the gateway as a background service? [Y/n] " -n 1 -r < /dev/tty
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
read -p "Would you like to start the gateway in the background? [Y/n] " -n 1 -r < /dev/tty
|
||||
else
|
||||
read -p "Would you like to install the gateway as a background service? [Y/n] " -n 1 -r < /dev/tty
|
||||
fi
|
||||
echo
|
||||
|
||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
||||
HERMES_CMD="$HOME/.local/bin/hermes"
|
||||
if [ ! -x "$HERMES_CMD" ]; then
|
||||
HERMES_CMD="hermes"
|
||||
fi
|
||||
HERMES_CMD="$(get_hermes_command_path)"
|
||||
|
||||
if command -v systemctl &> /dev/null; then
|
||||
if [ "$DISTRO" != "termux" ] && command -v systemctl &> /dev/null; then
|
||||
log_info "Installing systemd service..."
|
||||
if $HERMES_CMD gateway install 2>/dev/null; then
|
||||
log_success "Gateway service installed"
|
||||
@@ -1029,12 +1224,19 @@ maybe_start_gateway() {
|
||||
log_warn "Systemd install failed. You can start manually: hermes gateway"
|
||||
fi
|
||||
else
|
||||
log_info "systemd not available — starting gateway in background..."
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
log_info "Termux detected — starting gateway in best-effort background mode..."
|
||||
else
|
||||
log_info "systemd not available — starting gateway in background..."
|
||||
fi
|
||||
nohup $HERMES_CMD gateway > "$HERMES_HOME/logs/gateway.log" 2>&1 &
|
||||
GATEWAY_PID=$!
|
||||
log_success "Gateway started (PID $GATEWAY_PID). Logs: ~/.hermes/logs/gateway.log"
|
||||
log_info "To stop: kill $GATEWAY_PID"
|
||||
log_info "To restart later: hermes gateway"
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
log_warn "Android may stop background processes when Termux is suspended or the system reclaims resources."
|
||||
fi
|
||||
fi
|
||||
else
|
||||
log_info "Skipped. Start the gateway later with: hermes gateway"
|
||||
@@ -1073,24 +1275,33 @@ print_success() {
|
||||
|
||||
echo -e "${CYAN}─────────────────────────────────────────────────────────${NC}"
|
||||
echo ""
|
||||
echo -e "${YELLOW}⚡ Reload your shell to use 'hermes' command:${NC}"
|
||||
echo ""
|
||||
LOGIN_SHELL="$(basename "${SHELL:-/bin/bash}")"
|
||||
if [ "$LOGIN_SHELL" = "zsh" ]; then
|
||||
echo " source ~/.zshrc"
|
||||
elif [ "$LOGIN_SHELL" = "bash" ]; then
|
||||
echo " source ~/.bashrc"
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
echo -e "${YELLOW}⚡ 'hermes' was linked into $(get_command_link_display_dir), which is already on PATH in Termux.${NC}"
|
||||
echo ""
|
||||
else
|
||||
echo " source ~/.bashrc # or ~/.zshrc"
|
||||
echo -e "${YELLOW}⚡ Reload your shell to use 'hermes' command:${NC}"
|
||||
echo ""
|
||||
LOGIN_SHELL="$(basename "${SHELL:-/bin/bash}")"
|
||||
if [ "$LOGIN_SHELL" = "zsh" ]; then
|
||||
echo " source ~/.zshrc"
|
||||
elif [ "$LOGIN_SHELL" = "bash" ]; then
|
||||
echo " source ~/.bashrc"
|
||||
else
|
||||
echo " source ~/.bashrc # or ~/.zshrc"
|
||||
fi
|
||||
echo ""
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Show Node.js warning if auto-install failed
|
||||
if [ "$HAS_NODE" = false ]; then
|
||||
echo -e "${YELLOW}"
|
||||
echo "Note: Node.js could not be installed automatically."
|
||||
echo "Browser tools need Node.js. Install manually:"
|
||||
echo " https://nodejs.org/en/download/"
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
echo " pkg install nodejs"
|
||||
else
|
||||
echo " https://nodejs.org/en/download/"
|
||||
fi
|
||||
echo -e "${NC}"
|
||||
fi
|
||||
|
||||
@@ -1099,7 +1310,11 @@ print_success() {
|
||||
echo -e "${YELLOW}"
|
||||
echo "Note: ripgrep (rg) was not found. File search will use"
|
||||
echo "grep as a fallback. For faster search in large codebases,"
|
||||
echo "install ripgrep: sudo apt install ripgrep (or brew install ripgrep)"
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
echo "install ripgrep: pkg install ripgrep"
|
||||
else
|
||||
echo "install ripgrep: sudo apt install ripgrep (or brew install ripgrep)"
|
||||
fi
|
||||
echo -e "${NC}"
|
||||
fi
|
||||
}
|
||||
|
||||
+212
-120
@@ -3,17 +3,17 @@
|
||||
# Hermes Agent Setup Script
|
||||
# ============================================================================
|
||||
# Quick setup for developers who cloned the repo manually.
|
||||
# Uses uv for fast Python provisioning and package management.
|
||||
# Uses uv for desktop/server setup and Python's stdlib venv + pip on Termux.
|
||||
#
|
||||
# Usage:
|
||||
# ./setup-hermes.sh
|
||||
#
|
||||
# This script:
|
||||
# 1. Installs uv if not present
|
||||
# 2. Creates a virtual environment with Python 3.11 via uv
|
||||
# 3. Installs all dependencies (main package + submodules)
|
||||
# 1. Detects desktop/server vs Android/Termux setup path
|
||||
# 2. Creates a Python 3.11 virtual environment
|
||||
# 3. Installs the appropriate dependency set for the platform
|
||||
# 4. Creates .env from template (if not exists)
|
||||
# 5. Symlinks the 'hermes' CLI command into ~/.local/bin
|
||||
# 5. Symlinks the 'hermes' CLI command into a user-facing bin dir
|
||||
# 6. Runs the setup wizard (optional)
|
||||
# ============================================================================
|
||||
|
||||
@@ -31,6 +31,26 @@ cd "$SCRIPT_DIR"
|
||||
|
||||
PYTHON_VERSION="3.11"
|
||||
|
||||
is_termux() {
|
||||
[ -n "${TERMUX_VERSION:-}" ] || [[ "${PREFIX:-}" == *"com.termux/files/usr"* ]]
|
||||
}
|
||||
|
||||
get_command_link_dir() {
|
||||
if is_termux && [ -n "${PREFIX:-}" ]; then
|
||||
echo "$PREFIX/bin"
|
||||
else
|
||||
echo "$HOME/.local/bin"
|
||||
fi
|
||||
}
|
||||
|
||||
get_command_link_display_dir() {
|
||||
if is_termux && [ -n "${PREFIX:-}" ]; then
|
||||
echo '$PREFIX/bin'
|
||||
else
|
||||
echo '~/.local/bin'
|
||||
fi
|
||||
}
|
||||
|
||||
echo ""
|
||||
echo -e "${CYAN}⚕ Hermes Agent Setup${NC}"
|
||||
echo ""
|
||||
@@ -42,36 +62,40 @@ echo ""
|
||||
echo -e "${CYAN}→${NC} Checking for uv..."
|
||||
|
||||
UV_CMD=""
|
||||
if command -v uv &> /dev/null; then
|
||||
UV_CMD="uv"
|
||||
elif [ -x "$HOME/.local/bin/uv" ]; then
|
||||
UV_CMD="$HOME/.local/bin/uv"
|
||||
elif [ -x "$HOME/.cargo/bin/uv" ]; then
|
||||
UV_CMD="$HOME/.cargo/bin/uv"
|
||||
fi
|
||||
|
||||
if [ -n "$UV_CMD" ]; then
|
||||
UV_VERSION=$($UV_CMD --version 2>/dev/null)
|
||||
echo -e "${GREEN}✓${NC} uv found ($UV_VERSION)"
|
||||
if is_termux; then
|
||||
echo -e "${CYAN}→${NC} Termux detected — using Python's stdlib venv + pip instead of uv"
|
||||
else
|
||||
echo -e "${CYAN}→${NC} Installing uv..."
|
||||
if curl -LsSf https://astral.sh/uv/install.sh | sh 2>/dev/null; then
|
||||
if [ -x "$HOME/.local/bin/uv" ]; then
|
||||
UV_CMD="$HOME/.local/bin/uv"
|
||||
elif [ -x "$HOME/.cargo/bin/uv" ]; then
|
||||
UV_CMD="$HOME/.cargo/bin/uv"
|
||||
fi
|
||||
|
||||
if [ -n "$UV_CMD" ]; then
|
||||
UV_VERSION=$($UV_CMD --version 2>/dev/null)
|
||||
echo -e "${GREEN}✓${NC} uv installed ($UV_VERSION)"
|
||||
if command -v uv &> /dev/null; then
|
||||
UV_CMD="uv"
|
||||
elif [ -x "$HOME/.local/bin/uv" ]; then
|
||||
UV_CMD="$HOME/.local/bin/uv"
|
||||
elif [ -x "$HOME/.cargo/bin/uv" ]; then
|
||||
UV_CMD="$HOME/.cargo/bin/uv"
|
||||
fi
|
||||
|
||||
if [ -n "$UV_CMD" ]; then
|
||||
UV_VERSION=$($UV_CMD --version 2>/dev/null)
|
||||
echo -e "${GREEN}✓${NC} uv found ($UV_VERSION)"
|
||||
else
|
||||
echo -e "${CYAN}→${NC} Installing uv..."
|
||||
if curl -LsSf https://astral.sh/uv/install.sh | sh 2>/dev/null; then
|
||||
if [ -x "$HOME/.local/bin/uv" ]; then
|
||||
UV_CMD="$HOME/.local/bin/uv"
|
||||
elif [ -x "$HOME/.cargo/bin/uv" ]; then
|
||||
UV_CMD="$HOME/.cargo/bin/uv"
|
||||
fi
|
||||
|
||||
if [ -n "$UV_CMD" ]; then
|
||||
UV_VERSION=$($UV_CMD --version 2>/dev/null)
|
||||
echo -e "${GREEN}✓${NC} uv installed ($UV_VERSION)"
|
||||
else
|
||||
echo -e "${RED}✗${NC} uv installed but not found. Add ~/.local/bin to PATH and retry."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo -e "${RED}✗${NC} uv installed but not found. Add ~/.local/bin to PATH and retry."
|
||||
echo -e "${RED}✗${NC} Failed to install uv. Visit https://docs.astral.sh/uv/"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo -e "${RED}✗${NC} Failed to install uv. Visit https://docs.astral.sh/uv/"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
@@ -81,16 +105,34 @@ fi
|
||||
|
||||
echo -e "${CYAN}→${NC} Checking Python $PYTHON_VERSION..."
|
||||
|
||||
if $UV_CMD python find "$PYTHON_VERSION" &> /dev/null; then
|
||||
PYTHON_PATH=$($UV_CMD python find "$PYTHON_VERSION")
|
||||
PYTHON_FOUND_VERSION=$($PYTHON_PATH --version 2>/dev/null)
|
||||
echo -e "${GREEN}✓${NC} $PYTHON_FOUND_VERSION found"
|
||||
if is_termux; then
|
||||
if command -v python >/dev/null 2>&1; then
|
||||
PYTHON_PATH="$(command -v python)"
|
||||
if "$PYTHON_PATH" -c 'import sys; raise SystemExit(0 if sys.version_info >= (3, 11) else 1)' 2>/dev/null; then
|
||||
PYTHON_FOUND_VERSION=$($PYTHON_PATH --version 2>/dev/null)
|
||||
echo -e "${GREEN}✓${NC} $PYTHON_FOUND_VERSION found"
|
||||
else
|
||||
echo -e "${RED}✗${NC} Termux Python must be 3.11+"
|
||||
echo " Run: pkg install python"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo -e "${RED}✗${NC} Python not found in Termux"
|
||||
echo " Run: pkg install python"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo -e "${CYAN}→${NC} Python $PYTHON_VERSION not found, installing via uv..."
|
||||
$UV_CMD python install "$PYTHON_VERSION"
|
||||
PYTHON_PATH=$($UV_CMD python find "$PYTHON_VERSION")
|
||||
PYTHON_FOUND_VERSION=$($PYTHON_PATH --version 2>/dev/null)
|
||||
echo -e "${GREEN}✓${NC} $PYTHON_FOUND_VERSION installed"
|
||||
if $UV_CMD python find "$PYTHON_VERSION" &> /dev/null; then
|
||||
PYTHON_PATH=$($UV_CMD python find "$PYTHON_VERSION")
|
||||
PYTHON_FOUND_VERSION=$($PYTHON_PATH --version 2>/dev/null)
|
||||
echo -e "${GREEN}✓${NC} $PYTHON_FOUND_VERSION found"
|
||||
else
|
||||
echo -e "${CYAN}→${NC} Python $PYTHON_VERSION not found, installing via uv..."
|
||||
$UV_CMD python install "$PYTHON_VERSION"
|
||||
PYTHON_PATH=$($UV_CMD python find "$PYTHON_VERSION")
|
||||
PYTHON_FOUND_VERSION=$($PYTHON_PATH --version 2>/dev/null)
|
||||
echo -e "${GREEN}✓${NC} $PYTHON_FOUND_VERSION installed"
|
||||
fi
|
||||
fi
|
||||
|
||||
# ============================================================================
|
||||
@@ -104,11 +146,16 @@ if [ -d "venv" ]; then
|
||||
rm -rf venv
|
||||
fi
|
||||
|
||||
$UV_CMD venv venv --python "$PYTHON_VERSION"
|
||||
echo -e "${GREEN}✓${NC} venv created (Python $PYTHON_VERSION)"
|
||||
if is_termux; then
|
||||
"$PYTHON_PATH" -m venv venv
|
||||
echo -e "${GREEN}✓${NC} venv created with stdlib venv"
|
||||
else
|
||||
$UV_CMD venv venv --python "$PYTHON_VERSION"
|
||||
echo -e "${GREEN}✓${NC} venv created (Python $PYTHON_VERSION)"
|
||||
fi
|
||||
|
||||
# Tell uv to install into this venv (no activation needed for uv)
|
||||
export VIRTUAL_ENV="$SCRIPT_DIR/venv"
|
||||
SETUP_PYTHON="$SCRIPT_DIR/venv/bin/python"
|
||||
|
||||
# ============================================================================
|
||||
# Dependencies
|
||||
@@ -116,19 +163,34 @@ export VIRTUAL_ENV="$SCRIPT_DIR/venv"
|
||||
|
||||
echo -e "${CYAN}→${NC} Installing dependencies..."
|
||||
|
||||
# Prefer uv sync with lockfile (hash-verified installs) when available,
|
||||
# fall back to pip install for compatibility or when lockfile is stale.
|
||||
if [ -f "uv.lock" ]; then
|
||||
echo -e "${CYAN}→${NC} Using uv.lock for hash-verified installation..."
|
||||
UV_PROJECT_ENVIRONMENT="$SCRIPT_DIR/venv" $UV_CMD sync --all-extras --locked 2>/dev/null && \
|
||||
echo -e "${GREEN}✓${NC} Dependencies installed (lockfile verified)" || {
|
||||
echo -e "${YELLOW}⚠${NC} Lockfile install failed (may be outdated), falling back to pip install..."
|
||||
if is_termux; then
|
||||
export ANDROID_API_LEVEL="$(getprop ro.build.version.sdk 2>/dev/null || printf '%s' "${ANDROID_API_LEVEL:-}")"
|
||||
echo -e "${CYAN}→${NC} Termux detected — installing the tested Android bundle"
|
||||
"$SETUP_PYTHON" -m pip install --upgrade pip setuptools wheel
|
||||
if [ -f "constraints-termux.txt" ]; then
|
||||
"$SETUP_PYTHON" -m pip install -e ".[termux]" -c constraints-termux.txt || {
|
||||
echo -e "${YELLOW}⚠${NC} Termux bundle install failed, falling back to base install..."
|
||||
"$SETUP_PYTHON" -m pip install -e "." -c constraints-termux.txt
|
||||
}
|
||||
else
|
||||
"$SETUP_PYTHON" -m pip install -e ".[termux]" || "$SETUP_PYTHON" -m pip install -e "."
|
||||
fi
|
||||
echo -e "${GREEN}✓${NC} Dependencies installed"
|
||||
else
|
||||
# Prefer uv sync with lockfile (hash-verified installs) when available,
|
||||
# fall back to pip install for compatibility or when lockfile is stale.
|
||||
if [ -f "uv.lock" ]; then
|
||||
echo -e "${CYAN}→${NC} Using uv.lock for hash-verified installation..."
|
||||
UV_PROJECT_ENVIRONMENT="$SCRIPT_DIR/venv" $UV_CMD sync --all-extras --locked 2>/dev/null && \
|
||||
echo -e "${GREEN}✓${NC} Dependencies installed (lockfile verified)" || {
|
||||
echo -e "${YELLOW}⚠${NC} Lockfile install failed (may be outdated), falling back to pip install..."
|
||||
$UV_CMD pip install -e ".[all]" || $UV_CMD pip install -e "."
|
||||
echo -e "${GREEN}✓${NC} Dependencies installed"
|
||||
}
|
||||
else
|
||||
$UV_CMD pip install -e ".[all]" || $UV_CMD pip install -e "."
|
||||
echo -e "${GREEN}✓${NC} Dependencies installed"
|
||||
}
|
||||
else
|
||||
$UV_CMD pip install -e ".[all]" || $UV_CMD pip install -e "."
|
||||
echo -e "${GREEN}✓${NC} Dependencies installed"
|
||||
fi
|
||||
fi
|
||||
|
||||
# ============================================================================
|
||||
@@ -138,7 +200,9 @@ fi
|
||||
echo -e "${CYAN}→${NC} Installing optional submodules..."
|
||||
|
||||
# tinker-atropos (RL training backend)
|
||||
if [ -d "tinker-atropos" ] && [ -f "tinker-atropos/pyproject.toml" ]; then
|
||||
if is_termux; then
|
||||
echo -e "${CYAN}→${NC} Skipping tinker-atropos on Termux (not part of the tested Android path)"
|
||||
elif [ -d "tinker-atropos" ] && [ -f "tinker-atropos/pyproject.toml" ]; then
|
||||
$UV_CMD pip install -e "./tinker-atropos" && \
|
||||
echo -e "${GREEN}✓${NC} tinker-atropos installed" || \
|
||||
echo -e "${YELLOW}⚠${NC} tinker-atropos install failed (RL tools may not work)"
|
||||
@@ -160,34 +224,42 @@ else
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
||||
INSTALLED=false
|
||||
|
||||
# Check if sudo is available
|
||||
if command -v sudo &> /dev/null && sudo -n true 2>/dev/null; then
|
||||
if command -v apt &> /dev/null; then
|
||||
sudo apt install -y ripgrep && INSTALLED=true
|
||||
elif command -v dnf &> /dev/null; then
|
||||
sudo dnf install -y ripgrep && INSTALLED=true
|
||||
|
||||
if is_termux; then
|
||||
pkg install -y ripgrep && INSTALLED=true
|
||||
else
|
||||
# Check if sudo is available
|
||||
if command -v sudo &> /dev/null && sudo -n true 2>/dev/null; then
|
||||
if command -v apt &> /dev/null; then
|
||||
sudo apt install -y ripgrep && INSTALLED=true
|
||||
elif command -v dnf &> /dev/null; then
|
||||
sudo dnf install -y ripgrep && INSTALLED=true
|
||||
fi
|
||||
fi
|
||||
|
||||
# Try brew (no sudo needed)
|
||||
if [ "$INSTALLED" = false ] && command -v brew &> /dev/null; then
|
||||
brew install ripgrep && INSTALLED=true
|
||||
fi
|
||||
|
||||
# Try cargo (no sudo needed)
|
||||
if [ "$INSTALLED" = false ] && command -v cargo &> /dev/null; then
|
||||
echo -e "${CYAN}→${NC} Trying cargo install (no sudo required)..."
|
||||
cargo install ripgrep && INSTALLED=true
|
||||
fi
|
||||
fi
|
||||
|
||||
# Try brew (no sudo needed)
|
||||
if [ "$INSTALLED" = false ] && command -v brew &> /dev/null; then
|
||||
brew install ripgrep && INSTALLED=true
|
||||
fi
|
||||
|
||||
# Try cargo (no sudo needed)
|
||||
if [ "$INSTALLED" = false ] && command -v cargo &> /dev/null; then
|
||||
echo -e "${CYAN}→${NC} Trying cargo install (no sudo required)..."
|
||||
cargo install ripgrep && INSTALLED=true
|
||||
fi
|
||||
|
||||
|
||||
if [ "$INSTALLED" = true ]; then
|
||||
echo -e "${GREEN}✓${NC} ripgrep installed"
|
||||
else
|
||||
echo -e "${YELLOW}⚠${NC} Auto-install failed. Install options:"
|
||||
echo " sudo apt install ripgrep # Debian/Ubuntu"
|
||||
echo " brew install ripgrep # macOS"
|
||||
echo " cargo install ripgrep # With Rust (no sudo)"
|
||||
if is_termux; then
|
||||
echo " pkg install ripgrep # Termux / Android"
|
||||
else
|
||||
echo " sudo apt install ripgrep # Debian/Ubuntu"
|
||||
echo " brew install ripgrep # macOS"
|
||||
echo " cargo install ripgrep # With Rust (no sudo)"
|
||||
fi
|
||||
echo " https://github.com/BurntSushi/ripgrep#installation"
|
||||
fi
|
||||
fi
|
||||
@@ -207,49 +279,56 @@ else
|
||||
fi
|
||||
|
||||
# ============================================================================
|
||||
# PATH setup — symlink hermes into ~/.local/bin
|
||||
# PATH setup — symlink hermes into a user-facing bin dir
|
||||
# ============================================================================
|
||||
|
||||
echo -e "${CYAN}→${NC} Setting up hermes command..."
|
||||
|
||||
HERMES_BIN="$SCRIPT_DIR/venv/bin/hermes"
|
||||
mkdir -p "$HOME/.local/bin"
|
||||
ln -sf "$HERMES_BIN" "$HOME/.local/bin/hermes"
|
||||
echo -e "${GREEN}✓${NC} Symlinked hermes → ~/.local/bin/hermes"
|
||||
COMMAND_LINK_DIR="$(get_command_link_dir)"
|
||||
COMMAND_LINK_DISPLAY_DIR="$(get_command_link_display_dir)"
|
||||
mkdir -p "$COMMAND_LINK_DIR"
|
||||
ln -sf "$HERMES_BIN" "$COMMAND_LINK_DIR/hermes"
|
||||
echo -e "${GREEN}✓${NC} Symlinked hermes → $COMMAND_LINK_DISPLAY_DIR/hermes"
|
||||
|
||||
# Determine the appropriate shell config file
|
||||
SHELL_CONFIG=""
|
||||
if [[ "$SHELL" == *"zsh"* ]]; then
|
||||
SHELL_CONFIG="$HOME/.zshrc"
|
||||
elif [[ "$SHELL" == *"bash"* ]]; then
|
||||
SHELL_CONFIG="$HOME/.bashrc"
|
||||
[ ! -f "$SHELL_CONFIG" ] && SHELL_CONFIG="$HOME/.bash_profile"
|
||||
if is_termux; then
|
||||
export PATH="$COMMAND_LINK_DIR:$PATH"
|
||||
echo -e "${GREEN}✓${NC} $COMMAND_LINK_DISPLAY_DIR is already on PATH in Termux"
|
||||
else
|
||||
# Fallback to checking existing files
|
||||
if [ -f "$HOME/.zshrc" ]; then
|
||||
# Determine the appropriate shell config file
|
||||
SHELL_CONFIG=""
|
||||
if [[ "$SHELL" == *"zsh"* ]]; then
|
||||
SHELL_CONFIG="$HOME/.zshrc"
|
||||
elif [ -f "$HOME/.bashrc" ]; then
|
||||
elif [[ "$SHELL" == *"bash"* ]]; then
|
||||
SHELL_CONFIG="$HOME/.bashrc"
|
||||
elif [ -f "$HOME/.bash_profile" ]; then
|
||||
SHELL_CONFIG="$HOME/.bash_profile"
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ -n "$SHELL_CONFIG" ]; then
|
||||
# Touch the file just in case it doesn't exist yet but was selected
|
||||
touch "$SHELL_CONFIG" 2>/dev/null || true
|
||||
|
||||
if ! echo "$PATH" | tr ':' '\n' | grep -q "^$HOME/.local/bin$"; then
|
||||
if ! grep -q '\.local/bin' "$SHELL_CONFIG" 2>/dev/null; then
|
||||
echo "" >> "$SHELL_CONFIG"
|
||||
echo "# Hermes Agent — ensure ~/.local/bin is on PATH" >> "$SHELL_CONFIG"
|
||||
echo 'export PATH="$HOME/.local/bin:$PATH"' >> "$SHELL_CONFIG"
|
||||
echo -e "${GREEN}✓${NC} Added ~/.local/bin to PATH in $SHELL_CONFIG"
|
||||
else
|
||||
echo -e "${GREEN}✓${NC} ~/.local/bin already in $SHELL_CONFIG"
|
||||
fi
|
||||
[ ! -f "$SHELL_CONFIG" ] && SHELL_CONFIG="$HOME/.bash_profile"
|
||||
else
|
||||
echo -e "${GREEN}✓${NC} ~/.local/bin already on PATH"
|
||||
# Fallback to checking existing files
|
||||
if [ -f "$HOME/.zshrc" ]; then
|
||||
SHELL_CONFIG="$HOME/.zshrc"
|
||||
elif [ -f "$HOME/.bashrc" ]; then
|
||||
SHELL_CONFIG="$HOME/.bashrc"
|
||||
elif [ -f "$HOME/.bash_profile" ]; then
|
||||
SHELL_CONFIG="$HOME/.bash_profile"
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ -n "$SHELL_CONFIG" ]; then
|
||||
# Touch the file just in case it doesn't exist yet but was selected
|
||||
touch "$SHELL_CONFIG" 2>/dev/null || true
|
||||
|
||||
if ! echo "$PATH" | tr ':' '\n' | grep -q "^$HOME/.local/bin$"; then
|
||||
if ! grep -q '\.local/bin' "$SHELL_CONFIG" 2>/dev/null; then
|
||||
echo "" >> "$SHELL_CONFIG"
|
||||
echo "# Hermes Agent — ensure ~/.local/bin is on PATH" >> "$SHELL_CONFIG"
|
||||
echo 'export PATH="$HOME/.local/bin:$PATH"' >> "$SHELL_CONFIG"
|
||||
echo -e "${GREEN}✓${NC} Added ~/.local/bin to PATH in $SHELL_CONFIG"
|
||||
else
|
||||
echo -e "${GREEN}✓${NC} ~/.local/bin already in $SHELL_CONFIG"
|
||||
fi
|
||||
else
|
||||
echo -e "${GREEN}✓${NC} ~/.local/bin already on PATH"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
@@ -281,18 +360,31 @@ echo -e "${GREEN}✓ Setup complete!${NC}"
|
||||
echo ""
|
||||
echo "Next steps:"
|
||||
echo ""
|
||||
echo " 1. Reload your shell:"
|
||||
echo " source $SHELL_CONFIG"
|
||||
echo ""
|
||||
echo " 2. Run the setup wizard to configure API keys:"
|
||||
echo " hermes setup"
|
||||
echo ""
|
||||
echo " 3. Start chatting:"
|
||||
echo " hermes"
|
||||
echo ""
|
||||
if is_termux; then
|
||||
echo " 1. Run the setup wizard to configure API keys:"
|
||||
echo " hermes setup"
|
||||
echo ""
|
||||
echo " 2. Start chatting:"
|
||||
echo " hermes"
|
||||
echo ""
|
||||
else
|
||||
echo " 1. Reload your shell:"
|
||||
echo " source $SHELL_CONFIG"
|
||||
echo ""
|
||||
echo " 2. Run the setup wizard to configure API keys:"
|
||||
echo " hermes setup"
|
||||
echo ""
|
||||
echo " 3. Start chatting:"
|
||||
echo " hermes"
|
||||
echo ""
|
||||
fi
|
||||
echo "Other commands:"
|
||||
echo " hermes status # Check configuration"
|
||||
echo " hermes gateway install # Install gateway service (messaging + cron)"
|
||||
if is_termux; then
|
||||
echo " hermes gateway # Run gateway in foreground"
|
||||
else
|
||||
echo " hermes gateway install # Install gateway service (messaging + cron)"
|
||||
fi
|
||||
echo " hermes cron list # View scheduled jobs"
|
||||
echo " hermes doctor # Diagnose issues"
|
||||
echo ""
|
||||
|
||||
@@ -250,7 +250,7 @@ Type these during an interactive chat session.
|
||||
/model [name] Show or change model
|
||||
/provider Show provider info
|
||||
/personality [name] Set personality
|
||||
/reasoning [level] Set reasoning (none|low|medium|high|xhigh|show|hide)
|
||||
/reasoning [level] Set reasoning (none|minimal|low|medium|high|xhigh|show|hide)
|
||||
/verbose Cycle: off → new → all → verbose
|
||||
/voice [on|off|tts] Voice mode
|
||||
/yolo Toggle approval bypass
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
---
|
||||
name: google-workspace
|
||||
description: Gmail, Calendar, Drive, Contacts, Sheets, and Docs integration via Python. Uses OAuth2 with automatic token refresh. No external binaries needed — runs entirely with Google's Python client libraries in the Hermes venv.
|
||||
version: 1.0.0
|
||||
description: Gmail, Calendar, Drive, Contacts, Sheets, and Docs integration via gws CLI (googleworkspace/cli). Uses OAuth2 with automatic token refresh via bridge script. Requires gws binary.
|
||||
version: 2.0.0
|
||||
author: Nous Research
|
||||
license: MIT
|
||||
required_credential_files:
|
||||
@@ -11,14 +11,25 @@ required_credential_files:
|
||||
description: Google OAuth2 client credentials (downloaded from Google Cloud Console)
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Google, Gmail, Calendar, Drive, Sheets, Docs, Contacts, Email, OAuth]
|
||||
tags: [Google, Gmail, Calendar, Drive, Sheets, Docs, Contacts, Email, OAuth, gws]
|
||||
homepage: https://github.com/NousResearch/hermes-agent
|
||||
related_skills: [himalaya]
|
||||
---
|
||||
|
||||
# Google Workspace
|
||||
|
||||
Gmail, Calendar, Drive, Contacts, Sheets, and Docs — all through Python scripts in this skill. No external binaries to install.
|
||||
Gmail, Calendar, Drive, Contacts, Sheets, and Docs — powered by `gws` (Google's official Rust CLI). The skill provides a backward-compatible Python wrapper that handles OAuth token refresh and delegates to `gws`.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
google_api.py → gws_bridge.py → gws CLI
|
||||
(argparse compat) (token refresh) (Google APIs)
|
||||
```
|
||||
|
||||
- `setup.py` handles OAuth2 (headless-compatible, works on CLI/Telegram/Discord)
|
||||
- `gws_bridge.py` refreshes the Hermes token and injects it into `gws` via `GOOGLE_WORKSPACE_CLI_TOKEN`
|
||||
- `google_api.py` provides the same CLI interface as v1 but delegates to `gws`
|
||||
|
||||
## References
|
||||
|
||||
@@ -27,7 +38,22 @@ Gmail, Calendar, Drive, Contacts, Sheets, and Docs — all through Python script
|
||||
## Scripts
|
||||
|
||||
- `scripts/setup.py` — OAuth2 setup (run once to authorize)
|
||||
- `scripts/google_api.py` — API wrapper CLI (agent uses this for all operations)
|
||||
- `scripts/gws_bridge.py` — Token refresh bridge to gws CLI
|
||||
- `scripts/google_api.py` — Backward-compatible API wrapper (delegates to gws)
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Install `gws`:
|
||||
|
||||
```bash
|
||||
cargo install google-workspace-cli
|
||||
# or via npm (recommended, downloads prebuilt binary):
|
||||
npm install -g @googleworkspace/cli
|
||||
# or via Homebrew:
|
||||
brew install googleworkspace-cli
|
||||
```
|
||||
|
||||
Verify: `gws --version`
|
||||
|
||||
## First-Time Setup
|
||||
|
||||
@@ -56,42 +82,29 @@ If it prints `AUTHENTICATED`, skip to Usage — setup is already done.
|
||||
|
||||
### Step 1: Triage — ask the user what they need
|
||||
|
||||
Before starting OAuth setup, ask the user TWO questions:
|
||||
|
||||
**Question 1: "What Google services do you need? Just email, or also
|
||||
Calendar/Drive/Sheets/Docs?"**
|
||||
|
||||
- **Email only** → They don't need this skill at all. Use the `himalaya` skill
|
||||
instead — it works with a Gmail App Password (Settings → Security → App
|
||||
Passwords) and takes 2 minutes to set up. No Google Cloud project needed.
|
||||
Load the himalaya skill and follow its setup instructions.
|
||||
- **Email only** → Use the `himalaya` skill instead — simpler setup.
|
||||
- **Calendar, Drive, Sheets, Docs (or email + these)** → Continue below.
|
||||
|
||||
- **Calendar, Drive, Sheets, Docs (or email + these)** → Continue with this
|
||||
skill's OAuth setup below.
|
||||
**Partial scopes**: Users can authorize only a subset of services. The setup
|
||||
script accepts partial scopes and warns about missing ones.
|
||||
|
||||
**Question 2: "Does your Google account use Advanced Protection (hardware
|
||||
security keys required to sign in)? If you're not sure, you probably don't
|
||||
— it's something you would have explicitly enrolled in."**
|
||||
**Question 2: "Does your Google account use Advanced Protection?"**
|
||||
|
||||
- **No / Not sure** → Normal setup. Continue below.
|
||||
- **Yes** → Their Workspace admin must add the OAuth client ID to the org's
|
||||
allowed apps list before Step 4 will work. Let them know upfront.
|
||||
- **No / Not sure** → Normal setup.
|
||||
- **Yes** → Workspace admin must add the OAuth client ID to allowed apps first.
|
||||
|
||||
### Step 2: Create OAuth credentials (one-time, ~5 minutes)
|
||||
|
||||
Tell the user:
|
||||
|
||||
> You need a Google Cloud OAuth client. This is a one-time setup:
|
||||
>
|
||||
> 1. Go to https://console.cloud.google.com/apis/credentials
|
||||
> 2. Create a project (or use an existing one)
|
||||
> 3. Click "Enable APIs" and enable: Gmail API, Google Calendar API,
|
||||
> Google Drive API, Google Sheets API, Google Docs API, People API
|
||||
> 4. Go to Credentials → Create Credentials → OAuth 2.0 Client ID
|
||||
> 5. Application type: "Desktop app" → Create
|
||||
> 6. Click "Download JSON" and tell me the file path
|
||||
|
||||
Once they provide the path:
|
||||
> 3. Enable the APIs you need (Gmail, Calendar, Drive, Sheets, Docs, People)
|
||||
> 4. Credentials → Create Credentials → OAuth 2.0 Client ID → Desktop app
|
||||
> 5. Download JSON and tell me the file path
|
||||
|
||||
```bash
|
||||
$GSETUP --client-secret /path/to/client_secret.json
|
||||
@@ -103,20 +116,10 @@ $GSETUP --client-secret /path/to/client_secret.json
|
||||
$GSETUP --auth-url
|
||||
```
|
||||
|
||||
This prints a URL. **Send the URL to the user** and tell them:
|
||||
|
||||
> Open this link in your browser, sign in with your Google account, and
|
||||
> authorize access. After authorizing, you'll be redirected to a page that
|
||||
> may show an error — that's expected. Copy the ENTIRE URL from your
|
||||
> browser's address bar and paste it back to me.
|
||||
Send the URL to the user. After authorizing, they paste back the redirect URL or code.
|
||||
|
||||
### Step 4: Exchange the code
|
||||
|
||||
The user will paste back either a URL like `http://localhost:1/?code=4/0A...&scope=...`
|
||||
or just the code string. Either works. The `--auth-url` step stores a temporary
|
||||
pending OAuth session locally so `--auth-code` can complete the PKCE exchange
|
||||
later, even on headless systems:
|
||||
|
||||
```bash
|
||||
$GSETUP --auth-code "THE_URL_OR_CODE_THE_USER_PASTED"
|
||||
```
|
||||
@@ -127,18 +130,11 @@ $GSETUP --auth-code "THE_URL_OR_CODE_THE_USER_PASTED"
|
||||
$GSETUP --check
|
||||
```
|
||||
|
||||
Should print `AUTHENTICATED`. Setup is complete — token refreshes automatically from now on.
|
||||
|
||||
### Notes
|
||||
|
||||
- Token is stored at `google_token.json` under the active profile's `HERMES_HOME` and auto-refreshes.
|
||||
- Pending OAuth session state/verifier are stored temporarily at `google_oauth_pending.json` under the active profile's `HERMES_HOME` until exchange completes.
|
||||
- Hermes now refuses to overwrite a full Google Workspace token with a narrower re-auth token missing Gmail scopes, so one profile's partial consent cannot silently break email actions later.
|
||||
- To revoke: `$GSETUP --revoke`
|
||||
Should print `AUTHENTICATED`. Token refreshes automatically from now on.
|
||||
|
||||
## Usage
|
||||
|
||||
All commands go through the API script. Set `GAPI` as a shorthand:
|
||||
All commands go through the API script:
|
||||
|
||||
```bash
|
||||
HERMES_HOME="${HERMES_HOME:-$HOME/.hermes}"
|
||||
@@ -153,40 +149,21 @@ GAPI="$PYTHON_BIN $GWORKSPACE_SKILL_DIR/scripts/google_api.py"
|
||||
### Gmail
|
||||
|
||||
```bash
|
||||
# Search (returns JSON array with id, from, subject, date, snippet)
|
||||
$GAPI gmail search "is:unread" --max 10
|
||||
$GAPI gmail search "from:boss@company.com newer_than:1d"
|
||||
$GAPI gmail search "has:attachment filename:pdf newer_than:7d"
|
||||
|
||||
# Read full message (returns JSON with body text)
|
||||
$GAPI gmail get MESSAGE_ID
|
||||
|
||||
# Send
|
||||
$GAPI gmail send --to user@example.com --subject "Hello" --body "Message text"
|
||||
$GAPI gmail send --to user@example.com --subject "Report" --body "<h1>Q4</h1><p>Details...</p>" --html
|
||||
|
||||
# Reply (automatically threads and sets In-Reply-To)
|
||||
$GAPI gmail send --to user@example.com --subject "Report" --body "<h1>Q4</h1>" --html
|
||||
$GAPI gmail reply MESSAGE_ID --body "Thanks, that works for me."
|
||||
|
||||
# Labels
|
||||
$GAPI gmail labels
|
||||
$GAPI gmail modify MESSAGE_ID --add-labels LABEL_ID
|
||||
$GAPI gmail modify MESSAGE_ID --remove-labels UNREAD
|
||||
```
|
||||
|
||||
### Calendar
|
||||
|
||||
```bash
|
||||
# List events (defaults to next 7 days)
|
||||
$GAPI calendar list
|
||||
$GAPI calendar list --start 2026-03-01T00:00:00Z --end 2026-03-07T23:59:59Z
|
||||
|
||||
# Create event (ISO 8601 with timezone required)
|
||||
$GAPI calendar create --summary "Team Standup" --start 2026-03-01T10:00:00-06:00 --end 2026-03-01T10:30:00-06:00
|
||||
$GAPI calendar create --summary "Lunch" --start 2026-03-01T12:00:00Z --end 2026-03-01T13:00:00Z --location "Cafe"
|
||||
$GAPI calendar create --summary "Review" --start 2026-03-01T14:00:00Z --end 2026-03-01T15:00:00Z --attendees "alice@co.com,bob@co.com"
|
||||
|
||||
# Delete event
|
||||
$GAPI calendar create --summary "Standup" --start 2026-03-01T10:00:00+01:00 --end 2026-03-01T10:30:00+01:00
|
||||
$GAPI calendar create --summary "Review" --start ... --end ... --attendees "alice@co.com,bob@co.com"
|
||||
$GAPI calendar delete EVENT_ID
|
||||
```
|
||||
|
||||
@@ -206,13 +183,8 @@ $GAPI contacts list --max 20
|
||||
### Sheets
|
||||
|
||||
```bash
|
||||
# Read
|
||||
$GAPI sheets get SHEET_ID "Sheet1!A1:D10"
|
||||
|
||||
# Write
|
||||
$GAPI sheets update SHEET_ID "Sheet1!A1:B2" --values '[["Name","Score"],["Alice","95"]]'
|
||||
|
||||
# Append rows
|
||||
$GAPI sheets append SHEET_ID "Sheet1!A:C" --values '[["new","row","data"]]'
|
||||
```
|
||||
|
||||
@@ -222,37 +194,52 @@ $GAPI sheets append SHEET_ID "Sheet1!A:C" --values '[["new","row","data"]]'
|
||||
$GAPI docs get DOC_ID
|
||||
```
|
||||
|
||||
### Direct gws access (advanced)
|
||||
|
||||
For operations not covered by the wrapper, use `gws_bridge.py` directly:
|
||||
|
||||
```bash
|
||||
GBRIDGE="$PYTHON_BIN $GWORKSPACE_SKILL_DIR/scripts/gws_bridge.py"
|
||||
$GBRIDGE calendar +agenda --today --format table
|
||||
$GBRIDGE gmail +triage --labels --format json
|
||||
$GBRIDGE drive +upload ./report.pdf
|
||||
$GBRIDGE sheets +read --spreadsheet SHEET_ID --range "Sheet1!A1:D10"
|
||||
```
|
||||
|
||||
## Output Format
|
||||
|
||||
All commands return JSON. Parse with `jq` or read directly. Key fields:
|
||||
All commands return JSON via `gws --format json`. Key output shapes:
|
||||
|
||||
- **Gmail search**: `[{id, threadId, from, to, subject, date, snippet, labels}]`
|
||||
- **Gmail get**: `{id, threadId, from, to, subject, date, labels, body}`
|
||||
- **Gmail send/reply**: `{status: "sent", id, threadId}`
|
||||
- **Calendar list**: `[{id, summary, start, end, location, description, htmlLink}]`
|
||||
- **Calendar create**: `{status: "created", id, summary, htmlLink}`
|
||||
- **Drive search**: `[{id, name, mimeType, modifiedTime, webViewLink}]`
|
||||
- **Contacts list**: `[{name, emails: [...], phones: [...]}]`
|
||||
- **Sheets get**: `[[cell, cell, ...], ...]`
|
||||
- **Gmail search/triage**: Array of message summaries (sender, subject, date, snippet)
|
||||
- **Gmail get/read**: Message object with headers and body text
|
||||
- **Gmail send/reply**: Confirmation with message ID
|
||||
- **Calendar list/agenda**: Array of event objects (summary, start, end, location)
|
||||
- **Calendar create**: Confirmation with event ID and htmlLink
|
||||
- **Drive search**: Array of file objects (id, name, mimeType, webViewLink)
|
||||
- **Sheets get/read**: 2D array of cell values
|
||||
- **Docs get**: Full document JSON (use `body.content` for text extraction)
|
||||
- **Contacts list**: Array of person objects with names, emails, phones
|
||||
|
||||
Parse output with `jq` or read JSON directly.
|
||||
|
||||
## Rules
|
||||
|
||||
1. **Never send email or create/delete events without confirming with the user first.** Show the draft content and ask for approval.
|
||||
2. **Check auth before first use** — run `setup.py --check`. If it fails, guide the user through setup.
|
||||
3. **Use the Gmail search syntax reference** for complex queries — load it with `skill_view("google-workspace", file_path="references/gmail-search-syntax.md")`.
|
||||
4. **Calendar times must include timezone** — always use ISO 8601 with offset (e.g., `2026-03-01T10:00:00-06:00`) or UTC (`Z`).
|
||||
5. **Respect rate limits** — avoid rapid-fire sequential API calls. Batch reads when possible.
|
||||
1. **Never send email or create/delete events without confirming with the user first.**
|
||||
2. **Check auth before first use** — run `setup.py --check`.
|
||||
3. **Use the Gmail search syntax reference** for complex queries.
|
||||
4. **Calendar times must include timezone** — ISO 8601 with offset or UTC.
|
||||
5. **Respect rate limits** — avoid rapid-fire sequential API calls.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Problem | Fix |
|
||||
|---------|-----|
|
||||
| `NOT_AUTHENTICATED` | Run setup Steps 2-5 above |
|
||||
| `REFRESH_FAILED` | Token revoked or expired — redo Steps 3-5 |
|
||||
| `HttpError 403: Insufficient Permission` | Missing API scope — `$GSETUP --revoke` then redo Steps 3-5 |
|
||||
| `HttpError 403: Access Not Configured` | API not enabled — user needs to enable it in Google Cloud Console |
|
||||
| `ModuleNotFoundError` | Run `$GSETUP --install-deps` |
|
||||
| Advanced Protection blocks auth | Workspace admin must allowlist the OAuth client ID |
|
||||
| `NOT_AUTHENTICATED` | Run setup Steps 2-5 |
|
||||
| `REFRESH_FAILED` | Token revoked — redo Steps 3-5 |
|
||||
| `gws: command not found` | Install: `npm install -g @googleworkspace/cli` |
|
||||
| `HttpError 403` | Missing scope — `$GSETUP --revoke` then redo Steps 3-5 |
|
||||
| `HttpError 403: Access Not Configured` | Enable API in Google Cloud Console |
|
||||
| Advanced Protection blocks auth | Admin must allowlist the OAuth client ID |
|
||||
|
||||
## Revoking Access
|
||||
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Google Workspace API CLI for Hermes Agent.
|
||||
|
||||
A thin CLI wrapper around Google's Python client libraries.
|
||||
Authenticates using the token stored by setup.py.
|
||||
Thin wrapper that delegates to gws (googleworkspace/cli) via gws_bridge.py.
|
||||
Maintains the same CLI interface for backward compatibility with Hermes skills.
|
||||
|
||||
Usage:
|
||||
python google_api.py gmail search "is:unread" [--max 10]
|
||||
python google_api.py gmail get MESSAGE_ID
|
||||
python google_api.py gmail send --to user@example.com --subject "Hi" --body "Hello"
|
||||
python google_api.py gmail reply MESSAGE_ID --body "Thanks"
|
||||
python google_api.py calendar list [--from DATE] [--to DATE] [--calendar primary]
|
||||
python google_api.py calendar list [--start DATE] [--end DATE] [--calendar primary]
|
||||
python google_api.py calendar create --summary "Meeting" --start DATETIME --end DATETIME
|
||||
python google_api.py calendar delete EVENT_ID
|
||||
python google_api.py drive search "budget report" [--max 10]
|
||||
python google_api.py contacts list [--max 20]
|
||||
python google_api.py sheets get SHEET_ID RANGE
|
||||
@@ -20,386 +21,193 @@ Usage:
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from email.mime.text import MIMEText
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from hermes_constants import display_hermes_home, get_hermes_home
|
||||
except ModuleNotFoundError:
|
||||
HERMES_AGENT_ROOT = Path(__file__).resolve().parents[4]
|
||||
if HERMES_AGENT_ROOT.exists():
|
||||
sys.path.insert(0, str(HERMES_AGENT_ROOT))
|
||||
from hermes_constants import display_hermes_home, get_hermes_home
|
||||
|
||||
HERMES_HOME = get_hermes_home()
|
||||
TOKEN_PATH = HERMES_HOME / "google_token.json"
|
||||
|
||||
SCOPES = [
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/gmail.send",
|
||||
"https://www.googleapis.com/auth/gmail.modify",
|
||||
"https://www.googleapis.com/auth/calendar",
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/contacts.readonly",
|
||||
"https://www.googleapis.com/auth/spreadsheets",
|
||||
"https://www.googleapis.com/auth/documents.readonly",
|
||||
]
|
||||
BRIDGE = Path(__file__).parent / "gws_bridge.py"
|
||||
PYTHON = sys.executable
|
||||
|
||||
|
||||
def _missing_scopes() -> list[str]:
|
||||
try:
|
||||
payload = json.loads(TOKEN_PATH.read_text())
|
||||
except Exception:
|
||||
return []
|
||||
raw = payload.get("scopes") or payload.get("scope")
|
||||
if not raw:
|
||||
return []
|
||||
granted = {s.strip() for s in (raw.split() if isinstance(raw, str) else raw) if s.strip()}
|
||||
return sorted(scope for scope in SCOPES if scope not in granted)
|
||||
def gws(*args: str) -> None:
|
||||
"""Call gws via the bridge and exit with its return code."""
|
||||
result = subprocess.run(
|
||||
[PYTHON, str(BRIDGE)] + list(args),
|
||||
env={**os.environ, "HERMES_HOME": os.environ.get("HERMES_HOME", str(Path.home() / ".hermes"))},
|
||||
)
|
||||
sys.exit(result.returncode)
|
||||
|
||||
|
||||
def get_credentials():
|
||||
"""Load and refresh credentials from token file."""
|
||||
if not TOKEN_PATH.exists():
|
||||
print("Not authenticated. Run the setup script first:", file=sys.stderr)
|
||||
print(f" python {Path(__file__).parent / 'setup.py'}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google.auth.transport.requests import Request
|
||||
|
||||
creds = Credentials.from_authorized_user_file(str(TOKEN_PATH), SCOPES)
|
||||
if creds.expired and creds.refresh_token:
|
||||
creds.refresh(Request())
|
||||
TOKEN_PATH.write_text(creds.to_json())
|
||||
if not creds.valid:
|
||||
print("Token is invalid. Re-run setup.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
missing_scopes = _missing_scopes()
|
||||
if missing_scopes:
|
||||
print(
|
||||
"Token is valid but missing Google Workspace scopes required by this skill.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
for scope in missing_scopes:
|
||||
print(f" - {scope}", file=sys.stderr)
|
||||
print(
|
||||
f"Re-run setup.py from the active Hermes profile ({display_hermes_home()}) to restore full access.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
return creds
|
||||
|
||||
|
||||
def build_service(api, version):
|
||||
from googleapiclient.discovery import build
|
||||
return build(api, version, credentials=get_credentials())
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Gmail
|
||||
# =========================================================================
|
||||
# -- Gmail --
|
||||
|
||||
def gmail_search(args):
|
||||
service = build_service("gmail", "v1")
|
||||
results = service.users().messages().list(
|
||||
userId="me", q=args.query, maxResults=args.max
|
||||
).execute()
|
||||
messages = results.get("messages", [])
|
||||
if not messages:
|
||||
print("No messages found.")
|
||||
return
|
||||
|
||||
output = []
|
||||
for msg_meta in messages:
|
||||
msg = service.users().messages().get(
|
||||
userId="me", id=msg_meta["id"], format="metadata",
|
||||
metadataHeaders=["From", "To", "Subject", "Date"],
|
||||
).execute()
|
||||
headers = {h["name"]: h["value"] for h in msg.get("payload", {}).get("headers", [])}
|
||||
output.append({
|
||||
"id": msg["id"],
|
||||
"threadId": msg["threadId"],
|
||||
"from": headers.get("From", ""),
|
||||
"to": headers.get("To", ""),
|
||||
"subject": headers.get("Subject", ""),
|
||||
"date": headers.get("Date", ""),
|
||||
"snippet": msg.get("snippet", ""),
|
||||
"labels": msg.get("labelIds", []),
|
||||
})
|
||||
print(json.dumps(output, indent=2, ensure_ascii=False))
|
||||
|
||||
cmd = ["gmail", "+triage", "--query", args.query, "--max", str(args.max), "--format", "json"]
|
||||
gws(*cmd)
|
||||
|
||||
def gmail_get(args):
|
||||
service = build_service("gmail", "v1")
|
||||
msg = service.users().messages().get(
|
||||
userId="me", id=args.message_id, format="full"
|
||||
).execute()
|
||||
|
||||
headers = {h["name"]: h["value"] for h in msg.get("payload", {}).get("headers", [])}
|
||||
|
||||
# Extract body text
|
||||
body = ""
|
||||
payload = msg.get("payload", {})
|
||||
if payload.get("body", {}).get("data"):
|
||||
body = base64.urlsafe_b64decode(payload["body"]["data"]).decode("utf-8", errors="replace")
|
||||
elif payload.get("parts"):
|
||||
for part in payload["parts"]:
|
||||
if part.get("mimeType") == "text/plain" and part.get("body", {}).get("data"):
|
||||
body = base64.urlsafe_b64decode(part["body"]["data"]).decode("utf-8", errors="replace")
|
||||
break
|
||||
if not body:
|
||||
for part in payload["parts"]:
|
||||
if part.get("mimeType") == "text/html" and part.get("body", {}).get("data"):
|
||||
body = base64.urlsafe_b64decode(part["body"]["data"]).decode("utf-8", errors="replace")
|
||||
break
|
||||
|
||||
result = {
|
||||
"id": msg["id"],
|
||||
"threadId": msg["threadId"],
|
||||
"from": headers.get("From", ""),
|
||||
"to": headers.get("To", ""),
|
||||
"subject": headers.get("Subject", ""),
|
||||
"date": headers.get("Date", ""),
|
||||
"labels": msg.get("labelIds", []),
|
||||
"body": body,
|
||||
}
|
||||
print(json.dumps(result, indent=2, ensure_ascii=False))
|
||||
|
||||
gws("gmail", "+read", "--id", args.message_id, "--headers", "--format", "json")
|
||||
|
||||
def gmail_send(args):
|
||||
service = build_service("gmail", "v1")
|
||||
message = MIMEText(args.body, "html" if args.html else "plain")
|
||||
message["to"] = args.to
|
||||
message["subject"] = args.subject
|
||||
cmd = ["gmail", "+send", "--to", args.to, "--subject", args.subject, "--body", args.body, "--format", "json"]
|
||||
if args.cc:
|
||||
message["cc"] = args.cc
|
||||
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
body = {"raw": raw}
|
||||
|
||||
if args.thread_id:
|
||||
body["threadId"] = args.thread_id
|
||||
|
||||
result = service.users().messages().send(userId="me", body=body).execute()
|
||||
print(json.dumps({"status": "sent", "id": result["id"], "threadId": result.get("threadId", "")}, indent=2))
|
||||
|
||||
cmd += ["--cc", args.cc]
|
||||
if args.html:
|
||||
cmd.append("--html")
|
||||
gws(*cmd)
|
||||
|
||||
def gmail_reply(args):
|
||||
service = build_service("gmail", "v1")
|
||||
# Fetch original to get thread ID and headers
|
||||
original = service.users().messages().get(
|
||||
userId="me", id=args.message_id, format="metadata",
|
||||
metadataHeaders=["From", "Subject", "Message-ID"],
|
||||
).execute()
|
||||
headers = {h["name"]: h["value"] for h in original.get("payload", {}).get("headers", [])}
|
||||
|
||||
subject = headers.get("Subject", "")
|
||||
if not subject.startswith("Re:"):
|
||||
subject = f"Re: {subject}"
|
||||
|
||||
message = MIMEText(args.body)
|
||||
message["to"] = headers.get("From", "")
|
||||
message["subject"] = subject
|
||||
if headers.get("Message-ID"):
|
||||
message["In-Reply-To"] = headers["Message-ID"]
|
||||
message["References"] = headers["Message-ID"]
|
||||
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
body = {"raw": raw, "threadId": original["threadId"]}
|
||||
|
||||
result = service.users().messages().send(userId="me", body=body).execute()
|
||||
print(json.dumps({"status": "sent", "id": result["id"], "threadId": result.get("threadId", "")}, indent=2))
|
||||
|
||||
gws("gmail", "+reply", "--message-id", args.message_id, "--body", args.body, "--format", "json")
|
||||
|
||||
def gmail_labels(args):
|
||||
service = build_service("gmail", "v1")
|
||||
results = service.users().labels().list(userId="me").execute()
|
||||
labels = [{"id": l["id"], "name": l["name"], "type": l.get("type", "")} for l in results.get("labels", [])]
|
||||
print(json.dumps(labels, indent=2))
|
||||
|
||||
gws("gmail", "users", "labels", "list", "--params", json.dumps({"userId": "me"}), "--format", "json")
|
||||
|
||||
def gmail_modify(args):
|
||||
service = build_service("gmail", "v1")
|
||||
body = {}
|
||||
if args.add_labels:
|
||||
body["addLabelIds"] = args.add_labels.split(",")
|
||||
if args.remove_labels:
|
||||
body["removeLabelIds"] = args.remove_labels.split(",")
|
||||
result = service.users().messages().modify(userId="me", id=args.message_id, body=body).execute()
|
||||
print(json.dumps({"id": result["id"], "labels": result.get("labelIds", [])}, indent=2))
|
||||
gws(
|
||||
"gmail", "users", "messages", "modify",
|
||||
"--params", json.dumps({"userId": "me", "id": args.message_id}),
|
||||
"--json", json.dumps(body),
|
||||
"--format", "json",
|
||||
)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Calendar
|
||||
# =========================================================================
|
||||
# -- Calendar --
|
||||
|
||||
def calendar_list(args):
|
||||
service = build_service("calendar", "v3")
|
||||
now = datetime.now(timezone.utc)
|
||||
time_min = args.start or now.isoformat()
|
||||
time_max = args.end or (now + timedelta(days=7)).isoformat()
|
||||
|
||||
# Ensure timezone info
|
||||
for val in [time_min, time_max]:
|
||||
if "T" in val and "Z" not in val and "+" not in val and "-" not in val[11:]:
|
||||
val += "Z"
|
||||
|
||||
results = service.events().list(
|
||||
calendarId=args.calendar, timeMin=time_min, timeMax=time_max,
|
||||
maxResults=args.max, singleEvents=True, orderBy="startTime",
|
||||
).execute()
|
||||
|
||||
events = []
|
||||
for e in results.get("items", []):
|
||||
events.append({
|
||||
"id": e["id"],
|
||||
"summary": e.get("summary", "(no title)"),
|
||||
"start": e.get("start", {}).get("dateTime", e.get("start", {}).get("date", "")),
|
||||
"end": e.get("end", {}).get("dateTime", e.get("end", {}).get("date", "")),
|
||||
"location": e.get("location", ""),
|
||||
"description": e.get("description", ""),
|
||||
"status": e.get("status", ""),
|
||||
"htmlLink": e.get("htmlLink", ""),
|
||||
})
|
||||
print(json.dumps(events, indent=2, ensure_ascii=False))
|
||||
|
||||
if args.start or args.end:
|
||||
# Specific date range — use raw Calendar API for precise timeMin/timeMax
|
||||
from datetime import datetime, timedelta, timezone as tz
|
||||
now = datetime.now(tz.utc)
|
||||
time_min = args.start or now.isoformat()
|
||||
time_max = args.end or (now + timedelta(days=7)).isoformat()
|
||||
gws(
|
||||
"calendar", "events", "list",
|
||||
"--params", json.dumps({
|
||||
"calendarId": args.calendar,
|
||||
"timeMin": time_min,
|
||||
"timeMax": time_max,
|
||||
"maxResults": args.max,
|
||||
"singleEvents": True,
|
||||
"orderBy": "startTime",
|
||||
}),
|
||||
"--format", "json",
|
||||
)
|
||||
else:
|
||||
# No date range — use +agenda helper (defaults to 7 days)
|
||||
cmd = ["calendar", "+agenda", "--days", "7", "--format", "json"]
|
||||
if args.calendar != "primary":
|
||||
cmd += ["--calendar", args.calendar]
|
||||
gws(*cmd)
|
||||
|
||||
def calendar_create(args):
|
||||
service = build_service("calendar", "v3")
|
||||
event = {
|
||||
"summary": args.summary,
|
||||
"start": {"dateTime": args.start},
|
||||
"end": {"dateTime": args.end},
|
||||
}
|
||||
cmd = [
|
||||
"calendar", "+insert",
|
||||
"--summary", args.summary,
|
||||
"--start", args.start,
|
||||
"--end", args.end,
|
||||
"--format", "json",
|
||||
]
|
||||
if args.location:
|
||||
event["location"] = args.location
|
||||
cmd += ["--location", args.location]
|
||||
if args.description:
|
||||
event["description"] = args.description
|
||||
cmd += ["--description", args.description]
|
||||
if args.attendees:
|
||||
event["attendees"] = [{"email": e.strip()} for e in args.attendees.split(",")]
|
||||
|
||||
result = service.events().insert(calendarId=args.calendar, body=event).execute()
|
||||
print(json.dumps({
|
||||
"status": "created",
|
||||
"id": result["id"],
|
||||
"summary": result.get("summary", ""),
|
||||
"htmlLink": result.get("htmlLink", ""),
|
||||
}, indent=2))
|
||||
|
||||
for email in args.attendees.split(","):
|
||||
cmd += ["--attendee", email.strip()]
|
||||
if args.calendar != "primary":
|
||||
cmd += ["--calendar", args.calendar]
|
||||
gws(*cmd)
|
||||
|
||||
def calendar_delete(args):
|
||||
service = build_service("calendar", "v3")
|
||||
service.events().delete(calendarId=args.calendar, eventId=args.event_id).execute()
|
||||
print(json.dumps({"status": "deleted", "eventId": args.event_id}))
|
||||
gws(
|
||||
"calendar", "events", "delete",
|
||||
"--params", json.dumps({"calendarId": args.calendar, "eventId": args.event_id}),
|
||||
"--format", "json",
|
||||
)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Drive
|
||||
# =========================================================================
|
||||
# -- Drive --
|
||||
|
||||
def drive_search(args):
|
||||
service = build_service("drive", "v3")
|
||||
query = f"fullText contains '{args.query}'" if not args.raw_query else args.query
|
||||
results = service.files().list(
|
||||
q=query, pageSize=args.max, fields="files(id, name, mimeType, modifiedTime, webViewLink)",
|
||||
).execute()
|
||||
files = results.get("files", [])
|
||||
print(json.dumps(files, indent=2, ensure_ascii=False))
|
||||
query = args.query if args.raw_query else f"fullText contains '{args.query}'"
|
||||
gws(
|
||||
"drive", "files", "list",
|
||||
"--params", json.dumps({
|
||||
"q": query,
|
||||
"pageSize": args.max,
|
||||
"fields": "files(id,name,mimeType,modifiedTime,webViewLink)",
|
||||
}),
|
||||
"--format", "json",
|
||||
)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Contacts
|
||||
# =========================================================================
|
||||
# -- Contacts --
|
||||
|
||||
def contacts_list(args):
|
||||
service = build_service("people", "v1")
|
||||
results = service.people().connections().list(
|
||||
resourceName="people/me",
|
||||
pageSize=args.max,
|
||||
personFields="names,emailAddresses,phoneNumbers",
|
||||
).execute()
|
||||
contacts = []
|
||||
for person in results.get("connections", []):
|
||||
names = person.get("names", [{}])
|
||||
emails = person.get("emailAddresses", [])
|
||||
phones = person.get("phoneNumbers", [])
|
||||
contacts.append({
|
||||
"name": names[0].get("displayName", "") if names else "",
|
||||
"emails": [e.get("value", "") for e in emails],
|
||||
"phones": [p.get("value", "") for p in phones],
|
||||
})
|
||||
print(json.dumps(contacts, indent=2, ensure_ascii=False))
|
||||
gws(
|
||||
"people", "people", "connections", "list",
|
||||
"--params", json.dumps({
|
||||
"resourceName": "people/me",
|
||||
"pageSize": args.max,
|
||||
"personFields": "names,emailAddresses,phoneNumbers",
|
||||
}),
|
||||
"--format", "json",
|
||||
)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Sheets
|
||||
# =========================================================================
|
||||
# -- Sheets --
|
||||
|
||||
def sheets_get(args):
|
||||
service = build_service("sheets", "v4")
|
||||
result = service.spreadsheets().values().get(
|
||||
spreadsheetId=args.sheet_id, range=args.range,
|
||||
).execute()
|
||||
print(json.dumps(result.get("values", []), indent=2, ensure_ascii=False))
|
||||
|
||||
gws(
|
||||
"sheets", "+read",
|
||||
"--spreadsheet", args.sheet_id,
|
||||
"--range", args.range,
|
||||
"--format", "json",
|
||||
)
|
||||
|
||||
def sheets_update(args):
|
||||
service = build_service("sheets", "v4")
|
||||
values = json.loads(args.values)
|
||||
body = {"values": values}
|
||||
result = service.spreadsheets().values().update(
|
||||
spreadsheetId=args.sheet_id, range=args.range,
|
||||
valueInputOption="USER_ENTERED", body=body,
|
||||
).execute()
|
||||
print(json.dumps({"updatedCells": result.get("updatedCells", 0), "updatedRange": result.get("updatedRange", "")}, indent=2))
|
||||
|
||||
gws(
|
||||
"sheets", "spreadsheets", "values", "update",
|
||||
"--params", json.dumps({
|
||||
"spreadsheetId": args.sheet_id,
|
||||
"range": args.range,
|
||||
"valueInputOption": "USER_ENTERED",
|
||||
}),
|
||||
"--json", json.dumps({"values": values}),
|
||||
"--format", "json",
|
||||
)
|
||||
|
||||
def sheets_append(args):
|
||||
service = build_service("sheets", "v4")
|
||||
values = json.loads(args.values)
|
||||
body = {"values": values}
|
||||
result = service.spreadsheets().values().append(
|
||||
spreadsheetId=args.sheet_id, range=args.range,
|
||||
valueInputOption="USER_ENTERED", insertDataOption="INSERT_ROWS", body=body,
|
||||
).execute()
|
||||
print(json.dumps({"updatedCells": result.get("updates", {}).get("updatedCells", 0)}, indent=2))
|
||||
gws(
|
||||
"sheets", "+append",
|
||||
"--spreadsheet", args.sheet_id,
|
||||
"--json-values", json.dumps(values),
|
||||
"--format", "json",
|
||||
)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Docs
|
||||
# =========================================================================
|
||||
# -- Docs --
|
||||
|
||||
def docs_get(args):
|
||||
service = build_service("docs", "v1")
|
||||
doc = service.documents().get(documentId=args.doc_id).execute()
|
||||
# Extract plain text from the document structure
|
||||
text_parts = []
|
||||
for element in doc.get("body", {}).get("content", []):
|
||||
paragraph = element.get("paragraph", {})
|
||||
for pe in paragraph.get("elements", []):
|
||||
text_run = pe.get("textRun", {})
|
||||
if text_run.get("content"):
|
||||
text_parts.append(text_run["content"])
|
||||
result = {
|
||||
"title": doc.get("title", ""),
|
||||
"documentId": doc.get("documentId", ""),
|
||||
"body": "".join(text_parts),
|
||||
}
|
||||
print(json.dumps(result, indent=2, ensure_ascii=False))
|
||||
gws(
|
||||
"docs", "documents", "get",
|
||||
"--params", json.dumps({"documentId": args.doc_id}),
|
||||
"--format", "json",
|
||||
)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# CLI parser
|
||||
# =========================================================================
|
||||
# -- CLI parser (backward-compatible interface) --
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Google Workspace API for Hermes Agent")
|
||||
parser = argparse.ArgumentParser(description="Google Workspace API for Hermes Agent (gws backend)")
|
||||
sub = parser.add_subparsers(dest="service", required=True)
|
||||
|
||||
# --- Gmail ---
|
||||
@@ -421,7 +229,7 @@ def main():
|
||||
p.add_argument("--body", required=True)
|
||||
p.add_argument("--cc", default="")
|
||||
p.add_argument("--html", action="store_true", help="Send body as HTML")
|
||||
p.add_argument("--thread-id", default="", help="Thread ID for threading")
|
||||
p.add_argument("--thread-id", default="", help="Thread ID (unused with gws, kept for compat)")
|
||||
p.set_defaults(func=gmail_send)
|
||||
|
||||
p = gmail_sub.add_parser("reply")
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Bridge between Hermes OAuth token and gws CLI.
|
||||
|
||||
Refreshes the token if expired, then executes gws with the valid access token.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def get_hermes_home() -> Path:
|
||||
return Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
||||
|
||||
def get_token_path() -> Path:
|
||||
return get_hermes_home() / "google_token.json"
|
||||
|
||||
|
||||
def refresh_token(token_data: dict) -> dict:
|
||||
"""Refresh the access token using the refresh token."""
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
|
||||
params = urllib.parse.urlencode({
|
||||
"client_id": token_data["client_id"],
|
||||
"client_secret": token_data["client_secret"],
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"grant_type": "refresh_token",
|
||||
}).encode()
|
||||
|
||||
req = urllib.request.Request(token_data["token_uri"], data=params)
|
||||
try:
|
||||
with urllib.request.urlopen(req) as resp:
|
||||
result = json.loads(resp.read())
|
||||
except urllib.error.HTTPError as e:
|
||||
body = e.read().decode("utf-8", errors="replace")
|
||||
print(f"ERROR: Token refresh failed (HTTP {e.code}): {body}", file=sys.stderr)
|
||||
print("Re-run setup.py to re-authenticate.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
token_data["token"] = result["access_token"]
|
||||
token_data["expiry"] = datetime.fromtimestamp(
|
||||
datetime.now(timezone.utc).timestamp() + result["expires_in"],
|
||||
tz=timezone.utc,
|
||||
).isoformat()
|
||||
|
||||
get_token_path().write_text(json.dumps(token_data, indent=2))
|
||||
return token_data
|
||||
|
||||
|
||||
def get_valid_token() -> str:
|
||||
"""Return a valid access token, refreshing if needed."""
|
||||
token_path = get_token_path()
|
||||
if not token_path.exists():
|
||||
print("ERROR: No Google token found. Run setup.py --auth-url first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
token_data = json.loads(token_path.read_text())
|
||||
|
||||
expiry = token_data.get("expiry", "")
|
||||
if expiry:
|
||||
exp_dt = datetime.fromisoformat(expiry.replace("Z", "+00:00"))
|
||||
now = datetime.now(timezone.utc)
|
||||
if now >= exp_dt:
|
||||
token_data = refresh_token(token_data)
|
||||
|
||||
return token_data["token"]
|
||||
|
||||
|
||||
def main():
|
||||
"""Refresh token if needed, then exec gws with remaining args."""
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: gws_bridge.py <gws args...>", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
access_token = get_valid_token()
|
||||
env = os.environ.copy()
|
||||
env["GOOGLE_WORKSPACE_CLI_TOKEN"] = access_token
|
||||
|
||||
result = subprocess.run(["gws"] + sys.argv[1:], env=env)
|
||||
sys.exit(result.returncode)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -23,6 +23,7 @@ Agent workflow:
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@@ -128,7 +129,11 @@ def check_auth():
|
||||
from google.auth.transport.requests import Request
|
||||
|
||||
try:
|
||||
creds = Credentials.from_authorized_user_file(str(TOKEN_PATH), SCOPES)
|
||||
# Don't pass scopes — user may have authorized only a subset.
|
||||
# Passing scopes forces google-auth to validate them on refresh,
|
||||
# which fails with invalid_scope if the token has fewer scopes
|
||||
# than requested.
|
||||
creds = Credentials.from_authorized_user_file(str(TOKEN_PATH))
|
||||
except Exception as e:
|
||||
print(f"TOKEN_CORRUPT: {e}")
|
||||
return False
|
||||
@@ -137,8 +142,9 @@ def check_auth():
|
||||
if creds.valid:
|
||||
missing_scopes = _missing_scopes_from_payload(payload)
|
||||
if missing_scopes:
|
||||
print(f"AUTH_SCOPE_MISMATCH: {_format_missing_scopes(missing_scopes)}")
|
||||
return False
|
||||
print(f"AUTHENTICATED (partial): Token valid but missing {len(missing_scopes)} scopes:")
|
||||
for s in missing_scopes:
|
||||
print(f" - {s}")
|
||||
print(f"AUTHENTICATED: Token valid at {TOKEN_PATH}")
|
||||
return True
|
||||
|
||||
@@ -148,8 +154,9 @@ def check_auth():
|
||||
TOKEN_PATH.write_text(creds.to_json())
|
||||
missing_scopes = _missing_scopes_from_payload(_load_token_payload(TOKEN_PATH))
|
||||
if missing_scopes:
|
||||
print(f"AUTH_SCOPE_MISMATCH: {_format_missing_scopes(missing_scopes)}")
|
||||
return False
|
||||
print(f"AUTHENTICATED (partial): Token refreshed but missing {len(missing_scopes)} scopes:")
|
||||
for s in missing_scopes:
|
||||
print(f" - {s}")
|
||||
print(f"AUTHENTICATED: Token refreshed at {TOKEN_PATH}")
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -272,16 +279,33 @@ def exchange_auth_code(code: str):
|
||||
|
||||
_ensure_deps()
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
# Extract granted scopes from the callback URL if present
|
||||
if returned_state and "scope" in parse_qs(urlparse(code).query if isinstance(code, str) and code.startswith("http") else {}):
|
||||
granted_scopes = parse_qs(urlparse(code).query)["scope"][0].split()
|
||||
else:
|
||||
# Try to extract from code_or_url parameter
|
||||
if isinstance(code, str) and code.startswith("http"):
|
||||
params = parse_qs(urlparse(code).query)
|
||||
if "scope" in params:
|
||||
granted_scopes = params["scope"][0].split()
|
||||
else:
|
||||
granted_scopes = SCOPES
|
||||
else:
|
||||
granted_scopes = SCOPES
|
||||
|
||||
flow = Flow.from_client_secrets_file(
|
||||
str(CLIENT_SECRET_PATH),
|
||||
scopes=SCOPES,
|
||||
scopes=granted_scopes,
|
||||
redirect_uri=pending_auth.get("redirect_uri", REDIRECT_URI),
|
||||
state=pending_auth["state"],
|
||||
code_verifier=pending_auth["code_verifier"],
|
||||
)
|
||||
|
||||
try:
|
||||
# Accept partial scopes — user may deselect some permissions in the consent screen
|
||||
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
|
||||
flow.fetch_token(code=code)
|
||||
except Exception as e:
|
||||
print(f"ERROR: Token exchange failed: {e}")
|
||||
@@ -290,11 +314,21 @@ def exchange_auth_code(code: str):
|
||||
|
||||
creds = flow.credentials
|
||||
token_payload = json.loads(creds.to_json())
|
||||
|
||||
# Store only the scopes actually granted by the user, not what was requested.
|
||||
# creds.to_json() writes the requested scopes, which causes refresh to fail
|
||||
# with invalid_scope if the user only authorized a subset.
|
||||
actually_granted = list(creds.granted_scopes or []) if hasattr(creds, "granted_scopes") and creds.granted_scopes else []
|
||||
if actually_granted:
|
||||
token_payload["scopes"] = actually_granted
|
||||
elif granted_scopes != SCOPES:
|
||||
# granted_scopes was extracted from the callback URL
|
||||
token_payload["scopes"] = granted_scopes
|
||||
|
||||
missing_scopes = _missing_scopes_from_payload(token_payload)
|
||||
if missing_scopes:
|
||||
print(f"ERROR: Refusing to save incomplete Google Workspace token. {_format_missing_scopes(missing_scopes)}")
|
||||
print(f"Existing token at {TOKEN_PATH} was left unchanged.")
|
||||
sys.exit(1)
|
||||
print(f"WARNING: Token missing some Google Workspace scopes: {', '.join(missing_scopes)}")
|
||||
print("Some services may not be available.")
|
||||
|
||||
TOKEN_PATH.write_text(json.dumps(token_payload, indent=2))
|
||||
PENDING_AUTH_PATH.unlink(missing_ok=True)
|
||||
|
||||
@@ -68,9 +68,22 @@ class TestInitialize:
|
||||
resp = await agent.initialize(protocol_version=1)
|
||||
caps = resp.agent_capabilities
|
||||
assert isinstance(caps, AgentCapabilities)
|
||||
assert caps.load_session is True
|
||||
assert caps.session_capabilities is not None
|
||||
assert caps.session_capabilities.fork is not None
|
||||
assert caps.session_capabilities.list is not None
|
||||
assert caps.session_capabilities.resume is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_capabilities_wire_format(self, agent):
|
||||
"""Verify the JSON wire format uses correct aliases so ACP clients see the right keys."""
|
||||
resp = await agent.initialize(protocol_version=1)
|
||||
payload = resp.agent_capabilities.model_dump(by_alias=True, exclude_none=True)
|
||||
assert payload["loadSession"] is True
|
||||
session_caps = payload["sessionCapabilities"]
|
||||
assert "fork" in session_caps
|
||||
assert "list" in session_caps
|
||||
assert "resume" in session_caps
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -410,6 +423,37 @@ class TestPrompt:
|
||||
update = last_call[1].get("update") or last_call[0][1]
|
||||
assert update.session_update == "agent_message_chunk"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_populates_usage_from_top_level_run_conversation_fields(self, agent):
|
||||
"""ACP should map top-level token fields into PromptResponse.usage."""
|
||||
new_resp = await agent.new_session(cwd=".")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
|
||||
state.agent.run_conversation = MagicMock(return_value={
|
||||
"final_response": "usage attached",
|
||||
"messages": [],
|
||||
"prompt_tokens": 123,
|
||||
"completion_tokens": 45,
|
||||
"total_tokens": 168,
|
||||
"reasoning_tokens": 7,
|
||||
"cache_read_tokens": 11,
|
||||
})
|
||||
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
prompt = [TextContentBlock(type="text", text="show usage")]
|
||||
resp = await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
|
||||
|
||||
assert isinstance(resp, PromptResponse)
|
||||
assert resp.usage is not None
|
||||
assert resp.usage.input_tokens == 123
|
||||
assert resp.usage.output_tokens == 45
|
||||
assert resp.usage.total_tokens == 168
|
||||
assert resp.usage.thought_tokens == 7
|
||||
assert resp.usage.cached_read_tokens == 11
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_cancelled_returns_cancelled_stop_reason(self, agent):
|
||||
"""If cancel is called during prompt, stop_reason should be 'cancelled'."""
|
||||
|
||||
@@ -17,7 +17,6 @@ from agent.anthropic_adapter import (
|
||||
build_anthropic_kwargs,
|
||||
convert_messages_to_anthropic,
|
||||
convert_tools_to_anthropic,
|
||||
get_anthropic_token_source,
|
||||
is_claude_code_token_valid,
|
||||
normalize_anthropic_response,
|
||||
normalize_model_name,
|
||||
@@ -81,6 +80,9 @@ class TestBuildAnthropicClient:
|
||||
build_anthropic_client("sk-ant-api03-x", base_url="https://custom.api.com")
|
||||
kwargs = mock_sdk.Anthropic.call_args[1]
|
||||
assert kwargs["base_url"] == "https://custom.api.com"
|
||||
assert kwargs["default_headers"] == {
|
||||
"anthropic-beta": "interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14"
|
||||
}
|
||||
|
||||
def test_minimax_anthropic_endpoint_uses_bearer_auth_for_regular_api_keys(self):
|
||||
with patch("agent.anthropic_adapter._anthropic_sdk") as mock_sdk:
|
||||
@@ -92,7 +94,20 @@ class TestBuildAnthropicClient:
|
||||
assert kwargs["auth_token"] == "minimax-secret-123"
|
||||
assert "api_key" not in kwargs
|
||||
assert kwargs["default_headers"] == {
|
||||
"anthropic-beta": "interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14"
|
||||
"anthropic-beta": "interleaved-thinking-2025-05-14"
|
||||
}
|
||||
|
||||
def test_minimax_cn_anthropic_endpoint_omits_tool_streaming_beta(self):
|
||||
with patch("agent.anthropic_adapter._anthropic_sdk") as mock_sdk:
|
||||
build_anthropic_client(
|
||||
"minimax-cn-secret-123",
|
||||
base_url="https://api.minimaxi.com/anthropic",
|
||||
)
|
||||
kwargs = mock_sdk.Anthropic.call_args[1]
|
||||
assert kwargs["auth_token"] == "minimax-cn-secret-123"
|
||||
assert "api_key" not in kwargs
|
||||
assert kwargs["default_headers"] == {
|
||||
"anthropic-beta": "interleaved-thinking-2025-05-14"
|
||||
}
|
||||
|
||||
|
||||
@@ -165,15 +180,6 @@ class TestResolveAnthropicToken:
|
||||
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
|
||||
assert resolve_anthropic_token() == "sk-ant-oat01-mytoken"
|
||||
|
||||
def test_reports_claude_json_primary_key_source(self, monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
|
||||
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
|
||||
(tmp_path / ".claude.json").write_text(json.dumps({"primaryApiKey": "sk-ant-api03-primary"}))
|
||||
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
|
||||
|
||||
assert get_anthropic_token_source("sk-ant-api03-primary") == "claude_json_primary_api_key"
|
||||
|
||||
def test_does_not_resolve_primary_api_key_as_native_anthropic_token(self, monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
|
||||
|
||||
@@ -9,7 +9,6 @@ import pytest
|
||||
|
||||
from agent.auxiliary_client import (
|
||||
get_text_auxiliary_client,
|
||||
get_vision_auxiliary_client,
|
||||
get_available_vision_backends,
|
||||
resolve_vision_provider_client,
|
||||
resolve_provider_client,
|
||||
@@ -20,7 +19,6 @@ from agent.auxiliary_client import (
|
||||
_get_provider_chain,
|
||||
_is_payment_error,
|
||||
_try_payment_fallback,
|
||||
_resolve_forced_provider,
|
||||
_resolve_auto,
|
||||
)
|
||||
|
||||
@@ -664,15 +662,6 @@ class TestGetTextAuxiliaryClient:
|
||||
class TestVisionClientFallback:
|
||||
"""Vision client auto mode resolves known-good multimodal backends."""
|
||||
|
||||
def test_vision_returns_none_without_any_credentials(self):
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
|
||||
patch("agent.auxiliary_client._try_anthropic", return_value=(None, None)),
|
||||
):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_vision_auto_includes_active_provider_when_configured(self, monkeypatch):
|
||||
"""Active provider appears in available backends when credentials exist."""
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "***")
|
||||
@@ -754,21 +743,6 @@ class TestAuxiliaryPoolAwareness:
|
||||
assert call_kwargs["base_url"] == "https://api.githubcopilot.com"
|
||||
assert call_kwargs["default_headers"]["Editor-Version"]
|
||||
|
||||
def test_vision_auto_uses_active_provider_as_fallback(self, monkeypatch):
|
||||
"""When no OpenRouter/Nous available, vision auto falls back to active provider."""
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "***")
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
|
||||
patch("agent.auxiliary_client._read_main_provider", return_value="anthropic"),
|
||||
patch("agent.auxiliary_client._read_main_model", return_value="claude-sonnet-4"),
|
||||
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="***"),
|
||||
):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
|
||||
assert client is not None
|
||||
assert client.__class__.__name__ == "AnthropicAuxiliaryClient"
|
||||
|
||||
def test_vision_auto_prefers_active_provider_over_openrouter(self, monkeypatch):
|
||||
"""Active provider is tried before OpenRouter in vision auto."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
@@ -800,43 +774,6 @@ class TestAuxiliaryPoolAwareness:
|
||||
assert client is not None
|
||||
assert provider == "custom:local"
|
||||
|
||||
def test_vision_direct_endpoint_override(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
monkeypatch.setenv("AUXILIARY_VISION_BASE_URL", "http://localhost:4567/v1")
|
||||
monkeypatch.setenv("AUXILIARY_VISION_API_KEY", "vision-key")
|
||||
monkeypatch.setenv("AUXILIARY_VISION_MODEL", "vision-model")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert model == "vision-model"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "http://localhost:4567/v1"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "vision-key"
|
||||
|
||||
def test_vision_direct_endpoint_without_key_uses_placeholder(self, monkeypatch):
|
||||
"""Vision endpoint without API key should use 'no-key-required' placeholder."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
monkeypatch.setenv("AUXILIARY_VISION_BASE_URL", "http://localhost:4567/v1")
|
||||
monkeypatch.setenv("AUXILIARY_VISION_MODEL", "vision-model")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is not None
|
||||
assert model == "vision-model"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "no-key-required"
|
||||
|
||||
def test_vision_uses_openrouter_when_available(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert client is not None
|
||||
|
||||
def test_vision_uses_nous_when_available(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth") as mock_nous, \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
mock_nous.return_value = {"access_token": "nous-tok"}
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert client is not None
|
||||
|
||||
def test_vision_config_google_provider_uses_gemini_credentials(self, monkeypatch):
|
||||
config = {
|
||||
"auxiliary": {
|
||||
@@ -862,53 +799,6 @@ class TestAuxiliaryPoolAwareness:
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "gemini-key"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
|
||||
def test_vision_forced_main_uses_custom_endpoint(self, monkeypatch):
|
||||
"""When explicitly forced to 'main', vision CAN use custom endpoint."""
|
||||
config = {
|
||||
"model": {
|
||||
"provider": "custom",
|
||||
"base_url": "http://localhost:1234/v1",
|
||||
"default": "my-local-model",
|
||||
}
|
||||
}
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "main")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is not None
|
||||
assert model == "my-local-model"
|
||||
|
||||
def test_vision_forced_main_returns_none_without_creds(self, monkeypatch):
|
||||
"""Forced main with no credentials still returns None."""
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "main")
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
# Clear client cache to avoid stale entries from previous tests
|
||||
from agent.auxiliary_client import _client_cache
|
||||
_client_cache.clear()
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client._read_main_provider", return_value=""), \
|
||||
patch("agent.auxiliary_client._read_main_model", return_value=""), \
|
||||
patch("agent.auxiliary_client._select_pool_entry", return_value=(False, None)), \
|
||||
patch("agent.auxiliary_client._resolve_custom_runtime", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client._read_codex_access_token", return_value=None), \
|
||||
patch("agent.auxiliary_client._resolve_api_key_provider", return_value=(None, None)):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_vision_forced_codex(self, monkeypatch, codex_auth_dir):
|
||||
"""When forced to 'codex', vision uses Codex OAuth."""
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "codex")
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
|
||||
class TestGetAuxiliaryProvider:
|
||||
@@ -948,122 +838,6 @@ class TestGetAuxiliaryProvider:
|
||||
assert _get_auxiliary_provider("web_extract") == "main"
|
||||
|
||||
|
||||
class TestResolveForcedProvider:
|
||||
"""Tests for _resolve_forced_provider with explicit provider selection."""
|
||||
|
||||
def test_forced_openrouter(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = _resolve_forced_provider("openrouter")
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert client is not None
|
||||
|
||||
def test_forced_openrouter_no_key(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None):
|
||||
client, model = _resolve_forced_provider("openrouter")
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_forced_nous(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth") as mock_nous, \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
mock_nous.return_value = {"access_token": "nous-tok"}
|
||||
client, model = _resolve_forced_provider("nous")
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert client is not None
|
||||
|
||||
def test_forced_nous_not_configured(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None):
|
||||
client, model = _resolve_forced_provider("nous")
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_forced_main_uses_custom(self, monkeypatch):
|
||||
config = {
|
||||
"model": {
|
||||
"provider": "custom",
|
||||
"base_url": "http://local:8080/v1",
|
||||
"default": "my-local-model",
|
||||
}
|
||||
}
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = _resolve_forced_provider("main")
|
||||
assert model == "my-local-model"
|
||||
|
||||
def test_forced_main_uses_config_saved_custom_endpoint(self, monkeypatch):
|
||||
config = {
|
||||
"model": {
|
||||
"provider": "custom",
|
||||
"base_url": "http://local:8080/v1",
|
||||
"default": "my-local-model",
|
||||
}
|
||||
}
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client._read_codex_access_token", return_value=None), \
|
||||
patch("agent.auxiliary_client._resolve_api_key_provider", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = _resolve_forced_provider("main")
|
||||
assert client is not None
|
||||
assert model == "my-local-model"
|
||||
call_kwargs = mock_openai.call_args
|
||||
assert call_kwargs.kwargs["base_url"] == "http://local:8080/v1"
|
||||
|
||||
def test_forced_main_skips_openrouter_nous(self, monkeypatch):
|
||||
"""Even if OpenRouter key is set, 'main' skips it."""
|
||||
config = {
|
||||
"model": {
|
||||
"provider": "custom",
|
||||
"base_url": "http://local:8080/v1",
|
||||
"default": "my-local-model",
|
||||
}
|
||||
}
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = _resolve_forced_provider("main")
|
||||
# Should use custom endpoint, not OpenRouter
|
||||
assert model == "my-local-model"
|
||||
|
||||
def test_forced_main_falls_to_codex(self, codex_auth_dir, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = _resolve_forced_provider("main")
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
def test_forced_codex(self, codex_auth_dir, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = _resolve_forced_provider("codex")
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
def test_forced_codex_no_token(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_codex_access_token", return_value=None):
|
||||
client, model = _resolve_forced_provider("codex")
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_forced_unknown_returns_none(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client._read_codex_access_token", return_value=None):
|
||||
client, model = _resolve_forced_provider("invalid-provider")
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
|
||||
class TestTaskSpecificOverrides:
|
||||
"""Integration tests for per-task provider routing via get_text_auxiliary_client(task=...)."""
|
||||
|
||||
@@ -1337,3 +1111,45 @@ class TestCallLlmPaymentFallback:
|
||||
task="compression",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gate: _resolve_api_key_provider must skip anthropic when not configured
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_resolve_api_key_provider_skips_unconfigured_anthropic(monkeypatch):
|
||||
"""_resolve_api_key_provider must not try anthropic when user never configured it."""
|
||||
from collections import OrderedDict
|
||||
from hermes_cli.auth import ProviderConfig
|
||||
|
||||
# Build a minimal registry with only "anthropic" so the loop is guaranteed
|
||||
# to reach it without being short-circuited by earlier providers.
|
||||
fake_registry = OrderedDict({
|
||||
"anthropic": ProviderConfig(
|
||||
id="anthropic",
|
||||
name="Anthropic",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.anthropic.com",
|
||||
api_key_env_vars=("ANTHROPIC_API_KEY",),
|
||||
),
|
||||
})
|
||||
|
||||
called = []
|
||||
|
||||
def mock_try_anthropic():
|
||||
called.append("anthropic")
|
||||
return None, None
|
||||
|
||||
monkeypatch.setattr("agent.auxiliary_client._try_anthropic", mock_try_anthropic)
|
||||
monkeypatch.setattr("hermes_cli.auth.PROVIDER_REGISTRY", fake_registry)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.is_provider_explicitly_configured",
|
||||
lambda pid: False,
|
||||
)
|
||||
|
||||
from agent.auxiliary_client import _resolve_api_key_provider
|
||||
_resolve_api_key_provider()
|
||||
|
||||
assert "anthropic" not in called, \
|
||||
"_try_anthropic() should not be called when anthropic is not explicitly configured"
|
||||
|
||||
@@ -12,6 +12,17 @@ def _isolate(tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
for env_var in (
|
||||
"AUXILIARY_VISION_PROVIDER",
|
||||
"AUXILIARY_VISION_MODEL",
|
||||
"AUXILIARY_VISION_BASE_URL",
|
||||
"AUXILIARY_VISION_API_KEY",
|
||||
"CONTEXT_VISION_PROVIDER",
|
||||
"CONTEXT_VISION_MODEL",
|
||||
"CONTEXT_VISION_BASE_URL",
|
||||
"CONTEXT_VISION_API_KEY",
|
||||
):
|
||||
monkeypatch.delenv(env_var, raising=False)
|
||||
# Write a minimal config so load_config doesn't fail
|
||||
(hermes_home / "config.yaml").write_text("model:\n default: test-model\n")
|
||||
|
||||
@@ -149,3 +160,83 @@ class TestResolveProviderClientNamedCustom:
|
||||
# "coffee" doesn't exist in custom_providers
|
||||
client, model = resolve_provider_client("coffee", "test")
|
||||
assert client is None
|
||||
|
||||
|
||||
class TestResolveProviderClientModelNormalization:
|
||||
"""Direct-provider auxiliary routing should normalize models like main runtime."""
|
||||
|
||||
def test_matching_native_prefix_is_stripped_for_main_provider(self, tmp_path):
|
||||
_write_config(tmp_path, {
|
||||
"model": {"default": "zai/glm-5.1", "provider": "zai"},
|
||||
})
|
||||
with (
|
||||
patch("hermes_cli.auth.resolve_api_key_provider_credentials", return_value={
|
||||
"api_key": "glm-key",
|
||||
"base_url": "https://api.z.ai/api/paas/v4",
|
||||
}),
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
):
|
||||
mock_openai.return_value = MagicMock()
|
||||
from agent.auxiliary_client import resolve_provider_client
|
||||
|
||||
client, model = resolve_provider_client("main", "zai/glm-5.1")
|
||||
|
||||
assert client is not None
|
||||
assert model == "glm-5.1"
|
||||
|
||||
def test_non_matching_prefix_is_preserved_for_direct_provider(self, tmp_path):
|
||||
_write_config(tmp_path, {
|
||||
"model": {"default": "zai/glm-5.1", "provider": "zai"},
|
||||
})
|
||||
with (
|
||||
patch("hermes_cli.auth.resolve_api_key_provider_credentials", return_value={
|
||||
"api_key": "glm-key",
|
||||
"base_url": "https://api.z.ai/api/paas/v4",
|
||||
}),
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
):
|
||||
mock_openai.return_value = MagicMock()
|
||||
from agent.auxiliary_client import resolve_provider_client
|
||||
|
||||
client, model = resolve_provider_client("zai", "google/gemini-2.5-pro")
|
||||
|
||||
assert client is not None
|
||||
assert model == "google/gemini-2.5-pro"
|
||||
|
||||
def test_aggregator_vendor_slug_is_preserved(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
mock_openai.return_value = MagicMock()
|
||||
from agent.auxiliary_client import resolve_provider_client
|
||||
|
||||
client, model = resolve_provider_client(
|
||||
"openrouter", "anthropic/claude-sonnet-4.6"
|
||||
)
|
||||
|
||||
assert client is not None
|
||||
assert model == "anthropic/claude-sonnet-4.6"
|
||||
|
||||
|
||||
class TestResolveVisionProviderClientModelNormalization:
|
||||
"""Vision auto-routing should reuse the same provider-specific normalization."""
|
||||
|
||||
def test_vision_auto_strips_matching_main_provider_prefix(self, tmp_path):
|
||||
_write_config(tmp_path, {
|
||||
"model": {"default": "zai/glm-5.1", "provider": "zai"},
|
||||
})
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
|
||||
patch("hermes_cli.auth.resolve_api_key_provider_credentials", return_value={
|
||||
"api_key": "glm-key",
|
||||
"base_url": "https://api.z.ai/api/paas/v4",
|
||||
}),
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
):
|
||||
mock_openai.return_value = MagicMock()
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "zai"
|
||||
assert client is not None
|
||||
assert model == "glm-5.1"
|
||||
|
||||
@@ -38,16 +38,6 @@ class TestShouldCompress:
|
||||
assert compressor.should_compress(prompt_tokens=50000) is False
|
||||
|
||||
|
||||
class TestShouldCompressPreflight:
|
||||
def test_short_messages(self, compressor):
|
||||
msgs = [{"role": "user", "content": "short"}]
|
||||
assert compressor.should_compress_preflight(msgs) is False
|
||||
|
||||
def test_long_messages(self, compressor):
|
||||
# Each message ~100k chars / 4 = 25k tokens, need >85k threshold
|
||||
msgs = [{"role": "user", "content": "x" * 400000}]
|
||||
assert compressor.should_compress_preflight(msgs) is True
|
||||
|
||||
|
||||
class TestUpdateFromResponse:
|
||||
def test_updates_fields(self, compressor):
|
||||
@@ -58,27 +48,12 @@ class TestUpdateFromResponse:
|
||||
})
|
||||
assert compressor.last_prompt_tokens == 5000
|
||||
assert compressor.last_completion_tokens == 1000
|
||||
assert compressor.last_total_tokens == 6000
|
||||
|
||||
def test_missing_fields_default_zero(self, compressor):
|
||||
compressor.update_from_response({})
|
||||
assert compressor.last_prompt_tokens == 0
|
||||
|
||||
|
||||
class TestGetStatus:
|
||||
def test_returns_expected_keys(self, compressor):
|
||||
status = compressor.get_status()
|
||||
assert "last_prompt_tokens" in status
|
||||
assert "threshold_tokens" in status
|
||||
assert "context_length" in status
|
||||
assert "usage_percent" in status
|
||||
assert "compression_count" in status
|
||||
|
||||
def test_usage_percent_calculation(self, compressor):
|
||||
compressor.last_prompt_tokens = 50000
|
||||
status = compressor.get_status()
|
||||
assert status["usage_percent"] == 50.0
|
||||
|
||||
|
||||
class TestCompress:
|
||||
def _make_messages(self, n):
|
||||
|
||||
@@ -83,6 +83,24 @@ def test_parse_references_strips_trailing_punctuation():
|
||||
assert refs[1].target == "https://example.com/docs"
|
||||
|
||||
|
||||
def test_parse_quoted_references_with_spaces_and_preserve_unquoted_ranges():
|
||||
from agent.context_references import parse_context_references
|
||||
|
||||
refs = parse_context_references(
|
||||
'review @file:"C:\\Users\\Simba\\My Project\\main.py":7-9 '
|
||||
'and @folder:"docs and specs" plus @file:src/main.py:1-2'
|
||||
)
|
||||
|
||||
assert [ref.kind for ref in refs] == ["file", "folder", "file"]
|
||||
assert refs[0].target == r"C:\Users\Simba\My Project\main.py"
|
||||
assert refs[0].line_start == 7
|
||||
assert refs[0].line_end == 9
|
||||
assert refs[1].target == "docs and specs"
|
||||
assert refs[2].target == "src/main.py"
|
||||
assert refs[2].line_start == 1
|
||||
assert refs[2].line_end == 2
|
||||
|
||||
|
||||
def test_expand_file_range_and_folder_listing(sample_repo: Path):
|
||||
from agent.context_references import preprocess_context_references
|
||||
|
||||
@@ -106,6 +124,30 @@ def test_expand_file_range_and_folder_listing(sample_repo: Path):
|
||||
assert not result.warnings
|
||||
|
||||
|
||||
def test_expand_quoted_file_reference_with_spaces(tmp_path: Path):
|
||||
from agent.context_references import preprocess_context_references
|
||||
|
||||
workspace = tmp_path / "repo"
|
||||
folder = workspace / "docs and specs"
|
||||
folder.mkdir(parents=True)
|
||||
file_path = folder / "release notes.txt"
|
||||
file_path.write_text("line 1\nline 2\nline 3\n", encoding="utf-8")
|
||||
|
||||
result = preprocess_context_references(
|
||||
'Review @file:"docs and specs/release notes.txt":2-3',
|
||||
cwd=workspace,
|
||||
context_length=100_000,
|
||||
)
|
||||
|
||||
assert result.expanded
|
||||
assert result.message.startswith("Review")
|
||||
assert "line 1" not in result.message
|
||||
assert "line 2" in result.message
|
||||
assert "line 3" in result.message
|
||||
assert "release notes.txt" in result.message
|
||||
assert not result.warnings
|
||||
|
||||
|
||||
def test_expand_git_diff_staged_and_log(sample_repo: Path):
|
||||
from agent.context_references import preprocess_context_references
|
||||
|
||||
|
||||
@@ -567,6 +567,7 @@ def test_singleton_seed_does_not_clobber_manual_oauth_entry(tmp_path, monkeypatc
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
|
||||
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
|
||||
monkeypatch.setattr("hermes_cli.auth.is_provider_explicitly_configured", lambda pid: True)
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
@@ -702,53 +703,6 @@ def test_least_used_strategy_selects_lowest_count(tmp_path, monkeypatch):
|
||||
assert entry.access_token == "sk-or-light"
|
||||
|
||||
|
||||
def test_mark_used_increments_request_count(tmp_path, monkeypatch):
|
||||
"""mark_used should increment the request_count of the current entry."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool.get_pool_strategy",
|
||||
lambda _provider: "fill_first",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool._seed_from_singletons",
|
||||
lambda provider, entries: (False, set()),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool._seed_from_env",
|
||||
lambda provider, entries: (False, set()),
|
||||
)
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"openrouter": [
|
||||
{
|
||||
"id": "key-a",
|
||||
"label": "test",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-or-test",
|
||||
"request_count": 5,
|
||||
},
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("openrouter")
|
||||
entry = pool.select()
|
||||
assert entry is not None
|
||||
assert entry.request_count == 5
|
||||
pool.mark_used()
|
||||
updated = pool.current()
|
||||
assert updated is not None
|
||||
assert updated.request_count == 6
|
||||
|
||||
|
||||
def test_thread_safety_concurrent_select(tmp_path, monkeypatch):
|
||||
"""Concurrent select() calls should not corrupt pool state."""
|
||||
import threading as _threading
|
||||
@@ -798,7 +752,6 @@ def test_thread_safety_concurrent_select(tmp_path, monkeypatch):
|
||||
entry = pool.select()
|
||||
if entry:
|
||||
results.append(entry.id)
|
||||
pool.mark_used(entry.id)
|
||||
except Exception as exc:
|
||||
errors.append(exc)
|
||||
|
||||
@@ -1056,8 +1009,8 @@ def test_acquire_lease_prefers_unleased_entry(tmp_path, monkeypatch):
|
||||
|
||||
assert first == "cred-1"
|
||||
assert second == "cred-2"
|
||||
assert pool.active_lease_count("cred-1") == 1
|
||||
assert pool.active_lease_count("cred-2") == 1
|
||||
assert pool._active_leases.get("cred-1", 0) == 1
|
||||
assert pool._active_leases.get("cred-2", 0) == 1
|
||||
|
||||
|
||||
|
||||
@@ -1087,7 +1040,34 @@ def test_release_lease_decrements_counter(tmp_path, monkeypatch):
|
||||
pool = load_pool("openrouter")
|
||||
leased = pool.acquire_lease()
|
||||
assert leased == "cred-1"
|
||||
assert pool.active_lease_count("cred-1") == 1
|
||||
assert pool._active_leases.get("cred-1", 0) == 1
|
||||
|
||||
pool.release_lease("cred-1")
|
||||
assert pool.active_lease_count("cred-1") == 0
|
||||
assert pool._active_leases.get("cred-1", 0) == 0
|
||||
|
||||
|
||||
def test_load_pool_does_not_seed_claude_code_when_anthropic_not_configured(tmp_path, monkeypatch):
|
||||
"""Claude Code credentials must not be auto-seeded when the user never selected anthropic."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(tmp_path, {"version": 1, "credential_pool": {}})
|
||||
|
||||
# Claude Code credentials exist on disk
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.read_claude_code_credentials",
|
||||
lambda: {"accessToken": "sk-ant...oken", "refreshToken": "rt", "expiresAt": 9999999999999},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.read_hermes_oauth_credentials",
|
||||
lambda: None,
|
||||
)
|
||||
# User configured kimi-coding, NOT anthropic
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.is_provider_explicitly_configured",
|
||||
lambda pid: pid == "kimi-coding",
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
pool = load_pool("anthropic")
|
||||
|
||||
# Should NOT have seeded the claude_code entry
|
||||
assert pool.entries() == []
|
||||
|
||||
@@ -75,28 +75,6 @@ class TestClassifiedError:
|
||||
e3 = ClassifiedError(reason=FailoverReason.billing)
|
||||
assert e3.is_auth is False
|
||||
|
||||
def test_is_transient_property(self):
|
||||
transient_reasons = [
|
||||
FailoverReason.rate_limit,
|
||||
FailoverReason.overloaded,
|
||||
FailoverReason.server_error,
|
||||
FailoverReason.timeout,
|
||||
FailoverReason.unknown,
|
||||
]
|
||||
for reason in transient_reasons:
|
||||
e = ClassifiedError(reason=reason)
|
||||
assert e.is_transient is True, f"{reason} should be transient"
|
||||
|
||||
non_transient = [
|
||||
FailoverReason.auth,
|
||||
FailoverReason.billing,
|
||||
FailoverReason.model_not_found,
|
||||
FailoverReason.format_error,
|
||||
]
|
||||
for reason in non_transient:
|
||||
e = ClassifiedError(reason=reason)
|
||||
assert e.is_transient is False, f"{reason} should NOT be transient"
|
||||
|
||||
def test_defaults(self):
|
||||
e = ClassifiedError(reason=FailoverReason.unknown)
|
||||
assert e.retryable is True
|
||||
@@ -271,6 +249,22 @@ class TestClassifyApiError:
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
assert result.should_fallback is True
|
||||
|
||||
def test_alibaba_rate_increased_too_quickly(self):
|
||||
"""Alibaba/DashScope returns a unique throttling message.
|
||||
|
||||
Port from anomalyco/opencode#21355.
|
||||
"""
|
||||
msg = (
|
||||
"Upstream error from Alibaba: Request rate increased too quickly. "
|
||||
"To ensure system stability, please adjust your client logic to "
|
||||
"scale requests more smoothly over time."
|
||||
)
|
||||
e = MockAPIError(msg, status_code=400)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
assert result.retryable is True
|
||||
assert result.should_rotate_credential is True
|
||||
|
||||
# ── Server errors ──
|
||||
|
||||
def test_500_server_error(self):
|
||||
@@ -480,6 +474,39 @@ class TestClassifyApiError:
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
# ── Message-only usage limit disambiguation (no status code) ──
|
||||
|
||||
def test_message_usage_limit_transient_is_rate_limit(self):
|
||||
"""'usage limit' + 'try again' with no status code → rate_limit, not billing."""
|
||||
e = Exception("usage limit exceeded, try again in 5 minutes")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
assert result.retryable is True
|
||||
assert result.should_rotate_credential is True
|
||||
assert result.should_fallback is True
|
||||
|
||||
def test_message_usage_limit_no_retry_signal_is_billing(self):
|
||||
"""'usage limit' with no transient signal and no status code → billing."""
|
||||
e = Exception("usage limit reached")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.billing
|
||||
assert result.retryable is False
|
||||
assert result.should_rotate_credential is True
|
||||
|
||||
def test_message_quota_with_reset_window_is_rate_limit(self):
|
||||
"""'quota' + 'resets at' with no status code → rate_limit."""
|
||||
e = Exception("quota exceeded, resets at midnight UTC")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
assert result.retryable is True
|
||||
|
||||
def test_message_limit_exceeded_with_wait_is_rate_limit(self):
|
||||
"""'limit exceeded' + 'wait' with no status code → rate_limit."""
|
||||
e = Exception("key limit exceeded, please wait before retrying")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
assert result.retryable is True
|
||||
|
||||
# ── Unknown / fallback ──
|
||||
|
||||
def test_generic_exception_is_unknown(self):
|
||||
|
||||
@@ -7,7 +7,6 @@ from pathlib import Path
|
||||
from hermes_state import SessionDB
|
||||
from agent.insights import (
|
||||
InsightsEngine,
|
||||
_get_pricing,
|
||||
_estimate_cost,
|
||||
_format_duration,
|
||||
_bar_chart,
|
||||
@@ -118,45 +117,6 @@ def populated_db(db):
|
||||
return db
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Pricing helpers
|
||||
# =========================================================================
|
||||
|
||||
class TestPricing:
|
||||
def test_provider_prefix_stripped(self):
|
||||
pricing = _get_pricing("anthropic/claude-sonnet-4-20250514")
|
||||
assert pricing["input"] == 3.00
|
||||
assert pricing["output"] == 15.00
|
||||
|
||||
def test_unknown_models_do_not_use_heuristics(self):
|
||||
pricing = _get_pricing("some-new-opus-model")
|
||||
assert pricing == _DEFAULT_PRICING
|
||||
pricing = _get_pricing("anthropic/claude-haiku-future")
|
||||
assert pricing == _DEFAULT_PRICING
|
||||
|
||||
def test_unknown_model_returns_zero_cost(self):
|
||||
"""Unknown/custom models should NOT have fabricated costs."""
|
||||
pricing = _get_pricing("totally-unknown-model-xyz")
|
||||
assert pricing == _DEFAULT_PRICING
|
||||
assert pricing["input"] == 0.0
|
||||
assert pricing["output"] == 0.0
|
||||
|
||||
def test_custom_endpoint_model_zero_cost(self):
|
||||
"""Self-hosted models should return zero cost."""
|
||||
for model in ["FP16_Hermes_4.5", "Hermes_4.5_1T_epoch2", "my-local-llama"]:
|
||||
pricing = _get_pricing(model)
|
||||
assert pricing["input"] == 0.0, f"{model} should have zero cost"
|
||||
assert pricing["output"] == 0.0, f"{model} should have zero cost"
|
||||
|
||||
def test_none_model(self):
|
||||
pricing = _get_pricing(None)
|
||||
assert pricing == _DEFAULT_PRICING
|
||||
|
||||
def test_empty_model(self):
|
||||
pricing = _get_pricing("")
|
||||
assert pricing == _DEFAULT_PRICING
|
||||
|
||||
|
||||
class TestHasKnownPricing:
|
||||
def test_known_commercial_model(self):
|
||||
assert _has_known_pricing("gpt-4o", provider="openai") is True
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Tests for local provider stream read timeout auto-detection.
|
||||
|
||||
When a local LLM provider is detected (Ollama, llama.cpp, vLLM, etc.),
|
||||
the httpx stream read timeout should be automatically increased from the
|
||||
default 60s to HERMES_API_TIMEOUT (1800s) to avoid premature connection
|
||||
kills during long prefill phases.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent.model_metadata import is_local_endpoint
|
||||
|
||||
|
||||
class TestLocalStreamReadTimeout:
|
||||
"""Verify stream read timeout auto-detection logic."""
|
||||
|
||||
@pytest.mark.parametrize("base_url", [
|
||||
"http://localhost:11434",
|
||||
"http://127.0.0.1:8080",
|
||||
"http://0.0.0.0:5000",
|
||||
"http://192.168.1.100:8000",
|
||||
"http://10.0.0.5:1234",
|
||||
])
|
||||
def test_local_endpoint_bumps_read_timeout(self, base_url):
|
||||
"""Local endpoint + default timeout -> bumps to base_timeout."""
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("HERMES_STREAM_READ_TIMEOUT", None)
|
||||
_base_timeout = float(os.getenv("HERMES_API_TIMEOUT", 1800.0))
|
||||
_stream_read_timeout = float(os.getenv("HERMES_STREAM_READ_TIMEOUT", 120.0))
|
||||
if _stream_read_timeout == 120.0 and base_url and is_local_endpoint(base_url):
|
||||
_stream_read_timeout = _base_timeout
|
||||
assert _stream_read_timeout == 1800.0
|
||||
|
||||
def test_user_override_respected_for_local(self):
|
||||
"""User sets HERMES_STREAM_READ_TIMEOUT -> keep their value even for local."""
|
||||
with patch.dict(os.environ, {"HERMES_STREAM_READ_TIMEOUT": "300"}, clear=False):
|
||||
_base_timeout = float(os.getenv("HERMES_API_TIMEOUT", 1800.0))
|
||||
_stream_read_timeout = float(os.getenv("HERMES_STREAM_READ_TIMEOUT", 120.0))
|
||||
base_url = "http://localhost:11434"
|
||||
if _stream_read_timeout == 120.0 and base_url and is_local_endpoint(base_url):
|
||||
_stream_read_timeout = _base_timeout
|
||||
assert _stream_read_timeout == 300.0
|
||||
|
||||
@pytest.mark.parametrize("base_url", [
|
||||
"https://api.openai.com",
|
||||
"https://openrouter.ai/api",
|
||||
"https://api.anthropic.com",
|
||||
])
|
||||
def test_remote_endpoint_keeps_default(self, base_url):
|
||||
"""Remote endpoint -> keep 120s default."""
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("HERMES_STREAM_READ_TIMEOUT", None)
|
||||
_base_timeout = float(os.getenv("HERMES_API_TIMEOUT", 1800.0))
|
||||
_stream_read_timeout = float(os.getenv("HERMES_STREAM_READ_TIMEOUT", 120.0))
|
||||
if _stream_read_timeout == 120.0 and base_url and is_local_endpoint(base_url):
|
||||
_stream_read_timeout = _base_timeout
|
||||
assert _stream_read_timeout == 120.0
|
||||
|
||||
def test_empty_base_url_keeps_default(self):
|
||||
"""No base_url set -> keep 120s default."""
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("HERMES_STREAM_READ_TIMEOUT", None)
|
||||
_base_timeout = float(os.getenv("HERMES_API_TIMEOUT", 1800.0))
|
||||
_stream_read_timeout = float(os.getenv("HERMES_STREAM_READ_TIMEOUT", 120.0))
|
||||
base_url = ""
|
||||
if _stream_read_timeout == 120.0 and base_url and is_local_endpoint(base_url):
|
||||
_stream_read_timeout = _base_timeout
|
||||
assert _stream_read_timeout == 120.0
|
||||
@@ -1,299 +0,0 @@
|
||||
"""End-to-end test: a SQLite-backed memory plugin exercising the full interface.
|
||||
|
||||
This proves a real plugin can register as a MemoryProvider and get wired
|
||||
into the agent loop via MemoryManager. Uses SQLite + FTS5 (stdlib, no
|
||||
external deps, no API keys).
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
from agent.memory_manager import MemoryManager
|
||||
from agent.builtin_memory_provider import BuiltinMemoryProvider
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQLite FTS5 memory provider — a real, minimal plugin implementation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SQLiteMemoryProvider(MemoryProvider):
|
||||
"""Minimal SQLite + FTS5 memory provider for testing.
|
||||
|
||||
Demonstrates the full MemoryProvider interface with a real backend.
|
||||
No external dependencies — just stdlib sqlite3.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str = ":memory:"):
|
||||
self._db_path = db_path
|
||||
self._conn = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "sqlite_memory"
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return True # SQLite is always available
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
|
||||
self._conn = sqlite3.connect(self._db_path)
|
||||
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||
self._conn.execute("""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS memories
|
||||
USING fts5(content, context, session_id)
|
||||
""")
|
||||
self._session_id = session_id
|
||||
|
||||
def system_prompt_block(self) -> str:
|
||||
if not self._conn:
|
||||
return ""
|
||||
count = self._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
|
||||
if count == 0:
|
||||
return ""
|
||||
return (
|
||||
f"# SQLite Memory Plugin\n"
|
||||
f"Active. {count} memories stored.\n"
|
||||
f"Use sqlite_recall to search, sqlite_retain to store."
|
||||
)
|
||||
|
||||
def prefetch(self, query: str, *, session_id: str = "") -> str:
|
||||
if not self._conn or not query:
|
||||
return ""
|
||||
# FTS5 search
|
||||
try:
|
||||
rows = self._conn.execute(
|
||||
"SELECT content FROM memories WHERE memories MATCH ? LIMIT 5",
|
||||
(query,)
|
||||
).fetchall()
|
||||
if not rows:
|
||||
return ""
|
||||
results = [row[0] for row in rows]
|
||||
return "## SQLite Memory\n" + "\n".join(f"- {r}" for r in results)
|
||||
except sqlite3.OperationalError:
|
||||
return ""
|
||||
|
||||
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
|
||||
if not self._conn:
|
||||
return
|
||||
combined = f"User: {user_content}\nAssistant: {assistant_content}"
|
||||
self._conn.execute(
|
||||
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
|
||||
(combined, "conversation", self._session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def get_tool_schemas(self):
|
||||
return [
|
||||
{
|
||||
"name": "sqlite_retain",
|
||||
"description": "Store a fact to SQLite memory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {"type": "string", "description": "What to remember"},
|
||||
"context": {"type": "string", "description": "Category/context"},
|
||||
},
|
||||
"required": ["content"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "sqlite_recall",
|
||||
"description": "Search SQLite memory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
|
||||
if tool_name == "sqlite_retain":
|
||||
content = args.get("content", "")
|
||||
context = args.get("context", "explicit")
|
||||
if not content:
|
||||
return json.dumps({"error": "content is required"})
|
||||
self._conn.execute(
|
||||
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
|
||||
(content, context, self._session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
return json.dumps({"result": "Stored."})
|
||||
|
||||
elif tool_name == "sqlite_recall":
|
||||
query = args.get("query", "")
|
||||
if not query:
|
||||
return json.dumps({"error": "query is required"})
|
||||
try:
|
||||
rows = self._conn.execute(
|
||||
"SELECT content, context FROM memories WHERE memories MATCH ? LIMIT 10",
|
||||
(query,)
|
||||
).fetchall()
|
||||
results = [{"content": r[0], "context": r[1]} for r in rows]
|
||||
return json.dumps({"results": results})
|
||||
except sqlite3.OperationalError:
|
||||
return json.dumps({"results": []})
|
||||
|
||||
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
||||
|
||||
def on_memory_write(self, action, target, content):
|
||||
"""Mirror built-in memory writes to SQLite."""
|
||||
if action == "add" and self._conn:
|
||||
self._conn.execute(
|
||||
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
|
||||
(content, f"builtin_{target}", self._session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def shutdown(self):
|
||||
if self._conn:
|
||||
self._conn.close()
|
||||
self._conn = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# End-to-end tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSQLiteMemoryPlugin:
|
||||
"""Full lifecycle test with the SQLite provider."""
|
||||
|
||||
def test_full_lifecycle(self):
|
||||
"""Exercise init → store → recall → sync → prefetch → shutdown."""
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider()
|
||||
sqlite_mem = SQLiteMemoryProvider()
|
||||
|
||||
mgr.add_provider(builtin)
|
||||
mgr.add_provider(sqlite_mem)
|
||||
|
||||
# Initialize
|
||||
mgr.initialize_all(session_id="test-session-1", platform="cli")
|
||||
assert sqlite_mem._conn is not None
|
||||
|
||||
# System prompt — empty at first
|
||||
prompt = mgr.build_system_prompt()
|
||||
assert "SQLite Memory Plugin" not in prompt
|
||||
|
||||
# Store via tool call
|
||||
result = json.loads(mgr.handle_tool_call(
|
||||
"sqlite_retain", {"content": "User prefers dark mode", "context": "preference"}
|
||||
))
|
||||
assert result["result"] == "Stored."
|
||||
|
||||
# System prompt now shows count
|
||||
prompt = mgr.build_system_prompt()
|
||||
assert "1 memories stored" in prompt
|
||||
|
||||
# Recall via tool call
|
||||
result = json.loads(mgr.handle_tool_call(
|
||||
"sqlite_recall", {"query": "dark mode"}
|
||||
))
|
||||
assert len(result["results"]) == 1
|
||||
assert "dark mode" in result["results"][0]["content"]
|
||||
|
||||
# Sync a turn (auto-stores conversation)
|
||||
mgr.sync_all("What's my theme?", "You prefer dark mode.")
|
||||
count = sqlite_mem._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
|
||||
assert count == 2 # 1 explicit + 1 synced
|
||||
|
||||
# Prefetch for next turn
|
||||
prefetched = mgr.prefetch_all("dark mode")
|
||||
assert "dark mode" in prefetched
|
||||
|
||||
# Memory bridge — mirroring builtin writes
|
||||
mgr.on_memory_write("add", "user", "Timezone: US Pacific")
|
||||
count = sqlite_mem._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
|
||||
assert count == 3
|
||||
|
||||
# Shutdown
|
||||
mgr.shutdown_all()
|
||||
assert sqlite_mem._conn is None
|
||||
|
||||
def test_tool_routing_with_builtin(self):
|
||||
"""Verify builtin + plugin tools coexist without conflict."""
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider()
|
||||
sqlite_mem = SQLiteMemoryProvider()
|
||||
mgr.add_provider(builtin)
|
||||
mgr.add_provider(sqlite_mem)
|
||||
mgr.initialize_all(session_id="test-2")
|
||||
|
||||
# Builtin has no tools
|
||||
assert len(builtin.get_tool_schemas()) == 0
|
||||
# SQLite has 2 tools
|
||||
schemas = mgr.get_all_tool_schemas()
|
||||
names = {s["name"] for s in schemas}
|
||||
assert names == {"sqlite_retain", "sqlite_recall"}
|
||||
|
||||
# Routing works
|
||||
assert mgr.has_tool("sqlite_retain")
|
||||
assert mgr.has_tool("sqlite_recall")
|
||||
assert not mgr.has_tool("memory") # builtin doesn't register this
|
||||
|
||||
def test_second_external_plugin_rejected(self):
|
||||
"""Only one external memory provider is allowed at a time."""
|
||||
mgr = MemoryManager()
|
||||
p1 = SQLiteMemoryProvider()
|
||||
p2 = SQLiteMemoryProvider()
|
||||
# Hack name for p2
|
||||
p2._name_override = "sqlite_memory_2"
|
||||
original_name = p2.__class__.name
|
||||
type(p2).name = property(lambda self: getattr(self, '_name_override', 'sqlite_memory'))
|
||||
|
||||
mgr.add_provider(p1)
|
||||
mgr.add_provider(p2) # should be rejected
|
||||
|
||||
# Only p1 was accepted
|
||||
assert len(mgr.providers) == 1
|
||||
assert mgr.provider_names == ["sqlite_memory"]
|
||||
|
||||
# Restore class
|
||||
type(p2).name = original_name
|
||||
mgr.shutdown_all()
|
||||
|
||||
def test_provider_failure_isolation(self):
|
||||
"""Failing external provider doesn't break builtin."""
|
||||
from agent.builtin_memory_provider import BuiltinMemoryProvider
|
||||
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider() # name="builtin", always accepted
|
||||
ext = SQLiteMemoryProvider()
|
||||
|
||||
mgr.add_provider(builtin)
|
||||
mgr.add_provider(ext)
|
||||
mgr.initialize_all(session_id="test-4")
|
||||
|
||||
# Break external provider's connection
|
||||
ext._conn.close()
|
||||
ext._conn = None
|
||||
|
||||
# Sync — external fails silently, builtin (no-op sync) succeeds
|
||||
mgr.sync_all("user", "assistant") # should not raise
|
||||
|
||||
mgr.shutdown_all()
|
||||
|
||||
def test_plugin_registration_flow(self):
|
||||
"""Simulate the full plugin load → agent init path."""
|
||||
# Simulate what AIAgent.__init__ does via plugins/memory/ discovery
|
||||
provider = SQLiteMemoryProvider()
|
||||
|
||||
mem_mgr = MemoryManager()
|
||||
mem_mgr.add_provider(BuiltinMemoryProvider())
|
||||
if provider.is_available():
|
||||
mem_mgr.add_provider(provider)
|
||||
mem_mgr.initialize_all(session_id="agent-session")
|
||||
|
||||
assert len(mem_mgr.providers) == 2
|
||||
assert mem_mgr.provider_names == ["builtin", "sqlite_memory"]
|
||||
assert provider._conn is not None # initialized = connection established
|
||||
|
||||
mem_mgr.shutdown_all()
|
||||
@@ -6,8 +6,6 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
from agent.memory_manager import MemoryManager
|
||||
from agent.builtin_memory_provider import BuiltinMemoryProvider
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concrete test provider
|
||||
@@ -118,7 +116,7 @@ class TestMemoryManager:
|
||||
def test_empty_manager(self):
|
||||
mgr = MemoryManager()
|
||||
assert mgr.providers == []
|
||||
assert mgr.provider_names == []
|
||||
assert [p.name for p in mgr.providers] == []
|
||||
assert mgr.get_all_tool_schemas() == []
|
||||
assert mgr.build_system_prompt() == ""
|
||||
assert mgr.prefetch_all("test") == ""
|
||||
@@ -128,7 +126,7 @@ class TestMemoryManager:
|
||||
p = FakeMemoryProvider("test1")
|
||||
mgr.add_provider(p)
|
||||
assert len(mgr.providers) == 1
|
||||
assert mgr.provider_names == ["test1"]
|
||||
assert [p.name for p in mgr.providers] == ["test1"]
|
||||
|
||||
def test_get_provider_by_name(self):
|
||||
mgr = MemoryManager()
|
||||
@@ -143,7 +141,7 @@ class TestMemoryManager:
|
||||
p2 = FakeMemoryProvider("external")
|
||||
mgr.add_provider(p1)
|
||||
mgr.add_provider(p2)
|
||||
assert mgr.provider_names == ["builtin", "external"]
|
||||
assert [p.name for p in mgr.providers] == ["builtin", "external"]
|
||||
|
||||
def test_second_external_rejected(self):
|
||||
"""Only one non-builtin provider is allowed."""
|
||||
@@ -154,7 +152,7 @@ class TestMemoryManager:
|
||||
mgr.add_provider(builtin)
|
||||
mgr.add_provider(ext1)
|
||||
mgr.add_provider(ext2) # should be rejected
|
||||
assert mgr.provider_names == ["builtin", "mem0"]
|
||||
assert [p.name for p in mgr.providers] == ["builtin", "mem0"]
|
||||
assert len(mgr.providers) == 2
|
||||
|
||||
def test_system_prompt_merges_blocks(self):
|
||||
@@ -321,17 +319,6 @@ class TestMemoryManager:
|
||||
mgr.on_pre_compress([{"role": "user", "content": "old"}])
|
||||
assert p.pre_compress_called
|
||||
|
||||
def test_on_memory_write_skips_builtin(self):
|
||||
"""on_memory_write should skip the builtin provider."""
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider()
|
||||
external = FakeMemoryProvider("external")
|
||||
mgr.add_provider(builtin)
|
||||
mgr.add_provider(external)
|
||||
|
||||
mgr.on_memory_write("add", "memory", "test fact")
|
||||
assert external.memory_writes == [("add", "memory", "test fact")]
|
||||
|
||||
def test_shutdown_all_reverse_order(self):
|
||||
mgr = MemoryManager()
|
||||
order = []
|
||||
@@ -385,146 +372,6 @@ class TestMemoryManager:
|
||||
assert result == "works fine"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BuiltinMemoryProvider tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuiltinMemoryProvider:
|
||||
def test_name(self):
|
||||
p = BuiltinMemoryProvider()
|
||||
assert p.name == "builtin"
|
||||
|
||||
def test_always_available(self):
|
||||
p = BuiltinMemoryProvider()
|
||||
assert p.is_available()
|
||||
|
||||
def test_no_tools(self):
|
||||
"""Builtin provider exposes no tools (memory tool is agent-level)."""
|
||||
p = BuiltinMemoryProvider()
|
||||
assert p.get_tool_schemas() == []
|
||||
|
||||
def test_system_prompt_with_store(self):
|
||||
store = MagicMock()
|
||||
store.format_for_system_prompt.side_effect = lambda t: f"BLOCK_{t}" if t == "memory" else f"BLOCK_{t}"
|
||||
|
||||
p = BuiltinMemoryProvider(
|
||||
memory_store=store,
|
||||
memory_enabled=True,
|
||||
user_profile_enabled=True,
|
||||
)
|
||||
block = p.system_prompt_block()
|
||||
assert "BLOCK_memory" in block
|
||||
assert "BLOCK_user" in block
|
||||
|
||||
def test_system_prompt_memory_disabled(self):
|
||||
store = MagicMock()
|
||||
store.format_for_system_prompt.return_value = "content"
|
||||
|
||||
p = BuiltinMemoryProvider(
|
||||
memory_store=store,
|
||||
memory_enabled=False,
|
||||
user_profile_enabled=False,
|
||||
)
|
||||
assert p.system_prompt_block() == ""
|
||||
|
||||
def test_system_prompt_no_store(self):
|
||||
p = BuiltinMemoryProvider(memory_store=None, memory_enabled=True)
|
||||
assert p.system_prompt_block() == ""
|
||||
|
||||
def test_prefetch_returns_empty(self):
|
||||
p = BuiltinMemoryProvider()
|
||||
assert p.prefetch("anything") == ""
|
||||
|
||||
def test_store_property(self):
|
||||
store = MagicMock()
|
||||
p = BuiltinMemoryProvider(memory_store=store)
|
||||
assert p.store is store
|
||||
|
||||
def test_initialize_loads_from_disk(self):
|
||||
store = MagicMock()
|
||||
p = BuiltinMemoryProvider(memory_store=store)
|
||||
p.initialize(session_id="test")
|
||||
store.load_from_disk.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Plugin registration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSingleProviderGating:
|
||||
"""Only the configured provider should activate."""
|
||||
|
||||
def test_no_provider_configured_means_builtin_only(self):
|
||||
"""When memory.provider is empty, no plugin providers activate."""
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider()
|
||||
mgr.add_provider(builtin)
|
||||
|
||||
# Simulate what run_agent.py does when provider=""
|
||||
configured = ""
|
||||
available_plugins = [
|
||||
FakeMemoryProvider("holographic"),
|
||||
FakeMemoryProvider("mem0"),
|
||||
]
|
||||
# With empty config, no plugins should be added
|
||||
if configured:
|
||||
for p in available_plugins:
|
||||
if p.name == configured and p.is_available():
|
||||
mgr.add_provider(p)
|
||||
|
||||
assert mgr.provider_names == ["builtin"]
|
||||
|
||||
def test_configured_provider_activates(self):
|
||||
"""Only the named provider should be added."""
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider()
|
||||
mgr.add_provider(builtin)
|
||||
|
||||
configured = "holographic"
|
||||
p1 = FakeMemoryProvider("holographic")
|
||||
p2 = FakeMemoryProvider("mem0")
|
||||
p3 = FakeMemoryProvider("hindsight")
|
||||
|
||||
for p in [p1, p2, p3]:
|
||||
if p.name == configured and p.is_available():
|
||||
mgr.add_provider(p)
|
||||
|
||||
assert mgr.provider_names == ["builtin", "holographic"]
|
||||
assert p1.initialized is False # not initialized by the gating logic itself
|
||||
|
||||
def test_unavailable_provider_skipped(self):
|
||||
"""If the configured provider is unavailable, it should be skipped."""
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider()
|
||||
mgr.add_provider(builtin)
|
||||
|
||||
configured = "holographic"
|
||||
p1 = FakeMemoryProvider("holographic", available=False)
|
||||
|
||||
for p in [p1]:
|
||||
if p.name == configured and p.is_available():
|
||||
mgr.add_provider(p)
|
||||
|
||||
assert mgr.provider_names == ["builtin"]
|
||||
|
||||
def test_nonexistent_provider_results_in_builtin_only(self):
|
||||
"""If the configured name doesn't match any plugin, only builtin remains."""
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider()
|
||||
mgr.add_provider(builtin)
|
||||
|
||||
configured = "nonexistent"
|
||||
plugins = [FakeMemoryProvider("holographic"), FakeMemoryProvider("mem0")]
|
||||
|
||||
for p in plugins:
|
||||
if p.name == configured and p.is_available():
|
||||
mgr.add_provider(p)
|
||||
|
||||
assert mgr.provider_names == ["builtin"]
|
||||
|
||||
|
||||
class TestPluginMemoryDiscovery:
|
||||
"""Memory providers are discovered from plugins/memory/ directory."""
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""Tests for MiniMax provider hardening — context lengths, thinking guard, catalog."""
|
||||
"""Tests for MiniMax provider hardening — context lengths, thinking guard, catalog, beta headers."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
class TestMinimaxContextLengths:
|
||||
@@ -103,3 +105,100 @@ class TestMinimaxModelCatalog:
|
||||
models = _PROVIDER_MODELS[provider]
|
||||
assert "MiniMax-M2.7-highspeed" not in models
|
||||
assert "MiniMax-M2.5-highspeed" not in models
|
||||
|
||||
|
||||
class TestMinimaxBetaHeaders:
|
||||
"""MiniMax Anthropic-compat endpoints reject fine-grained-tool-streaming beta.
|
||||
|
||||
Verify that build_anthropic_client omits the tool-streaming beta for MiniMax
|
||||
(both global and China domains) while keeping it for native Anthropic and
|
||||
other third-party endpoints. Covers the fix for #6510 / #6555.
|
||||
"""
|
||||
|
||||
_TOOL_BETA = "fine-grained-tool-streaming-2025-05-14"
|
||||
_THINKING_BETA = "interleaved-thinking-2025-05-14"
|
||||
|
||||
# -- helper ----------------------------------------------------------
|
||||
|
||||
def _build_and_get_betas(self, api_key, base_url=None):
|
||||
"""Build client, return the anthropic-beta header string."""
|
||||
from agent.anthropic_adapter import build_anthropic_client
|
||||
with patch("agent.anthropic_adapter._anthropic_sdk") as mock_sdk:
|
||||
build_anthropic_client(api_key, base_url=base_url)
|
||||
kwargs = mock_sdk.Anthropic.call_args[1]
|
||||
headers = kwargs.get("default_headers", {})
|
||||
return headers.get("anthropic-beta", "")
|
||||
|
||||
# -- MiniMax global --------------------------------------------------
|
||||
|
||||
def test_minimax_global_omits_tool_streaming(self):
|
||||
betas = self._build_and_get_betas(
|
||||
"mm-key-123", base_url="https://api.minimax.io/anthropic"
|
||||
)
|
||||
assert self._TOOL_BETA not in betas
|
||||
assert self._THINKING_BETA in betas
|
||||
|
||||
def test_minimax_global_trailing_slash(self):
|
||||
betas = self._build_and_get_betas(
|
||||
"mm-key-123", base_url="https://api.minimax.io/anthropic/"
|
||||
)
|
||||
assert self._TOOL_BETA not in betas
|
||||
|
||||
# -- MiniMax China ---------------------------------------------------
|
||||
|
||||
def test_minimax_cn_omits_tool_streaming(self):
|
||||
betas = self._build_and_get_betas(
|
||||
"mm-cn-key-456", base_url="https://api.minimaxi.com/anthropic"
|
||||
)
|
||||
assert self._TOOL_BETA not in betas
|
||||
assert self._THINKING_BETA in betas
|
||||
|
||||
def test_minimax_cn_trailing_slash(self):
|
||||
betas = self._build_and_get_betas(
|
||||
"mm-cn-key-456", base_url="https://api.minimaxi.com/anthropic/"
|
||||
)
|
||||
assert self._TOOL_BETA not in betas
|
||||
|
||||
# -- Non-MiniMax keeps full betas ------------------------------------
|
||||
|
||||
def test_native_anthropic_keeps_tool_streaming(self):
|
||||
betas = self._build_and_get_betas("sk-ant-api03-real-key-here")
|
||||
assert self._TOOL_BETA in betas
|
||||
assert self._THINKING_BETA in betas
|
||||
|
||||
def test_third_party_proxy_keeps_tool_streaming(self):
|
||||
betas = self._build_and_get_betas(
|
||||
"custom-key", base_url="https://my-proxy.example.com/anthropic"
|
||||
)
|
||||
assert self._TOOL_BETA in betas
|
||||
|
||||
def test_custom_base_url_keeps_tool_streaming(self):
|
||||
betas = self._build_and_get_betas(
|
||||
"custom-key", base_url="https://custom.api.com"
|
||||
)
|
||||
assert self._TOOL_BETA in betas
|
||||
|
||||
# -- _common_betas_for_base_url unit tests ---------------------------
|
||||
|
||||
def test_common_betas_none_url(self):
|
||||
from agent.anthropic_adapter import _common_betas_for_base_url, _COMMON_BETAS
|
||||
assert _common_betas_for_base_url(None) == _COMMON_BETAS
|
||||
|
||||
def test_common_betas_empty_url(self):
|
||||
from agent.anthropic_adapter import _common_betas_for_base_url, _COMMON_BETAS
|
||||
assert _common_betas_for_base_url("") == _COMMON_BETAS
|
||||
|
||||
def test_common_betas_minimax_url(self):
|
||||
from agent.anthropic_adapter import _common_betas_for_base_url, _TOOL_STREAMING_BETA
|
||||
betas = _common_betas_for_base_url("https://api.minimax.io/anthropic")
|
||||
assert _TOOL_STREAMING_BETA not in betas
|
||||
assert len(betas) > 0 # still has other betas
|
||||
|
||||
def test_common_betas_minimax_cn_url(self):
|
||||
from agent.anthropic_adapter import _common_betas_for_base_url, _TOOL_STREAMING_BETA
|
||||
betas = _common_betas_for_base_url("https://api.minimaxi.com/anthropic")
|
||||
assert _TOOL_STREAMING_BETA not in betas
|
||||
|
||||
def test_common_betas_regular_url(self):
|
||||
from agent.anthropic_adapter import _common_betas_for_base_url, _COMMON_BETAS
|
||||
assert _common_betas_for_base_url("https://api.anthropic.com") == _COMMON_BETAS
|
||||
|
||||
@@ -132,6 +132,61 @@ class TestDefaultContextLengths:
|
||||
if "gemini" in key:
|
||||
assert value == 1048576, f"{key} should be 1048576"
|
||||
|
||||
def test_grok_models_context_lengths(self):
|
||||
# xAI /v1/models does not return context_length metadata, so
|
||||
# DEFAULT_CONTEXT_LENGTHS must cover the Grok family explicitly.
|
||||
# Values sourced from models.dev (2026-04).
|
||||
expected = {
|
||||
"grok-4.20": 2000000,
|
||||
"grok-4-1-fast": 2000000,
|
||||
"grok-4-fast": 2000000,
|
||||
"grok-4": 256000,
|
||||
"grok-code-fast": 256000,
|
||||
"grok-3": 131072,
|
||||
"grok-2": 131072,
|
||||
"grok-2-vision": 8192,
|
||||
"grok": 131072,
|
||||
}
|
||||
for key, value in expected.items():
|
||||
assert key in DEFAULT_CONTEXT_LENGTHS, f"{key} missing from DEFAULT_CONTEXT_LENGTHS"
|
||||
assert DEFAULT_CONTEXT_LENGTHS[key] == value, (
|
||||
f"{key} should be {value}, got {DEFAULT_CONTEXT_LENGTHS[key]}"
|
||||
)
|
||||
|
||||
def test_grok_substring_matching(self):
|
||||
# Longest-first substring matching must resolve the real xAI model
|
||||
# IDs to the correct fallback entries without 128k probe-down.
|
||||
from agent.model_metadata import get_model_context_length
|
||||
from unittest.mock import patch as mock_patch
|
||||
|
||||
# Fake the provider/API/cache layers so the lookup falls through
|
||||
# to DEFAULT_CONTEXT_LENGTHS.
|
||||
with mock_patch("agent.model_metadata.fetch_model_metadata", return_value={}), mock_patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), mock_patch("agent.model_metadata.get_cached_context_length", return_value=None):
|
||||
cases = [
|
||||
("grok-4.20-0309-reasoning", 2000000),
|
||||
("grok-4.20-0309-non-reasoning", 2000000),
|
||||
("grok-4.20-multi-agent-0309", 2000000),
|
||||
("grok-4-1-fast-reasoning", 2000000),
|
||||
("grok-4-1-fast-non-reasoning", 2000000),
|
||||
("grok-4-fast-reasoning", 2000000),
|
||||
("grok-4-fast-non-reasoning", 2000000),
|
||||
("grok-4", 256000),
|
||||
("grok-4-0709", 256000),
|
||||
("grok-code-fast-1", 256000),
|
||||
("grok-3", 131072),
|
||||
("grok-3-mini", 131072),
|
||||
("grok-3-mini-fast", 131072),
|
||||
("grok-2", 131072),
|
||||
("grok-2-vision", 8192),
|
||||
("grok-2-vision-1212", 8192),
|
||||
("grok-beta", 131072),
|
||||
]
|
||||
for model_id, expected_ctx in cases:
|
||||
actual = get_model_context_length(model_id)
|
||||
assert actual == expected_ctx, (
|
||||
f"{model_id}: expected {expected_ctx}, got {actual}"
|
||||
)
|
||||
|
||||
def test_all_values_positive(self):
|
||||
for key, value in DEFAULT_CONTEXT_LENGTHS.items():
|
||||
assert value > 0, f"{key} has non-positive context length"
|
||||
|
||||
@@ -11,7 +11,6 @@ from agent.prompt_builder import (
|
||||
_scan_context_content,
|
||||
_truncate_content,
|
||||
_parse_skill_file,
|
||||
_read_skill_conditions,
|
||||
_skill_should_show,
|
||||
_find_hermes_md,
|
||||
_find_git_root,
|
||||
@@ -775,61 +774,6 @@ class TestPromptBuilderConstants:
|
||||
# Conditional skill activation
|
||||
# =========================================================================
|
||||
|
||||
class TestReadSkillConditions:
|
||||
def test_no_conditions_returns_empty_lists(self, tmp_path):
|
||||
skill_file = tmp_path / "SKILL.md"
|
||||
skill_file.write_text("---\nname: test\ndescription: A skill\n---\n")
|
||||
conditions = _read_skill_conditions(skill_file)
|
||||
assert conditions["fallback_for_toolsets"] == []
|
||||
assert conditions["requires_toolsets"] == []
|
||||
assert conditions["fallback_for_tools"] == []
|
||||
assert conditions["requires_tools"] == []
|
||||
|
||||
def test_reads_fallback_for_toolsets(self, tmp_path):
|
||||
skill_file = tmp_path / "SKILL.md"
|
||||
skill_file.write_text(
|
||||
"---\nname: ddg\ndescription: DuckDuckGo\nmetadata:\n hermes:\n fallback_for_toolsets: [web]\n---\n"
|
||||
)
|
||||
conditions = _read_skill_conditions(skill_file)
|
||||
assert conditions["fallback_for_toolsets"] == ["web"]
|
||||
|
||||
def test_reads_requires_toolsets(self, tmp_path):
|
||||
skill_file = tmp_path / "SKILL.md"
|
||||
skill_file.write_text(
|
||||
"---\nname: openhue\ndescription: Hue lights\nmetadata:\n hermes:\n requires_toolsets: [terminal]\n---\n"
|
||||
)
|
||||
conditions = _read_skill_conditions(skill_file)
|
||||
assert conditions["requires_toolsets"] == ["terminal"]
|
||||
|
||||
def test_reads_multiple_conditions(self, tmp_path):
|
||||
skill_file = tmp_path / "SKILL.md"
|
||||
skill_file.write_text(
|
||||
"---\nname: test\ndescription: Test\nmetadata:\n hermes:\n fallback_for_toolsets: [browser]\n requires_tools: [terminal]\n---\n"
|
||||
)
|
||||
conditions = _read_skill_conditions(skill_file)
|
||||
assert conditions["fallback_for_toolsets"] == ["browser"]
|
||||
assert conditions["requires_tools"] == ["terminal"]
|
||||
|
||||
def test_missing_file_returns_empty(self, tmp_path):
|
||||
conditions = _read_skill_conditions(tmp_path / "missing.md")
|
||||
assert conditions == {}
|
||||
|
||||
def test_logs_condition_read_failures_and_returns_empty(self, tmp_path, monkeypatch, caplog):
|
||||
skill_file = tmp_path / "SKILL.md"
|
||||
skill_file.write_text("---\nname: broken\n---\n")
|
||||
|
||||
def boom(*args, **kwargs):
|
||||
raise OSError("read exploded")
|
||||
|
||||
monkeypatch.setattr(type(skill_file), "read_text", boom)
|
||||
with caplog.at_level(logging.DEBUG, logger="agent.prompt_builder"):
|
||||
conditions = _read_skill_conditions(skill_file)
|
||||
|
||||
assert conditions == {}
|
||||
assert "Failed to read skill conditions" in caplog.text
|
||||
assert str(skill_file) in caplog.text
|
||||
|
||||
|
||||
class TestSkillShouldShow:
|
||||
def test_no_filter_info_always_shows(self):
|
||||
assert _skill_should_show({}, None, None) is True
|
||||
|
||||
@@ -6,6 +6,17 @@ from unittest.mock import patch
|
||||
from cli import HermesCLI
|
||||
|
||||
|
||||
def _assert_chrome_debug_cmd(cmd, expected_chrome, expected_port):
|
||||
"""Verify the auto-launch command has all required flags."""
|
||||
assert cmd[0] == expected_chrome
|
||||
assert f"--remote-debugging-port={expected_port}" in cmd
|
||||
assert "--no-first-run" in cmd
|
||||
assert "--no-default-browser-check" in cmd
|
||||
user_data_args = [a for a in cmd if a.startswith("--user-data-dir=")]
|
||||
assert len(user_data_args) == 1, "Expected exactly one --user-data-dir flag"
|
||||
assert "chrome-debug" in user_data_args[0]
|
||||
|
||||
|
||||
class TestChromeDebugLaunch:
|
||||
def test_windows_launch_uses_browser_found_on_path(self):
|
||||
captured = {}
|
||||
@@ -20,7 +31,7 @@ class TestChromeDebugLaunch:
|
||||
patch("subprocess.Popen", side_effect=fake_popen):
|
||||
assert HermesCLI._try_launch_chrome_debug(9333, "Windows") is True
|
||||
|
||||
assert captured["cmd"] == [r"C:\Chrome\chrome.exe", "--remote-debugging-port=9333"]
|
||||
_assert_chrome_debug_cmd(captured["cmd"], r"C:\Chrome\chrome.exe", 9333)
|
||||
assert captured["kwargs"]["start_new_session"] is True
|
||||
|
||||
def test_windows_launch_falls_back_to_common_install_dirs(self, monkeypatch):
|
||||
@@ -43,4 +54,4 @@ class TestChromeDebugLaunch:
|
||||
patch("subprocess.Popen", side_effect=fake_popen):
|
||||
assert HermesCLI._try_launch_chrome_debug(9222, "Windows") is True
|
||||
|
||||
assert captured["cmd"] == [installed, "--remote-debugging-port=9222"]
|
||||
_assert_chrome_debug_cmd(captured["cmd"], installed, 9222)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user