Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f7d57c0108 | |||
| d77783d198 | |||
| 4af69097f2 | |||
| 59471b79e5 | |||
| 0e459f2b7b | |||
| 3befb9389f |
@@ -8,9 +8,6 @@ on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: docker-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
@@ -20,29 +17,22 @@ jobs:
|
||||
# Only run on the upstream repository, not on forks
|
||||
if: github.repository == 'NousResearch/hermes-agent'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 60
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
# Build amd64 only so we can `load` the image for smoke testing.
|
||||
# `load: true` cannot export a multi-arch manifest to the local daemon.
|
||||
# The multi-arch build follows on push to main / release.
|
||||
- name: Build image (amd64, smoke test)
|
||||
- name: Build image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: Dockerfile
|
||||
load: true
|
||||
platforms: linux/amd64
|
||||
tags: nousresearch/hermes-agent:test
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
@@ -61,28 +51,26 @@ jobs:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Push multi-arch image (main branch)
|
||||
- name: Push image (main branch)
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: Dockerfile
|
||||
push: true
|
||||
platforms: linux/amd64,linux/arm64
|
||||
tags: |
|
||||
nousresearch/hermes-agent:latest
|
||||
nousresearch/hermes-agent:${{ github.sha }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
- name: Push multi-arch image (release)
|
||||
- name: Push image (release)
|
||||
if: github.event_name == 'release'
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: Dockerfile
|
||||
push: true
|
||||
platforms: linux/amd64,linux/arm64
|
||||
tags: |
|
||||
nousresearch/hermes-agent:latest
|
||||
nousresearch/hermes-agent:${{ github.event.release.tag_name }}
|
||||
|
||||
@@ -27,8 +27,8 @@ jobs:
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install ascii-guard
|
||||
run: python -m pip install ascii-guard==2.3.0 pyyaml==6.0.3
|
||||
- name: Install Python dependencies
|
||||
run: python -m pip install ascii-guard pyyaml
|
||||
|
||||
- name: Extract skill metadata for dashboard
|
||||
run: python3 website/scripts/extract-skills.py
|
||||
|
||||
@@ -27,8 +27,8 @@ jobs:
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: DeterminateSystems/nix-installer-action@ef8a148080ab6020fd15196c2084a2eea5ff2d25 # v22
|
||||
- uses: DeterminateSystems/magic-nix-cache-action@565684385bcd71bad329742eefe8d12f2e765b39 # v13
|
||||
- uses: DeterminateSystems/nix-installer-action@main
|
||||
- uses: DeterminateSystems/magic-nix-cache-action@main
|
||||
- name: Check flake
|
||||
if: runner.os == 'Linux'
|
||||
run: nix flake check --print-build-logs
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
FROM debian:13.4
|
||||
|
||||
# Disable Python stdout buffering to ensure logs are printed immediately
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
# Install system dependencies in one layer, clear APT cache
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
|
||||
+83
-27
@@ -485,6 +485,35 @@ def _prefer_refreshable_claude_code_token(env_token: str, creds: Optional[Dict[s
|
||||
return None
|
||||
|
||||
|
||||
def get_anthropic_token_source(token: Optional[str] = None) -> str:
|
||||
"""Best-effort source classification for an Anthropic credential token."""
|
||||
token = (token or "").strip()
|
||||
if not token:
|
||||
return "none"
|
||||
|
||||
env_token = os.getenv("ANTHROPIC_TOKEN", "").strip()
|
||||
if env_token and env_token == token:
|
||||
return "anthropic_token_env"
|
||||
|
||||
cc_env_token = os.getenv("CLAUDE_CODE_OAUTH_TOKEN", "").strip()
|
||||
if cc_env_token and cc_env_token == token:
|
||||
return "claude_code_oauth_token_env"
|
||||
|
||||
creds = read_claude_code_credentials()
|
||||
if creds and creds.get("accessToken") == token:
|
||||
return str(creds.get("source") or "claude_code_credentials")
|
||||
|
||||
managed_key = read_claude_managed_key()
|
||||
if managed_key and managed_key == token:
|
||||
return "claude_json_primary_api_key"
|
||||
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY", "").strip()
|
||||
if api_key and api_key == token:
|
||||
return "anthropic_api_key_env"
|
||||
|
||||
return "unknown"
|
||||
|
||||
|
||||
def resolve_anthropic_token() -> Optional[str]:
|
||||
"""Resolve an Anthropic token from all available sources.
|
||||
|
||||
@@ -691,6 +720,21 @@ def run_hermes_oauth_login_pure() -> Optional[Dict[str, Any]]:
|
||||
}
|
||||
|
||||
|
||||
def _save_hermes_oauth_credentials(access_token: str, refresh_token: str, expires_at_ms: int) -> None:
|
||||
"""Save OAuth credentials to ~/.hermes/.anthropic_oauth.json."""
|
||||
data = {
|
||||
"accessToken": access_token,
|
||||
"refreshToken": refresh_token,
|
||||
"expiresAt": expires_at_ms,
|
||||
}
|
||||
try:
|
||||
_HERMES_OAUTH_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_HERMES_OAUTH_FILE.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
||||
_HERMES_OAUTH_FILE.chmod(0o600)
|
||||
except (OSError, IOError) as e:
|
||||
logger.debug("Failed to save Hermes OAuth credentials: %s", e)
|
||||
|
||||
|
||||
def read_hermes_oauth_credentials() -> Optional[Dict[str, Any]]:
|
||||
"""Read Hermes-managed OAuth credentials from ~/.hermes/.anthropic_oauth.json."""
|
||||
if _HERMES_OAUTH_FILE.exists():
|
||||
@@ -739,6 +783,39 @@ def _sanitize_tool_id(tool_id: str) -> str:
|
||||
return sanitized or "tool_0"
|
||||
|
||||
|
||||
def _convert_openai_image_part_to_anthropic(part: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""Convert an OpenAI-style image block to Anthropic's image source format."""
|
||||
image_data = part.get("image_url", {})
|
||||
url = image_data.get("url", "") if isinstance(image_data, dict) else str(image_data)
|
||||
if not isinstance(url, str) or not url.strip():
|
||||
return None
|
||||
url = url.strip()
|
||||
|
||||
if url.startswith("data:"):
|
||||
header, sep, data = url.partition(",")
|
||||
if sep and ";base64" in header:
|
||||
media_type = header[5:].split(";", 1)[0] or "image/png"
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": data,
|
||||
},
|
||||
}
|
||||
|
||||
if url.startswith(("http://", "https://")):
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "url",
|
||||
"url": url,
|
||||
},
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def convert_tools_to_anthropic(tools: List[Dict]) -> List[Dict]:
|
||||
"""Convert OpenAI tool definitions to Anthropic format."""
|
||||
if not tools:
|
||||
@@ -1161,27 +1238,10 @@ def build_anthropic_kwargs(
|
||||
) -> Dict[str, Any]:
|
||||
"""Build kwargs for anthropic.messages.create().
|
||||
|
||||
Naming note — two distinct concepts, easily confused:
|
||||
max_tokens = OUTPUT token cap for a single response.
|
||||
Anthropic's API calls this "max_tokens" but it only
|
||||
limits the *output*. Anthropic's own native SDK
|
||||
renamed it "max_output_tokens" for clarity.
|
||||
context_length = TOTAL context window (input tokens + output tokens).
|
||||
The API enforces: input_tokens + max_tokens ≤ context_length.
|
||||
Stored on the ContextCompressor; reduced on overflow errors.
|
||||
|
||||
When *max_tokens* is None the model's native output ceiling is used
|
||||
(e.g. 128K for Opus 4.6, 64K for Sonnet 4.6).
|
||||
|
||||
When *context_length* is provided and the model's native output ceiling
|
||||
exceeds it (e.g. a local endpoint with an 8K window), the output cap is
|
||||
clamped to context_length − 1. This only kicks in for unusually small
|
||||
context windows; for full-size models the native output cap is always
|
||||
smaller than the context window so no clamping happens.
|
||||
NOTE: this clamping does not account for prompt size — if the prompt is
|
||||
large, Anthropic may still reject the request. The caller must detect
|
||||
"max_tokens too large given prompt" errors and retry with a smaller cap
|
||||
(see parse_available_output_tokens_from_error + _ephemeral_max_output_tokens).
|
||||
When *max_tokens* is None, the model's native output limit is used
|
||||
(e.g. 128K for Opus 4.6, 64K for Sonnet 4.6). If *context_length*
|
||||
is provided, the effective limit is clamped so it doesn't exceed
|
||||
the context window.
|
||||
|
||||
When *is_oauth* is True, applies Claude Code compatibility transforms:
|
||||
system prompt prefix, tool name prefixing, and prompt sanitization.
|
||||
@@ -1196,14 +1256,10 @@ def build_anthropic_kwargs(
|
||||
anthropic_tools = convert_tools_to_anthropic(tools) if tools else []
|
||||
|
||||
model = normalize_model_name(model, preserve_dots=preserve_dots)
|
||||
# effective_max_tokens = output cap for this call (≠ total context window)
|
||||
effective_max_tokens = max_tokens or _get_anthropic_max_output(model)
|
||||
|
||||
# Clamp output cap to fit inside the total context window.
|
||||
# Only matters for small custom endpoints where context_length < native
|
||||
# output ceiling. For standard Anthropic models context_length (e.g.
|
||||
# 200K) is always larger than the output ceiling (e.g. 128K), so this
|
||||
# branch is not taken.
|
||||
# Clamp to context window if the user set a lower context_length
|
||||
# (e.g. custom endpoint with limited capacity).
|
||||
if context_length and effective_max_tokens > context_length:
|
||||
effective_max_tokens = max(context_length - 1, 1)
|
||||
|
||||
|
||||
+74
-67
@@ -629,19 +629,11 @@ def _nous_base_url() -> str:
|
||||
|
||||
|
||||
def _read_codex_access_token() -> Optional[str]:
|
||||
"""Read a valid, non-expired Codex OAuth access token from Hermes auth store.
|
||||
|
||||
If a credential pool exists but currently has no selectable runtime entry
|
||||
(for example all pool slots are marked exhausted), fall back to the
|
||||
profile's auth.json token instead of hard-failing. This keeps explicit
|
||||
fallback-to-Codex working when the pool state is stale but the stored OAuth
|
||||
token is still valid.
|
||||
"""
|
||||
"""Read a valid, non-expired Codex OAuth access token from Hermes auth store."""
|
||||
pool_present, entry = _select_pool_entry("openai-codex")
|
||||
if pool_present:
|
||||
token = _pool_runtime_api_key(entry)
|
||||
if token:
|
||||
return token
|
||||
return token or None
|
||||
|
||||
try:
|
||||
from hermes_cli.auth import _read_codex_tokens
|
||||
@@ -702,7 +694,7 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
logger.debug("Auxiliary text client: %s (%s) via pool", pconfig.name, model)
|
||||
extra = {}
|
||||
if "api.kimi.com" in base_url.lower():
|
||||
extra["default_headers"] = {"User-Agent": "KimiCLI/1.3"}
|
||||
extra["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
|
||||
elif "api.githubcopilot.com" in base_url.lower():
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
@@ -721,7 +713,7 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
logger.debug("Auxiliary text client: %s (%s)", pconfig.name, model)
|
||||
extra = {}
|
||||
if "api.kimi.com" in base_url.lower():
|
||||
extra["default_headers"] = {"User-Agent": "KimiCLI/1.3"}
|
||||
extra["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
|
||||
elif "api.githubcopilot.com" in base_url.lower():
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
@@ -902,13 +894,9 @@ def _try_codex() -> Tuple[Optional[Any], Optional[str]]:
|
||||
pool_present, entry = _select_pool_entry("openai-codex")
|
||||
if pool_present:
|
||||
codex_token = _pool_runtime_api_key(entry)
|
||||
if codex_token:
|
||||
base_url = _pool_runtime_base_url(entry, _CODEX_AUX_BASE_URL) or _CODEX_AUX_BASE_URL
|
||||
else:
|
||||
codex_token = _read_codex_access_token()
|
||||
if not codex_token:
|
||||
return None, None
|
||||
base_url = _CODEX_AUX_BASE_URL
|
||||
if not codex_token:
|
||||
return None, None
|
||||
base_url = _pool_runtime_base_url(entry, _CODEX_AUX_BASE_URL) or _CODEX_AUX_BASE_URL
|
||||
else:
|
||||
codex_token = _read_codex_access_token()
|
||||
if not codex_token:
|
||||
@@ -967,6 +955,40 @@ def _try_anthropic() -> Tuple[Optional[Any], Optional[str]]:
|
||||
return AnthropicAuxiliaryClient(real_client, model, token, base_url, is_oauth=is_oauth), model
|
||||
|
||||
|
||||
def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Resolve a specific forced provider. Returns (None, None) if creds missing."""
|
||||
if forced == "openrouter":
|
||||
client, model = _try_openrouter()
|
||||
if client is None:
|
||||
logger.warning("auxiliary.provider=openrouter but OPENROUTER_API_KEY not set")
|
||||
return client, model
|
||||
|
||||
if forced == "nous":
|
||||
client, model = _try_nous()
|
||||
if client is None:
|
||||
logger.warning("auxiliary.provider=nous but Nous Portal not configured (run: hermes auth)")
|
||||
return client, model
|
||||
|
||||
if forced == "codex":
|
||||
client, model = _try_codex()
|
||||
if client is None:
|
||||
logger.warning("auxiliary.provider=codex but no Codex OAuth token found (run: hermes model)")
|
||||
return client, model
|
||||
|
||||
if forced == "main":
|
||||
# "main" = skip OpenRouter/Nous, use the main chat model's credentials.
|
||||
for try_fn in (_try_custom_endpoint, _try_codex, _resolve_api_key_provider):
|
||||
client, model = try_fn()
|
||||
if client is not None:
|
||||
return client, model
|
||||
logger.warning("auxiliary.provider=main but no main endpoint credentials found")
|
||||
return None, None
|
||||
|
||||
# Unknown provider name — fall through to auto
|
||||
logger.warning("Unknown auxiliary.provider=%r, falling back to auto", forced)
|
||||
return None, None
|
||||
|
||||
|
||||
_AUTO_PROVIDER_LABELS = {
|
||||
"_try_openrouter": "openrouter",
|
||||
"_try_nous": "nous",
|
||||
@@ -1013,32 +1035,6 @@ def _is_payment_error(exc: Exception) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _is_connection_error(exc: Exception) -> bool:
|
||||
"""Detect connection/network errors that warrant provider fallback.
|
||||
|
||||
Returns True for errors indicating the provider endpoint is unreachable
|
||||
(DNS failure, connection refused, TLS errors, timeouts). These are
|
||||
distinct from API errors (4xx/5xx) which indicate the provider IS
|
||||
reachable but returned an error.
|
||||
"""
|
||||
from openai import APIConnectionError, APITimeoutError
|
||||
|
||||
if isinstance(exc, (APIConnectionError, APITimeoutError)):
|
||||
return True
|
||||
# urllib3 / httpx / httpcore connection errors
|
||||
err_type = type(exc).__name__
|
||||
if any(kw in err_type for kw in ("Connection", "Timeout", "DNS", "SSL")):
|
||||
return True
|
||||
err_lower = str(exc).lower()
|
||||
if any(kw in err_lower for kw in (
|
||||
"connection refused", "name or service not known",
|
||||
"no route to host", "network is unreachable",
|
||||
"timed out", "connection reset",
|
||||
)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _try_payment_fallback(
|
||||
failed_provider: str,
|
||||
task: str = None,
|
||||
@@ -1103,7 +1099,7 @@ def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
main_model = _read_main_model()
|
||||
if (main_provider and main_model
|
||||
and main_provider not in _AGGREGATOR_PROVIDERS
|
||||
and main_provider not in ("auto", "")):
|
||||
and main_provider not in ("auto", "custom", "")):
|
||||
client, resolved = resolve_provider_client(main_provider, main_model)
|
||||
if client is not None:
|
||||
logger.info("Auxiliary auto-detect: using main provider %s (%s)",
|
||||
@@ -1161,7 +1157,7 @@ def _to_async_client(sync_client, model: str):
|
||||
|
||||
async_kwargs["default_headers"] = copilot_default_headers()
|
||||
elif "api.kimi.com" in base_lower:
|
||||
async_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.3"}
|
||||
async_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
|
||||
return AsyncOpenAI(**async_kwargs), model
|
||||
|
||||
|
||||
@@ -1281,13 +1277,7 @@ def resolve_provider_client(
|
||||
)
|
||||
return None, None
|
||||
final_model = model or _read_main_model() or "gpt-4o-mini"
|
||||
extra = {}
|
||||
if "api.kimi.com" in custom_base.lower():
|
||||
extra["default_headers"] = {"User-Agent": "KimiCLI/1.3"}
|
||||
elif "api.githubcopilot.com" in custom_base.lower():
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
extra["default_headers"] = copilot_default_headers()
|
||||
client = OpenAI(api_key=custom_key, base_url=custom_base, **extra)
|
||||
client = OpenAI(api_key=custom_key, base_url=custom_base)
|
||||
return (_to_async_client(client, final_model) if async_mode
|
||||
else (client, final_model))
|
||||
# Try custom first, then codex, then API-key providers
|
||||
@@ -1366,7 +1356,7 @@ def resolve_provider_client(
|
||||
# Provider-specific headers
|
||||
headers = {}
|
||||
if "api.kimi.com" in base_url.lower():
|
||||
headers["User-Agent"] = "KimiCLI/1.3"
|
||||
headers["User-Agent"] = "KimiCLI/1.0"
|
||||
elif "api.githubcopilot.com" in base_url.lower():
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
@@ -1461,6 +1451,22 @@ def _strict_vision_backend_available(provider: str) -> bool:
|
||||
return _resolve_strict_vision_backend(provider)[0] is not None
|
||||
|
||||
|
||||
def _preferred_main_vision_provider() -> Optional[str]:
|
||||
"""Return the selected main provider when it is also a supported vision backend."""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
model_cfg = config.get("model", {})
|
||||
if isinstance(model_cfg, dict):
|
||||
provider = _normalize_vision_provider(model_cfg.get("provider", ""))
|
||||
if provider in _VISION_AUTO_PROVIDER_ORDER:
|
||||
return provider
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def get_available_vision_backends() -> List[str]:
|
||||
"""Return the currently available vision backends in auto-selection order.
|
||||
|
||||
@@ -1574,6 +1580,18 @@ def resolve_vision_provider_client(
|
||||
return requested, client, final_model
|
||||
|
||||
|
||||
def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Return (client, default_model_slug) for vision/multimodal auxiliary tasks."""
|
||||
_, client, final_model = resolve_vision_provider_client(async_mode=False)
|
||||
return client, final_model
|
||||
|
||||
|
||||
def get_async_vision_auxiliary_client():
|
||||
"""Return (async_client, model_slug) for async vision consumers."""
|
||||
_, client, final_model = resolve_vision_provider_client(async_mode=True)
|
||||
return client, final_model
|
||||
|
||||
|
||||
def get_auxiliary_extra_body() -> dict:
|
||||
"""Return extra_body kwargs for auxiliary API calls.
|
||||
|
||||
@@ -2063,18 +2081,7 @@ def call_llm(
|
||||
# try alternative providers instead of giving up. This handles the
|
||||
# common case where a user runs out of OpenRouter credits but has
|
||||
# Codex OAuth or another provider available.
|
||||
#
|
||||
# ── Connection error fallback ────────────────────────────────
|
||||
# When a provider endpoint is unreachable (DNS failure, connection
|
||||
# refused, timeout), try alternative providers. This handles stale
|
||||
# Codex/OAuth tokens that authenticate but whose endpoint is down,
|
||||
# and providers the user never configured that got picked up by
|
||||
# the auto-detection chain.
|
||||
should_fallback = _is_payment_error(first_err) or _is_connection_error(first_err)
|
||||
if should_fallback:
|
||||
reason = "payment error" if _is_payment_error(first_err) else "connection error"
|
||||
logger.info("Auxiliary %s: %s on %s (%s), trying fallback",
|
||||
task or "call", reason, resolved_provider, first_err)
|
||||
if _is_payment_error(first_err):
|
||||
fb_client, fb_model, fb_label = _try_payment_fallback(
|
||||
resolved_provider, task)
|
||||
if fb_client is not None:
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
"""BuiltinMemoryProvider — wraps MEMORY.md / USER.md as a MemoryProvider.
|
||||
|
||||
Always registered as the first provider. Cannot be disabled or removed.
|
||||
This is the existing Hermes memory system exposed through the provider
|
||||
interface for compatibility with the MemoryManager.
|
||||
|
||||
The actual storage logic lives in tools/memory_tool.py (MemoryStore).
|
||||
This provider is a thin adapter that delegates to MemoryStore and
|
||||
exposes the memory tool schema.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
from tools.registry import tool_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BuiltinMemoryProvider(MemoryProvider):
|
||||
"""Built-in file-backed memory (MEMORY.md + USER.md).
|
||||
|
||||
Always active, never disabled by other providers. The `memory` tool
|
||||
is handled by run_agent.py's agent-level tool interception (not through
|
||||
the normal registry), so get_tool_schemas() returns an empty list —
|
||||
the memory tool is already wired separately.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
memory_store=None,
|
||||
memory_enabled: bool = False,
|
||||
user_profile_enabled: bool = False,
|
||||
):
|
||||
self._store = memory_store
|
||||
self._memory_enabled = memory_enabled
|
||||
self._user_profile_enabled = user_profile_enabled
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "builtin"
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Built-in memory is always available."""
|
||||
return True
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
|
||||
"""Load memory from disk if not already loaded."""
|
||||
if self._store is not None:
|
||||
self._store.load_from_disk()
|
||||
|
||||
def system_prompt_block(self) -> str:
|
||||
"""Return MEMORY.md and USER.md content for the system prompt.
|
||||
|
||||
Uses the frozen snapshot captured at load time. This ensures the
|
||||
system prompt stays stable throughout a session (preserving the
|
||||
prompt cache), even though the live entries may change via tool calls.
|
||||
"""
|
||||
if not self._store:
|
||||
return ""
|
||||
|
||||
parts = []
|
||||
if self._memory_enabled:
|
||||
mem_block = self._store.format_for_system_prompt("memory")
|
||||
if mem_block:
|
||||
parts.append(mem_block)
|
||||
if self._user_profile_enabled:
|
||||
user_block = self._store.format_for_system_prompt("user")
|
||||
if user_block:
|
||||
parts.append(user_block)
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def prefetch(self, query: str, *, session_id: str = "") -> str:
|
||||
"""Built-in memory doesn't do query-based recall — it's injected via system_prompt_block."""
|
||||
return ""
|
||||
|
||||
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
|
||||
"""Built-in memory doesn't auto-sync turns — writes happen via the memory tool."""
|
||||
|
||||
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
"""Return empty list.
|
||||
|
||||
The `memory` tool is an agent-level intercepted tool, handled
|
||||
specially in run_agent.py before normal tool dispatch. It's not
|
||||
part of the standard tool registry. We don't duplicate it here.
|
||||
"""
|
||||
return []
|
||||
|
||||
def handle_tool_call(self, tool_name: str, args: Dict[str, Any], **kwargs) -> str:
|
||||
"""Not used — the memory tool is intercepted in run_agent.py."""
|
||||
return tool_error("Built-in memory tool is handled by the agent loop")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""No cleanup needed — files are saved on every write."""
|
||||
|
||||
# -- Property access for backward compatibility --------------------------
|
||||
|
||||
@property
|
||||
def store(self):
|
||||
"""Access the underlying MemoryStore for legacy code paths."""
|
||||
return self._store
|
||||
|
||||
@property
|
||||
def memory_enabled(self) -> bool:
|
||||
return self._memory_enabled
|
||||
|
||||
@property
|
||||
def user_profile_enabled(self) -> bool:
|
||||
return self._user_profile_enabled
|
||||
+70
-112
@@ -114,6 +114,7 @@ class ContextCompressor:
|
||||
|
||||
self.last_prompt_tokens = 0
|
||||
self.last_completion_tokens = 0
|
||||
self.last_total_tokens = 0
|
||||
|
||||
self.summary_model = summary_model_override or ""
|
||||
|
||||
@@ -125,27 +126,40 @@ class ContextCompressor:
|
||||
"""Update tracked token usage from API response."""
|
||||
self.last_prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
self.last_completion_tokens = usage.get("completion_tokens", 0)
|
||||
self.last_total_tokens = usage.get("total_tokens", 0)
|
||||
|
||||
def should_compress(self, prompt_tokens: int = None) -> bool:
|
||||
"""Check if context exceeds the compression threshold."""
|
||||
tokens = prompt_tokens if prompt_tokens is not None else self.last_prompt_tokens
|
||||
return tokens >= self.threshold_tokens
|
||||
|
||||
def should_compress_preflight(self, messages: List[Dict[str, Any]]) -> bool:
|
||||
"""Quick pre-flight check using rough estimate (before API call)."""
|
||||
rough_estimate = estimate_messages_tokens_rough(messages)
|
||||
return rough_estimate >= self.threshold_tokens
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""Get current compression status for display/logging."""
|
||||
return {
|
||||
"last_prompt_tokens": self.last_prompt_tokens,
|
||||
"threshold_tokens": self.threshold_tokens,
|
||||
"context_length": self.context_length,
|
||||
"usage_percent": min(100, (self.last_prompt_tokens / self.context_length * 100)) if self.context_length else 0,
|
||||
"compression_count": self.compression_count,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool output pruning (cheap pre-pass, no LLM call)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _prune_old_tool_results(
|
||||
self, messages: List[Dict[str, Any]], protect_tail_count: int,
|
||||
protect_tail_tokens: int | None = None,
|
||||
) -> tuple[List[Dict[str, Any]], int]:
|
||||
"""Replace old tool result contents with a short placeholder.
|
||||
|
||||
Walks backward from the end, protecting the most recent messages that
|
||||
fall within ``protect_tail_tokens`` (when provided) OR the last
|
||||
``protect_tail_count`` messages (backward-compatible default).
|
||||
When both are given, the token budget takes priority and the message
|
||||
count acts as a hard minimum floor.
|
||||
Walks backward from the end, protecting the most recent
|
||||
``protect_tail_count`` messages. Older tool results get their
|
||||
content replaced with a placeholder string.
|
||||
|
||||
Returns (pruned_messages, pruned_count).
|
||||
"""
|
||||
@@ -154,29 +168,7 @@ class ContextCompressor:
|
||||
|
||||
result = [m.copy() for m in messages]
|
||||
pruned = 0
|
||||
|
||||
# Determine the prune boundary
|
||||
if protect_tail_tokens is not None and protect_tail_tokens > 0:
|
||||
# Token-budget approach: walk backward accumulating tokens
|
||||
accumulated = 0
|
||||
boundary = len(result)
|
||||
min_protect = min(protect_tail_count, len(result) - 1)
|
||||
for i in range(len(result) - 1, -1, -1):
|
||||
msg = result[i]
|
||||
content_len = len(msg.get("content") or "")
|
||||
msg_tokens = content_len // _CHARS_PER_TOKEN + 10
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
if isinstance(tc, dict):
|
||||
args = tc.get("function", {}).get("arguments", "")
|
||||
msg_tokens += len(args) // _CHARS_PER_TOKEN
|
||||
if accumulated + msg_tokens > protect_tail_tokens and (len(result) - i) >= min_protect:
|
||||
boundary = i
|
||||
break
|
||||
accumulated += msg_tokens
|
||||
boundary = i
|
||||
prune_boundary = max(boundary, len(result) - min_protect)
|
||||
else:
|
||||
prune_boundary = len(result) - protect_tail_count
|
||||
prune_boundary = len(result) - protect_tail_count
|
||||
|
||||
for i in range(prune_boundary):
|
||||
msg = result[i]
|
||||
@@ -207,39 +199,30 @@ class ContextCompressor:
|
||||
budget = int(content_tokens * _SUMMARY_RATIO)
|
||||
return max(_MIN_SUMMARY_TOKENS, min(budget, self.max_summary_tokens))
|
||||
|
||||
# Truncation limits for the summarizer input. These bound how much of
|
||||
# each message the summary model sees — the budget is the *summary*
|
||||
# model's context window, not the main model's.
|
||||
_CONTENT_MAX = 6000 # total chars per message body
|
||||
_CONTENT_HEAD = 4000 # chars kept from the start
|
||||
_CONTENT_TAIL = 1500 # chars kept from the end
|
||||
_TOOL_ARGS_MAX = 1500 # tool call argument chars
|
||||
_TOOL_ARGS_HEAD = 1200 # kept from the start of tool args
|
||||
|
||||
def _serialize_for_summary(self, turns: List[Dict[str, Any]]) -> str:
|
||||
"""Serialize conversation turns into labeled text for the summarizer.
|
||||
|
||||
Includes tool call arguments and result content (up to
|
||||
``_CONTENT_MAX`` chars per message) so the summarizer can preserve
|
||||
specific details like file paths, commands, and outputs.
|
||||
Includes tool call arguments and result content (up to 3000 chars
|
||||
per message) so the summarizer can preserve specific details like
|
||||
file paths, commands, and outputs.
|
||||
"""
|
||||
parts = []
|
||||
for msg in turns:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content") or ""
|
||||
|
||||
# Tool results: keep enough content for the summarizer
|
||||
# Tool results: keep more content than before (3000 chars)
|
||||
if role == "tool":
|
||||
tool_id = msg.get("tool_call_id", "")
|
||||
if len(content) > self._CONTENT_MAX:
|
||||
content = content[:self._CONTENT_HEAD] + "\n...[truncated]...\n" + content[-self._CONTENT_TAIL:]
|
||||
if len(content) > 3000:
|
||||
content = content[:2000] + "\n...[truncated]...\n" + content[-800:]
|
||||
parts.append(f"[TOOL RESULT {tool_id}]: {content}")
|
||||
continue
|
||||
|
||||
# Assistant messages: include tool call names AND arguments
|
||||
if role == "assistant":
|
||||
if len(content) > self._CONTENT_MAX:
|
||||
content = content[:self._CONTENT_HEAD] + "\n...[truncated]...\n" + content[-self._CONTENT_TAIL:]
|
||||
if len(content) > 3000:
|
||||
content = content[:2000] + "\n...[truncated]...\n" + content[-800:]
|
||||
tool_calls = msg.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
tc_parts = []
|
||||
@@ -249,8 +232,8 @@ class ContextCompressor:
|
||||
name = fn.get("name", "?")
|
||||
args = fn.get("arguments", "")
|
||||
# Truncate long arguments but keep enough for context
|
||||
if len(args) > self._TOOL_ARGS_MAX:
|
||||
args = args[:self._TOOL_ARGS_HEAD] + "..."
|
||||
if len(args) > 500:
|
||||
args = args[:400] + "..."
|
||||
tc_parts.append(f" {name}({args})")
|
||||
else:
|
||||
fn = getattr(tc, "function", None)
|
||||
@@ -261,8 +244,8 @@ class ContextCompressor:
|
||||
continue
|
||||
|
||||
# User and other roles
|
||||
if len(content) > self._CONTENT_MAX:
|
||||
content = content[:self._CONTENT_HEAD] + "\n...[truncated]...\n" + content[-self._CONTENT_TAIL:]
|
||||
if len(content) > 3000:
|
||||
content = content[:2000] + "\n...[truncated]...\n" + content[-800:]
|
||||
parts.append(f"[{role.upper()}]: {content}")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
@@ -327,9 +310,6 @@ Update the summary using this exact structure. PRESERVE all existing information
|
||||
## Critical Context
|
||||
[Any specific values, error messages, configuration details, or data that would be lost without explicit preservation]
|
||||
|
||||
## Tools & Patterns
|
||||
[Which tools were used, how they were used effectively, and any tool-specific discoveries. Accumulate across compactions.]
|
||||
|
||||
Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions.
|
||||
|
||||
Write only the summary body. Do not include any preamble or prefix."""
|
||||
@@ -368,9 +348,6 @@ Use this exact structure:
|
||||
## Critical Context
|
||||
[Any specific values, error messages, configuration details, or data that would be lost without explicit preservation]
|
||||
|
||||
## Tools & Patterns
|
||||
[Which tools were used, how they were used effectively, and any tool-specific discoveries (e.g., preferred flags, working invocations, successful command patterns)]
|
||||
|
||||
Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions. The goal is to prevent the next assistant from repeating work or losing important details.
|
||||
|
||||
Write only the summary body. Do not include any preamble or prefix."""
|
||||
@@ -541,20 +518,13 @@ Write only the summary body. Do not include any preamble or prefix."""
|
||||
derived from ``summary_target_ratio * context_length``, so it
|
||||
scales automatically with the model's context window.
|
||||
|
||||
Token budget is the primary criterion. A hard minimum of 3 messages
|
||||
is always protected, but the budget is allowed to exceed by up to
|
||||
1.5x to avoid cutting inside an oversized message (tool output, file
|
||||
read, etc.). If even the minimum 3 messages exceed 1.5x the budget
|
||||
the cut is placed right after the head so compression still runs.
|
||||
|
||||
Never cuts inside a tool_call/result group.
|
||||
Never cuts inside a tool_call/result group. Falls back to the old
|
||||
``protect_last_n`` if the budget would protect fewer messages.
|
||||
"""
|
||||
if token_budget is None:
|
||||
token_budget = self.tail_token_budget
|
||||
n = len(messages)
|
||||
# Hard minimum: always keep at least 3 messages in the tail
|
||||
min_tail = min(3, n - head_end - 1) if n - head_end > 1 else 0
|
||||
soft_ceiling = int(token_budget * 1.5)
|
||||
min_tail = self.protect_last_n
|
||||
accumulated = 0
|
||||
cut_idx = n # start from beyond the end
|
||||
|
||||
@@ -567,21 +537,21 @@ Write only the summary body. Do not include any preamble or prefix."""
|
||||
if isinstance(tc, dict):
|
||||
args = tc.get("function", {}).get("arguments", "")
|
||||
msg_tokens += len(args) // _CHARS_PER_TOKEN
|
||||
# Stop once we exceed the soft ceiling (unless we haven't hit min_tail yet)
|
||||
if accumulated + msg_tokens > soft_ceiling and (n - i) >= min_tail:
|
||||
if accumulated + msg_tokens > token_budget and (n - i) >= min_tail:
|
||||
break
|
||||
accumulated += msg_tokens
|
||||
cut_idx = i
|
||||
|
||||
# Ensure we protect at least min_tail messages
|
||||
# Ensure we protect at least protect_last_n messages
|
||||
fallback_cut = n - min_tail
|
||||
if cut_idx > fallback_cut:
|
||||
cut_idx = fallback_cut
|
||||
|
||||
# If the token budget would protect everything (small conversations),
|
||||
# force a cut after the head so compression can still remove middle turns.
|
||||
# fall back to the fixed protect_last_n approach so compression can
|
||||
# still remove middle turns.
|
||||
if cut_idx <= head_end:
|
||||
cut_idx = max(fallback_cut, head_end + 1)
|
||||
cut_idx = fallback_cut
|
||||
|
||||
# Align to avoid splitting tool groups
|
||||
cut_idx = self._align_boundary_backward(messages, cut_idx)
|
||||
@@ -606,13 +576,12 @@ Write only the summary body. Do not include any preamble or prefix."""
|
||||
up so the API never receives mismatched IDs.
|
||||
"""
|
||||
n_messages = len(messages)
|
||||
# Only need head + 3 tail messages minimum (token budget decides the real tail size)
|
||||
_min_for_compress = self.protect_first_n + 3 + 1
|
||||
if n_messages <= _min_for_compress:
|
||||
if n_messages <= self.protect_first_n + self.protect_last_n + 1:
|
||||
if not self.quiet_mode:
|
||||
logger.warning(
|
||||
"Cannot compress: only %d messages (need > %d)",
|
||||
n_messages, _min_for_compress,
|
||||
n_messages,
|
||||
self.protect_first_n + self.protect_last_n + 1,
|
||||
)
|
||||
return messages
|
||||
|
||||
@@ -620,8 +589,7 @@ Write only the summary body. Do not include any preamble or prefix."""
|
||||
|
||||
# Phase 1: Prune old tool results (cheap, no LLM call)
|
||||
messages, pruned_count = self._prune_old_tool_results(
|
||||
messages, protect_tail_count=self.protect_last_n,
|
||||
protect_tail_tokens=self.tail_token_budget,
|
||||
messages, protect_tail_count=self.protect_last_n * 3,
|
||||
)
|
||||
if pruned_count and not self.quiet_mode:
|
||||
logger.info("Pre-compression: pruned %d old tool result(s)", pruned_count)
|
||||
@@ -674,43 +642,33 @@ Write only the summary body. Do not include any preamble or prefix."""
|
||||
)
|
||||
compressed.append(msg)
|
||||
|
||||
# If LLM summary failed, insert a static fallback so the model
|
||||
# knows context was lost rather than silently dropping everything.
|
||||
if not summary:
|
||||
if not self.quiet_mode:
|
||||
logger.warning("Summary generation failed — inserting static fallback context marker")
|
||||
n_dropped = compress_end - compress_start
|
||||
summary = (
|
||||
f"{SUMMARY_PREFIX}\n"
|
||||
f"Summary generation was unavailable. {n_dropped} conversation turns were "
|
||||
f"removed to free context space but could not be summarized. The removed "
|
||||
f"turns contained earlier work in this session. Continue based on the "
|
||||
f"recent messages below and the current state of any files or resources."
|
||||
)
|
||||
|
||||
_merge_summary_into_tail = False
|
||||
last_head_role = messages[compress_start - 1].get("role", "user") if compress_start > 0 else "user"
|
||||
first_tail_role = messages[compress_end].get("role", "user") if compress_end < n_messages else "user"
|
||||
# Pick a role that avoids consecutive same-role with both neighbors.
|
||||
# Priority: avoid colliding with head (already committed), then tail.
|
||||
if last_head_role in ("assistant", "tool"):
|
||||
summary_role = "user"
|
||||
else:
|
||||
summary_role = "assistant"
|
||||
# If the chosen role collides with the tail AND flipping wouldn't
|
||||
# collide with the head, flip it.
|
||||
if summary_role == first_tail_role:
|
||||
flipped = "assistant" if summary_role == "user" else "user"
|
||||
if flipped != last_head_role:
|
||||
summary_role = flipped
|
||||
if summary:
|
||||
last_head_role = messages[compress_start - 1].get("role", "user") if compress_start > 0 else "user"
|
||||
first_tail_role = messages[compress_end].get("role", "user") if compress_end < n_messages else "user"
|
||||
# Pick a role that avoids consecutive same-role with both neighbors.
|
||||
# Priority: avoid colliding with head (already committed), then tail.
|
||||
if last_head_role in ("assistant", "tool"):
|
||||
summary_role = "user"
|
||||
else:
|
||||
# Both roles would create consecutive same-role messages
|
||||
# (e.g. head=assistant, tail=user — neither role works).
|
||||
# Merge the summary into the first tail message instead
|
||||
# of inserting a standalone message that breaks alternation.
|
||||
_merge_summary_into_tail = True
|
||||
if not _merge_summary_into_tail:
|
||||
compressed.append({"role": summary_role, "content": summary})
|
||||
summary_role = "assistant"
|
||||
# If the chosen role collides with the tail AND flipping wouldn't
|
||||
# collide with the head, flip it.
|
||||
if summary_role == first_tail_role:
|
||||
flipped = "assistant" if summary_role == "user" else "user"
|
||||
if flipped != last_head_role:
|
||||
summary_role = flipped
|
||||
else:
|
||||
# Both roles would create consecutive same-role messages
|
||||
# (e.g. head=assistant, tail=user — neither role works).
|
||||
# Merge the summary into the first tail message instead
|
||||
# of inserting a standalone message that breaks alternation.
|
||||
_merge_summary_into_tail = True
|
||||
if not _merge_summary_into_tail:
|
||||
compressed.append({"role": summary_role, "content": summary})
|
||||
else:
|
||||
if not self.quiet_mode:
|
||||
logger.debug("No summary model available — middle turns dropped without summary")
|
||||
|
||||
for i in range(compress_end, n_messages):
|
||||
msg = messages[i].copy()
|
||||
|
||||
@@ -18,14 +18,12 @@ import hermes_cli.auth as auth_mod
|
||||
from hermes_cli.auth import (
|
||||
CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
|
||||
DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
|
||||
KIMI_CODE_BASE_URL,
|
||||
PROVIDER_REGISTRY,
|
||||
_codex_access_token_is_expiring,
|
||||
_decode_jwt_claims,
|
||||
_import_codex_cli_tokens,
|
||||
_load_auth_store,
|
||||
_load_provider_state,
|
||||
_resolve_kimi_base_url,
|
||||
_resolve_zai_base_url,
|
||||
read_credential_pool,
|
||||
write_credential_pool,
|
||||
@@ -66,10 +64,10 @@ SUPPORTED_POOL_STRATEGIES = {
|
||||
}
|
||||
|
||||
# Cooldown before retrying an exhausted credential.
|
||||
# 429 (rate-limited) and 402 (billing/quota) both cool down after 1 hour.
|
||||
# Provider-supplied reset_at timestamps override these defaults.
|
||||
# 429 (rate-limited) cools down faster since quotas reset frequently.
|
||||
# 402 (billing/quota) and other codes use a longer default.
|
||||
EXHAUSTED_TTL_429_SECONDS = 60 * 60 # 1 hour
|
||||
EXHAUSTED_TTL_DEFAULT_SECONDS = 60 * 60 # 1 hour
|
||||
EXHAUSTED_TTL_DEFAULT_SECONDS = 24 * 60 * 60 # 24 hours
|
||||
|
||||
# Pool key prefix for custom OpenAI-compatible endpoints.
|
||||
# Custom endpoints all share provider='custom' but are keyed by their
|
||||
@@ -633,6 +631,17 @@ class CredentialPool:
|
||||
return False
|
||||
return False
|
||||
|
||||
def mark_used(self, entry_id: Optional[str] = None) -> None:
|
||||
"""Increment request_count for tracking. Used by least_used strategy."""
|
||||
target_id = entry_id or self._current_id
|
||||
if not target_id:
|
||||
return
|
||||
with self._lock:
|
||||
for idx, entry in enumerate(self._entries):
|
||||
if entry.id == target_id:
|
||||
self._entries[idx] = replace(entry, request_count=entry.request_count + 1)
|
||||
return
|
||||
|
||||
def select(self) -> Optional[PooledCredential]:
|
||||
with self._lock:
|
||||
return self._select_unlocked()
|
||||
@@ -794,6 +803,11 @@ class CredentialPool:
|
||||
else:
|
||||
self._active_leases[credential_id] = count - 1
|
||||
|
||||
def active_lease_count(self, credential_id: str) -> int:
|
||||
"""Return the number of active leases for a credential."""
|
||||
with self._lock:
|
||||
return self._active_leases.get(credential_id, 0)
|
||||
|
||||
def try_refresh_current(self) -> Optional[PooledCredential]:
|
||||
with self._lock:
|
||||
return self._try_refresh_current_unlocked()
|
||||
@@ -1070,9 +1084,7 @@ def _seed_from_env(provider: str, entries: List[PooledCredential]) -> Tuple[bool
|
||||
active_sources.add(source)
|
||||
auth_type = AUTH_TYPE_OAUTH if provider == "anthropic" and not token.startswith("sk-ant-api") else AUTH_TYPE_API_KEY
|
||||
base_url = env_url or pconfig.inference_base_url
|
||||
if provider == "kimi-coding":
|
||||
base_url = _resolve_kimi_base_url(token, pconfig.inference_base_url, env_url)
|
||||
elif provider == "zai":
|
||||
if provider == "zai":
|
||||
base_url = _resolve_zai_base_url(token, pconfig.inference_base_url, env_url)
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
|
||||
@@ -67,6 +67,26 @@ def _get_skin():
|
||||
return None
|
||||
|
||||
|
||||
def get_skin_faces(key: str, default: list) -> list:
|
||||
"""Get spinner face list from active skin, falling back to default."""
|
||||
skin = _get_skin()
|
||||
if skin:
|
||||
faces = skin.get_spinner_list(key)
|
||||
if faces:
|
||||
return faces
|
||||
return default
|
||||
|
||||
|
||||
def get_skin_verbs() -> list:
|
||||
"""Get thinking verbs from active skin."""
|
||||
skin = _get_skin()
|
||||
if skin:
|
||||
verbs = skin.get_spinner_list("thinking_verbs")
|
||||
if verbs:
|
||||
return verbs
|
||||
return KawaiiSpinner.THINKING_VERBS
|
||||
|
||||
|
||||
def get_skin_tool_prefix() -> str:
|
||||
"""Get tool output prefix character from active skin."""
|
||||
skin = _get_skin()
|
||||
@@ -703,6 +723,46 @@ class KawaiiSpinner:
|
||||
return False
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Kawaii face arrays (used by AIAgent._execute_tool_calls for spinner text)
|
||||
# =========================================================================
|
||||
|
||||
KAWAII_SEARCH = [
|
||||
"♪(´ε` )", "(。◕‿◕。)", "ヾ(^∇^)", "(◕ᴗ◕✿)", "( ˘▽˘)っ",
|
||||
"٩(◕‿◕。)۶", "(✿◠‿◠)", "♪~(´ε` )", "(ノ´ヮ`)ノ*:・゚✧", "\(◎o◎)/",
|
||||
]
|
||||
KAWAII_READ = [
|
||||
"φ(゜▽゜*)♪", "( ˘▽˘)っ", "(⌐■_■)", "٩(。•́‿•̀。)۶", "(◕‿◕✿)",
|
||||
"ヾ(@⌒ー⌒@)ノ", "(✧ω✧)", "♪(๑ᴖ◡ᴖ๑)♪", "(≧◡≦)", "( ´ ▽ ` )ノ",
|
||||
]
|
||||
KAWAII_TERMINAL = [
|
||||
"ヽ(>∀<☆)ノ", "(ノ°∀°)ノ", "٩(^ᴗ^)۶", "ヾ(⌐■_■)ノ♪", "(•̀ᴗ•́)و",
|
||||
"┗(^0^)┓", "(`・ω・´)", "\( ̄▽ ̄)/", "(ง •̀_•́)ง", "ヽ(´▽`)/",
|
||||
]
|
||||
KAWAII_BROWSER = [
|
||||
"(ノ°∀°)ノ", "(☞゚ヮ゚)☞", "( ͡° ͜ʖ ͡°)", "┌( ಠ_ಠ)┘", "(⊙_⊙)?",
|
||||
"ヾ(•ω•`)o", "( ̄ω ̄)", "( ˇωˇ )", "(ᵔᴥᵔ)", "\(◎o◎)/",
|
||||
]
|
||||
KAWAII_CREATE = [
|
||||
"✧*。٩(ˊᗜˋ*)و✧", "(ノ◕ヮ◕)ノ*:・゚✧", "ヽ(>∀<☆)ノ", "٩(♡ε♡)۶", "(◕‿◕)♡",
|
||||
"✿◕ ‿ ◕✿", "(*≧▽≦)", "ヾ(^-^)ノ", "(☆▽☆)", "°˖✧◝(⁰▿⁰)◜✧˖°",
|
||||
]
|
||||
KAWAII_SKILL = [
|
||||
"ヾ(@⌒ー⌒@)ノ", "(๑˃ᴗ˂)ﻭ", "٩(◕‿◕。)۶", "(✿╹◡╹)", "ヽ(・∀・)ノ",
|
||||
"(ノ´ヮ`)ノ*:・゚✧", "♪(๑ᴖ◡ᴖ๑)♪", "(◠‿◠)", "٩(ˊᗜˋ*)و", "(^▽^)",
|
||||
"ヾ(^∇^)", "(★ω★)/", "٩(。•́‿•̀。)۶", "(◕ᴗ◕✿)", "\(◎o◎)/",
|
||||
"(✧ω✧)", "ヽ(>∀<☆)ノ", "( ˘▽˘)っ", "(≧◡≦) ♡", "ヾ( ̄▽ ̄)",
|
||||
]
|
||||
KAWAII_THINK = [
|
||||
"(っ°Д°;)っ", "(;′⌒`)", "(・_・ヾ", "( ´_ゝ`)", "( ̄ヘ ̄)",
|
||||
"(。-`ω´-)", "( ˘︹˘ )", "(¬_¬)", "ヽ(ー_ー )ノ", "(;一_一)",
|
||||
]
|
||||
KAWAII_GENERIC = [
|
||||
"♪(´ε` )", "(◕‿◕✿)", "ヾ(^∇^)", "٩(◕‿◕。)۶", "(✿◠‿◠)",
|
||||
"(ノ´ヮ`)ノ*:・゚✧", "ヽ(>∀<☆)ノ", "(☆▽☆)", "( ˘▽˘)っ", "(≧◡≦)",
|
||||
]
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Cute tool message (completion line that replaces the spinner)
|
||||
# =========================================================================
|
||||
@@ -910,6 +970,22 @@ _SKY_BLUE = "\033[38;5;117m"
|
||||
_ANSI_RESET = "\033[0m"
|
||||
|
||||
|
||||
def honcho_session_url(workspace: str, session_name: str) -> str:
|
||||
"""Build a Honcho app URL for a session."""
|
||||
from urllib.parse import quote
|
||||
return (
|
||||
f"https://app.honcho.dev/explore"
|
||||
f"?workspace={quote(workspace, safe='')}"
|
||||
f"&view=sessions"
|
||||
f"&session={quote(session_name, safe='')}"
|
||||
)
|
||||
|
||||
|
||||
def _osc8_link(url: str, text: str) -> str:
|
||||
"""OSC 8 terminal hyperlink (clickable in iTerm2, Ghostty, WezTerm, etc.)."""
|
||||
return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Context pressure display (CLI user-facing warnings)
|
||||
# =========================================================================
|
||||
|
||||
@@ -1,782 +0,0 @@
|
||||
"""API error classification for smart failover and recovery.
|
||||
|
||||
Provides a structured taxonomy of API errors and a priority-ordered
|
||||
classification pipeline that determines the correct recovery action
|
||||
(retry, rotate credential, fallback to another provider, compress
|
||||
context, or abort).
|
||||
|
||||
Replaces scattered inline string-matching with a centralized classifier
|
||||
that the main retry loop in run_agent.py consults for every API failure.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Error taxonomy ──────────────────────────────────────────────────────
|
||||
|
||||
class FailoverReason(enum.Enum):
|
||||
"""Why an API call failed — determines recovery strategy."""
|
||||
|
||||
# Authentication / authorization
|
||||
auth = "auth" # Transient auth (401/403) — refresh/rotate
|
||||
auth_permanent = "auth_permanent" # Auth failed after refresh — abort
|
||||
|
||||
# Billing / quota
|
||||
billing = "billing" # 402 or confirmed credit exhaustion — rotate immediately
|
||||
rate_limit = "rate_limit" # 429 or quota-based throttling — backoff then rotate
|
||||
|
||||
# Server-side
|
||||
overloaded = "overloaded" # 503/529 — provider overloaded, backoff
|
||||
server_error = "server_error" # 500/502 — internal server error, retry
|
||||
|
||||
# Transport
|
||||
timeout = "timeout" # Connection/read timeout — rebuild client + retry
|
||||
|
||||
# Context / payload
|
||||
context_overflow = "context_overflow" # Context too large — compress, not failover
|
||||
payload_too_large = "payload_too_large" # 413 — compress payload
|
||||
|
||||
# Model
|
||||
model_not_found = "model_not_found" # 404 or invalid model — fallback to different model
|
||||
|
||||
# Request format
|
||||
format_error = "format_error" # 400 bad request — abort or strip + retry
|
||||
|
||||
# Provider-specific
|
||||
thinking_signature = "thinking_signature" # Anthropic thinking block sig invalid
|
||||
long_context_tier = "long_context_tier" # Anthropic "extra usage" tier gate
|
||||
|
||||
# Catch-all
|
||||
unknown = "unknown" # Unclassifiable — retry with backoff
|
||||
|
||||
|
||||
# ── Classification result ───────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class ClassifiedError:
|
||||
"""Structured classification of an API error with recovery hints."""
|
||||
|
||||
reason: FailoverReason
|
||||
status_code: Optional[int] = None
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
message: str = ""
|
||||
error_context: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Recovery action hints — the retry loop checks these instead of
|
||||
# re-classifying the error itself.
|
||||
retryable: bool = True
|
||||
should_compress: bool = False
|
||||
should_rotate_credential: bool = False
|
||||
should_fallback: bool = False
|
||||
|
||||
@property
|
||||
def is_auth(self) -> bool:
|
||||
return self.reason in (FailoverReason.auth, FailoverReason.auth_permanent)
|
||||
|
||||
|
||||
|
||||
# ── Provider-specific patterns ──────────────────────────────────────────
|
||||
|
||||
# Patterns that indicate billing exhaustion (not transient rate limit)
|
||||
_BILLING_PATTERNS = [
|
||||
"insufficient credits",
|
||||
"insufficient_quota",
|
||||
"credit balance",
|
||||
"credits have been exhausted",
|
||||
"top up your credits",
|
||||
"payment required",
|
||||
"billing hard limit",
|
||||
"exceeded your current quota",
|
||||
"account is deactivated",
|
||||
"plan does not include",
|
||||
]
|
||||
|
||||
# Patterns that indicate rate limiting (transient, will resolve)
|
||||
_RATE_LIMIT_PATTERNS = [
|
||||
"rate limit",
|
||||
"rate_limit",
|
||||
"too many requests",
|
||||
"throttled",
|
||||
"requests per minute",
|
||||
"tokens per minute",
|
||||
"requests per day",
|
||||
"try again in",
|
||||
"please retry after",
|
||||
"resource_exhausted",
|
||||
]
|
||||
|
||||
# Usage-limit patterns that need disambiguation (could be billing OR rate_limit)
|
||||
_USAGE_LIMIT_PATTERNS = [
|
||||
"usage limit",
|
||||
"quota",
|
||||
"limit exceeded",
|
||||
"key limit exceeded",
|
||||
]
|
||||
|
||||
# Patterns confirming usage limit is transient (not billing)
|
||||
_USAGE_LIMIT_TRANSIENT_SIGNALS = [
|
||||
"try again",
|
||||
"retry",
|
||||
"resets at",
|
||||
"reset in",
|
||||
"wait",
|
||||
"requests remaining",
|
||||
"periodic",
|
||||
"window",
|
||||
]
|
||||
|
||||
# Payload-too-large patterns detected from message text (no status_code attr).
|
||||
# Proxies and some backends embed the HTTP status in the error message.
|
||||
_PAYLOAD_TOO_LARGE_PATTERNS = [
|
||||
"request entity too large",
|
||||
"payload too large",
|
||||
"error code: 413",
|
||||
]
|
||||
|
||||
# Context overflow patterns
|
||||
_CONTEXT_OVERFLOW_PATTERNS = [
|
||||
"context length",
|
||||
"context size",
|
||||
"maximum context",
|
||||
"token limit",
|
||||
"too many tokens",
|
||||
"reduce the length",
|
||||
"exceeds the limit",
|
||||
"context window",
|
||||
"prompt is too long",
|
||||
"prompt exceeds max length",
|
||||
"max_tokens",
|
||||
"maximum number of tokens",
|
||||
# Chinese error messages (some providers return these)
|
||||
"超过最大长度",
|
||||
"上下文长度",
|
||||
]
|
||||
|
||||
# Model not found patterns
|
||||
_MODEL_NOT_FOUND_PATTERNS = [
|
||||
"is not a valid model",
|
||||
"invalid model",
|
||||
"model not found",
|
||||
"model_not_found",
|
||||
"does not exist",
|
||||
"no such model",
|
||||
"unknown model",
|
||||
"unsupported model",
|
||||
]
|
||||
|
||||
# Auth patterns (non-status-code signals)
|
||||
_AUTH_PATTERNS = [
|
||||
"invalid api key",
|
||||
"invalid_api_key",
|
||||
"authentication",
|
||||
"unauthorized",
|
||||
"forbidden",
|
||||
"invalid token",
|
||||
"token expired",
|
||||
"token revoked",
|
||||
"access denied",
|
||||
]
|
||||
|
||||
# Anthropic thinking block signature patterns
|
||||
_THINKING_SIG_PATTERNS = [
|
||||
"signature", # Combined with "thinking" check
|
||||
]
|
||||
|
||||
# Transport error type names
|
||||
_TRANSPORT_ERROR_TYPES = frozenset({
|
||||
"ReadTimeout", "ConnectTimeout", "PoolTimeout",
|
||||
"ConnectError", "RemoteProtocolError",
|
||||
"ConnectionError", "ConnectionResetError",
|
||||
"ConnectionAbortedError", "BrokenPipeError",
|
||||
"TimeoutError", "ReadError",
|
||||
"ServerDisconnectedError",
|
||||
# OpenAI SDK errors (not subclasses of Python builtins)
|
||||
"APIConnectionError",
|
||||
"APITimeoutError",
|
||||
})
|
||||
|
||||
# Server disconnect patterns (no status code, but transport-level)
|
||||
_SERVER_DISCONNECT_PATTERNS = [
|
||||
"server disconnected",
|
||||
"peer closed connection",
|
||||
"connection reset by peer",
|
||||
"connection was closed",
|
||||
"network connection lost",
|
||||
"unexpected eof",
|
||||
"incomplete chunked read",
|
||||
]
|
||||
|
||||
|
||||
# ── Classification pipeline ─────────────────────────────────────────────
|
||||
|
||||
def classify_api_error(
|
||||
error: Exception,
|
||||
*,
|
||||
provider: str = "",
|
||||
model: str = "",
|
||||
approx_tokens: int = 0,
|
||||
context_length: int = 200000,
|
||||
num_messages: int = 0,
|
||||
) -> ClassifiedError:
|
||||
"""Classify an API error into a structured recovery recommendation.
|
||||
|
||||
Priority-ordered pipeline:
|
||||
1. Special-case provider-specific patterns (thinking sigs, tier gates)
|
||||
2. HTTP status code + message-aware refinement
|
||||
3. Error code classification (from body)
|
||||
4. Message pattern matching (billing vs rate_limit vs context vs auth)
|
||||
5. Transport error heuristics
|
||||
6. Server disconnect + large session → context overflow
|
||||
7. Fallback: unknown (retryable with backoff)
|
||||
|
||||
Args:
|
||||
error: The exception from the API call.
|
||||
provider: Current provider name (e.g. "openrouter", "anthropic").
|
||||
model: Current model slug.
|
||||
approx_tokens: Approximate token count of the current context.
|
||||
context_length: Maximum context length for the current model.
|
||||
|
||||
Returns:
|
||||
ClassifiedError with reason and recovery action hints.
|
||||
"""
|
||||
status_code = _extract_status_code(error)
|
||||
error_type = type(error).__name__
|
||||
body = _extract_error_body(error)
|
||||
error_code = _extract_error_code(body)
|
||||
|
||||
# Build a comprehensive error message string for pattern matching.
|
||||
# str(error) alone may not include the body message (e.g. OpenAI SDK's
|
||||
# APIStatusError.__str__ returns the first arg, not the body). Append
|
||||
# the body message so patterns like "try again" in 402 disambiguation
|
||||
# are detected even when only present in the structured body.
|
||||
#
|
||||
# Also extract metadata.raw — OpenRouter wraps upstream provider errors
|
||||
# inside {"error": {"message": "Provider returned error", "metadata":
|
||||
# {"raw": "<actual error JSON>"}}} and the real error message (e.g.
|
||||
# "context length exceeded") is only in the inner JSON.
|
||||
_raw_msg = str(error).lower()
|
||||
_body_msg = ""
|
||||
_metadata_msg = ""
|
||||
if isinstance(body, dict):
|
||||
_err_obj = body.get("error", {})
|
||||
if isinstance(_err_obj, dict):
|
||||
_body_msg = (_err_obj.get("message") or "").lower()
|
||||
# Parse metadata.raw for wrapped provider errors
|
||||
_metadata = _err_obj.get("metadata", {})
|
||||
if isinstance(_metadata, dict):
|
||||
_raw_json = _metadata.get("raw") or ""
|
||||
if isinstance(_raw_json, str) and _raw_json.strip():
|
||||
try:
|
||||
import json
|
||||
_inner = json.loads(_raw_json)
|
||||
if isinstance(_inner, dict):
|
||||
_inner_err = _inner.get("error", {})
|
||||
if isinstance(_inner_err, dict):
|
||||
_metadata_msg = (_inner_err.get("message") or "").lower()
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
if not _body_msg:
|
||||
_body_msg = (body.get("message") or "").lower()
|
||||
# Combine all message sources for pattern matching
|
||||
parts = [_raw_msg]
|
||||
if _body_msg and _body_msg not in _raw_msg:
|
||||
parts.append(_body_msg)
|
||||
if _metadata_msg and _metadata_msg not in _raw_msg and _metadata_msg not in _body_msg:
|
||||
parts.append(_metadata_msg)
|
||||
error_msg = " ".join(parts)
|
||||
provider_lower = (provider or "").strip().lower()
|
||||
model_lower = (model or "").strip().lower()
|
||||
|
||||
def _result(reason: FailoverReason, **overrides) -> ClassifiedError:
|
||||
defaults = {
|
||||
"reason": reason,
|
||||
"status_code": status_code,
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"message": _extract_message(error, body),
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return ClassifiedError(**defaults)
|
||||
|
||||
# ── 1. Provider-specific patterns (highest priority) ────────────
|
||||
|
||||
# Anthropic thinking block signature invalid (400).
|
||||
# Don't gate on provider — OpenRouter proxies Anthropic errors, so the
|
||||
# provider may be "openrouter" even though the error is Anthropic-specific.
|
||||
# The message pattern ("signature" + "thinking") is unique enough.
|
||||
if (
|
||||
status_code == 400
|
||||
and "signature" in error_msg
|
||||
and "thinking" in error_msg
|
||||
):
|
||||
return _result(
|
||||
FailoverReason.thinking_signature,
|
||||
retryable=True,
|
||||
should_compress=False,
|
||||
)
|
||||
|
||||
# Anthropic long-context tier gate (429 "extra usage" + "long context")
|
||||
if (
|
||||
status_code == 429
|
||||
and "extra usage" in error_msg
|
||||
and "long context" in error_msg
|
||||
):
|
||||
return _result(
|
||||
FailoverReason.long_context_tier,
|
||||
retryable=True,
|
||||
should_compress=True,
|
||||
)
|
||||
|
||||
# ── 2. HTTP status code classification ──────────────────────────
|
||||
|
||||
if status_code is not None:
|
||||
classified = _classify_by_status(
|
||||
status_code, error_msg, error_code, body,
|
||||
provider=provider_lower, model=model_lower,
|
||||
approx_tokens=approx_tokens, context_length=context_length,
|
||||
num_messages=num_messages,
|
||||
result_fn=_result,
|
||||
)
|
||||
if classified is not None:
|
||||
return classified
|
||||
|
||||
# ── 3. Error code classification ────────────────────────────────
|
||||
|
||||
if error_code:
|
||||
classified = _classify_by_error_code(error_code, error_msg, _result)
|
||||
if classified is not None:
|
||||
return classified
|
||||
|
||||
# ── 4. Message pattern matching (no status code) ────────────────
|
||||
|
||||
classified = _classify_by_message(
|
||||
error_msg, error_type,
|
||||
approx_tokens=approx_tokens,
|
||||
context_length=context_length,
|
||||
result_fn=_result,
|
||||
)
|
||||
if classified is not None:
|
||||
return classified
|
||||
|
||||
# ── 5. Server disconnect + large session → context overflow ─────
|
||||
# Must come BEFORE generic transport error catch — a disconnect on
|
||||
# a large session is more likely context overflow than a transient
|
||||
# transport hiccup. Without this ordering, RemoteProtocolError
|
||||
# always maps to timeout regardless of session size.
|
||||
|
||||
is_disconnect = any(p in error_msg for p in _SERVER_DISCONNECT_PATTERNS)
|
||||
if is_disconnect and not status_code:
|
||||
is_large = approx_tokens > context_length * 0.6 or approx_tokens > 120000 or num_messages > 200
|
||||
if is_large:
|
||||
return _result(
|
||||
FailoverReason.context_overflow,
|
||||
retryable=True,
|
||||
should_compress=True,
|
||||
)
|
||||
return _result(FailoverReason.timeout, retryable=True)
|
||||
|
||||
# ── 6. Transport / timeout heuristics ───────────────────────────
|
||||
|
||||
if error_type in _TRANSPORT_ERROR_TYPES or isinstance(error, (TimeoutError, ConnectionError, OSError)):
|
||||
return _result(FailoverReason.timeout, retryable=True)
|
||||
|
||||
# ── 7. Fallback: unknown ────────────────────────────────────────
|
||||
|
||||
return _result(FailoverReason.unknown, retryable=True)
|
||||
|
||||
|
||||
# ── Status code classification ──────────────────────────────────────────
|
||||
|
||||
def _classify_by_status(
|
||||
status_code: int,
|
||||
error_msg: str,
|
||||
error_code: str,
|
||||
body: dict,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
approx_tokens: int,
|
||||
context_length: int,
|
||||
num_messages: int = 0,
|
||||
result_fn,
|
||||
) -> Optional[ClassifiedError]:
|
||||
"""Classify based on HTTP status code with message-aware refinement."""
|
||||
|
||||
if status_code == 401:
|
||||
# Not retryable on its own — credential pool rotation and
|
||||
# provider-specific refresh (Codex, Anthropic, Nous) run before
|
||||
# the retryability check in run_agent.py. If those succeed, the
|
||||
# loop `continue`s. If they fail, retryable=False ensures we
|
||||
# hit the client-error abort path (which tries fallback first).
|
||||
return result_fn(
|
||||
FailoverReason.auth,
|
||||
retryable=False,
|
||||
should_rotate_credential=True,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
if status_code == 403:
|
||||
# OpenRouter 403 "key limit exceeded" is actually billing
|
||||
if "key limit exceeded" in error_msg or "spending limit" in error_msg:
|
||||
return result_fn(
|
||||
FailoverReason.billing,
|
||||
retryable=False,
|
||||
should_rotate_credential=True,
|
||||
should_fallback=True,
|
||||
)
|
||||
return result_fn(
|
||||
FailoverReason.auth,
|
||||
retryable=False,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
if status_code == 402:
|
||||
return _classify_402(error_msg, result_fn)
|
||||
|
||||
if status_code == 404:
|
||||
if any(p in error_msg for p in _MODEL_NOT_FOUND_PATTERNS):
|
||||
return result_fn(
|
||||
FailoverReason.model_not_found,
|
||||
retryable=False,
|
||||
should_fallback=True,
|
||||
)
|
||||
# Generic 404 — could be model or endpoint
|
||||
return result_fn(
|
||||
FailoverReason.model_not_found,
|
||||
retryable=False,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
if status_code == 413:
|
||||
return result_fn(
|
||||
FailoverReason.payload_too_large,
|
||||
retryable=True,
|
||||
should_compress=True,
|
||||
)
|
||||
|
||||
if status_code == 429:
|
||||
# Already checked long_context_tier above; this is a normal rate limit
|
||||
return result_fn(
|
||||
FailoverReason.rate_limit,
|
||||
retryable=True,
|
||||
should_rotate_credential=True,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
if status_code == 400:
|
||||
return _classify_400(
|
||||
error_msg, error_code, body,
|
||||
provider=provider, model=model,
|
||||
approx_tokens=approx_tokens,
|
||||
context_length=context_length,
|
||||
num_messages=num_messages,
|
||||
result_fn=result_fn,
|
||||
)
|
||||
|
||||
if status_code in (500, 502):
|
||||
return result_fn(FailoverReason.server_error, retryable=True)
|
||||
|
||||
if status_code in (503, 529):
|
||||
return result_fn(FailoverReason.overloaded, retryable=True)
|
||||
|
||||
# Other 4xx — non-retryable
|
||||
if 400 <= status_code < 500:
|
||||
return result_fn(
|
||||
FailoverReason.format_error,
|
||||
retryable=False,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
# Other 5xx — retryable
|
||||
if 500 <= status_code < 600:
|
||||
return result_fn(FailoverReason.server_error, retryable=True)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _classify_402(error_msg: str, result_fn) -> ClassifiedError:
|
||||
"""Disambiguate 402: billing exhaustion vs transient usage limit.
|
||||
|
||||
The key insight from OpenClaw: some 402s are transient rate limits
|
||||
disguised as payment errors. "Usage limit, try again in 5 minutes"
|
||||
is NOT a billing problem — it's a periodic quota that resets.
|
||||
"""
|
||||
# Check for transient usage-limit signals first
|
||||
has_usage_limit = any(p in error_msg for p in _USAGE_LIMIT_PATTERNS)
|
||||
has_transient_signal = any(p in error_msg for p in _USAGE_LIMIT_TRANSIENT_SIGNALS)
|
||||
|
||||
if has_usage_limit and has_transient_signal:
|
||||
# Transient quota — treat as rate limit, not billing
|
||||
return result_fn(
|
||||
FailoverReason.rate_limit,
|
||||
retryable=True,
|
||||
should_rotate_credential=True,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
# Confirmed billing exhaustion
|
||||
return result_fn(
|
||||
FailoverReason.billing,
|
||||
retryable=False,
|
||||
should_rotate_credential=True,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
|
||||
def _classify_400(
|
||||
error_msg: str,
|
||||
error_code: str,
|
||||
body: dict,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
approx_tokens: int,
|
||||
context_length: int,
|
||||
num_messages: int = 0,
|
||||
result_fn,
|
||||
) -> ClassifiedError:
|
||||
"""Classify 400 Bad Request — context overflow, format error, or generic."""
|
||||
|
||||
# Context overflow from 400
|
||||
if any(p in error_msg for p in _CONTEXT_OVERFLOW_PATTERNS):
|
||||
return result_fn(
|
||||
FailoverReason.context_overflow,
|
||||
retryable=True,
|
||||
should_compress=True,
|
||||
)
|
||||
|
||||
# Some providers return model-not-found as 400 instead of 404 (e.g. OpenRouter).
|
||||
if any(p in error_msg for p in _MODEL_NOT_FOUND_PATTERNS):
|
||||
return result_fn(
|
||||
FailoverReason.model_not_found,
|
||||
retryable=False,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
# Some providers return rate limit / billing errors as 400 instead of 429/402.
|
||||
# Check these patterns before falling through to format_error.
|
||||
if any(p in error_msg for p in _RATE_LIMIT_PATTERNS):
|
||||
return result_fn(
|
||||
FailoverReason.rate_limit,
|
||||
retryable=True,
|
||||
should_rotate_credential=True,
|
||||
should_fallback=True,
|
||||
)
|
||||
if any(p in error_msg for p in _BILLING_PATTERNS):
|
||||
return result_fn(
|
||||
FailoverReason.billing,
|
||||
retryable=False,
|
||||
should_rotate_credential=True,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
# Generic 400 + large session → probable context overflow
|
||||
# Anthropic sometimes returns a bare "Error" message when context is too large
|
||||
err_body_msg = ""
|
||||
if isinstance(body, dict):
|
||||
err_obj = body.get("error", {})
|
||||
if isinstance(err_obj, dict):
|
||||
err_body_msg = (err_obj.get("message") or "").strip().lower()
|
||||
# Responses API (and some providers) use flat body: {"message": "..."}
|
||||
if not err_body_msg:
|
||||
err_body_msg = (body.get("message") or "").strip().lower()
|
||||
is_generic = len(err_body_msg) < 30 or err_body_msg in ("error", "")
|
||||
is_large = approx_tokens > context_length * 0.4 or approx_tokens > 80000 or num_messages > 80
|
||||
|
||||
if is_generic and is_large:
|
||||
return result_fn(
|
||||
FailoverReason.context_overflow,
|
||||
retryable=True,
|
||||
should_compress=True,
|
||||
)
|
||||
|
||||
# Non-retryable format error
|
||||
return result_fn(
|
||||
FailoverReason.format_error,
|
||||
retryable=False,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
|
||||
# ── Error code classification ───────────────────────────────────────────
|
||||
|
||||
def _classify_by_error_code(
|
||||
error_code: str, error_msg: str, result_fn,
|
||||
) -> Optional[ClassifiedError]:
|
||||
"""Classify by structured error codes from the response body."""
|
||||
code_lower = error_code.lower()
|
||||
|
||||
if code_lower in ("resource_exhausted", "throttled", "rate_limit_exceeded"):
|
||||
return result_fn(
|
||||
FailoverReason.rate_limit,
|
||||
retryable=True,
|
||||
should_rotate_credential=True,
|
||||
)
|
||||
|
||||
if code_lower in ("insufficient_quota", "billing_not_active", "payment_required"):
|
||||
return result_fn(
|
||||
FailoverReason.billing,
|
||||
retryable=False,
|
||||
should_rotate_credential=True,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
if code_lower in ("model_not_found", "model_not_available", "invalid_model"):
|
||||
return result_fn(
|
||||
FailoverReason.model_not_found,
|
||||
retryable=False,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
if code_lower in ("context_length_exceeded", "max_tokens_exceeded"):
|
||||
return result_fn(
|
||||
FailoverReason.context_overflow,
|
||||
retryable=True,
|
||||
should_compress=True,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# ── Message pattern classification ──────────────────────────────────────
|
||||
|
||||
def _classify_by_message(
|
||||
error_msg: str,
|
||||
error_type: str,
|
||||
*,
|
||||
approx_tokens: int,
|
||||
context_length: int,
|
||||
result_fn,
|
||||
) -> Optional[ClassifiedError]:
|
||||
"""Classify based on error message patterns when no status code is available."""
|
||||
|
||||
# Payload-too-large patterns (from message text when no status_code)
|
||||
if any(p in error_msg for p in _PAYLOAD_TOO_LARGE_PATTERNS):
|
||||
return result_fn(
|
||||
FailoverReason.payload_too_large,
|
||||
retryable=True,
|
||||
should_compress=True,
|
||||
)
|
||||
|
||||
# Billing patterns
|
||||
if any(p in error_msg for p in _BILLING_PATTERNS):
|
||||
return result_fn(
|
||||
FailoverReason.billing,
|
||||
retryable=False,
|
||||
should_rotate_credential=True,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
# Rate limit patterns
|
||||
if any(p in error_msg for p in _RATE_LIMIT_PATTERNS):
|
||||
return result_fn(
|
||||
FailoverReason.rate_limit,
|
||||
retryable=True,
|
||||
should_rotate_credential=True,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
# Context overflow patterns
|
||||
if any(p in error_msg for p in _CONTEXT_OVERFLOW_PATTERNS):
|
||||
return result_fn(
|
||||
FailoverReason.context_overflow,
|
||||
retryable=True,
|
||||
should_compress=True,
|
||||
)
|
||||
|
||||
# Auth patterns
|
||||
if any(p in error_msg for p in _AUTH_PATTERNS):
|
||||
return result_fn(
|
||||
FailoverReason.auth,
|
||||
retryable=True,
|
||||
should_rotate_credential=True,
|
||||
)
|
||||
|
||||
# Model not found patterns
|
||||
if any(p in error_msg for p in _MODEL_NOT_FOUND_PATTERNS):
|
||||
return result_fn(
|
||||
FailoverReason.model_not_found,
|
||||
retryable=False,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# ── Helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _extract_status_code(error: Exception) -> Optional[int]:
|
||||
"""Walk the error and its cause chain to find an HTTP status code."""
|
||||
current = error
|
||||
for _ in range(5): # Max depth to prevent infinite loops
|
||||
code = getattr(current, "status_code", None)
|
||||
if isinstance(code, int):
|
||||
return code
|
||||
# Some SDKs use .status instead of .status_code
|
||||
code = getattr(current, "status", None)
|
||||
if isinstance(code, int) and 100 <= code < 600:
|
||||
return code
|
||||
# Walk cause chain
|
||||
cause = getattr(current, "__cause__", None) or getattr(current, "__context__", None)
|
||||
if cause is None or cause is current:
|
||||
break
|
||||
current = cause
|
||||
return None
|
||||
|
||||
|
||||
def _extract_error_body(error: Exception) -> dict:
|
||||
"""Extract the structured error body from an SDK exception."""
|
||||
body = getattr(error, "body", None)
|
||||
if isinstance(body, dict):
|
||||
return body
|
||||
# Some errors have .response.json()
|
||||
response = getattr(error, "response", None)
|
||||
if response is not None:
|
||||
try:
|
||||
json_body = response.json()
|
||||
if isinstance(json_body, dict):
|
||||
return json_body
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
def _extract_error_code(body: dict) -> str:
|
||||
"""Extract an error code string from the response body."""
|
||||
if not body:
|
||||
return ""
|
||||
error_obj = body.get("error", {})
|
||||
if isinstance(error_obj, dict):
|
||||
code = error_obj.get("code") or error_obj.get("type") or ""
|
||||
if isinstance(code, str) and code.strip():
|
||||
return code.strip()
|
||||
# Top-level code
|
||||
code = body.get("code") or body.get("error_code") or ""
|
||||
if isinstance(code, (str, int)):
|
||||
return str(code).strip()
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_message(error: Exception, body: dict) -> str:
|
||||
"""Extract the most informative error message."""
|
||||
# Try structured body first
|
||||
if body:
|
||||
error_obj = body.get("error", {})
|
||||
if isinstance(error_obj, dict):
|
||||
msg = error_obj.get("message", "")
|
||||
if isinstance(msg, str) and msg.strip():
|
||||
return msg.strip()[:500]
|
||||
msg = body.get("message", "")
|
||||
if isinstance(msg, str) and msg.strip():
|
||||
return msg.strip()[:500]
|
||||
# Fallback to str(error)
|
||||
return str(error)[:500]
|
||||
@@ -39,6 +39,15 @@ def _has_known_pricing(model_name: str, provider: str = None, base_url: str = No
|
||||
return has_known_pricing(model_name, provider=provider, base_url=base_url)
|
||||
|
||||
|
||||
def _get_pricing(model_name: str) -> Dict[str, float]:
|
||||
"""Look up pricing for a model. Uses fuzzy matching on model name.
|
||||
|
||||
Returns _DEFAULT_PRICING (zero cost) for unknown/custom models —
|
||||
we can't assume costs for self-hosted endpoints, local inference, etc.
|
||||
"""
|
||||
return get_pricing(model_name)
|
||||
|
||||
|
||||
def _estimate_cost(
|
||||
session_or_model: Dict[str, Any] | str,
|
||||
input_tokens: int = 0,
|
||||
|
||||
@@ -134,6 +134,11 @@ class MemoryManager:
|
||||
"""All registered providers in order."""
|
||||
return list(self._providers)
|
||||
|
||||
@property
|
||||
def provider_names(self) -> List[str]:
|
||||
"""Names of all registered providers."""
|
||||
return [p.name for p in self._providers]
|
||||
|
||||
def get_provider(self, name: str) -> Optional[MemoryProvider]:
|
||||
"""Get a provider by name, or None if not registered."""
|
||||
for p in self._providers:
|
||||
|
||||
@@ -197,7 +197,6 @@ _URL_TO_PROVIDER: Dict[str, str] = {
|
||||
"api.githubcopilot.com": "copilot",
|
||||
"models.github.ai": "copilot",
|
||||
"api.fireworks.ai": "fireworks",
|
||||
"opencode.ai": "opencode-go",
|
||||
}
|
||||
|
||||
|
||||
@@ -603,49 +602,6 @@ def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
|
||||
return None
|
||||
|
||||
|
||||
def parse_available_output_tokens_from_error(error_msg: str) -> Optional[int]:
|
||||
"""Detect an "output cap too large" error and return how many output tokens are available.
|
||||
|
||||
Background — two distinct context errors exist:
|
||||
1. "Prompt too long" — the INPUT itself exceeds the context window.
|
||||
Fix: compress history and/or halve context_length.
|
||||
2. "max_tokens too large" — input is fine, but input + requested_output > window.
|
||||
Fix: reduce max_tokens (the output cap) for this call.
|
||||
Do NOT touch context_length — the window hasn't shrunk.
|
||||
|
||||
Anthropic's API returns errors like:
|
||||
"max_tokens: 32768 > context_window: 200000 - input_tokens: 190000 = available_tokens: 10000"
|
||||
|
||||
Returns the number of output tokens that would fit (e.g. 10000 above), or None if
|
||||
the error does not look like a max_tokens-too-large error.
|
||||
"""
|
||||
error_lower = error_msg.lower()
|
||||
|
||||
# Must look like an output-cap error, not a prompt-length error.
|
||||
is_output_cap_error = (
|
||||
"max_tokens" in error_lower
|
||||
and ("available_tokens" in error_lower or "available tokens" in error_lower)
|
||||
)
|
||||
if not is_output_cap_error:
|
||||
return None
|
||||
|
||||
# Extract the available_tokens figure.
|
||||
# Anthropic format: "… = available_tokens: 10000"
|
||||
patterns = [
|
||||
r'available_tokens[:\s]+(\d+)',
|
||||
r'available\s+tokens[:\s]+(\d+)',
|
||||
# fallback: last number after "=" in expressions like "200000 - 190000 = 10000"
|
||||
r'=\s*(\d+)\s*$',
|
||||
]
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, error_lower)
|
||||
if match:
|
||||
tokens = int(match.group(1))
|
||||
if tokens >= 1:
|
||||
return tokens
|
||||
return None
|
||||
|
||||
|
||||
def _model_id_matches(candidate_id: str, lookup_model: str) -> bool:
|
||||
"""Return True if *candidate_id* (from server) matches *lookup_model* (configured).
|
||||
|
||||
|
||||
@@ -135,6 +135,9 @@ class ProviderInfo:
|
||||
doc: str = "" # documentation URL
|
||||
model_count: int = 0
|
||||
|
||||
def has_api_url(self) -> bool:
|
||||
return bool(self.api)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider ID mapping: Hermes ↔ models.dev
|
||||
@@ -631,6 +634,43 @@ def get_provider_info(provider_id: str) -> Optional[ProviderInfo]:
|
||||
return _parse_provider_info(mdev_id, raw)
|
||||
|
||||
|
||||
def list_all_providers() -> Dict[str, ProviderInfo]:
|
||||
"""Return all providers from models.dev as {provider_id: ProviderInfo}.
|
||||
|
||||
Returns the full catalog — 109+ providers. For providers that have
|
||||
a Hermes alias, both the models.dev ID and the Hermes ID are included.
|
||||
"""
|
||||
data = fetch_models_dev()
|
||||
result: Dict[str, ProviderInfo] = {}
|
||||
|
||||
for pid, pdata in data.items():
|
||||
if isinstance(pdata, dict):
|
||||
info = _parse_provider_info(pid, pdata)
|
||||
result[pid] = info
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_providers_for_env_var(env_var: str) -> List[str]:
|
||||
"""Reverse lookup: find all providers that use a given env var.
|
||||
|
||||
Useful for auto-detection: "user has ANTHROPIC_API_KEY set, which
|
||||
providers does that enable?"
|
||||
|
||||
Returns list of models.dev provider IDs.
|
||||
"""
|
||||
data = fetch_models_dev()
|
||||
matches: List[str] = []
|
||||
|
||||
for pid, pdata in data.items():
|
||||
if isinstance(pdata, dict):
|
||||
env = pdata.get("env", [])
|
||||
if isinstance(env, list) and env_var in env:
|
||||
matches.append(pid)
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model-level queries (rich ModelInfo)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -668,3 +708,74 @@ def get_model_info(
|
||||
return None
|
||||
|
||||
|
||||
def get_model_info_any_provider(model_id: str) -> Optional[ModelInfo]:
|
||||
"""Search all providers for a model by ID.
|
||||
|
||||
Useful when you have a full slug like "anthropic/claude-sonnet-4.6" or
|
||||
a bare name and want to find it anywhere. Checks Hermes-mapped providers
|
||||
first, then falls back to all models.dev providers.
|
||||
"""
|
||||
data = fetch_models_dev()
|
||||
|
||||
# Try Hermes-mapped providers first (more likely what the user wants)
|
||||
for hermes_id, mdev_id in PROVIDER_TO_MODELS_DEV.items():
|
||||
pdata = data.get(mdev_id)
|
||||
if not isinstance(pdata, dict):
|
||||
continue
|
||||
models = pdata.get("models", {})
|
||||
if not isinstance(models, dict):
|
||||
continue
|
||||
|
||||
raw = models.get(model_id)
|
||||
if isinstance(raw, dict):
|
||||
return _parse_model_info(model_id, raw, mdev_id)
|
||||
|
||||
# Case-insensitive
|
||||
model_lower = model_id.lower()
|
||||
for mid, mdata in models.items():
|
||||
if mid.lower() == model_lower and isinstance(mdata, dict):
|
||||
return _parse_model_info(mid, mdata, mdev_id)
|
||||
|
||||
# Fall back to ALL providers
|
||||
for pid, pdata in data.items():
|
||||
if pid in _get_reverse_mapping():
|
||||
continue # already checked
|
||||
if not isinstance(pdata, dict):
|
||||
continue
|
||||
models = pdata.get("models", {})
|
||||
if not isinstance(models, dict):
|
||||
continue
|
||||
|
||||
raw = models.get(model_id)
|
||||
if isinstance(raw, dict):
|
||||
return _parse_model_info(model_id, raw, pid)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def list_provider_model_infos(provider_id: str) -> List[ModelInfo]:
|
||||
"""Return all models for a provider as ModelInfo objects.
|
||||
|
||||
Filters out deprecated models by default.
|
||||
"""
|
||||
mdev_id = PROVIDER_TO_MODELS_DEV.get(provider_id, provider_id)
|
||||
|
||||
data = fetch_models_dev()
|
||||
pdata = data.get(mdev_id)
|
||||
if not isinstance(pdata, dict):
|
||||
return []
|
||||
|
||||
models = pdata.get("models", {})
|
||||
if not isinstance(models, dict):
|
||||
return []
|
||||
|
||||
result: List[ModelInfo] = []
|
||||
for mid, mdata in models.items():
|
||||
if not isinstance(mdata, dict):
|
||||
continue
|
||||
status = mdata.get("status", "")
|
||||
if status == "deprecated":
|
||||
continue
|
||||
result.append(_parse_model_info(mid, mdata, mdev_id))
|
||||
|
||||
return result
|
||||
|
||||
+11
-7
@@ -349,13 +349,6 @@ PLATFORM_HINTS = {
|
||||
"only — no markdown, no formatting. SMS messages are limited to ~1600 "
|
||||
"characters, so be brief and direct."
|
||||
),
|
||||
"bluebubbles": (
|
||||
"You are chatting via iMessage (BlueBubbles). iMessage does not render "
|
||||
"markdown formatting — use plain text. Keep responses concise as they "
|
||||
"appear as text messages. You can send media files natively: include "
|
||||
"MEDIA:/absolute/path/to/file in your response. Images (.jpg, .png, "
|
||||
".heic) appear as photos and other files arrive as attachments."
|
||||
),
|
||||
}
|
||||
|
||||
CONTEXT_FILE_MAX_CHARS = 20_000
|
||||
@@ -491,6 +484,17 @@ def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]:
|
||||
return True, {}, ""
|
||||
|
||||
|
||||
def _read_skill_conditions(skill_file: Path) -> dict:
|
||||
"""Extract conditional activation fields from SKILL.md frontmatter."""
|
||||
try:
|
||||
raw = skill_file.read_text(encoding="utf-8")[:2000]
|
||||
frontmatter, _ = parse_frontmatter(raw)
|
||||
return extract_skill_conditions(frontmatter)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to read skill conditions from %s: %s", skill_file, e)
|
||||
return {}
|
||||
|
||||
|
||||
def _skill_should_show(
|
||||
conditions: dict,
|
||||
available_tools: "set[str] | None",
|
||||
|
||||
@@ -1,242 +0,0 @@
|
||||
"""Rate limit tracking for inference API responses.
|
||||
|
||||
Captures x-ratelimit-* headers from provider responses and provides
|
||||
formatted display for the /usage slash command. Currently supports
|
||||
the Nous Portal header format (also used by OpenRouter and OpenAI-compatible
|
||||
APIs that follow the same convention).
|
||||
|
||||
Header schema (12 headers total):
|
||||
x-ratelimit-limit-requests RPM cap
|
||||
x-ratelimit-limit-requests-1h RPH cap
|
||||
x-ratelimit-limit-tokens TPM cap
|
||||
x-ratelimit-limit-tokens-1h TPH cap
|
||||
x-ratelimit-remaining-requests requests left in minute window
|
||||
x-ratelimit-remaining-requests-1h requests left in hour window
|
||||
x-ratelimit-remaining-tokens tokens left in minute window
|
||||
x-ratelimit-remaining-tokens-1h tokens left in hour window
|
||||
x-ratelimit-reset-requests seconds until minute request window resets
|
||||
x-ratelimit-reset-requests-1h seconds until hour request window resets
|
||||
x-ratelimit-reset-tokens seconds until minute token window resets
|
||||
x-ratelimit-reset-tokens-1h seconds until hour token window resets
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitBucket:
|
||||
"""One rate-limit window (e.g. requests per minute)."""
|
||||
|
||||
limit: int = 0
|
||||
remaining: int = 0
|
||||
reset_seconds: float = 0.0
|
||||
captured_at: float = 0.0 # time.time() when this was captured
|
||||
|
||||
@property
|
||||
def used(self) -> int:
|
||||
return max(0, self.limit - self.remaining)
|
||||
|
||||
@property
|
||||
def usage_pct(self) -> float:
|
||||
if self.limit <= 0:
|
||||
return 0.0
|
||||
return (self.used / self.limit) * 100.0
|
||||
|
||||
@property
|
||||
def remaining_seconds_now(self) -> float:
|
||||
"""Estimated seconds remaining until reset, adjusted for elapsed time."""
|
||||
elapsed = time.time() - self.captured_at
|
||||
return max(0.0, self.reset_seconds - elapsed)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitState:
|
||||
"""Full rate-limit state parsed from response headers."""
|
||||
|
||||
requests_min: RateLimitBucket = field(default_factory=RateLimitBucket)
|
||||
requests_hour: RateLimitBucket = field(default_factory=RateLimitBucket)
|
||||
tokens_min: RateLimitBucket = field(default_factory=RateLimitBucket)
|
||||
tokens_hour: RateLimitBucket = field(default_factory=RateLimitBucket)
|
||||
captured_at: float = 0.0 # when the headers were captured
|
||||
provider: str = ""
|
||||
|
||||
@property
|
||||
def has_data(self) -> bool:
|
||||
return self.captured_at > 0
|
||||
|
||||
@property
|
||||
def age_seconds(self) -> float:
|
||||
if not self.has_data:
|
||||
return float("inf")
|
||||
return time.time() - self.captured_at
|
||||
|
||||
|
||||
def _safe_int(value: Any, default: int = 0) -> int:
|
||||
try:
|
||||
return int(float(value))
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def _safe_float(value: Any, default: float = 0.0) -> float:
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def parse_rate_limit_headers(
|
||||
headers: Mapping[str, str],
|
||||
provider: str = "",
|
||||
) -> Optional[RateLimitState]:
|
||||
"""Parse x-ratelimit-* headers into a RateLimitState.
|
||||
|
||||
Returns None if no rate limit headers are present.
|
||||
"""
|
||||
# Quick check: at least one rate limit header must exist
|
||||
has_any = any(k.lower().startswith("x-ratelimit-") for k in headers)
|
||||
if not has_any:
|
||||
return None
|
||||
|
||||
now = time.time()
|
||||
|
||||
def _bucket(resource: str, suffix: str = "") -> RateLimitBucket:
|
||||
# e.g. resource="requests", suffix="" -> per-minute
|
||||
# resource="tokens", suffix="-1h" -> per-hour
|
||||
tag = f"{resource}{suffix}"
|
||||
return RateLimitBucket(
|
||||
limit=_safe_int(headers.get(f"x-ratelimit-limit-{tag}")),
|
||||
remaining=_safe_int(headers.get(f"x-ratelimit-remaining-{tag}")),
|
||||
reset_seconds=_safe_float(headers.get(f"x-ratelimit-reset-{tag}")),
|
||||
captured_at=now,
|
||||
)
|
||||
|
||||
return RateLimitState(
|
||||
requests_min=_bucket("requests"),
|
||||
requests_hour=_bucket("requests", "-1h"),
|
||||
tokens_min=_bucket("tokens"),
|
||||
tokens_hour=_bucket("tokens", "-1h"),
|
||||
captured_at=now,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
|
||||
# ── Formatting ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _fmt_count(n: int) -> str:
|
||||
"""Human-friendly number: 7999856 -> '8.0M', 33599 -> '33.6K', 799 -> '799'."""
|
||||
if n >= 1_000_000:
|
||||
return f"{n / 1_000_000:.1f}M"
|
||||
if n >= 10_000:
|
||||
return f"{n / 1_000:.1f}K"
|
||||
if n >= 1_000:
|
||||
return f"{n / 1_000:.1f}K"
|
||||
return str(n)
|
||||
|
||||
|
||||
def _fmt_seconds(seconds: float) -> str:
|
||||
"""Seconds -> human-friendly duration: '58s', '2m 14s', '58m 57s', '1h 2m'."""
|
||||
s = max(0, int(seconds))
|
||||
if s < 60:
|
||||
return f"{s}s"
|
||||
if s < 3600:
|
||||
m, sec = divmod(s, 60)
|
||||
return f"{m}m {sec}s" if sec else f"{m}m"
|
||||
h, remainder = divmod(s, 3600)
|
||||
m = remainder // 60
|
||||
return f"{h}h {m}m" if m else f"{h}h"
|
||||
|
||||
|
||||
def _bar(pct: float, width: int = 20) -> str:
|
||||
"""ASCII progress bar: [████████░░░░░░░░░░░░] 40%."""
|
||||
filled = int(pct / 100.0 * width)
|
||||
filled = max(0, min(width, filled))
|
||||
empty = width - filled
|
||||
return f"[{'█' * filled}{'░' * empty}]"
|
||||
|
||||
|
||||
def _bucket_line(label: str, bucket: RateLimitBucket, label_width: int = 14) -> str:
|
||||
"""Format one bucket as a single line."""
|
||||
if bucket.limit <= 0:
|
||||
return f" {label:<{label_width}} (no data)"
|
||||
|
||||
pct = bucket.usage_pct
|
||||
used = _fmt_count(bucket.used)
|
||||
limit = _fmt_count(bucket.limit)
|
||||
remaining = _fmt_count(bucket.remaining)
|
||||
reset = _fmt_seconds(bucket.remaining_seconds_now)
|
||||
|
||||
bar = _bar(pct)
|
||||
return f" {label:<{label_width}} {bar} {pct:5.1f}% {used}/{limit} used ({remaining} left, resets in {reset})"
|
||||
|
||||
|
||||
def format_rate_limit_display(state: RateLimitState) -> str:
|
||||
"""Format rate limit state for terminal/chat display."""
|
||||
if not state.has_data:
|
||||
return "No rate limit data yet — make an API request first."
|
||||
|
||||
age = state.age_seconds
|
||||
if age < 5:
|
||||
freshness = "just now"
|
||||
elif age < 60:
|
||||
freshness = f"{int(age)}s ago"
|
||||
else:
|
||||
freshness = f"{_fmt_seconds(age)} ago"
|
||||
|
||||
provider_label = state.provider.title() if state.provider else "Provider"
|
||||
|
||||
lines = [
|
||||
f"{provider_label} Rate Limits (captured {freshness}):",
|
||||
"",
|
||||
_bucket_line("Requests/min", state.requests_min),
|
||||
_bucket_line("Requests/hr", state.requests_hour),
|
||||
"",
|
||||
_bucket_line("Tokens/min", state.tokens_min),
|
||||
_bucket_line("Tokens/hr", state.tokens_hour),
|
||||
]
|
||||
|
||||
# Add warnings if any bucket is getting hot
|
||||
warnings = []
|
||||
for label, bucket in [
|
||||
("requests/min", state.requests_min),
|
||||
("requests/hr", state.requests_hour),
|
||||
("tokens/min", state.tokens_min),
|
||||
("tokens/hr", state.tokens_hour),
|
||||
]:
|
||||
if bucket.limit > 0 and bucket.usage_pct >= 80:
|
||||
reset = _fmt_seconds(bucket.remaining_seconds_now)
|
||||
warnings.append(f" ⚠ {label} at {bucket.usage_pct:.0f}% — resets in {reset}")
|
||||
|
||||
if warnings:
|
||||
lines.append("")
|
||||
lines.extend(warnings)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def format_rate_limit_compact(state: RateLimitState) -> str:
|
||||
"""One-line compact summary for status bars / gateway messages."""
|
||||
if not state.has_data:
|
||||
return "No rate limit data."
|
||||
|
||||
rm = state.requests_min
|
||||
tm = state.tokens_min
|
||||
rh = state.requests_hour
|
||||
th = state.tokens_hour
|
||||
|
||||
parts = []
|
||||
if rm.limit > 0:
|
||||
parts.append(f"RPM: {rm.remaining}/{rm.limit}")
|
||||
if rh.limit > 0:
|
||||
parts.append(f"RPH: {_fmt_count(rh.remaining)}/{_fmt_count(rh.limit)} (resets {_fmt_seconds(rh.remaining_seconds_now)})")
|
||||
if tm.limit > 0:
|
||||
parts.append(f"TPM: {_fmt_count(tm.remaining)}/{_fmt_count(tm.limit)}")
|
||||
if th.limit > 0:
|
||||
parts.append(f"TPH: {_fmt_count(th.remaining)}/{_fmt_count(th.limit)} (resets {_fmt_seconds(th.remaining_seconds_now)})")
|
||||
|
||||
return " | ".join(parts)
|
||||
@@ -159,10 +159,7 @@ class SubdirectoryHintTracker:
|
||||
|
||||
def _is_valid_subdir(self, path: Path) -> bool:
|
||||
"""Check if path is a valid directory to scan for hints."""
|
||||
try:
|
||||
if not path.is_dir():
|
||||
return False
|
||||
except OSError:
|
||||
if not path.is_dir():
|
||||
return False
|
||||
if path in self._loaded_dirs:
|
||||
return False
|
||||
@@ -175,10 +172,7 @@ class SubdirectoryHintTracker:
|
||||
found_hints = []
|
||||
for filename in _HINT_FILENAMES:
|
||||
hint_path = directory / filename
|
||||
try:
|
||||
if not hint_path.is_file():
|
||||
continue
|
||||
except OSError:
|
||||
if not hint_path.is_file():
|
||||
continue
|
||||
try:
|
||||
content = hint_path.read_text(encoding="utf-8").strip()
|
||||
|
||||
@@ -595,6 +595,30 @@ def get_pricing(
|
||||
}
|
||||
|
||||
|
||||
def estimate_cost_usd(
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
*,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> float:
|
||||
"""Backward-compatible helper for legacy callers.
|
||||
|
||||
This uses non-cached input/output only. New code should call
|
||||
`estimate_usage_cost()` with canonical usage buckets.
|
||||
"""
|
||||
result = estimate_usage_cost(
|
||||
model,
|
||||
CanonicalUsage(input_tokens=input_tokens, output_tokens=output_tokens),
|
||||
provider=provider,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
)
|
||||
return float(result.amount_usd or _ZERO)
|
||||
|
||||
|
||||
def format_duration_compact(seconds: float) -> str:
|
||||
if seconds < 60:
|
||||
return f"{seconds:.0f}s"
|
||||
|
||||
+2
-2
@@ -1158,7 +1158,7 @@ def main(
|
||||
providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
|
||||
provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
|
||||
max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
|
||||
reasoning_effort (str): OpenRouter reasoning effort level: "none", "minimal", "low", "medium", "high", "xhigh" (default: "medium")
|
||||
reasoning_effort (str): OpenRouter reasoning effort level: "xhigh", "high", "medium", "low", "minimal", "none" (default: "medium")
|
||||
reasoning_disabled (bool): Completely disable reasoning/thinking tokens (default: False)
|
||||
prefill_messages_file (str): Path to JSON file containing prefill messages (list of {role, content} dicts)
|
||||
max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)
|
||||
@@ -1227,7 +1227,7 @@ def main(
|
||||
print("🧠 Reasoning: DISABLED (effort=none)")
|
||||
elif reasoning_effort:
|
||||
# Use specified effort level
|
||||
valid_efforts = ["none", "minimal", "low", "medium", "high", "xhigh"]
|
||||
valid_efforts = ["xhigh", "high", "medium", "low", "minimal", "none"]
|
||||
if reasoning_effort not in valid_efforts:
|
||||
print(f"❌ Error: --reasoning_effort must be one of: {', '.join(valid_efforts)}")
|
||||
return
|
||||
|
||||
+2
-37
@@ -48,25 +48,6 @@ model:
|
||||
# api_key: "your-key-here" # Uncomment to set here instead of .env
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
|
||||
# ── Token limits — two settings, easy to confuse ──────────────────────────
|
||||
#
|
||||
# context_length: TOTAL context window (input + output tokens combined).
|
||||
# Controls when Hermes compresses history and validates requests.
|
||||
# Leave unset — Hermes auto-detects the correct value from the provider.
|
||||
# Set manually only when auto-detection is wrong (e.g. a local server with
|
||||
# a custom num_ctx, or a proxy that doesn't expose /v1/models).
|
||||
#
|
||||
# context_length: 131072
|
||||
#
|
||||
# max_tokens: OUTPUT cap — maximum tokens the model may generate per response.
|
||||
# Unrelated to how long your conversation history can be.
|
||||
# The OpenAI-standard name "max_tokens" is a misnomer; Anthropic's native
|
||||
# API has since renamed it "max_output_tokens" for clarity.
|
||||
# Leave unset to use the model's native output ceiling (recommended).
|
||||
# Set only if you want to deliberately limit individual response length.
|
||||
#
|
||||
# max_tokens: 8192
|
||||
|
||||
# =============================================================================
|
||||
# OpenRouter Provider Routing (only applies when using OpenRouter)
|
||||
# =============================================================================
|
||||
@@ -136,8 +117,7 @@ terminal:
|
||||
timeout: 180
|
||||
docker_mount_cwd_to_workspace: false # SECURITY: off by default. Opt in to mount the launch cwd into Docker /workspace.
|
||||
lifetime_seconds: 300
|
||||
# sudo_password: "hunter2" # Optional: pipe a sudo password via sudo -S. SECURITY WARNING: plaintext.
|
||||
# sudo_password: "" # Explicit empty password: try empty and never open the interactive sudo prompt.
|
||||
# sudo_password: "" # Enable sudo commands (pipes via sudo -S) - SECURITY WARNING: plaintext!
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 2: SSH remote execution
|
||||
@@ -228,18 +208,13 @@ terminal:
|
||||
#
|
||||
# SECURITY WARNING: Password stored in plaintext!
|
||||
#
|
||||
# INTERACTIVE PROMPT: If sudo_password is unset and the CLI is running,
|
||||
# INTERACTIVE PROMPT: If no sudo_password is set and the CLI is running,
|
||||
# you'll be prompted to enter your password when sudo is needed:
|
||||
# - 45-second timeout (auto-skips if no input)
|
||||
# - Press Enter to skip (command fails gracefully)
|
||||
# - Password is hidden while typing
|
||||
# - Password is cached for the session
|
||||
#
|
||||
# EMPTY PASSWORDS: Setting sudo_password to an explicit empty string is different
|
||||
# from leaving it unset. Hermes will try an empty password via `sudo -S` and
|
||||
# will not open the interactive prompt. This is useful for passwordless sudo,
|
||||
# Touch ID sudo setups, and environments where prompting is just noise.
|
||||
#
|
||||
# ALTERNATIVES:
|
||||
# - SSH backend: Configure passwordless sudo on the remote server
|
||||
# - Containers: Run as root inside the container (no sudo needed)
|
||||
@@ -470,16 +445,6 @@ agent:
|
||||
# Higher = more room for complex tasks, but costs more tokens
|
||||
# Recommended: 20-30 for focused tasks, 50-100 for open exploration
|
||||
max_turns: 60
|
||||
|
||||
# Inactivity timeout for gateway agent runs (seconds, 0 = unlimited).
|
||||
# The agent can run indefinitely when actively calling tools or receiving
|
||||
# API responses. Only fires after the agent has been idle for this duration.
|
||||
# gateway_timeout: 1800
|
||||
|
||||
# Staged warning: send a warning before escalating to full timeout.
|
||||
# Fires once per run when inactivity reaches this threshold (seconds).
|
||||
# Set to 0 to disable the warning.
|
||||
# gateway_timeout_warning: 900
|
||||
|
||||
# Enable verbose logging
|
||||
verbose: false
|
||||
|
||||
@@ -1118,6 +1118,14 @@ HERMES_CADUCEUS = """[#CD7F32]⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⡀⠀⣀⣀
|
||||
[#B8860B]⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠳⠈⣡⠞⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[/]
|
||||
[#B8860B]⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[/]"""
|
||||
|
||||
# Compact banner for smaller terminals (fallback)
|
||||
# Note: built dynamically by _build_compact_banner() to fit terminal width
|
||||
COMPACT_BANNER = """
|
||||
[bold #FFD700]╔══════════════════════════════════════════════════════════════╗[/]
|
||||
[bold #FFD700]║[/] [#FFBF00]⚕ NOUS HERMES[/] [dim #B8860B]- AI Agent Framework[/] [bold #FFD700]║[/]
|
||||
[bold #FFD700]║[/] [#CD7F32]Messenger of the Digital Gods[/] [dim #B8860B]Nous Research[/] [bold #FFD700]║[/]
|
||||
[bold #FFD700]╚══════════════════════════════════════════════════════════════╝[/]
|
||||
"""
|
||||
|
||||
|
||||
def _build_compact_banner() -> str:
|
||||
@@ -1363,6 +1371,7 @@ class HermesCLI:
|
||||
self._stream_buf = "" # Partial line buffer for line-buffered rendering
|
||||
self._stream_started = False # True once first delta arrives
|
||||
self._stream_box_opened = False # True once the response box header is printed
|
||||
self._reasoning_stream_started = False # True once live reasoning starts streaming
|
||||
self._reasoning_preview_buf = "" # Coalesce tiny reasoning chunks for [thinking] output
|
||||
self._pending_edit_snapshots = {}
|
||||
|
||||
@@ -1420,6 +1429,8 @@ class HermesCLI:
|
||||
self.api_key = api_key or os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY")
|
||||
else:
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY")
|
||||
self._nous_key_expires_at: Optional[str] = None
|
||||
self._nous_key_source: Optional[str] = None
|
||||
# Max turns priority: CLI arg > config file > env var > default
|
||||
if max_turns is not None: # CLI arg was explicitly set
|
||||
self.max_turns = max_turns
|
||||
@@ -1535,7 +1546,6 @@ class HermesCLI:
|
||||
self._clarify_deadline = 0
|
||||
self._sudo_state = None
|
||||
self._sudo_deadline = 0
|
||||
self._modal_input_snapshot = None
|
||||
self._approval_state = None
|
||||
self._approval_deadline = 0
|
||||
self._approval_lock = threading.Lock()
|
||||
@@ -1592,12 +1602,7 @@ class HermesCLI:
|
||||
return f"[{('█' * filled) + ('░' * max(0, width - filled))}]"
|
||||
|
||||
def _get_status_bar_snapshot(self) -> Dict[str, Any]:
|
||||
# Prefer the agent's model name — it updates on fallback.
|
||||
# self.model reflects the originally configured model and never
|
||||
# changes mid-session, so the TUI would show a stale name after
|
||||
# _try_activate_fallback() switches provider/model.
|
||||
agent = getattr(self, "agent", None)
|
||||
model_name = (getattr(agent, "model", None) or self.model or "unknown")
|
||||
model_name = self.model or "unknown"
|
||||
model_short = model_name.split("/")[-1] if "/" in model_name else model_name
|
||||
if model_short.endswith(".gguf"):
|
||||
model_short = model_short[:-5]
|
||||
@@ -1623,6 +1628,7 @@ class HermesCLI:
|
||||
"compressions": 0,
|
||||
}
|
||||
|
||||
agent = getattr(self, "agent", None)
|
||||
if not agent:
|
||||
return snapshot
|
||||
|
||||
@@ -1995,6 +2001,7 @@ class HermesCLI:
|
||||
"""
|
||||
if not text:
|
||||
return
|
||||
self._reasoning_stream_started = True
|
||||
self._reasoning_shown_this_turn = True
|
||||
if getattr(self, "_stream_box_opened", False):
|
||||
return
|
||||
@@ -2204,6 +2211,7 @@ class HermesCLI:
|
||||
self._stream_buf = ""
|
||||
self._stream_started = False
|
||||
self._stream_box_opened = False
|
||||
self._reasoning_stream_started = False
|
||||
self._stream_text_ansi = ""
|
||||
self._stream_prefilt = ""
|
||||
self._in_reasoning_block = False
|
||||
@@ -3995,7 +4003,59 @@ class HermesCLI:
|
||||
|
||||
print(" To change model or provider, use: hermes model")
|
||||
|
||||
|
||||
def _handle_prompt_command(self, cmd: str):
|
||||
"""Handle the /prompt command to view or set system prompt."""
|
||||
parts = cmd.split(maxsplit=1)
|
||||
|
||||
if len(parts) > 1:
|
||||
# Set new prompt
|
||||
new_prompt = parts[1].strip()
|
||||
|
||||
if new_prompt.lower() == "clear":
|
||||
self.system_prompt = ""
|
||||
self.agent = None # Force re-init
|
||||
if save_config_value("agent.system_prompt", ""):
|
||||
print("(^_^)b System prompt cleared (saved to config)")
|
||||
else:
|
||||
print("(^_^) System prompt cleared (session only)")
|
||||
else:
|
||||
self.system_prompt = new_prompt
|
||||
self.agent = None # Force re-init
|
||||
if save_config_value("agent.system_prompt", new_prompt):
|
||||
print("(^_^)b System prompt set (saved to config)")
|
||||
else:
|
||||
print("(^_^) System prompt set (session only)")
|
||||
print(f" \"{new_prompt[:60]}{'...' if len(new_prompt) > 60 else ''}\"")
|
||||
else:
|
||||
# Show current prompt
|
||||
print()
|
||||
print("+" + "-" * 50 + "+")
|
||||
print("|" + " " * 15 + "(^_^) System Prompt" + " " * 15 + "|")
|
||||
print("+" + "-" * 50 + "+")
|
||||
print()
|
||||
if self.system_prompt:
|
||||
# Word wrap the prompt for display
|
||||
words = self.system_prompt.split()
|
||||
lines = []
|
||||
current_line = ""
|
||||
for word in words:
|
||||
if len(current_line) + len(word) + 1 <= 50:
|
||||
current_line += (" " if current_line else "") + word
|
||||
else:
|
||||
lines.append(current_line)
|
||||
current_line = word
|
||||
if current_line:
|
||||
lines.append(current_line)
|
||||
for line in lines:
|
||||
print(f" {line}")
|
||||
else:
|
||||
print(" (no custom prompt set - using default)")
|
||||
print()
|
||||
print(" Usage:")
|
||||
print(" /prompt <text> - Set a custom system prompt")
|
||||
print(" /prompt clear - Remove custom prompt")
|
||||
print(" /personality - Use a predefined personality")
|
||||
print()
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -4495,7 +4555,9 @@ class HermesCLI:
|
||||
self._handle_model_switch(cmd_original)
|
||||
elif canonical == "provider":
|
||||
self._show_model_and_providers()
|
||||
|
||||
elif canonical == "prompt":
|
||||
# Use original case so prompt text isn't lowercased
|
||||
self._handle_prompt_command(cmd_original)
|
||||
elif canonical == "personality":
|
||||
# Use original case (handler lowercases the personality name itself)
|
||||
self._handle_personality_command(cmd_original)
|
||||
@@ -4975,9 +5037,6 @@ class HermesCLI:
|
||||
def _try_launch_chrome_debug(port: int, system: str) -> bool:
|
||||
"""Try to launch Chrome/Chromium with remote debugging enabled.
|
||||
|
||||
Uses a dedicated user-data-dir so the debug instance doesn't conflict
|
||||
with an already-running Chrome using the default profile.
|
||||
|
||||
Returns True if a launch command was executed (doesn't guarantee success).
|
||||
"""
|
||||
import subprocess as _sp
|
||||
@@ -4987,20 +5046,10 @@ class HermesCLI:
|
||||
if not candidates:
|
||||
return False
|
||||
|
||||
# Dedicated profile dir so debug Chrome won't collide with normal Chrome
|
||||
data_dir = str(_hermes_home / "chrome-debug")
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
chrome = candidates[0]
|
||||
try:
|
||||
_sp.Popen(
|
||||
[
|
||||
chrome,
|
||||
f"--remote-debugging-port={port}",
|
||||
f"--user-data-dir={data_dir}",
|
||||
"--no-first-run",
|
||||
"--no-default-browser-check",
|
||||
],
|
||||
[chrome, f"--remote-debugging-port={port}"],
|
||||
stdout=_sp.DEVNULL,
|
||||
stderr=_sp.DEVNULL,
|
||||
start_new_session=True, # detach from terminal
|
||||
@@ -5075,33 +5124,18 @@ class HermesCLI:
|
||||
print(f" ✓ Chrome launched and listening on port {_port}")
|
||||
else:
|
||||
print(f" ⚠ Chrome launched but port {_port} isn't responding yet")
|
||||
print(" Try again in a few seconds — the debug instance may still be starting")
|
||||
print(" You may need to close existing Chrome windows first and retry")
|
||||
else:
|
||||
print(" ⚠ Could not auto-launch Chrome")
|
||||
# Show manual instructions as fallback
|
||||
_data_dir = str(_hermes_home / "chrome-debug")
|
||||
sys_name = _plat.system()
|
||||
if sys_name == "Darwin":
|
||||
chrome_cmd = (
|
||||
'open -a "Google Chrome" --args'
|
||||
f" --remote-debugging-port=9222"
|
||||
f' --user-data-dir="{_data_dir}"'
|
||||
" --no-first-run --no-default-browser-check"
|
||||
)
|
||||
chrome_cmd = 'open -a "Google Chrome" --args --remote-debugging-port=9222'
|
||||
elif sys_name == "Windows":
|
||||
chrome_cmd = (
|
||||
f'chrome.exe --remote-debugging-port=9222'
|
||||
f' --user-data-dir="{_data_dir}"'
|
||||
f" --no-first-run --no-default-browser-check"
|
||||
)
|
||||
chrome_cmd = 'chrome.exe --remote-debugging-port=9222'
|
||||
else:
|
||||
chrome_cmd = (
|
||||
f"google-chrome --remote-debugging-port=9222"
|
||||
f' --user-data-dir="{_data_dir}"'
|
||||
f" --no-first-run --no-default-browser-check"
|
||||
)
|
||||
print(f" Launch Chrome manually:")
|
||||
print(f" {chrome_cmd}")
|
||||
chrome_cmd = "google-chrome --remote-debugging-port=9222"
|
||||
print(f" Launch Chrome manually: {chrome_cmd}")
|
||||
else:
|
||||
print(f" ⚠ Port {_port} is not reachable at {cdp_url}")
|
||||
|
||||
@@ -5274,7 +5308,7 @@ class HermesCLI:
|
||||
|
||||
Usage:
|
||||
/reasoning Show current effort level and display state
|
||||
/reasoning <level> Set reasoning effort (none, minimal, low, medium, high, xhigh)
|
||||
/reasoning <level> Set reasoning effort (none, low, medium, high, xhigh)
|
||||
/reasoning show|on Show model thinking/reasoning in output
|
||||
/reasoning hide|off Hide model thinking/reasoning from output
|
||||
"""
|
||||
@@ -5292,7 +5326,7 @@ class HermesCLI:
|
||||
display_state = "on ✓" if self.show_reasoning else "off"
|
||||
_cprint(f" {_GOLD}Reasoning effort: {level}{_RST}")
|
||||
_cprint(f" {_GOLD}Reasoning display: {display_state}{_RST}")
|
||||
_cprint(f" {_DIM}Usage: /reasoning <none|minimal|low|medium|high|xhigh|show|hide>{_RST}")
|
||||
_cprint(f" {_DIM}Usage: /reasoning <none|low|medium|high|xhigh|show|hide>{_RST}")
|
||||
return
|
||||
|
||||
arg = parts[1].strip().lower()
|
||||
@@ -5318,7 +5352,7 @@ class HermesCLI:
|
||||
parsed = _parse_reasoning_config(arg)
|
||||
if parsed is None:
|
||||
_cprint(f" {_DIM}(._.) Unknown argument: {arg}{_RST}")
|
||||
_cprint(f" {_DIM}Valid levels: none, minimal, low, medium, high, xhigh{_RST}")
|
||||
_cprint(f" {_DIM}Valid levels: none, low, minimal, medium, high, xhigh{_RST}")
|
||||
_cprint(f" {_DIM}Display: show, hide{_RST}")
|
||||
return
|
||||
|
||||
@@ -5357,7 +5391,7 @@ class HermesCLI:
|
||||
approx_tokens = estimate_messages_tokens_rough(self.conversation_history)
|
||||
print(f"🗜️ Compressing {original_count} messages (~{approx_tokens:,} tokens)...")
|
||||
|
||||
compressed, _new_system = self.agent._compress_context(
|
||||
compressed, new_system = self.agent._compress_context(
|
||||
self.conversation_history,
|
||||
self.agent._cached_system_prompt or "",
|
||||
approx_tokens=approx_tokens,
|
||||
@@ -5374,27 +5408,12 @@ class HermesCLI:
|
||||
print(f" ❌ Compression failed: {e}")
|
||||
|
||||
def _show_usage(self):
|
||||
"""Show rate limits (if available) and session token usage."""
|
||||
"""Show cumulative token usage for the current session."""
|
||||
if not self.agent:
|
||||
print("(._.) No active agent -- send a message first.")
|
||||
return
|
||||
|
||||
agent = self.agent
|
||||
calls = agent.session_api_calls
|
||||
|
||||
if calls == 0:
|
||||
print("(._.) No API calls made yet in this session.")
|
||||
return
|
||||
|
||||
# ── Rate limits (shown first when available) ────────────────
|
||||
rl_state = agent.get_rate_limit_state()
|
||||
if rl_state and rl_state.has_data:
|
||||
from agent.rate_limit_tracker import format_rate_limit_display
|
||||
print()
|
||||
print(format_rate_limit_display(rl_state))
|
||||
print()
|
||||
|
||||
# ── Session token usage ─────────────────────────────────────
|
||||
input_tokens = getattr(agent, "session_input_tokens", 0) or 0
|
||||
output_tokens = getattr(agent, "session_output_tokens", 0) or 0
|
||||
cache_read_tokens = getattr(agent, "session_cache_read_tokens", 0) or 0
|
||||
@@ -5402,7 +5421,13 @@ class HermesCLI:
|
||||
prompt = agent.session_prompt_tokens
|
||||
completion = agent.session_completion_tokens
|
||||
total = agent.session_total_tokens
|
||||
calls = agent.session_api_calls
|
||||
|
||||
if calls == 0:
|
||||
print("(._.) No API calls made yet in this session.")
|
||||
return
|
||||
|
||||
# Current context window state
|
||||
compressor = agent.context_compressor
|
||||
last_prompt = compressor.last_prompt_tokens
|
||||
ctx_len = compressor.context_length
|
||||
@@ -6180,7 +6205,6 @@ class HermesCLI:
|
||||
timeout = 45
|
||||
response_queue = queue.Queue()
|
||||
|
||||
self._capture_modal_input_snapshot()
|
||||
self._sudo_state = {
|
||||
"response_queue": response_queue,
|
||||
}
|
||||
@@ -6193,7 +6217,6 @@ class HermesCLI:
|
||||
result = response_queue.get(timeout=1)
|
||||
self._sudo_state = None
|
||||
self._sudo_deadline = 0
|
||||
self._restore_modal_input_snapshot()
|
||||
self._invalidate()
|
||||
if result:
|
||||
_cprint(f"\n{_DIM} ✓ Password received (cached for session){_RST}")
|
||||
@@ -6208,7 +6231,6 @@ class HermesCLI:
|
||||
|
||||
self._sudo_state = None
|
||||
self._sudo_deadline = 0
|
||||
self._restore_modal_input_snapshot()
|
||||
self._invalidate()
|
||||
_cprint(f"\n{_DIM} ⏱ Timeout — continuing without sudo{_RST}")
|
||||
return ""
|
||||
@@ -6381,33 +6403,6 @@ class HermesCLI:
|
||||
def _secret_capture_callback(self, var_name: str, prompt: str, metadata=None) -> dict:
|
||||
return prompt_for_secret(self, var_name, prompt, metadata)
|
||||
|
||||
def _capture_modal_input_snapshot(self) -> None:
|
||||
"""Temporarily clear the input buffer and save the user's in-progress draft."""
|
||||
if self._modal_input_snapshot is not None or not getattr(self, "_app", None):
|
||||
return
|
||||
try:
|
||||
buf = self._app.current_buffer
|
||||
self._modal_input_snapshot = {
|
||||
"text": buf.text,
|
||||
"cursor_position": buf.cursor_position,
|
||||
}
|
||||
buf.reset()
|
||||
except Exception:
|
||||
self._modal_input_snapshot = None
|
||||
|
||||
def _restore_modal_input_snapshot(self) -> None:
|
||||
"""Restore any draft text that was present before a modal prompt opened."""
|
||||
snapshot = self._modal_input_snapshot
|
||||
self._modal_input_snapshot = None
|
||||
if not snapshot or not getattr(self, "_app", None):
|
||||
return
|
||||
try:
|
||||
buf = self._app.current_buffer
|
||||
buf.text = snapshot.get("text", "")
|
||||
buf.cursor_position = min(snapshot.get("cursor_position", 0), len(buf.text))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _submit_secret_response(self, value: str) -> None:
|
||||
if not self._secret_state:
|
||||
return
|
||||
@@ -7135,7 +7130,6 @@ class HermesCLI:
|
||||
# Sudo password prompt state (similar mechanism to clarify)
|
||||
self._sudo_state = None # dict with response_queue when active
|
||||
self._sudo_deadline = 0
|
||||
self._modal_input_snapshot = None
|
||||
|
||||
# Dangerous command approval state (similar mechanism to clarify)
|
||||
self._approval_state = None # dict with command, description, choices, selected, response_queue
|
||||
@@ -7207,6 +7201,7 @@ class HermesCLI:
|
||||
text = event.app.current_buffer.text
|
||||
self._sudo_state["response_queue"].put(text)
|
||||
self._sudo_state = None
|
||||
event.app.current_buffer.reset()
|
||||
event.app.invalidate()
|
||||
return
|
||||
|
||||
@@ -7411,6 +7406,7 @@ class HermesCLI:
|
||||
if self._sudo_state:
|
||||
self._sudo_state["response_queue"].put("")
|
||||
self._sudo_state = None
|
||||
event.app.current_buffer.reset()
|
||||
event.app.invalidate()
|
||||
return
|
||||
|
||||
|
||||
+2
-3
@@ -44,7 +44,7 @@ logger = logging.getLogger(__name__)
|
||||
_KNOWN_DELIVERY_PLATFORMS = frozenset({
|
||||
"telegram", "discord", "slack", "whatsapp", "signal",
|
||||
"matrix", "mattermost", "homeassistant", "dingtalk", "feishu",
|
||||
"wecom", "sms", "email", "webhook", "bluebubbles",
|
||||
"wecom", "sms", "email", "webhook",
|
||||
})
|
||||
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run
|
||||
@@ -91,7 +91,7 @@ def _resolve_delivery_target(job: dict) -> Optional[dict]:
|
||||
}
|
||||
# Origin missing (e.g. job created via API/script) — try each
|
||||
# platform's home channel as a fallback instead of silently dropping.
|
||||
for platform_name in ("matrix", "telegram", "discord", "slack", "bluebubbles"):
|
||||
for platform_name in ("matrix", "telegram", "discord", "slack"):
|
||||
chat_id = os.getenv(f"{platform_name.upper()}_HOME_CHANNEL", "")
|
||||
if chat_id:
|
||||
logger.info(
|
||||
@@ -236,7 +236,6 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option
|
||||
"wecom": Platform.WECOM,
|
||||
"email": Platform.EMAIL,
|
||||
"sms": Platform.SMS,
|
||||
"bluebubbles": Platform.BLUEBUBBLES,
|
||||
}
|
||||
platform = platform_map.get(platform_name.lower())
|
||||
if not platform:
|
||||
|
||||
+34
-13
@@ -138,6 +138,7 @@ class HermesAgentLoop:
|
||||
max_turns: int = 30,
|
||||
task_id: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
budget_config: Optional["BudgetConfig"] = None,
|
||||
@@ -153,6 +154,7 @@ class HermesAgentLoop:
|
||||
max_turns: Maximum number of LLM calls before stopping
|
||||
task_id: Unique ID for terminal/browser session isolation
|
||||
temperature: Sampling temperature for generation
|
||||
top_p: Nucleus sampling top_p (None = omit, use provider default)
|
||||
max_tokens: Max tokens per generation (None for server default)
|
||||
extra_body: Extra parameters passed to the OpenAI client's create() call.
|
||||
Used for OpenRouter provider preferences, transforms, etc.
|
||||
@@ -168,6 +170,7 @@ class HermesAgentLoop:
|
||||
self.max_turns = max_turns
|
||||
self.task_id = task_id or str(uuid.uuid4())
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.max_tokens = max_tokens
|
||||
self.extra_body = extra_body
|
||||
self.budget_config = budget_config or DEFAULT_BUDGET
|
||||
@@ -211,6 +214,9 @@ class HermesAgentLoop:
|
||||
"temperature": self.temperature,
|
||||
}
|
||||
|
||||
if self.top_p is not None:
|
||||
chat_kwargs["top_p"] = self.top_p
|
||||
|
||||
# Only pass tools if we have them
|
||||
if self.tool_schemas:
|
||||
chat_kwargs["tools"] = self.tool_schemas
|
||||
@@ -225,20 +231,35 @@ class HermesAgentLoop:
|
||||
chat_kwargs["extra_body"] = self.extra_body
|
||||
|
||||
# Make the API call -- standard OpenAI spec
|
||||
# Retry on timeout/connection errors (provider queuing, rate limits)
|
||||
api_start = _time.monotonic()
|
||||
try:
|
||||
response = await self.server.chat_completion(**chat_kwargs)
|
||||
except Exception as e:
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
logger.error("API call failed on turn %d (%.1fs): %s", turn + 1, api_elapsed, e)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
response = None
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
response = await self.server.chat_completion(**chat_kwargs)
|
||||
break
|
||||
except Exception as e:
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
is_retryable = "timeout" in type(e).__name__.lower() or "connection" in type(e).__name__.lower()
|
||||
if is_retryable and attempt < max_retries - 1:
|
||||
wait = 2 ** attempt
|
||||
logger.warning(
|
||||
"[%s] API call timed out on turn %d attempt %d (%.1fs), retrying in %ds: %s",
|
||||
self.task_id[:8], turn + 1, attempt + 1, api_elapsed, wait, e,
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
api_start = _time.monotonic()
|
||||
continue
|
||||
logger.error("API call failed on turn %d (%.1fs): %s", turn + 1, api_elapsed, e)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
|
||||
|
||||
@@ -15,15 +15,15 @@
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
max_agent_turns: 60
|
||||
max_agent_turns: 100
|
||||
max_token_length: 32000
|
||||
agent_temperature: 0.8
|
||||
agent_temperature: 1.0
|
||||
terminal_backend: "modal"
|
||||
terminal_timeout: 300 # 5 min per command (builds, pip install)
|
||||
tool_pool_size: 128 # thread pool for 89 parallel tasks
|
||||
dataset_name: "NousResearch/terminal-bench-2"
|
||||
terminal_timeout: 300 # 5 min per command (builds, pip install)
|
||||
tool_pool_size: 128 # thread pool for 89 parallel tasks
|
||||
dataset_name: "NousResearch/terminal-bench-2-verified-flattened"
|
||||
test_timeout: 600
|
||||
task_timeout: 1800 # 30 min wall-clock per task, auto-FAIL if exceeded
|
||||
task_timeout: 900 # 15 min wall-clock per task, auto-FAIL if exceeded
|
||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
use_wandb: true
|
||||
wandb_name: "terminal-bench-2"
|
||||
@@ -33,10 +33,15 @@ env:
|
||||
# Modal's blocking calls (App.lookup, etc.) deadlock when too many sandboxes
|
||||
# are created simultaneously inside thread pool workers via asyncio.run().
|
||||
max_concurrent_tasks: 8
|
||||
extra_body:
|
||||
provider:
|
||||
order: ["DeepInfra"]
|
||||
allow_fallbacks: false
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-opus-4.6"
|
||||
model_name: "nvidia/nemotron-3-super-120b-a12b"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
timeout: 300 # 5 min per API call (default 1200s causes 20min stalls)
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
|
||||
@@ -32,8 +32,8 @@ export PYTHONUNBUFFERED=1
|
||||
# These go to the log file; tqdm + [START]/[PASS]/[FAIL] go to terminal
|
||||
export LOGLEVEL=INFO
|
||||
|
||||
python terminalbench2_env.py evaluate \
|
||||
--config default.yaml \
|
||||
uv run python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
--config environments/benchmarks/terminalbench_2/default.yaml \
|
||||
"$@" \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
|
||||
@@ -52,18 +52,18 @@ _repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import EvalHandlingEnum
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from pydantic import Field
|
||||
|
||||
from agent.prompt_builder import DEFAULT_AGENT_IDENTITY
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
from tools.terminal_tool import (
|
||||
register_task_env_overrides,
|
||||
clear_task_env_overrides,
|
||||
cleanup_vm,
|
||||
clear_task_env_overrides,
|
||||
register_task_env_overrides,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -73,6 +73,7 @@ logger = logging.getLogger(__name__)
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TerminalBench2EvalConfig(HermesAgentEnvConfig):
|
||||
"""
|
||||
Configuration for the Terminal-Bench 2.0 evaluation environment.
|
||||
@@ -138,11 +139,27 @@ class TerminalBench2EvalConfig(HermesAgentEnvConfig):
|
||||
|
||||
# Tasks that cannot run properly on Modal and are excluded from scoring.
|
||||
MODAL_INCOMPATIBLE_TASKS = {
|
||||
"qemu-startup", # Needs KVM/hardware virtualization
|
||||
"qemu-alpine-ssh", # Needs KVM/hardware virtualization
|
||||
"crack-7z-hash", # Password brute-force -- too slow for cloud sandbox timeouts
|
||||
"qemu-startup", # Needs KVM/hardware virtualization
|
||||
"qemu-alpine-ssh", # Needs KVM/hardware virtualization
|
||||
"crack-7z-hash", # Password brute-force -- too slow for cloud sandbox timeouts
|
||||
}
|
||||
|
||||
# Injected as a user message when the model responds with plain text instead of
|
||||
# calling a tool or including a <task_status> tag.
|
||||
_FORMAT_NUDGE_MESSAGE = (
|
||||
"You wrote a plain text response instead of using your tools. "
|
||||
"Plain text responses do not affect the environment — nothing was executed or saved.\n\n"
|
||||
"You MUST use your tools (terminal, read_file, write_file) to actually complete the task. "
|
||||
"Do not describe what you would do — execute it now by making tool calls.\n\n"
|
||||
"If you have already completed all required work using tools in previous turns, "
|
||||
"respond with exactly: <task_status>DONE</task_status>\n"
|
||||
"If you have exhausted all approaches and cannot make further progress, "
|
||||
"respond with exactly: <task_status>UNFINISHED</task_status>"
|
||||
)
|
||||
|
||||
# Maximum number of format nudges before giving up and moving on to scoring.
|
||||
_MAX_FORMAT_NUDGES = 3
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tar extraction helper
|
||||
@@ -203,7 +220,6 @@ def _safe_extract_tar(tar: tarfile.TarFile, target_dir: Path) -> None:
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _extract_base64_tar(b64_data: str, target_dir: Path):
|
||||
"""Extract a base64-encoded tar.gz archive into target_dir."""
|
||||
if not b64_data:
|
||||
@@ -218,6 +234,7 @@ def _extract_base64_tar(b64_data: str, target_dir: Path):
|
||||
# Main Environment
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
Terminal-Bench 2.0 evaluation environment (eval-only, no training).
|
||||
@@ -262,23 +279,18 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
|
||||
# Agent settings -- TB2 tasks are complex, need many turns
|
||||
max_agent_turns=60,
|
||||
max_token_length=16000,
|
||||
agent_temperature=0.6,
|
||||
system_prompt=None,
|
||||
|
||||
system_prompt=DEFAULT_AGENT_IDENTITY,
|
||||
# Modal backend for per-task cloud-isolated sandboxes
|
||||
terminal_backend="modal",
|
||||
terminal_timeout=300, # 5 min per command (builds, pip install, etc.)
|
||||
|
||||
terminal_timeout=300, # 5 min per command (builds, pip install, etc.)
|
||||
# Test execution timeout (TB2 test scripts can install deps like pytest)
|
||||
test_timeout=180,
|
||||
|
||||
# 89 tasks run in parallel, each needs a thread for tool calls
|
||||
tool_pool_size=128,
|
||||
|
||||
# --- Eval-only Atropos settings ---
|
||||
# These settings make the env work as an eval-only environment:
|
||||
# - STOP_TRAIN: pauses training during eval (standard for eval envs)
|
||||
@@ -288,7 +300,6 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
group_size=1,
|
||||
steps_per_eval=1,
|
||||
total_steps=1,
|
||||
|
||||
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
use_wandb=True,
|
||||
wandb_name="terminal-bench-2",
|
||||
@@ -336,7 +347,11 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
|
||||
# Skip tasks incompatible with the current backend (e.g., QEMU on Modal)
|
||||
# plus any user-specified skip_tasks
|
||||
skip = set(MODAL_INCOMPATIBLE_TASKS) if self.config.terminal_backend == "modal" else set()
|
||||
skip = (
|
||||
set(MODAL_INCOMPATIBLE_TASKS)
|
||||
if self.config.terminal_backend == "modal"
|
||||
else set()
|
||||
)
|
||||
if self.config.skip_tasks:
|
||||
skip |= {name.strip() for name in self.config.skip_tasks.split(",")}
|
||||
if skip:
|
||||
@@ -344,7 +359,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
tasks = [t for t in tasks if t["task_name"] not in skip]
|
||||
skipped = before - len(tasks)
|
||||
if skipped > 0:
|
||||
print(f" Skipped {skipped} incompatible tasks: {sorted(skip & {t['task_name'] for t in ds})}")
|
||||
print(
|
||||
f" Skipped {skipped} incompatible tasks: {sorted(skip & {t['task_name'] for t in ds})}"
|
||||
)
|
||||
|
||||
self.all_eval_items = tasks
|
||||
self.iter = 0
|
||||
@@ -354,6 +371,16 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
for i, task in enumerate(self.all_eval_items):
|
||||
self.category_index[task.get("category", "unknown")].append(i)
|
||||
|
||||
# Pre-compute which tasks need Modal's add_python (avoids re-decoding
|
||||
# multi-MB environment_tar blobs during per-task rollouts).
|
||||
self._needs_add_python: Dict[str, bool] = {
|
||||
task["task_name"]: self._image_needs_add_python(task)
|
||||
for task in self.all_eval_items
|
||||
}
|
||||
add_py_count = sum(self._needs_add_python.values())
|
||||
if add_py_count:
|
||||
print(f" {add_py_count} tasks need add_python (non-python base image)")
|
||||
|
||||
# Reward tracking for wandb logging
|
||||
self.eval_metrics: List[Tuple[str, float]] = []
|
||||
|
||||
@@ -361,15 +388,30 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# immediately on completion so data is preserved even on Ctrl+C.
|
||||
# Timestamped filename so each run produces a unique file.
|
||||
import datetime
|
||||
|
||||
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
run_ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self._streaming_path = os.path.join(log_dir, f"samples_{run_ts}.jsonl")
|
||||
model_name = self.server.servers[0].config.model_name
|
||||
model_slug = model_name.replace("/", "_").replace(":", "_")
|
||||
self._streaming_path = os.path.join(
|
||||
log_dir, f"samples_{run_ts}_{model_slug}.jsonl"
|
||||
)
|
||||
self._streaming_file = open(self._streaming_path, "w")
|
||||
self._streaming_lock = __import__("threading").Lock()
|
||||
self._run_meta = {
|
||||
"model_name": model_name,
|
||||
"temperature": self.config.agent_temperature,
|
||||
"top_p": self.config.agent_top_p,
|
||||
"max_agent_turns": self.config.max_agent_turns,
|
||||
"task_timeout": self.config.task_timeout,
|
||||
"terminal_backend": self.config.terminal_backend,
|
||||
}
|
||||
print(f" Streaming results to: {self._streaming_path}")
|
||||
|
||||
print(f"TB2 ready: {len(self.all_eval_items)} tasks across {len(self.category_index)} categories")
|
||||
print(
|
||||
f"TB2 ready: {len(self.all_eval_items)} tasks across {len(self.category_index)} categories"
|
||||
)
|
||||
for cat, indices in sorted(self.category_index.items()):
|
||||
print(f" {cat}: {len(indices)} tasks")
|
||||
|
||||
@@ -378,7 +420,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
if not hasattr(self, "_streaming_file") or self._streaming_file.closed:
|
||||
return
|
||||
with self._streaming_lock:
|
||||
self._streaming_file.write(json.dumps(result, ensure_ascii=False, default=str) + "\n")
|
||||
self._streaming_file.write(
|
||||
json.dumps(result, ensure_ascii=False, default=str) + "\n"
|
||||
)
|
||||
self._streaming_file.flush()
|
||||
|
||||
# =========================================================================
|
||||
@@ -414,6 +458,36 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# Docker image resolution
|
||||
# =========================================================================
|
||||
|
||||
@staticmethod
|
||||
def _image_needs_add_python(item: Dict[str, Any]) -> bool:
|
||||
"""Check if the task's base image lacks `python` on PATH.
|
||||
|
||||
Parses the Dockerfile FROM line in environment_tar. Returns True
|
||||
for non-python base images (ubuntu, debian, etc.) that need
|
||||
Modal's add_python parameter.
|
||||
"""
|
||||
environment_tar = item.get("environment_tar", "")
|
||||
if not environment_tar:
|
||||
return False
|
||||
try:
|
||||
raw = base64.b64decode(environment_tar)
|
||||
buf = io.BytesIO(raw)
|
||||
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
|
||||
for member in tar:
|
||||
if not member.isfile() or "Dockerfile" not in member.name:
|
||||
continue
|
||||
f = tar.extractfile(member)
|
||||
if not f:
|
||||
continue
|
||||
for line in f.read().decode("utf-8", errors="ignore").splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped.upper().startswith("FROM "):
|
||||
base = stripped.split()[1].lower()
|
||||
return not base.startswith("python:")
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _resolve_task_image(
|
||||
self, item: Dict[str, Any], task_name: str
|
||||
) -> Tuple[str, Optional[Path]]:
|
||||
@@ -446,7 +520,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
if dockerfile_path.exists():
|
||||
logger.info(
|
||||
"Task %s: building from Dockerfile (force_build=%s, docker_image=%s)",
|
||||
task_name, self.config.force_build, bool(docker_image),
|
||||
task_name,
|
||||
self.config.force_build,
|
||||
bool(docker_image),
|
||||
)
|
||||
return str(dockerfile_path), task_dir
|
||||
|
||||
@@ -454,12 +530,80 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
if docker_image:
|
||||
logger.warning(
|
||||
"Task %s: force_build=True but no environment_tar, "
|
||||
"falling back to docker_image %s", task_name, docker_image,
|
||||
"falling back to docker_image %s",
|
||||
task_name,
|
||||
docker_image,
|
||||
)
|
||||
return docker_image, None
|
||||
|
||||
return "", None
|
||||
|
||||
# =========================================================================
|
||||
# Agent loop with format nudging
|
||||
# =========================================================================
|
||||
|
||||
async def _run_with_nudges(
|
||||
self,
|
||||
server,
|
||||
tools: List[Dict[str, Any]],
|
||||
valid_names: set,
|
||||
messages: List[Dict[str, Any]],
|
||||
task_id: str,
|
||||
task_name: str,
|
||||
) -> Tuple["AgentResult", int]:
|
||||
"""Run the agent loop, nudging if the model returns plain text without task_status tag."""
|
||||
total_turns_used = 0
|
||||
nudge_count = 0
|
||||
result = None
|
||||
|
||||
while total_turns_used < self.config.max_agent_turns:
|
||||
remaining = self.config.max_agent_turns - total_turns_used
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=remaining,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
top_p=self.config.agent_top_p,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
total_turns_used += result.turns_used
|
||||
|
||||
if not result.finished_naturally:
|
||||
break
|
||||
|
||||
last_content = next(
|
||||
(
|
||||
m.get("content", "") or ""
|
||||
for m in reversed(messages)
|
||||
if m.get("role") == "assistant"
|
||||
),
|
||||
"",
|
||||
)
|
||||
if "<task_status>" in last_content:
|
||||
break
|
||||
|
||||
if nudge_count >= _MAX_FORMAT_NUDGES:
|
||||
logger.warning(
|
||||
"Task %s: model ignored %d format nudges; stopping.",
|
||||
task_name,
|
||||
nudge_count,
|
||||
)
|
||||
break
|
||||
nudge_count += 1
|
||||
logger.info(
|
||||
"Task %s: nudging model (nudge %d/%d) — no tool calls and no task_status",
|
||||
task_name,
|
||||
nudge_count,
|
||||
_MAX_FORMAT_NUDGES,
|
||||
)
|
||||
messages.append({"role": "user", "content": _FORMAT_NUDGE_MESSAGE})
|
||||
|
||||
return result, total_turns_used
|
||||
|
||||
# =========================================================================
|
||||
# Per-task evaluation -- agent loop + test verification
|
||||
# =========================================================================
|
||||
@@ -488,6 +632,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
task_dir = None # Set if we extract a Dockerfile (needs cleanup)
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
tqdm.write(f" [START] {task_name} (task_id={task_id[:8]})")
|
||||
task_start = time.time()
|
||||
|
||||
@@ -495,24 +640,32 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# --- 1. Resolve Docker image ---
|
||||
modal_image, task_dir = self._resolve_task_image(eval_item, task_name)
|
||||
if not modal_image:
|
||||
logger.error("Task %s: no docker_image or environment_tar, skipping", task_name)
|
||||
logger.error(
|
||||
"Task %s: no docker_image or environment_tar, skipping", task_name
|
||||
)
|
||||
return {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
"passed": False,
|
||||
"reward": 0.0,
|
||||
"task_name": task_name,
|
||||
"category": category,
|
||||
"error": "no_image",
|
||||
}
|
||||
|
||||
# --- 2. Register per-task image override ---
|
||||
# Set both modal_image and docker_image so the task image is used
|
||||
# regardless of which backend is configured.
|
||||
register_task_env_overrides(task_id, {
|
||||
overrides = {
|
||||
"modal_image": modal_image,
|
||||
"docker_image": modal_image,
|
||||
"cwd": "/app",
|
||||
})
|
||||
}
|
||||
if self._needs_add_python.get(task_name, False):
|
||||
overrides["add_python"] = "3.12"
|
||||
register_task_env_overrides(task_id, overrides)
|
||||
logger.info(
|
||||
"Task %s: registered image override for task_id %s",
|
||||
task_name, task_id[:8],
|
||||
task_name,
|
||||
task_id[:8],
|
||||
)
|
||||
|
||||
# --- 3. Resolve tools and build messages ---
|
||||
@@ -520,53 +673,48 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
|
||||
messages: List[Dict[str, Any]] = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
messages.append(
|
||||
{"role": "system", "content": self.config.system_prompt}
|
||||
)
|
||||
messages.append({"role": "user", "content": self.format_prompt(eval_item)})
|
||||
|
||||
# --- 4. Run agent loop ---
|
||||
# Use ManagedServer (Phase 2) for vLLM/SGLang backends to get
|
||||
# token-level tracking via /generate. Falls back to direct
|
||||
# ServerManager (Phase 1) for OpenAI endpoints.
|
||||
# --- 4. Run agent loop with format enforcement ---
|
||||
# The model must either call a tool or end with <task_status>DONE/UNFINISHED</task_status>.
|
||||
# If it returns plain text without the tag, inject a nudge user message and
|
||||
# continue with the remaining turn budget (up to _MAX_FORMAT_NUDGES times).
|
||||
if self._use_managed_server():
|
||||
async with self.server.managed_server(
|
||||
tokenizer=self.tokenizer,
|
||||
preserve_think_blocks=bool(self.config.thinking_mode),
|
||||
) as managed:
|
||||
agent = HermesAgentLoop(
|
||||
result, total_turns_used = await self._run_with_nudges(
|
||||
server=managed,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
tools=tools,
|
||||
valid_names=valid_names,
|
||||
messages=messages,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
task_name=task_name,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
else:
|
||||
agent = HermesAgentLoop(
|
||||
result, total_turns_used = await self._run_with_nudges(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
tools=tools,
|
||||
valid_names=valid_names,
|
||||
messages=messages,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
task_name=task_name,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
# --- 5. Verify -- run test suite in the agent's sandbox ---
|
||||
# Skip verification if the agent produced no meaningful output
|
||||
only_system_and_user = all(
|
||||
msg.get("role") in ("system", "user") for msg in result.messages
|
||||
msg.get("role") in ("system", "user") for msg in messages
|
||||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
if total_turns_used == 0 or only_system_and_user:
|
||||
logger.warning(
|
||||
"Task %s: agent produced no output (turns=%d). Reward=0.",
|
||||
task_name, result.turns_used,
|
||||
task_name,
|
||||
total_turns_used,
|
||||
)
|
||||
reward = 0.0
|
||||
else:
|
||||
@@ -578,7 +726,10 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
loop = asyncio.get_event_loop()
|
||||
reward = await loop.run_in_executor(
|
||||
None, # default thread pool
|
||||
self._run_tests, eval_item, ctx, task_name,
|
||||
self._run_tests,
|
||||
eval_item,
|
||||
ctx,
|
||||
task_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Task %s: test verification failed: %s", task_name, e)
|
||||
@@ -589,20 +740,26 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
passed = reward == 1.0
|
||||
status = "PASS" if passed else "FAIL"
|
||||
elapsed = time.time() - task_start
|
||||
tqdm.write(f" [{status}] {task_name} (turns={result.turns_used}, {elapsed:.0f}s)")
|
||||
tqdm.write(
|
||||
f" [{status}] {task_name} (turns={total_turns_used}, {elapsed:.0f}s)"
|
||||
)
|
||||
logger.info(
|
||||
"Task %s: reward=%.1f, turns=%d, finished=%s",
|
||||
task_name, reward, result.turns_used, result.finished_naturally,
|
||||
task_name,
|
||||
reward,
|
||||
total_turns_used,
|
||||
result.finished_naturally,
|
||||
)
|
||||
|
||||
out = {
|
||||
**self._run_meta,
|
||||
"passed": passed,
|
||||
"reward": reward,
|
||||
"task_name": task_name,
|
||||
"category": category,
|
||||
"turns_used": result.turns_used,
|
||||
"turns_used": total_turns_used,
|
||||
"finished_naturally": result.finished_naturally,
|
||||
"messages": result.messages,
|
||||
"messages": messages,
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
@@ -612,8 +769,11 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
logger.error("Task %s: rollout failed: %s", task_name, e, exc_info=True)
|
||||
tqdm.write(f" [ERROR] {task_name}: {e} ({elapsed:.0f}s)")
|
||||
out = {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
**self._run_meta,
|
||||
"passed": False,
|
||||
"reward": 0.0,
|
||||
"task_name": task_name,
|
||||
"category": category,
|
||||
"error": str(e),
|
||||
}
|
||||
self._save_result(out)
|
||||
@@ -686,7 +846,8 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# Execute the test suite
|
||||
logger.info(
|
||||
"Task %s: running test suite (timeout=%ds)",
|
||||
task_name, self.config.test_timeout,
|
||||
task_name,
|
||||
self.config.test_timeout,
|
||||
)
|
||||
test_result = ctx.terminal(
|
||||
"bash /tests/test.sh",
|
||||
@@ -719,7 +880,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
logger.warning(
|
||||
"Task %s: reward.txt content unexpected (%r), "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, content, exit_code,
|
||||
task_name,
|
||||
content,
|
||||
exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
else:
|
||||
@@ -727,14 +890,17 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
logger.warning(
|
||||
"Task %s: reward.txt not found after download, "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, exit_code,
|
||||
task_name,
|
||||
exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Task %s: failed to download verifier dir: %s, "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, e, exit_code,
|
||||
task_name,
|
||||
e,
|
||||
exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
finally:
|
||||
@@ -745,7 +911,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
output_preview = output[-500:] if output else "(no output)"
|
||||
logger.info(
|
||||
"Task %s: FAIL (exit_code=%d)\n%s",
|
||||
task_name, exit_code, output_preview,
|
||||
task_name,
|
||||
exit_code,
|
||||
output_preview,
|
||||
)
|
||||
|
||||
return reward
|
||||
@@ -770,12 +938,18 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
from tqdm import tqdm
|
||||
|
||||
elapsed = self.config.task_timeout
|
||||
tqdm.write(f" [TIMEOUT] {task_name} (exceeded {elapsed}s wall-clock limit)")
|
||||
tqdm.write(
|
||||
f" [TIMEOUT] {task_name} (exceeded {elapsed}s wall-clock limit)"
|
||||
)
|
||||
logger.error("Task %s: wall-clock timeout after %ds", task_name, elapsed)
|
||||
out = {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
**self._run_meta,
|
||||
"passed": False,
|
||||
"reward": 0.0,
|
||||
"task_name": task_name,
|
||||
"category": category,
|
||||
"error": f"timeout ({elapsed}s)",
|
||||
}
|
||||
self._save_result(out)
|
||||
@@ -809,23 +983,25 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
self.handleError(record)
|
||||
|
||||
handler = _TqdmHandler()
|
||||
handler.setFormatter(logging.Formatter(
|
||||
"%(asctime)s [%(name)s] %(levelname)s: %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
))
|
||||
handler.setFormatter(
|
||||
logging.Formatter(
|
||||
"%(asctime)s [%(name)s] %(levelname)s: %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
)
|
||||
root = logging.getLogger()
|
||||
root.handlers = [handler] # Replace any existing handlers
|
||||
root.setLevel(logging.INFO)
|
||||
|
||||
# Silence noisy third-party loggers that flood the output
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING) # Every HTTP request
|
||||
logging.getLogger("openai").setLevel(logging.WARNING) # OpenAI client retries
|
||||
logging.getLogger("rex-deploy").setLevel(logging.WARNING) # Swerex deployment
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING) # Every HTTP request
|
||||
logging.getLogger("openai").setLevel(logging.WARNING) # OpenAI client retries
|
||||
logging.getLogger("rex-deploy").setLevel(logging.WARNING) # Swerex deployment
|
||||
logging.getLogger("rex_image_builder").setLevel(logging.WARNING) # Image builds
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"\n{'=' * 60}")
|
||||
print("Starting Terminal-Bench 2.0 Evaluation")
|
||||
print(f"{'='*60}")
|
||||
print(f"{'=' * 60}")
|
||||
print(f" Dataset: {self.config.dataset_name}")
|
||||
print(f" Total tasks: {len(self.all_eval_items)}")
|
||||
print(f" Max agent turns: {self.config.max_agent_turns}")
|
||||
@@ -833,9 +1009,11 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
print(f" Terminal backend: {self.config.terminal_backend}")
|
||||
print(f" Tool thread pool: {self.config.tool_pool_size}")
|
||||
print(f" Terminal timeout: {self.config.terminal_timeout}s/cmd")
|
||||
print(f" Terminal lifetime: {self.config.terminal_lifetime}s (auto: task_timeout + 120)")
|
||||
print(
|
||||
f" Terminal lifetime: {self.config.terminal_lifetime}s (auto: task_timeout + 120)"
|
||||
)
|
||||
print(f" Max concurrent tasks: {self.config.max_concurrent_tasks}")
|
||||
print(f"{'='*60}\n")
|
||||
print(f"{'=' * 60}\n")
|
||||
|
||||
# Semaphore to limit concurrent Modal sandbox creations.
|
||||
# Without this, all 86 tasks fire simultaneously, each creating a Modal
|
||||
@@ -877,6 +1055,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
await asyncio.gather(*eval_tasks, return_exceptions=True)
|
||||
# Belt-and-suspenders: clean up any remaining sandboxes
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
|
||||
cleanup_all_environments()
|
||||
print("All sandboxes cleaned up.")
|
||||
return
|
||||
@@ -922,9 +1101,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
|
||||
|
||||
# ---- Print summary ----
|
||||
print(f"\n{'='*60}")
|
||||
print(f"\n{'=' * 60}")
|
||||
print("Terminal-Bench 2.0 Evaluation Results")
|
||||
print(f"{'='*60}")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"Overall Pass Rate: {overall_pass_rate:.4f} ({passed}/{total})")
|
||||
print(f"Evaluation Time: {end_time - start_time:.1f} seconds")
|
||||
|
||||
@@ -944,7 +1123,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
extra = f" (error: {error})" if error else ""
|
||||
print(f" [{status}] {r['task_name']} (turns={turns}){extra}")
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
print(f"{'=' * 60}\n")
|
||||
|
||||
# Build sample records for evaluate_log (includes full conversations)
|
||||
samples = [
|
||||
@@ -969,6 +1148,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
end_time=end_time,
|
||||
generation_parameters={
|
||||
"temperature": self.config.agent_temperature,
|
||||
"top_p": self.config.agent_top_p,
|
||||
"max_tokens": self.config.max_token_length,
|
||||
"max_agent_turns": self.config.max_agent_turns,
|
||||
"terminal_backend": self.config.terminal_backend,
|
||||
@@ -985,6 +1165,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# Kill all remaining sandboxes. Timed-out tasks leave orphaned thread
|
||||
# pool workers still executing commands -- cleanup_all stops them.
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
|
||||
print("\nCleaning up all sandboxes...")
|
||||
cleanup_all_environments()
|
||||
|
||||
@@ -992,6 +1173,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# tasks are killed immediately instead of retrying against dead
|
||||
# sandboxes and spamming the console with TimeoutError warnings.
|
||||
from environments.agent_loop import _tool_executor
|
||||
|
||||
_tool_executor.shutdown(wait=False, cancel_futures=True)
|
||||
print("Done.")
|
||||
|
||||
|
||||
@@ -115,6 +115,10 @@ class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
default=1.0,
|
||||
description="Sampling temperature for agent generation during rollouts.",
|
||||
)
|
||||
agent_top_p: Optional[float] = Field(
|
||||
default=None,
|
||||
description="Nucleus sampling top_p for agent generation. None = provider default.",
|
||||
)
|
||||
|
||||
# --- Terminal backend ---
|
||||
terminal_backend: str = Field(
|
||||
@@ -529,6 +533,7 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
top_p=self.config.agent_top_p,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
@@ -547,6 +552,7 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
top_p=self.config.agent_top_p,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
@@ -561,6 +567,7 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
top_p=self.config.agent_top_p,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
|
||||
Generated
+4
-4
@@ -22,16 +22,16 @@
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1775036866,
|
||||
"narHash": "sha256-ZojAnPuCdy657PbTq5V0Y+AHKhZAIwSIT2cb8UgAz/U=",
|
||||
"lastModified": 1751274312,
|
||||
"narHash": "sha256-/bVBlRpECLVzjV19t5KMdMFWSwKLtb5RyXdjz3LJT+g=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "6201e203d09599479a3b3450ed24fa81537ebc4e",
|
||||
"rev": "50ab793786d9de88ee30ec4e4c24fb4236fc2674",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"ref": "nixos-24.11",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
description = "Hermes Agent - AI agent framework by Nous Research";
|
||||
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.11";
|
||||
flake-parts = {
|
||||
url = "github:hercules-ci/flake-parts";
|
||||
inputs.nixpkgs-lib.follows = "nixpkgs";
|
||||
|
||||
@@ -77,7 +77,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
|
||||
logger.warning("Channel directory: failed to build %s: %s", platform.value, e)
|
||||
|
||||
# Telegram, WhatsApp & Signal can't enumerate chats -- pull from session history
|
||||
for plat_name in ("telegram", "whatsapp", "signal", "email", "sms", "bluebubbles"):
|
||||
for plat_name in ("telegram", "whatsapp", "signal", "email", "sms"):
|
||||
if plat_name not in platforms:
|
||||
platforms[plat_name] = _build_from_sessions(plat_name)
|
||||
|
||||
|
||||
@@ -63,7 +63,6 @@ class Platform(Enum):
|
||||
WEBHOOK = "webhook"
|
||||
FEISHU = "feishu"
|
||||
WECOM = "wecom"
|
||||
BLUEBUBBLES = "bluebubbles"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -288,9 +287,6 @@ class GatewayConfig:
|
||||
# WeCom uses extra dict for bot credentials
|
||||
elif platform == Platform.WECOM and config.extra.get("bot_id"):
|
||||
connected.append(platform)
|
||||
# BlueBubbles uses extra dict for local server config
|
||||
elif platform == Platform.BLUEBUBBLES and config.extra.get("server_url") and config.extra.get("password"):
|
||||
connected.append(platform)
|
||||
return connected
|
||||
|
||||
def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]:
|
||||
@@ -532,8 +528,6 @@ def load_gateway_config() -> GatewayConfig:
|
||||
bridged["reply_prefix"] = platform_cfg["reply_prefix"]
|
||||
if "require_mention" in platform_cfg:
|
||||
bridged["require_mention"] = platform_cfg["require_mention"]
|
||||
if "free_response_channels" in platform_cfg:
|
||||
bridged["free_response_channels"] = platform_cfg["free_response_channels"]
|
||||
if "mention_patterns" in platform_cfg:
|
||||
bridged["mention_patterns"] = platform_cfg["mention_patterns"]
|
||||
if not bridged:
|
||||
@@ -548,19 +542,6 @@ def load_gateway_config() -> GatewayConfig:
|
||||
plat_data["extra"] = extra
|
||||
extra.update(bridged)
|
||||
|
||||
# Slack settings → env vars (env vars take precedence)
|
||||
slack_cfg = yaml_cfg.get("slack", {})
|
||||
if isinstance(slack_cfg, dict):
|
||||
if "require_mention" in slack_cfg and not os.getenv("SLACK_REQUIRE_MENTION"):
|
||||
os.environ["SLACK_REQUIRE_MENTION"] = str(slack_cfg["require_mention"]).lower()
|
||||
if "allow_bots" in slack_cfg and not os.getenv("SLACK_ALLOW_BOTS"):
|
||||
os.environ["SLACK_ALLOW_BOTS"] = str(slack_cfg["allow_bots"]).lower()
|
||||
frc = slack_cfg.get("free_response_channels")
|
||||
if frc is not None and not os.getenv("SLACK_FREE_RESPONSE_CHANNELS"):
|
||||
if isinstance(frc, list):
|
||||
frc = ",".join(str(v) for v in frc)
|
||||
os.environ["SLACK_FREE_RESPONSE_CHANNELS"] = str(frc)
|
||||
|
||||
# Discord settings → env vars (env vars take precedence)
|
||||
discord_cfg = yaml_cfg.get("discord", {})
|
||||
if isinstance(discord_cfg, dict):
|
||||
@@ -967,29 +948,6 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
name=os.getenv("WECOM_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# BlueBubbles (iMessage)
|
||||
bluebubbles_server_url = os.getenv("BLUEBUBBLES_SERVER_URL")
|
||||
bluebubbles_password = os.getenv("BLUEBUBBLES_PASSWORD")
|
||||
if bluebubbles_server_url and bluebubbles_password:
|
||||
if Platform.BLUEBUBBLES not in config.platforms:
|
||||
config.platforms[Platform.BLUEBUBBLES] = PlatformConfig()
|
||||
config.platforms[Platform.BLUEBUBBLES].enabled = True
|
||||
config.platforms[Platform.BLUEBUBBLES].extra.update({
|
||||
"server_url": bluebubbles_server_url.rstrip("/"),
|
||||
"password": bluebubbles_password,
|
||||
"webhook_host": os.getenv("BLUEBUBBLES_WEBHOOK_HOST", "127.0.0.1"),
|
||||
"webhook_port": int(os.getenv("BLUEBUBBLES_WEBHOOK_PORT", "8645")),
|
||||
"webhook_path": os.getenv("BLUEBUBBLES_WEBHOOK_PATH", "/bluebubbles-webhook"),
|
||||
"send_read_receipts": os.getenv("BLUEBUBBLES_SEND_READ_RECEIPTS", "true").lower() in ("true", "1", "yes"),
|
||||
})
|
||||
bluebubbles_home = os.getenv("BLUEBUBBLES_HOME_CHANNEL")
|
||||
if bluebubbles_home and Platform.BLUEBUBBLES in config.platforms:
|
||||
config.platforms[Platform.BLUEBUBBLES].home_channel = HomeChannel(
|
||||
platform=Platform.BLUEBUBBLES,
|
||||
chat_id=bluebubbles_home,
|
||||
name=os.getenv("BLUEBUBBLES_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# Session settings
|
||||
idle_minutes = os.getenv("SESSION_IDLE_MINUTES")
|
||||
if idle_minutes:
|
||||
|
||||
@@ -124,6 +124,53 @@ class DeliveryRouter:
|
||||
self.adapters = adapters or {}
|
||||
self.output_dir = get_hermes_home() / "cron" / "output"
|
||||
|
||||
def resolve_targets(
|
||||
self,
|
||||
deliver: Union[str, List[str]],
|
||||
origin: Optional[SessionSource] = None
|
||||
) -> List[DeliveryTarget]:
|
||||
"""
|
||||
Resolve delivery specification to concrete targets.
|
||||
|
||||
Args:
|
||||
deliver: Delivery spec - "origin", "telegram", ["local", "discord"], etc.
|
||||
origin: The source where the request originated (for "origin" target)
|
||||
|
||||
Returns:
|
||||
List of resolved delivery targets
|
||||
"""
|
||||
if isinstance(deliver, str):
|
||||
deliver = [deliver]
|
||||
|
||||
targets = []
|
||||
seen_platforms = set()
|
||||
|
||||
for target_str in deliver:
|
||||
target = DeliveryTarget.parse(target_str, origin)
|
||||
|
||||
# Resolve home channel if needed
|
||||
if target.chat_id is None and target.platform != Platform.LOCAL:
|
||||
home = self.config.get_home_channel(target.platform)
|
||||
if home:
|
||||
target.chat_id = home.chat_id
|
||||
else:
|
||||
# No home channel configured, skip this platform
|
||||
continue
|
||||
|
||||
# Deduplicate
|
||||
key = (target.platform, target.chat_id, target.thread_id)
|
||||
if key not in seen_platforms:
|
||||
seen_platforms.add(key)
|
||||
targets.append(target)
|
||||
|
||||
# Always include local if configured
|
||||
if self.config.always_log_local:
|
||||
local_key = (Platform.LOCAL, None, None)
|
||||
if local_key not in seen_platforms:
|
||||
targets.append(DeliveryTarget(platform=Platform.LOCAL))
|
||||
|
||||
return targets
|
||||
|
||||
async def deliver(
|
||||
self,
|
||||
content: str,
|
||||
@@ -252,5 +299,19 @@ class DeliveryRouter:
|
||||
return await adapter.send(target.chat_id, content, metadata=send_metadata or None)
|
||||
|
||||
|
||||
def parse_deliver_spec(
|
||||
deliver: Optional[Union[str, List[str]]],
|
||||
origin: Optional[SessionSource] = None,
|
||||
default: str = "origin"
|
||||
) -> Union[str, List[str]]:
|
||||
"""
|
||||
Normalize a delivery specification.
|
||||
|
||||
If None or empty, returns the default.
|
||||
"""
|
||||
if not deliver:
|
||||
return default
|
||||
return deliver
|
||||
|
||||
|
||||
|
||||
|
||||
+1
-130
@@ -10,142 +10,18 @@ import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _detect_macos_system_proxy() -> str | None:
|
||||
"""Read the macOS system HTTP(S) proxy via ``scutil --proxy``.
|
||||
|
||||
Returns an ``http://host:port`` URL string if an HTTP or HTTPS proxy is
|
||||
enabled, otherwise *None*. Falls back silently on non-macOS or on any
|
||||
subprocess error.
|
||||
"""
|
||||
if sys.platform != "darwin":
|
||||
return None
|
||||
try:
|
||||
out = subprocess.check_output(
|
||||
["scutil", "--proxy"], timeout=3, text=True, stderr=subprocess.DEVNULL,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
props: dict[str, str] = {}
|
||||
for line in out.splitlines():
|
||||
line = line.strip()
|
||||
if " : " in line:
|
||||
key, _, val = line.partition(" : ")
|
||||
props[key.strip()] = val.strip()
|
||||
|
||||
# Prefer HTTPS, fall back to HTTP
|
||||
for enable_key, host_key, port_key in (
|
||||
("HTTPSEnable", "HTTPSProxy", "HTTPSPort"),
|
||||
("HTTPEnable", "HTTPProxy", "HTTPPort"),
|
||||
):
|
||||
if props.get(enable_key) == "1":
|
||||
host = props.get(host_key)
|
||||
port = props.get(port_key)
|
||||
if host and port:
|
||||
return f"http://{host}:{port}"
|
||||
return None
|
||||
|
||||
|
||||
def resolve_proxy_url(platform_env_var: str | None = None) -> str | None:
|
||||
"""Return a proxy URL from env vars, or macOS system proxy.
|
||||
|
||||
Check order:
|
||||
0. *platform_env_var* (e.g. ``DISCORD_PROXY``) — highest priority
|
||||
1. HTTPS_PROXY / HTTP_PROXY / ALL_PROXY (and lowercase variants)
|
||||
2. macOS system proxy via ``scutil --proxy`` (auto-detect)
|
||||
|
||||
Returns *None* if no proxy is found.
|
||||
"""
|
||||
if platform_env_var:
|
||||
value = (os.environ.get(platform_env_var) or "").strip()
|
||||
if value:
|
||||
return value
|
||||
for key in ("HTTPS_PROXY", "HTTP_PROXY", "ALL_PROXY",
|
||||
"https_proxy", "http_proxy", "all_proxy"):
|
||||
value = (os.environ.get(key) or "").strip()
|
||||
if value:
|
||||
return value
|
||||
return _detect_macos_system_proxy()
|
||||
|
||||
|
||||
def proxy_kwargs_for_bot(proxy_url: str | None) -> dict:
|
||||
"""Build kwargs for ``commands.Bot()`` / ``discord.Client()`` with proxy.
|
||||
|
||||
Returns:
|
||||
- SOCKS URL → ``{"connector": ProxyConnector(..., rdns=True)}``
|
||||
- HTTP URL → ``{"proxy": url}``
|
||||
- *None* → ``{}``
|
||||
|
||||
``rdns=True`` forces remote DNS resolution through the proxy — required
|
||||
by many SOCKS implementations (Shadowrocket, Clash) and essential for
|
||||
bypassing DNS pollution behind the GFW.
|
||||
"""
|
||||
if not proxy_url:
|
||||
return {}
|
||||
if proxy_url.lower().startswith("socks"):
|
||||
try:
|
||||
from aiohttp_socks import ProxyConnector
|
||||
|
||||
connector = ProxyConnector.from_url(proxy_url, rdns=True)
|
||||
return {"connector": connector}
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"aiohttp_socks not installed — SOCKS proxy %s ignored. "
|
||||
"Run: pip install aiohttp-socks",
|
||||
proxy_url,
|
||||
)
|
||||
return {}
|
||||
return {"proxy": proxy_url}
|
||||
|
||||
|
||||
def proxy_kwargs_for_aiohttp(proxy_url: str | None) -> tuple[dict, dict]:
|
||||
"""Build kwargs for standalone ``aiohttp.ClientSession`` with proxy.
|
||||
|
||||
Returns ``(session_kwargs, request_kwargs)`` where:
|
||||
- SOCKS → ``({"connector": ProxyConnector(...)}, {})``
|
||||
- HTTP → ``({}, {"proxy": url})``
|
||||
- None → ``({}, {})``
|
||||
|
||||
Usage::
|
||||
|
||||
sess_kw, req_kw = proxy_kwargs_for_aiohttp(proxy_url)
|
||||
async with aiohttp.ClientSession(**sess_kw) as session:
|
||||
async with session.get(url, **req_kw) as resp:
|
||||
...
|
||||
"""
|
||||
if not proxy_url:
|
||||
return {}, {}
|
||||
if proxy_url.lower().startswith("socks"):
|
||||
try:
|
||||
from aiohttp_socks import ProxyConnector
|
||||
|
||||
connector = ProxyConnector.from_url(proxy_url, rdns=True)
|
||||
return {"connector": connector}, {}
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"aiohttp_socks not installed — SOCKS proxy %s ignored. "
|
||||
"Run: pip install aiohttp-socks",
|
||||
proxy_url,
|
||||
)
|
||||
return {}, {}
|
||||
return {}, {"proxy": proxy_url}
|
||||
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any, Callable, Awaitable, Tuple
|
||||
from enum import Enum
|
||||
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
@@ -422,7 +298,6 @@ SUPPORTED_DOCUMENT_TYPES = {
|
||||
".pdf": "application/pdf",
|
||||
".md": "text/markdown",
|
||||
".txt": "text/plain",
|
||||
".log": "text/plain",
|
||||
".zip": "application/zip",
|
||||
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
@@ -532,10 +407,6 @@ class MessageEvent:
|
||||
# Auto-loaded skill for topic/channel bindings (e.g., Telegram DM Topics)
|
||||
auto_skill: Optional[str] = None
|
||||
|
||||
# Internal flag — set for synthetic events (e.g. background process
|
||||
# completion notifications) that must bypass user authorization checks.
|
||||
internal: bool = False
|
||||
|
||||
# Timestamps
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
@@ -1,828 +0,0 @@
|
||||
"""BlueBubbles iMessage platform adapter.
|
||||
|
||||
Uses the local BlueBubbles macOS server for outbound REST sends and inbound
|
||||
webhooks. Supports text messaging, media attachments (images, voice, video,
|
||||
documents), tapback reactions, typing indicators, and read receipts.
|
||||
|
||||
Architecture based on PR #5869 (benjaminsehl) with inbound attachment
|
||||
downloading from PR #4588 (YuhangLin).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
import httpx
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_image_from_bytes,
|
||||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_WEBHOOK_HOST = "127.0.0.1"
|
||||
DEFAULT_WEBHOOK_PORT = 8645
|
||||
DEFAULT_WEBHOOK_PATH = "/bluebubbles-webhook"
|
||||
MAX_TEXT_LENGTH = 4000
|
||||
|
||||
# Tapback reaction codes (BlueBubbles associatedMessageType values)
|
||||
_TAPBACK_ADDED = {
|
||||
2000: "love", 2001: "like", 2002: "dislike",
|
||||
2003: "laugh", 2004: "emphasize", 2005: "question",
|
||||
}
|
||||
_TAPBACK_REMOVED = {
|
||||
3000: "love", 3001: "like", 3002: "dislike",
|
||||
3003: "laugh", 3004: "emphasize", 3005: "question",
|
||||
}
|
||||
|
||||
# Webhook event types that carry user messages
|
||||
_MESSAGE_EVENTS = {"new-message", "message", "updated-message"}
|
||||
|
||||
# Log redaction patterns
|
||||
_PHONE_RE = re.compile(r"\+?\d{7,15}")
|
||||
_EMAIL_RE = re.compile(r"[\w.+-]+@[\w-]+\.[\w.]+")
|
||||
|
||||
|
||||
def _redact(text: str) -> str:
|
||||
"""Redact phone numbers and emails from log output."""
|
||||
text = _PHONE_RE.sub("[REDACTED]", text)
|
||||
text = _EMAIL_RE.sub("[REDACTED]", text)
|
||||
return text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def check_bluebubbles_requirements() -> bool:
|
||||
try:
|
||||
import aiohttp # noqa: F401
|
||||
import httpx as _httpx # noqa: F401
|
||||
except ImportError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _normalize_server_url(raw: str) -> str:
|
||||
value = (raw or "").strip()
|
||||
if not value:
|
||||
return ""
|
||||
if not re.match(r"^https?://", value, flags=re.I):
|
||||
value = f"http://{value}"
|
||||
return value.rstrip("/")
|
||||
|
||||
|
||||
def _strip_markdown(text: str) -> str:
|
||||
"""Strip common markdown formatting for iMessage plain-text delivery."""
|
||||
text = re.sub(r"\*\*(.+?)\*\*", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"\*(.+?)\*", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"__(.+?)__", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"_(.+?)_", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"```[a-zA-Z0-9_+-]*\n?", "", text)
|
||||
text = re.sub(r"`(.+?)`", r"\1", text)
|
||||
text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE)
|
||||
text = re.sub(r"\[([^\]]+)\]\(([^\)]+)\)", r"\1", text)
|
||||
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adapter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class BlueBubblesAdapter(BasePlatformAdapter):
|
||||
platform = Platform.BLUEBUBBLES
|
||||
MAX_MESSAGE_LENGTH = MAX_TEXT_LENGTH
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.BLUEBUBBLES)
|
||||
extra = config.extra or {}
|
||||
self.server_url = _normalize_server_url(
|
||||
extra.get("server_url") or os.getenv("BLUEBUBBLES_SERVER_URL", "")
|
||||
)
|
||||
self.password = extra.get("password") or os.getenv("BLUEBUBBLES_PASSWORD", "")
|
||||
self.webhook_host = (
|
||||
extra.get("webhook_host")
|
||||
or os.getenv("BLUEBUBBLES_WEBHOOK_HOST", DEFAULT_WEBHOOK_HOST)
|
||||
)
|
||||
self.webhook_port = int(
|
||||
extra.get("webhook_port")
|
||||
or os.getenv("BLUEBUBBLES_WEBHOOK_PORT", str(DEFAULT_WEBHOOK_PORT))
|
||||
)
|
||||
self.webhook_path = (
|
||||
extra.get("webhook_path")
|
||||
or os.getenv("BLUEBUBBLES_WEBHOOK_PATH", DEFAULT_WEBHOOK_PATH)
|
||||
)
|
||||
if not str(self.webhook_path).startswith("/"):
|
||||
self.webhook_path = f"/{self.webhook_path}"
|
||||
self.send_read_receipts = bool(extra.get("send_read_receipts", True))
|
||||
self.client: Optional[httpx.AsyncClient] = None
|
||||
self._runner = None
|
||||
self._private_api_enabled: Optional[bool] = None
|
||||
self._helper_connected: bool = False
|
||||
self._guid_cache: Dict[str, str] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# API helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _api_url(self, path: str) -> str:
|
||||
sep = "&" if "?" in path else "?"
|
||||
return f"{self.server_url}{path}{sep}password={quote(self.password, safe='')}"
|
||||
|
||||
async def _api_get(self, path: str) -> Dict[str, Any]:
|
||||
assert self.client is not None
|
||||
res = await self.client.get(self._api_url(path))
|
||||
res.raise_for_status()
|
||||
return res.json()
|
||||
|
||||
async def _api_post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
assert self.client is not None
|
||||
res = await self.client.post(self._api_url(path), json=payload)
|
||||
res.raise_for_status()
|
||||
return res.json()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def connect(self) -> bool:
|
||||
if not self.server_url or not self.password:
|
||||
logger.error(
|
||||
"[bluebubbles] BLUEBUBBLES_SERVER_URL and BLUEBUBBLES_PASSWORD are required"
|
||||
)
|
||||
return False
|
||||
from aiohttp import web
|
||||
|
||||
self.client = httpx.AsyncClient(timeout=30.0)
|
||||
try:
|
||||
await self._api_get("/api/v1/ping")
|
||||
info = await self._api_get("/api/v1/server/info")
|
||||
server_data = (info or {}).get("data", {})
|
||||
self._private_api_enabled = bool(server_data.get("private_api"))
|
||||
self._helper_connected = bool(server_data.get("helper_connected"))
|
||||
logger.info(
|
||||
"[bluebubbles] connected to %s (private_api=%s, helper=%s)",
|
||||
self.server_url,
|
||||
self._private_api_enabled,
|
||||
self._helper_connected,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"[bluebubbles] cannot reach server at %s: %s", self.server_url, exc
|
||||
)
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
self.client = None
|
||||
return False
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/health", lambda _: web.Response(text="ok"))
|
||||
app.router.add_post(self.webhook_path, self._handle_webhook)
|
||||
self._runner = web.AppRunner(app)
|
||||
await self._runner.setup()
|
||||
site = web.TCPSite(self._runner, self.webhook_host, self.webhook_port)
|
||||
await site.start()
|
||||
self._mark_connected()
|
||||
logger.info(
|
||||
"[bluebubbles] webhook listening on http://%s:%s%s",
|
||||
self.webhook_host,
|
||||
self.webhook_port,
|
||||
self.webhook_path,
|
||||
)
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
self.client = None
|
||||
if self._runner:
|
||||
await self._runner.cleanup()
|
||||
self._runner = None
|
||||
self._mark_disconnected()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Chat GUID resolution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _resolve_chat_guid(self, target: str) -> Optional[str]:
|
||||
"""Resolve an email/phone to a BlueBubbles chat GUID.
|
||||
|
||||
If *target* already contains a semicolon (raw GUID format like
|
||||
``iMessage;-;user@example.com``), it is returned as-is. Otherwise
|
||||
the adapter queries the BlueBubbles chat list and matches on
|
||||
``chatIdentifier`` or participant address.
|
||||
"""
|
||||
target = (target or "").strip()
|
||||
if not target:
|
||||
return None
|
||||
# Already a raw GUID
|
||||
if ";" in target:
|
||||
return target
|
||||
if target in self._guid_cache:
|
||||
return self._guid_cache[target]
|
||||
try:
|
||||
payload = await self._api_post(
|
||||
"/api/v1/chat/query",
|
||||
{"limit": 100, "offset": 0, "with": ["participants"]},
|
||||
)
|
||||
for chat in payload.get("data", []) or []:
|
||||
guid = chat.get("guid") or chat.get("chatGuid")
|
||||
identifier = chat.get("chatIdentifier") or chat.get("identifier")
|
||||
if identifier == target:
|
||||
if guid:
|
||||
self._guid_cache[target] = guid
|
||||
return guid
|
||||
for part in chat.get("participants", []) or []:
|
||||
if (part.get("address") or "").strip() == target and guid:
|
||||
self._guid_cache[target] = guid
|
||||
return guid
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
async def _create_chat_for_handle(
|
||||
self, address: str, message: str
|
||||
) -> SendResult:
|
||||
"""Create a new chat by sending the first message to *address*."""
|
||||
payload = {
|
||||
"addresses": [address],
|
||||
"message": message,
|
||||
"tempGuid": f"temp-{datetime.utcnow().timestamp()}",
|
||||
}
|
||||
try:
|
||||
res = await self._api_post("/api/v1/chat/new", payload)
|
||||
data = res.get("data") or {}
|
||||
msg_id = data.get("guid") or data.get("messageGuid") or "ok"
|
||||
return SendResult(success=True, message_id=str(msg_id), raw_response=res)
|
||||
except Exception as exc:
|
||||
return SendResult(success=False, error=str(exc))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Text sending
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
text = _strip_markdown(content or "")
|
||||
if not text:
|
||||
return SendResult(success=False, error="BlueBubbles send requires text")
|
||||
chunks = self.truncate_message(text, max_length=self.MAX_MESSAGE_LENGTH)
|
||||
last = SendResult(success=True)
|
||||
for chunk in chunks:
|
||||
guid = await self._resolve_chat_guid(chat_id)
|
||||
if not guid:
|
||||
# If the target looks like an address, try creating a new chat
|
||||
if self._private_api_enabled and (
|
||||
"@" in chat_id or re.match(r"^\+\d+", chat_id)
|
||||
):
|
||||
return await self._create_chat_for_handle(chat_id, chunk)
|
||||
return SendResult(
|
||||
success=False,
|
||||
error=f"BlueBubbles chat not found for target: {chat_id}",
|
||||
)
|
||||
payload: Dict[str, Any] = {
|
||||
"chatGuid": guid,
|
||||
"tempGuid": f"temp-{datetime.utcnow().timestamp()}",
|
||||
"message": chunk,
|
||||
}
|
||||
if reply_to and self._private_api_enabled and self._helper_connected:
|
||||
payload["method"] = "private-api"
|
||||
payload["selectedMessageGuid"] = reply_to
|
||||
payload["partIndex"] = 0
|
||||
try:
|
||||
res = await self._api_post("/api/v1/message/text", payload)
|
||||
data = res.get("data") or {}
|
||||
msg_id = data.get("guid") or data.get("messageGuid") or "ok"
|
||||
last = SendResult(
|
||||
success=True, message_id=str(msg_id), raw_response=res
|
||||
)
|
||||
except Exception as exc:
|
||||
return SendResult(success=False, error=str(exc))
|
||||
return last
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Media sending (outbound)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _send_attachment(
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
filename: Optional[str] = None,
|
||||
caption: Optional[str] = None,
|
||||
is_audio_message: bool = False,
|
||||
) -> SendResult:
|
||||
"""Send a file attachment via BlueBubbles multipart upload."""
|
||||
if not self.client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
if not os.path.isfile(file_path):
|
||||
return SendResult(success=False, error=f"File not found: {file_path}")
|
||||
|
||||
guid = await self._resolve_chat_guid(chat_id)
|
||||
if not guid:
|
||||
return SendResult(success=False, error=f"Chat not found: {chat_id}")
|
||||
|
||||
fname = filename or os.path.basename(file_path)
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
files = {"attachment": (fname, f, "application/octet-stream")}
|
||||
data: Dict[str, str] = {
|
||||
"chatGuid": guid,
|
||||
"name": fname,
|
||||
"tempGuid": uuid.uuid4().hex,
|
||||
}
|
||||
if is_audio_message:
|
||||
data["isAudioMessage"] = "true"
|
||||
res = await self.client.post(
|
||||
self._api_url("/api/v1/message/attachment"),
|
||||
files=files,
|
||||
data=data,
|
||||
timeout=120,
|
||||
)
|
||||
res.raise_for_status()
|
||||
result = res.json()
|
||||
|
||||
if caption:
|
||||
await self.send(chat_id, caption)
|
||||
|
||||
if result.get("status") == 200:
|
||||
rdata = result.get("data") or {}
|
||||
msg_id = rdata.get("guid") if isinstance(rdata, dict) else None
|
||||
return SendResult(
|
||||
success=True, message_id=msg_id, raw_response=result
|
||||
)
|
||||
return SendResult(
|
||||
success=False,
|
||||
error=result.get("message", "Attachment upload failed"),
|
||||
)
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
try:
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
|
||||
local_path = await cache_image_from_url(image_url)
|
||||
return await self._send_attachment(chat_id, local_path, caption=caption)
|
||||
except Exception:
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
return await self._send_attachment(chat_id, image_path, caption=caption)
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
return await self._send_attachment(
|
||||
chat_id, audio_path, caption=caption, is_audio_message=True
|
||||
)
|
||||
|
||||
async def send_video(
|
||||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
return await self._send_attachment(chat_id, video_path, caption=caption)
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
return await self._send_attachment(
|
||||
chat_id, file_path, filename=file_name, caption=caption
|
||||
)
|
||||
|
||||
async def send_animation(
|
||||
self,
|
||||
chat_id: str,
|
||||
animation_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
return await self.send_image(
|
||||
chat_id, animation_url, caption, reply_to, metadata
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Typing indicators
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata=None) -> None:
|
||||
if not self._private_api_enabled or not self._helper_connected or not self.client:
|
||||
return
|
||||
try:
|
||||
guid = await self._resolve_chat_guid(chat_id)
|
||||
if guid:
|
||||
encoded = quote(guid, safe="")
|
||||
await self.client.post(
|
||||
self._api_url(f"/api/v1/chat/{encoded}/typing"), timeout=5
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def stop_typing(self, chat_id: str) -> None:
|
||||
if not self._private_api_enabled or not self._helper_connected or not self.client:
|
||||
return
|
||||
try:
|
||||
guid = await self._resolve_chat_guid(chat_id)
|
||||
if guid:
|
||||
encoded = quote(guid, safe="")
|
||||
await self.client.delete(
|
||||
self._api_url(f"/api/v1/chat/{encoded}/typing"), timeout=5
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Read receipts
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def mark_read(self, chat_id: str) -> bool:
|
||||
if not self._private_api_enabled or not self._helper_connected or not self.client:
|
||||
return False
|
||||
try:
|
||||
guid = await self._resolve_chat_guid(chat_id)
|
||||
if guid:
|
||||
encoded = quote(guid, safe="")
|
||||
await self.client.post(
|
||||
self._api_url(f"/api/v1/chat/{encoded}/read"), timeout=5
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tapback reactions
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send_reaction(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_guid: str,
|
||||
reaction: str,
|
||||
part_index: int = 0,
|
||||
) -> SendResult:
|
||||
"""Send a tapback reaction (requires Private API helper)."""
|
||||
if not self._private_api_enabled or not self._helper_connected:
|
||||
return SendResult(
|
||||
success=False, error="Private API helper not connected"
|
||||
)
|
||||
guid = await self._resolve_chat_guid(chat_id)
|
||||
if not guid:
|
||||
return SendResult(success=False, error=f"Chat not found: {chat_id}")
|
||||
try:
|
||||
res = await self._api_post(
|
||||
"/api/v1/message/react",
|
||||
{
|
||||
"chatGuid": guid,
|
||||
"selectedMessageGuid": message_guid,
|
||||
"reaction": reaction,
|
||||
"partIndex": part_index,
|
||||
},
|
||||
)
|
||||
return SendResult(success=True, raw_response=res)
|
||||
except Exception as exc:
|
||||
return SendResult(success=False, error=str(exc))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Chat info
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
is_group = ";+;" in (chat_id or "")
|
||||
info: Dict[str, Any] = {
|
||||
"name": chat_id,
|
||||
"type": "group" if is_group else "dm",
|
||||
}
|
||||
try:
|
||||
guid = await self._resolve_chat_guid(chat_id)
|
||||
if guid:
|
||||
encoded = quote(guid, safe="")
|
||||
res = await self._api_get(
|
||||
f"/api/v1/chat/{encoded}?with=participants"
|
||||
)
|
||||
data = (res or {}).get("data", {})
|
||||
display_name = (
|
||||
data.get("displayName")
|
||||
or data.get("chatIdentifier")
|
||||
or chat_id
|
||||
)
|
||||
participants = []
|
||||
for p in data.get("participants", []) or []:
|
||||
addr = (p.get("address") or "").strip()
|
||||
if addr:
|
||||
participants.append(addr)
|
||||
info["name"] = display_name
|
||||
if participants:
|
||||
info["participants"] = participants
|
||||
except Exception:
|
||||
pass
|
||||
return info
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
return _strip_markdown(content)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inbound attachment downloading (from #4588)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _download_attachment(
|
||||
self, att_guid: str, att_meta: Dict[str, Any]
|
||||
) -> Optional[str]:
|
||||
"""Download an attachment from BlueBubbles and cache it locally.
|
||||
|
||||
Returns the local file path on success, None on failure.
|
||||
"""
|
||||
if not self.client:
|
||||
return None
|
||||
try:
|
||||
encoded = quote(att_guid, safe="")
|
||||
resp = await self.client.get(
|
||||
self._api_url(f"/api/v1/attachment/{encoded}/download"),
|
||||
timeout=60,
|
||||
follow_redirects=True,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.content
|
||||
|
||||
mime = (att_meta.get("mimeType") or "").lower()
|
||||
transfer_name = att_meta.get("transferName", "")
|
||||
|
||||
if mime.startswith("image/"):
|
||||
ext_map = {
|
||||
"image/jpeg": ".jpg",
|
||||
"image/png": ".png",
|
||||
"image/gif": ".gif",
|
||||
"image/webp": ".webp",
|
||||
"image/heic": ".jpg",
|
||||
"image/heif": ".jpg",
|
||||
"image/tiff": ".jpg",
|
||||
}
|
||||
ext = ext_map.get(mime, ".jpg")
|
||||
return cache_image_from_bytes(data, ext)
|
||||
|
||||
if mime.startswith("audio/"):
|
||||
ext_map = {
|
||||
"audio/mp3": ".mp3",
|
||||
"audio/mpeg": ".mp3",
|
||||
"audio/ogg": ".ogg",
|
||||
"audio/wav": ".wav",
|
||||
"audio/x-caf": ".mp3",
|
||||
"audio/mp4": ".m4a",
|
||||
"audio/aac": ".m4a",
|
||||
}
|
||||
ext = ext_map.get(mime, ".mp3")
|
||||
return cache_audio_from_bytes(data, ext)
|
||||
|
||||
# Videos, documents, and everything else
|
||||
filename = transfer_name or f"file_{uuid.uuid4().hex[:8]}"
|
||||
return cache_document_from_bytes(data, filename)
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[bluebubbles] failed to download attachment %s: %s",
|
||||
_redact(att_guid),
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Webhook handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _extract_payload_record(
|
||||
self, payload: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
data = payload.get("data")
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
if isinstance(data, list):
|
||||
for item in data:
|
||||
if isinstance(item, dict):
|
||||
return item
|
||||
if isinstance(payload.get("message"), dict):
|
||||
return payload.get("message")
|
||||
return payload if isinstance(payload, dict) else None
|
||||
|
||||
@staticmethod
|
||||
def _value(*candidates: Any) -> Optional[str]:
|
||||
for candidate in candidates:
|
||||
if isinstance(candidate, str) and candidate.strip():
|
||||
return candidate.strip()
|
||||
return None
|
||||
|
||||
async def _handle_webhook(self, request):
|
||||
from aiohttp import web
|
||||
|
||||
token = (
|
||||
request.query.get("password")
|
||||
or request.query.get("guid")
|
||||
or request.headers.get("x-password")
|
||||
or request.headers.get("x-guid")
|
||||
or request.headers.get("x-bluebubbles-guid")
|
||||
)
|
||||
if token != self.password:
|
||||
return web.json_response({"error": "unauthorized"}, status=401)
|
||||
try:
|
||||
raw = await request.read()
|
||||
body = raw.decode("utf-8", errors="replace")
|
||||
try:
|
||||
payload = json.loads(body)
|
||||
except Exception:
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
form = parse_qs(body)
|
||||
payload_str = (
|
||||
form.get("payload")
|
||||
or form.get("data")
|
||||
or form.get("message")
|
||||
or [""]
|
||||
)[0]
|
||||
payload = json.loads(payload_str) if payload_str else {}
|
||||
except Exception as exc:
|
||||
logger.error("[bluebubbles] webhook parse error: %s", exc)
|
||||
return web.json_response({"error": "invalid payload"}, status=400)
|
||||
|
||||
event_type = self._value(payload.get("type"), payload.get("event")) or ""
|
||||
# Only process message events; silently acknowledge everything else
|
||||
if event_type and event_type not in _MESSAGE_EVENTS:
|
||||
return web.Response(text="ok")
|
||||
|
||||
record = self._extract_payload_record(payload) or {}
|
||||
is_from_me = bool(
|
||||
record.get("isFromMe")
|
||||
or record.get("fromMe")
|
||||
or record.get("is_from_me")
|
||||
)
|
||||
if is_from_me:
|
||||
return web.Response(text="ok")
|
||||
|
||||
# Skip tapback reactions delivered as messages
|
||||
assoc_type = record.get("associatedMessageType")
|
||||
if isinstance(assoc_type, int) and assoc_type in {
|
||||
**_TAPBACK_ADDED,
|
||||
**_TAPBACK_REMOVED,
|
||||
}:
|
||||
return web.Response(text="ok")
|
||||
|
||||
text = (
|
||||
self._value(
|
||||
record.get("text"), record.get("message"), record.get("body")
|
||||
)
|
||||
or ""
|
||||
)
|
||||
|
||||
# --- Inbound attachment handling ---
|
||||
attachments = record.get("attachments") or []
|
||||
media_urls: List[str] = []
|
||||
media_types: List[str] = []
|
||||
msg_type = MessageType.TEXT
|
||||
|
||||
for att in attachments:
|
||||
att_guid = att.get("guid", "")
|
||||
if not att_guid:
|
||||
continue
|
||||
cached = await self._download_attachment(att_guid, att)
|
||||
if cached:
|
||||
mime = (att.get("mimeType") or "").lower()
|
||||
media_urls.append(cached)
|
||||
media_types.append(mime)
|
||||
if mime.startswith("image/"):
|
||||
msg_type = MessageType.PHOTO
|
||||
elif mime.startswith("audio/") or (att.get("uti") or "").endswith(
|
||||
"caf"
|
||||
):
|
||||
msg_type = MessageType.VOICE
|
||||
elif mime.startswith("video/"):
|
||||
msg_type = MessageType.VIDEO
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
|
||||
# With multiple attachments, prefer PHOTO if any images present
|
||||
if len(media_urls) > 1:
|
||||
mime_prefixes = {(m or "").split("/")[0] for m in media_types}
|
||||
if "image" in mime_prefixes:
|
||||
msg_type = MessageType.PHOTO
|
||||
|
||||
if not text and media_urls:
|
||||
text = "(attachment)"
|
||||
# --- End attachment handling ---
|
||||
|
||||
chat_guid = self._value(
|
||||
record.get("chatGuid"),
|
||||
payload.get("chatGuid"),
|
||||
record.get("chat_guid"),
|
||||
payload.get("chat_guid"),
|
||||
payload.get("guid"),
|
||||
)
|
||||
chat_identifier = self._value(
|
||||
record.get("chatIdentifier"),
|
||||
record.get("identifier"),
|
||||
payload.get("chatIdentifier"),
|
||||
payload.get("identifier"),
|
||||
)
|
||||
sender = (
|
||||
self._value(
|
||||
record.get("handle", {}).get("address")
|
||||
if isinstance(record.get("handle"), dict)
|
||||
else None,
|
||||
record.get("sender"),
|
||||
record.get("from"),
|
||||
record.get("address"),
|
||||
)
|
||||
or chat_identifier
|
||||
or chat_guid
|
||||
)
|
||||
if not (chat_guid or chat_identifier) and sender:
|
||||
chat_identifier = sender
|
||||
if not sender or not (chat_guid or chat_identifier) or not text:
|
||||
return web.json_response({"error": "missing message fields"}, status=400)
|
||||
|
||||
session_chat_id = chat_guid or chat_identifier
|
||||
is_group = bool(record.get("isGroup")) or (";+;" in (chat_guid or ""))
|
||||
source = self.build_source(
|
||||
chat_id=session_chat_id,
|
||||
chat_name=chat_identifier or sender,
|
||||
chat_type="group" if is_group else "dm",
|
||||
user_id=sender,
|
||||
user_name=sender,
|
||||
chat_id_alt=chat_identifier,
|
||||
)
|
||||
event = MessageEvent(
|
||||
text=text,
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=payload,
|
||||
message_id=self._value(
|
||||
record.get("guid"),
|
||||
record.get("messageGuid"),
|
||||
record.get("id"),
|
||||
),
|
||||
reply_to_message_id=self._value(
|
||||
record.get("threadOriginatorGuid"),
|
||||
record.get("associatedMessageGuid"),
|
||||
),
|
||||
media_urls=media_urls,
|
||||
media_types=media_types,
|
||||
)
|
||||
task = asyncio.create_task(self.handle_message(event))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
# Fire-and-forget read receipt
|
||||
if self.send_read_receipts and session_chat_id:
|
||||
asyncio.create_task(self.mark_read(session_chat_id))
|
||||
|
||||
return web.Response(text="ok")
|
||||
@@ -529,17 +529,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
intents.members = any(not entry.isdigit() for entry in self._allowed_user_ids)
|
||||
intents.voice_states = True
|
||||
|
||||
# Resolve proxy (DISCORD_PROXY > generic env vars > macOS system proxy)
|
||||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_bot
|
||||
proxy_url = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||||
if proxy_url:
|
||||
logger.info("[%s] Using proxy for Discord: %s", self.name, proxy_url)
|
||||
|
||||
# Create bot — proxy= for HTTP, connector= for SOCKS
|
||||
# Create bot
|
||||
self._client = commands.Bot(
|
||||
command_prefix="!", # Not really used, we handle raw messages
|
||||
intents=intents,
|
||||
**proxy_kwargs_for_bot(proxy_url),
|
||||
)
|
||||
adapter_self = self # capture for closure
|
||||
|
||||
@@ -1314,11 +1307,8 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
# Download the image and send as a Discord file attachment
|
||||
# (Discord renders attachments inline, unlike plain URLs)
|
||||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
|
||||
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||||
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
|
||||
async with aiohttp.ClientSession(**_sess_kw) as session:
|
||||
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30), **_req_kw) as resp:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"Failed to download image: HTTP {resp.status}")
|
||||
|
||||
@@ -1595,7 +1585,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
await self._run_simple_slash(interaction, f"/model {name}".strip())
|
||||
|
||||
@tree.command(name="reasoning", description="Show or change reasoning effort")
|
||||
@discord.app_commands.describe(effort="Reasoning effort: none, minimal, low, medium, high, or xhigh.")
|
||||
@discord.app_commands.describe(effort="Reasoning effort: xhigh, high, medium, low, minimal, or none.")
|
||||
async def slash_reasoning(interaction: discord.Interaction, effort: str = ""):
|
||||
await self._run_simple_slash(interaction, f"/reasoning {effort}".strip())
|
||||
|
||||
@@ -1777,9 +1767,8 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if hasattr(interaction.channel, "guild") and interaction.channel.guild:
|
||||
chat_name = f"{interaction.channel.guild.name} / #{chat_name}"
|
||||
|
||||
# Get channel topic (if available).
|
||||
# For forum threads, inherit the parent forum's topic.
|
||||
chat_topic = self._get_effective_topic(interaction.channel, is_thread=is_thread)
|
||||
# Get channel topic (if available)
|
||||
chat_topic = getattr(interaction.channel, "topic", None)
|
||||
|
||||
source = self.build_source(
|
||||
chat_id=str(interaction.channel_id),
|
||||
@@ -1853,10 +1842,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
chat_name = f"{guild_name} / {thread_name}" if guild_name else thread_name
|
||||
|
||||
# Inherit forum topic when the thread was created inside a forum channel.
|
||||
_chan = getattr(interaction, "channel", None)
|
||||
chat_topic = self._get_effective_topic(_chan, is_thread=True) if _chan else None
|
||||
|
||||
source = self.build_source(
|
||||
chat_id=thread_id,
|
||||
chat_name=chat_name,
|
||||
@@ -1864,7 +1849,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
user_id=str(interaction.user.id),
|
||||
user_name=interaction.user.display_name,
|
||||
thread_id=thread_id,
|
||||
chat_topic=chat_topic,
|
||||
)
|
||||
|
||||
event = MessageEvent(
|
||||
@@ -2150,15 +2134,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_effective_topic(self, channel: Any, is_thread: bool = False) -> Optional[str]:
|
||||
"""Return the channel topic, falling back to the parent forum's topic for forum threads."""
|
||||
topic = getattr(channel, "topic", None)
|
||||
if not topic and is_thread:
|
||||
parent = getattr(channel, "parent", None)
|
||||
if parent and self._is_forum_parent(parent):
|
||||
topic = getattr(parent, "topic", None)
|
||||
return topic
|
||||
|
||||
def _format_thread_chat_name(self, thread: Any) -> str:
|
||||
"""Build a readable chat name for thread-like Discord channels, including forum context when available."""
|
||||
thread_name = getattr(thread, "name", None) or str(getattr(thread, "id", "thread"))
|
||||
@@ -2326,10 +2301,8 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if hasattr(message.channel, "guild") and message.channel.guild:
|
||||
chat_name = f"{message.channel.guild.name} / #{chat_name}"
|
||||
|
||||
# Get channel topic (if available - TextChannels have topics, DMs/threads don't).
|
||||
# For threads whose parent is a forum channel, inherit the parent's topic
|
||||
# so forum descriptions (e.g. project instructions) appear in the session context.
|
||||
chat_topic = self._get_effective_topic(message.channel, is_thread=is_thread)
|
||||
# Get channel topic (if available - TextChannels have topics, DMs/threads don't)
|
||||
chat_topic = getattr(message.channel, "topic", None)
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
@@ -2392,7 +2365,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
ext or "unknown", content_type,
|
||||
)
|
||||
else:
|
||||
MAX_DOC_BYTES = 32 * 1024 * 1024
|
||||
MAX_DOC_BYTES = 20 * 1024 * 1024
|
||||
if att.size and att.size > MAX_DOC_BYTES:
|
||||
logger.warning(
|
||||
"[Discord] Document too large (%s bytes), skipping: %s",
|
||||
@@ -2401,14 +2374,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
try:
|
||||
import aiohttp
|
||||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
|
||||
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||||
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
|
||||
async with aiohttp.ClientSession(**_sess_kw) as session:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
att.url,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
**_req_kw,
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"HTTP {resp.status}")
|
||||
@@ -2420,9 +2389,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(doc_mime)
|
||||
logger.info("[Discord] Cached user document: %s", cached_path)
|
||||
# Inject text content for plain-text documents (capped at 100 KB)
|
||||
# Inject text content for .txt/.md files (capped at 100 KB)
|
||||
MAX_TEXT_INJECT_BYTES = 100 * 1024
|
||||
if ext in (".md", ".txt", ".log") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
|
||||
if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
|
||||
try:
|
||||
text_content = raw_bytes.decode("utf-8")
|
||||
display_name = att.filename or f"document{ext}"
|
||||
|
||||
+69
-379
@@ -14,8 +14,7 @@ import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional, Any, Tuple
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
try:
|
||||
from slack_bolt.async_app import AsyncApp
|
||||
@@ -46,14 +45,6 @@ from gateway.platforms.base import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ThreadContextCache:
|
||||
"""Cache entry for fetched thread context."""
|
||||
content: str
|
||||
fetched_at: float = field(default_factory=time.monotonic)
|
||||
message_count: int = 0
|
||||
|
||||
|
||||
def check_slack_requirements() -> bool:
|
||||
"""Check if Slack dependencies are available."""
|
||||
return SLACK_AVAILABLE
|
||||
@@ -104,15 +95,6 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
# respond to ALL subsequent messages in that thread automatically.
|
||||
self._mentioned_threads: set = set()
|
||||
self._MENTIONED_THREADS_MAX = 5000
|
||||
# Assistant thread metadata keyed by (channel_id, thread_ts). Slack's
|
||||
# AI Assistant lifecycle events can arrive before/alongside message
|
||||
# events, and they carry the user/thread identity needed for stable
|
||||
# session + memory scoping.
|
||||
self._assistant_threads: Dict[Tuple[str, str], Dict[str, str]] = {}
|
||||
self._ASSISTANT_THREADS_MAX = 5000
|
||||
# Cache for _fetch_thread_context results: cache_key → _ThreadContextCache
|
||||
self._thread_context_cache: Dict[str, _ThreadContextCache] = {}
|
||||
self._THREAD_CACHE_TTL = 60.0
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Slack via Socket Mode."""
|
||||
@@ -199,14 +181,6 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
async def handle_app_mention(event, say):
|
||||
pass
|
||||
|
||||
@self._app.event("assistant_thread_started")
|
||||
async def handle_assistant_thread_started(event, say):
|
||||
await self._handle_assistant_thread_lifecycle_event(event)
|
||||
|
||||
@self._app.event("assistant_thread_context_changed")
|
||||
async def handle_assistant_thread_context_changed(event, say):
|
||||
await self._handle_assistant_thread_lifecycle_event(event)
|
||||
|
||||
# Register slash command handler
|
||||
@self._app.command("/hermes")
|
||||
async def handle_hermes_command(ack, command):
|
||||
@@ -293,7 +267,6 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
kwargs = {
|
||||
"channel": chat_id,
|
||||
"text": chunk,
|
||||
"mrkdwn": True,
|
||||
}
|
||||
if thread_ts:
|
||||
kwargs["thread_ts"] = thread_ts
|
||||
@@ -336,7 +309,9 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
if not self._app:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
# Convert standard markdown → Slack mrkdwn
|
||||
formatted = self.format_message(content)
|
||||
|
||||
await self._get_client(chat_id).chat_update(
|
||||
channel=chat_id,
|
||||
ts=message_id,
|
||||
@@ -468,36 +443,13 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
text = re.sub(r'(`[^`]+`)', lambda m: _ph(m.group(0)), text)
|
||||
|
||||
# 3) Convert markdown links [text](url) → <url|text>
|
||||
def _convert_markdown_link(m):
|
||||
label = m.group(1)
|
||||
url = m.group(2).strip()
|
||||
if url.startswith('<') and url.endswith('>'):
|
||||
url = url[1:-1].strip()
|
||||
return _ph(f'<{url}|{label}>')
|
||||
|
||||
text = re.sub(
|
||||
r'\[([^\]]+)\]\(([^()]*(?:\([^()]*\)[^()]*)*)\)',
|
||||
_convert_markdown_link,
|
||||
r'\[([^\]]+)\]\(([^)]+)\)',
|
||||
lambda m: _ph(f'<{m.group(2)}|{m.group(1)}>'),
|
||||
text,
|
||||
)
|
||||
|
||||
# 4) Protect existing Slack entities/manual links so escaping and later
|
||||
# formatting passes don't break them.
|
||||
text = re.sub(
|
||||
r'(<(?:[@#!]|(?:https?|mailto|tel):)[^>\n]+>)',
|
||||
lambda m: _ph(m.group(1)),
|
||||
text,
|
||||
)
|
||||
|
||||
# 5) Protect blockquote markers before escaping
|
||||
text = re.sub(r'^(>+\s)', lambda m: _ph(m.group(0)), text, flags=re.MULTILINE)
|
||||
|
||||
# 6) Escape Slack control characters in remaining plain text.
|
||||
# Unescape first so already-escaped input doesn't get double-escaped.
|
||||
text = text.replace('&', '&').replace('<', '<').replace('>', '>')
|
||||
text = text.replace('&', '&').replace('<', '<').replace('>', '>')
|
||||
|
||||
# 7) Convert headers (## Title) → *Title* (bold)
|
||||
# 4) Convert headers (## Title) → *Title* (bold)
|
||||
def _convert_header(m):
|
||||
inner = m.group(1).strip()
|
||||
# Strip redundant bold markers inside a header
|
||||
@@ -508,39 +460,34 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
r'^#{1,6}\s+(.+)$', _convert_header, text, flags=re.MULTILINE
|
||||
)
|
||||
|
||||
# 8) Convert bold+italic: ***text*** → *_text_* (Slack bold wrapping italic)
|
||||
text = re.sub(
|
||||
r'\*\*\*(.+?)\*\*\*',
|
||||
lambda m: _ph(f'*_{m.group(1)}_*'),
|
||||
text,
|
||||
)
|
||||
|
||||
# 9) Convert bold: **text** → *text* (Slack bold)
|
||||
# 5) Convert bold: **text** → *text* (Slack bold)
|
||||
text = re.sub(
|
||||
r'\*\*(.+?)\*\*',
|
||||
lambda m: _ph(f'*{m.group(1)}*'),
|
||||
text,
|
||||
)
|
||||
|
||||
# 10) Convert italic: _text_ stays as _text_ (already Slack italic)
|
||||
# Single *text* → _text_ (Slack italic)
|
||||
# 6) Convert italic: _text_ stays as _text_ (already Slack italic)
|
||||
# Single *text* → _text_ (Slack italic)
|
||||
text = re.sub(
|
||||
r'(?<!\*)\*([^*\n]+)\*(?!\*)',
|
||||
lambda m: _ph(f'_{m.group(1)}_'),
|
||||
text,
|
||||
)
|
||||
|
||||
# 11) Convert strikethrough: ~~text~~ → ~text~
|
||||
# 7) Convert strikethrough: ~~text~~ → ~text~
|
||||
text = re.sub(
|
||||
r'~~(.+?)~~',
|
||||
lambda m: _ph(f'~{m.group(1)}~'),
|
||||
text,
|
||||
)
|
||||
|
||||
# 12) Blockquotes: > prefix is already protected by step 5 above.
|
||||
# 8) Convert blockquotes: > text → > text (same syntax, just ensure
|
||||
# no extra escaping happens to the > character)
|
||||
# Slack uses the same > prefix, so this is a no-op for content.
|
||||
|
||||
# 13) Restore placeholders in reverse order
|
||||
for key in reversed(placeholders):
|
||||
# 9) Restore placeholders in reverse order
|
||||
for key in reversed(list(placeholders.keys())):
|
||||
text = text.replace(key, placeholders[key])
|
||||
|
||||
return text
|
||||
@@ -808,135 +755,6 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
|
||||
# ----- Internal handlers -----
|
||||
|
||||
def _assistant_thread_key(self, channel_id: str, thread_ts: str) -> Optional[Tuple[str, str]]:
|
||||
"""Return a stable cache key for Slack assistant thread metadata."""
|
||||
if not channel_id or not thread_ts:
|
||||
return None
|
||||
return (str(channel_id), str(thread_ts))
|
||||
|
||||
def _extract_assistant_thread_metadata(self, event: dict) -> Dict[str, str]:
|
||||
"""Extract Slack Assistant thread identity data from an event payload."""
|
||||
assistant_thread = event.get("assistant_thread") or {}
|
||||
context = assistant_thread.get("context") or event.get("context") or {}
|
||||
|
||||
channel_id = (
|
||||
assistant_thread.get("channel_id")
|
||||
or event.get("channel")
|
||||
or context.get("channel_id")
|
||||
or ""
|
||||
)
|
||||
thread_ts = (
|
||||
assistant_thread.get("thread_ts")
|
||||
or event.get("thread_ts")
|
||||
or event.get("message_ts")
|
||||
or ""
|
||||
)
|
||||
user_id = (
|
||||
assistant_thread.get("user_id")
|
||||
or event.get("user")
|
||||
or context.get("user_id")
|
||||
or ""
|
||||
)
|
||||
team_id = (
|
||||
event.get("team")
|
||||
or event.get("team_id")
|
||||
or assistant_thread.get("team_id")
|
||||
or ""
|
||||
)
|
||||
context_channel_id = context.get("channel_id") or ""
|
||||
|
||||
return {
|
||||
"channel_id": str(channel_id) if channel_id else "",
|
||||
"thread_ts": str(thread_ts) if thread_ts else "",
|
||||
"user_id": str(user_id) if user_id else "",
|
||||
"team_id": str(team_id) if team_id else "",
|
||||
"context_channel_id": str(context_channel_id) if context_channel_id else "",
|
||||
}
|
||||
|
||||
def _cache_assistant_thread_metadata(self, metadata: Dict[str, str]) -> None:
|
||||
"""Remember assistant thread identity data for later message events."""
|
||||
channel_id = metadata.get("channel_id", "")
|
||||
thread_ts = metadata.get("thread_ts", "")
|
||||
key = self._assistant_thread_key(channel_id, thread_ts)
|
||||
if not key:
|
||||
return
|
||||
|
||||
existing = self._assistant_threads.get(key, {})
|
||||
merged = dict(existing)
|
||||
merged.update({k: v for k, v in metadata.items() if v})
|
||||
self._assistant_threads[key] = merged
|
||||
|
||||
# Evict oldest entries when the cache exceeds the limit
|
||||
if len(self._assistant_threads) > self._ASSISTANT_THREADS_MAX:
|
||||
excess = len(self._assistant_threads) - self._ASSISTANT_THREADS_MAX // 2
|
||||
for old_key in list(self._assistant_threads)[:excess]:
|
||||
del self._assistant_threads[old_key]
|
||||
|
||||
team_id = merged.get("team_id", "")
|
||||
if team_id and channel_id:
|
||||
self._channel_team[channel_id] = team_id
|
||||
|
||||
def _lookup_assistant_thread_metadata(
|
||||
self,
|
||||
event: dict,
|
||||
channel_id: str = "",
|
||||
thread_ts: str = "",
|
||||
) -> Dict[str, str]:
|
||||
"""Load cached assistant-thread metadata that matches the current event."""
|
||||
metadata = self._extract_assistant_thread_metadata(event)
|
||||
if channel_id and not metadata.get("channel_id"):
|
||||
metadata["channel_id"] = channel_id
|
||||
if thread_ts and not metadata.get("thread_ts"):
|
||||
metadata["thread_ts"] = thread_ts
|
||||
|
||||
key = self._assistant_thread_key(
|
||||
metadata.get("channel_id", ""),
|
||||
metadata.get("thread_ts", ""),
|
||||
)
|
||||
cached = self._assistant_threads.get(key, {}) if key else {}
|
||||
if cached:
|
||||
merged = dict(cached)
|
||||
merged.update({k: v for k, v in metadata.items() if v})
|
||||
return merged
|
||||
return metadata
|
||||
|
||||
def _seed_assistant_thread_session(self, metadata: Dict[str, str]) -> None:
|
||||
"""Prime the session store so assistant threads get stable user scoping."""
|
||||
session_store = getattr(self, "_session_store", None)
|
||||
if not session_store:
|
||||
return
|
||||
|
||||
channel_id = metadata.get("channel_id", "")
|
||||
thread_ts = metadata.get("thread_ts", "")
|
||||
user_id = metadata.get("user_id", "")
|
||||
if not channel_id or not thread_ts or not user_id:
|
||||
return
|
||||
|
||||
source = self.build_source(
|
||||
chat_id=channel_id,
|
||||
chat_name=channel_id,
|
||||
chat_type="dm",
|
||||
user_id=user_id,
|
||||
thread_id=thread_ts,
|
||||
chat_topic=metadata.get("context_channel_id") or None,
|
||||
)
|
||||
|
||||
try:
|
||||
session_store.get_or_create_session(source)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"[Slack] Failed to seed assistant thread session for %s/%s",
|
||||
channel_id,
|
||||
thread_ts,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _handle_assistant_thread_lifecycle_event(self, event: dict) -> None:
|
||||
"""Handle Slack Assistant lifecycle events that carry user/thread identity."""
|
||||
metadata = self._extract_assistant_thread_metadata(event)
|
||||
self._cache_assistant_thread_metadata(metadata)
|
||||
self._seed_assistant_thread_session(metadata)
|
||||
|
||||
async def _handle_slack_message(self, event: dict) -> None:
|
||||
"""Handle an incoming Slack message event."""
|
||||
# Dedup: Slack Socket Mode can redeliver events after reconnects (#4777)
|
||||
@@ -953,26 +771,9 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
if v > cutoff
|
||||
}
|
||||
|
||||
# Bot message filtering (SLACK_ALLOW_BOTS / config allow_bots):
|
||||
# "none" — ignore all bot messages (default, backward-compatible)
|
||||
# "mentions" — accept bot messages only when they @mention us
|
||||
# "all" — accept all bot messages (except our own)
|
||||
# Ignore bot messages (including our own)
|
||||
if event.get("bot_id") or event.get("subtype") == "bot_message":
|
||||
allow_bots = self.config.extra.get("allow_bots", "")
|
||||
if not allow_bots:
|
||||
allow_bots = os.getenv("SLACK_ALLOW_BOTS", "none")
|
||||
allow_bots = str(allow_bots).lower().strip()
|
||||
if allow_bots == "none":
|
||||
return
|
||||
elif allow_bots == "mentions":
|
||||
text_check = event.get("text", "")
|
||||
if self._bot_user_id and f"<@{self._bot_user_id}>" not in text_check:
|
||||
return
|
||||
# "all" falls through to process the message
|
||||
# Always ignore our own messages to prevent echo loops
|
||||
msg_user = event.get("user", "")
|
||||
if msg_user and self._bot_user_id and msg_user == self._bot_user_id:
|
||||
return
|
||||
return
|
||||
|
||||
# Ignore message edits and deletions
|
||||
subtype = event.get("subtype")
|
||||
@@ -980,21 +781,10 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
return
|
||||
|
||||
text = event.get("text", "")
|
||||
user_id = event.get("user", "")
|
||||
channel_id = event.get("channel", "")
|
||||
ts = event.get("ts", "")
|
||||
assistant_meta = self._lookup_assistant_thread_metadata(
|
||||
event,
|
||||
channel_id=channel_id,
|
||||
thread_ts=event.get("thread_ts", ""),
|
||||
)
|
||||
user_id = event.get("user") or assistant_meta.get("user_id", "")
|
||||
if not channel_id:
|
||||
channel_id = assistant_meta.get("channel_id", "")
|
||||
team_id = (
|
||||
event.get("team")
|
||||
or event.get("team_id")
|
||||
or assistant_meta.get("team_id", "")
|
||||
)
|
||||
team_id = event.get("team", "")
|
||||
|
||||
# Track which workspace owns this channel
|
||||
if team_id and channel_id:
|
||||
@@ -1002,9 +792,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
|
||||
# Determine if this is a DM or channel message
|
||||
channel_type = event.get("channel_type", "")
|
||||
if not channel_type and channel_id.startswith("D"):
|
||||
channel_type = "im"
|
||||
is_dm = channel_type in ("im", "mpim") # Both 1:1 and group DMs
|
||||
is_dm = channel_type == "im"
|
||||
|
||||
# Build thread_ts for session keying.
|
||||
# In channels: fall back to ts so each top-level @mention starts a
|
||||
@@ -1012,13 +800,11 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
# In DMs: only use the real thread_ts — top-level DMs should share
|
||||
# one continuous session, threaded DMs get their own session.
|
||||
if is_dm:
|
||||
thread_ts = event.get("thread_ts") or assistant_meta.get("thread_ts") # None for top-level DMs
|
||||
thread_ts = event.get("thread_ts") # None for top-level DMs
|
||||
else:
|
||||
thread_ts = event.get("thread_ts") or ts # ts fallback for channels
|
||||
|
||||
# In channels, respond if:
|
||||
# 0. Channel is in free_response_channels, OR require_mention is
|
||||
# disabled — always process regardless of mention.
|
||||
# 1. The bot is @mentioned in this message, OR
|
||||
# 2. The message is a reply in a thread the bot started/participated in, OR
|
||||
# 3. The message is in a thread where the bot was previously @mentioned, OR
|
||||
@@ -1028,29 +814,24 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
event_thread_ts = event.get("thread_ts")
|
||||
is_thread_reply = bool(event_thread_ts and event_thread_ts != ts)
|
||||
|
||||
if not is_dm and bot_uid:
|
||||
if channel_id in self._slack_free_response_channels():
|
||||
pass # Free-response channel — always process
|
||||
elif not self._slack_require_mention():
|
||||
pass # Mention requirement disabled globally for Slack
|
||||
elif not is_mentioned:
|
||||
reply_to_bot_thread = (
|
||||
is_thread_reply and event_thread_ts in self._bot_message_ts
|
||||
if not is_dm and bot_uid and not is_mentioned:
|
||||
reply_to_bot_thread = (
|
||||
is_thread_reply and event_thread_ts in self._bot_message_ts
|
||||
)
|
||||
in_mentioned_thread = (
|
||||
event_thread_ts is not None
|
||||
and event_thread_ts in self._mentioned_threads
|
||||
)
|
||||
has_session = (
|
||||
is_thread_reply
|
||||
and self._has_active_session_for_thread(
|
||||
channel_id=channel_id,
|
||||
thread_ts=event_thread_ts,
|
||||
user_id=user_id,
|
||||
)
|
||||
in_mentioned_thread = (
|
||||
event_thread_ts is not None
|
||||
and event_thread_ts in self._mentioned_threads
|
||||
)
|
||||
has_session = (
|
||||
is_thread_reply
|
||||
and self._has_active_session_for_thread(
|
||||
channel_id=channel_id,
|
||||
thread_ts=event_thread_ts,
|
||||
user_id=user_id,
|
||||
)
|
||||
)
|
||||
if not reply_to_bot_thread and not in_mentioned_thread and not has_session:
|
||||
return
|
||||
)
|
||||
if not reply_to_bot_thread and not in_mentioned_thread and not has_session:
|
||||
return
|
||||
|
||||
if is_mentioned:
|
||||
# Strip the bot mention from the text
|
||||
@@ -1191,19 +972,14 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
reply_to_message_id=thread_ts if thread_ts != ts else None,
|
||||
)
|
||||
|
||||
# Only react when bot is directly addressed (DM or @mention).
|
||||
# In listen-all channels (require_mention=false), reacting to every
|
||||
# casual message would be noisy.
|
||||
_should_react = is_dm or is_mentioned
|
||||
|
||||
if _should_react:
|
||||
await self._add_reaction(channel_id, ts, "eyes")
|
||||
# Add 👀 reaction to acknowledge receipt
|
||||
await self._add_reaction(channel_id, ts, "eyes")
|
||||
|
||||
await self.handle_message(msg_event)
|
||||
|
||||
if _should_react:
|
||||
await self._remove_reaction(channel_id, ts, "eyes")
|
||||
await self._add_reaction(channel_id, ts, "white_check_mark")
|
||||
# Replace 👀 with ✅ when done
|
||||
await self._remove_reaction(channel_id, ts, "eyes")
|
||||
await self._add_reaction(channel_id, ts, "white_check_mark")
|
||||
|
||||
# ----- Approval button support (Block Kit) -----
|
||||
|
||||
@@ -1297,20 +1073,6 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
msg_ts = message.get("ts", "")
|
||||
channel_id = body.get("channel", {}).get("id", "")
|
||||
user_name = body.get("user", {}).get("name", "unknown")
|
||||
user_id = body.get("user", {}).get("id", "")
|
||||
|
||||
# Only authorized users may click approval buttons. Button clicks
|
||||
# bypass the normal message auth flow in gateway/run.py, so we must
|
||||
# check here as well.
|
||||
allowed_csv = os.getenv("SLACK_ALLOWED_USERS", "").strip()
|
||||
if allowed_csv:
|
||||
allowed_ids = {uid.strip() for uid in allowed_csv.split(",") if uid.strip()}
|
||||
if "*" not in allowed_ids and user_id not in allowed_ids:
|
||||
logger.warning(
|
||||
"[Slack] Unauthorized approval click by %s (%s) — ignoring",
|
||||
user_name, user_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Map action_id to approval choice
|
||||
choice_map = {
|
||||
@@ -1321,9 +1083,10 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
}
|
||||
choice = choice_map.get(action_id, "deny")
|
||||
|
||||
# Prevent double-clicks — atomic pop; first caller gets False, others get True (default)
|
||||
if self._approval_resolved.pop(msg_ts, True):
|
||||
# Prevent double-clicks
|
||||
if self._approval_resolved.get(msg_ts, False):
|
||||
return
|
||||
self._approval_resolved[msg_ts] = True
|
||||
|
||||
# Update the message to show the decision and remove buttons
|
||||
label_map = {
|
||||
@@ -1378,7 +1141,8 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
except Exception as exc:
|
||||
logger.error("Failed to resolve gateway approval from Slack button: %s", exc)
|
||||
|
||||
# (approval state already consumed by atomic pop above)
|
||||
# Clean up stale approval state
|
||||
self._approval_resolved.pop(msg_ts, None)
|
||||
|
||||
# ----- Thread context fetching -----
|
||||
|
||||
@@ -1389,104 +1153,57 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
"""Fetch recent thread messages to provide context when the bot is
|
||||
mentioned mid-thread for the first time.
|
||||
|
||||
This method is only called when there is NO active session for the
|
||||
thread (guarded at the call site by _has_active_session_for_thread).
|
||||
That guard ensures thread messages are prepended only on the very
|
||||
first turn — after that the session history already holds them, so
|
||||
there is no duplication across subsequent turns.
|
||||
|
||||
Results are cached for _THREAD_CACHE_TTL seconds per thread to avoid
|
||||
hammering conversations.replies (Tier 3, ~50 req/min).
|
||||
|
||||
Returns a formatted string with prior thread history, or empty string
|
||||
on failure or if the thread has no prior messages.
|
||||
Returns a formatted string with thread history, or empty string on
|
||||
failure or if the thread is empty (just the parent message).
|
||||
"""
|
||||
cache_key = f"{channel_id}:{thread_ts}"
|
||||
now = time.monotonic()
|
||||
cached = self._thread_context_cache.get(cache_key)
|
||||
if cached and (now - cached.fetched_at) < self._THREAD_CACHE_TTL:
|
||||
return cached.content
|
||||
|
||||
try:
|
||||
client = self._get_client(channel_id)
|
||||
|
||||
# Retry with exponential backoff for Tier-3 rate limits (429).
|
||||
result = None
|
||||
for attempt in range(3):
|
||||
try:
|
||||
result = await client.conversations_replies(
|
||||
channel=channel_id,
|
||||
ts=thread_ts,
|
||||
limit=limit + 1, # +1 because it includes the current message
|
||||
inclusive=True,
|
||||
)
|
||||
break
|
||||
except Exception as exc:
|
||||
# Check for rate-limit error from slack_sdk
|
||||
err_str = str(exc).lower()
|
||||
is_rate_limit = (
|
||||
"ratelimited" in err_str
|
||||
or "429" in err_str
|
||||
or "rate_limited" in err_str
|
||||
)
|
||||
if is_rate_limit and attempt < 2:
|
||||
retry_after = 1.0 * (2 ** attempt) # 1s, 2s
|
||||
logger.warning(
|
||||
"[Slack] conversations.replies rate limited; retrying in %.1fs (attempt %d/3)",
|
||||
retry_after, attempt + 1,
|
||||
)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
raise
|
||||
|
||||
if result is None:
|
||||
return ""
|
||||
|
||||
result = await client.conversations_replies(
|
||||
channel=channel_id,
|
||||
ts=thread_ts,
|
||||
limit=limit + 1, # +1 because it includes the current message
|
||||
inclusive=True,
|
||||
)
|
||||
messages = result.get("messages", [])
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
bot_uid = self._team_bot_user_ids.get(team_id, self._bot_user_id)
|
||||
context_parts = []
|
||||
for msg in messages:
|
||||
msg_ts = msg.get("ts", "")
|
||||
# Exclude the current triggering message — it will be delivered
|
||||
# as the user message itself, so including it here would duplicate it.
|
||||
# Skip the current message (the one that triggered this fetch)
|
||||
if msg_ts == current_ts:
|
||||
continue
|
||||
# Exclude our own bot messages to avoid circular context.
|
||||
# Skip bot messages from ourselves
|
||||
if msg.get("bot_id") or msg.get("subtype") == "bot_message":
|
||||
continue
|
||||
|
||||
msg_user = msg.get("user", "unknown")
|
||||
msg_text = msg.get("text", "").strip()
|
||||
if not msg_text:
|
||||
continue
|
||||
|
||||
# Strip bot mentions from context messages
|
||||
bot_uid = self._team_bot_user_ids.get(team_id, self._bot_user_id)
|
||||
if bot_uid:
|
||||
msg_text = msg_text.replace(f"<@{bot_uid}>", "").strip()
|
||||
|
||||
msg_user = msg.get("user", "unknown")
|
||||
# Mark the thread parent
|
||||
is_parent = msg_ts == thread_ts
|
||||
prefix = "[thread parent] " if is_parent else ""
|
||||
|
||||
# Resolve user name (cached)
|
||||
name = await self._resolve_user_name(msg_user, chat_id=channel_id)
|
||||
context_parts.append(f"{prefix}{name}: {msg_text}")
|
||||
|
||||
content = ""
|
||||
if context_parts:
|
||||
content = (
|
||||
"[Thread context — prior messages in this thread (not yet in conversation history):]\n"
|
||||
+ "\n".join(context_parts)
|
||||
+ "\n[End of thread context]\n\n"
|
||||
)
|
||||
if not context_parts:
|
||||
return ""
|
||||
|
||||
self._thread_context_cache[cache_key] = _ThreadContextCache(
|
||||
content=content,
|
||||
fetched_at=now,
|
||||
message_count=len(context_parts),
|
||||
return (
|
||||
"[Thread context — previous messages in this thread:]\n"
|
||||
+ "\n".join(context_parts)
|
||||
+ "\n[End of thread context]\n\n"
|
||||
)
|
||||
return content
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("[Slack] Failed to fetch thread context: %s", e)
|
||||
return ""
|
||||
@@ -1642,30 +1359,3 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
continue
|
||||
raise
|
||||
raise last_exc
|
||||
|
||||
# ── Channel mention gating ─────────────────────────────────────────────
|
||||
|
||||
def _slack_require_mention(self) -> bool:
|
||||
"""Return whether channel messages require an explicit bot mention.
|
||||
|
||||
Uses explicit-false parsing (like Discord/Matrix) rather than
|
||||
truthy parsing, since the safe default is True (gating on).
|
||||
Unrecognised or empty values keep gating enabled.
|
||||
"""
|
||||
configured = self.config.extra.get("require_mention")
|
||||
if configured is not None:
|
||||
if isinstance(configured, str):
|
||||
return configured.lower() not in ("false", "0", "no", "off")
|
||||
return bool(configured)
|
||||
return os.getenv("SLACK_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no", "off")
|
||||
|
||||
def _slack_free_response_channels(self) -> set:
|
||||
"""Return channel IDs where no @mention is required."""
|
||||
raw = self.config.extra.get("free_response_channels")
|
||||
if raw is None:
|
||||
raw = os.getenv("SLACK_FREE_RESPONSE_CHANNELS", "")
|
||||
if isinstance(raw, list):
|
||||
return {str(part).strip() for part in raw if str(part).strip()}
|
||||
if isinstance(raw, str) and raw.strip():
|
||||
return {part.strip() for part in raw.split(",") if part.strip()}
|
||||
return set()
|
||||
|
||||
@@ -1398,15 +1398,6 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
await query.answer(text="Invalid approval data.")
|
||||
return
|
||||
|
||||
# Only authorized users may click approval buttons.
|
||||
caller_id = str(getattr(query.from_user, "id", ""))
|
||||
allowed_csv = os.getenv("TELEGRAM_ALLOWED_USERS", "").strip()
|
||||
if allowed_csv:
|
||||
allowed_ids = {uid.strip() for uid in allowed_csv.split(",") if uid.strip()}
|
||||
if "*" not in allowed_ids and caller_id not in allowed_ids:
|
||||
await query.answer(text="⛔ You are not authorized to approve commands.")
|
||||
return
|
||||
|
||||
session_key = self._approval_state.pop(approval_id, None)
|
||||
if not session_key:
|
||||
await query.answer(text="This approval has already been resolved.")
|
||||
|
||||
@@ -45,9 +45,11 @@ _SEED_FALLBACK_IPS: list[str] = ["149.154.167.220"]
|
||||
|
||||
|
||||
def _resolve_proxy_url() -> str | None:
|
||||
# Delegate to shared implementation (env vars + macOS system proxy detection)
|
||||
from gateway.platforms.base import resolve_proxy_url
|
||||
return resolve_proxy_url()
|
||||
for key in ("HTTPS_PROXY", "HTTP_PROXY", "ALL_PROXY", "https_proxy", "http_proxy", "all_proxy"):
|
||||
value = (os.environ.get(key) or "").strip()
|
||||
if value:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
class TelegramFallbackTransport(httpx.AsyncBaseTransport):
|
||||
|
||||
+29
-77
@@ -184,8 +184,6 @@ if _config_path.exists():
|
||||
# Env var from .env takes precedence (already in os.environ).
|
||||
if "gateway_timeout" in _agent_cfg and "HERMES_AGENT_TIMEOUT" not in os.environ:
|
||||
os.environ["HERMES_AGENT_TIMEOUT"] = str(_agent_cfg["gateway_timeout"])
|
||||
if "gateway_timeout_warning" in _agent_cfg and "HERMES_AGENT_TIMEOUT_WARNING" not in os.environ:
|
||||
os.environ["HERMES_AGENT_TIMEOUT_WARNING"] = str(_agent_cfg["gateway_timeout_warning"])
|
||||
# Timezone: bridge config.yaml → HERMES_TIMEZONE env var.
|
||||
# HERMES_TIMEZONE from .env takes precedence (already in os.environ).
|
||||
_tz_cfg = _cfg.get("timezone", "")
|
||||
@@ -514,6 +512,12 @@ class GatewayRunner:
|
||||
self._agent_cache: Dict[str, tuple] = {}
|
||||
self._agent_cache_lock = _threading.Lock()
|
||||
|
||||
# Track active fallback model/provider when primary is rate-limited.
|
||||
# Set after an agent run where fallback was activated; cleared when
|
||||
# the primary model succeeds again or the user switches via /model.
|
||||
self._effective_model: Optional[str] = None
|
||||
self._effective_provider: Optional[str] = None
|
||||
|
||||
# Per-session model overrides from /model command.
|
||||
# Key: session_key, Value: dict with model/provider/api_key/base_url/api_mode
|
||||
self._session_model_overrides: Dict[str, Dict[str, str]] = {}
|
||||
@@ -919,8 +923,8 @@ class GatewayRunner:
|
||||
def _load_reasoning_config() -> dict | None:
|
||||
"""Load reasoning effort from config.yaml.
|
||||
|
||||
Reads agent.reasoning_effort from config.yaml. Valid: "none",
|
||||
"minimal", "low", "medium", "high", "xhigh". Returns None to use
|
||||
Reads agent.reasoning_effort from config.yaml. Valid: "xhigh",
|
||||
"high", "medium", "low", "minimal", "none". Returns None to use
|
||||
default (medium).
|
||||
"""
|
||||
from hermes_constants import parse_reasoning_effort
|
||||
@@ -1069,7 +1073,6 @@ class GatewayRunner:
|
||||
"MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS",
|
||||
"FEISHU_ALLOWED_USERS",
|
||||
"WECOM_ALLOWED_USERS",
|
||||
"BLUEBUBBLES_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOWED_USERS")
|
||||
)
|
||||
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") or any(
|
||||
@@ -1080,8 +1083,7 @@ class GatewayRunner:
|
||||
"SMS_ALLOW_ALL_USERS", "MATTERMOST_ALLOW_ALL_USERS",
|
||||
"MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS",
|
||||
"FEISHU_ALLOW_ALL_USERS",
|
||||
"WECOM_ALLOW_ALL_USERS",
|
||||
"BLUEBUBBLES_ALLOW_ALL_USERS")
|
||||
"WECOM_ALLOW_ALL_USERS")
|
||||
)
|
||||
if not _any_allowlist and not _allow_all:
|
||||
logger.warning(
|
||||
@@ -1652,13 +1654,6 @@ class GatewayRunner:
|
||||
adapter.gateway_runner = self # For cross-platform delivery
|
||||
return adapter
|
||||
|
||||
elif platform == Platform.BLUEBUBBLES:
|
||||
from gateway.platforms.bluebubbles import BlueBubblesAdapter, check_bluebubbles_requirements
|
||||
if not check_bluebubbles_requirements():
|
||||
logger.warning("BlueBubbles: aiohttp/httpx missing or BLUEBUBBLES_SERVER_URL/BLUEBUBBLES_PASSWORD not configured")
|
||||
return None
|
||||
return BlueBubblesAdapter(config)
|
||||
|
||||
return None
|
||||
|
||||
def _is_user_authorized(self, source: SessionSource) -> bool:
|
||||
@@ -1697,7 +1692,6 @@ class GatewayRunner:
|
||||
Platform.DINGTALK: "DINGTALK_ALLOWED_USERS",
|
||||
Platform.FEISHU: "FEISHU_ALLOWED_USERS",
|
||||
Platform.WECOM: "WECOM_ALLOWED_USERS",
|
||||
Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOWED_USERS",
|
||||
}
|
||||
platform_allow_all_map = {
|
||||
Platform.TELEGRAM: "TELEGRAM_ALLOW_ALL_USERS",
|
||||
@@ -1712,7 +1706,6 @@ class GatewayRunner:
|
||||
Platform.DINGTALK: "DINGTALK_ALLOW_ALL_USERS",
|
||||
Platform.FEISHU: "FEISHU_ALLOW_ALL_USERS",
|
||||
Platform.WECOM: "WECOM_ALLOW_ALL_USERS",
|
||||
Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOW_ALL_USERS",
|
||||
}
|
||||
|
||||
# Per-platform allow-all flag (e.g., DISCORD_ALLOW_ALL_USERS=true)
|
||||
@@ -1786,11 +1779,8 @@ class GatewayRunner:
|
||||
"""
|
||||
source = event.source
|
||||
|
||||
# Internal events (e.g. background-process completion notifications)
|
||||
# are system-generated and must skip user authorization.
|
||||
if getattr(event, "internal", False):
|
||||
pass
|
||||
elif not self._is_user_authorized(source):
|
||||
# Check if user is authorized
|
||||
if not self._is_user_authorized(source):
|
||||
logger.warning("Unauthorized user: %s (%s) on %s", source.user_id, source.user_name, source.platform.value)
|
||||
# In DMs: offer pairing code. In groups: silently ignore.
|
||||
if source.chat_type == "dm" and self._get_unauthorized_dm_behavior(source.platform) == "pair":
|
||||
@@ -4834,7 +4824,7 @@ class GatewayRunner:
|
||||
|
||||
Usage:
|
||||
/reasoning Show current effort level and display state
|
||||
/reasoning <level> Set reasoning effort (none, minimal, low, medium, high, xhigh)
|
||||
/reasoning <level> Set reasoning effort (none, low, medium, high, xhigh)
|
||||
/reasoning show|on Show model reasoning in responses
|
||||
/reasoning hide|off Hide model reasoning from responses
|
||||
"""
|
||||
@@ -4879,7 +4869,7 @@ class GatewayRunner:
|
||||
"🧠 **Reasoning Settings**\n\n"
|
||||
f"**Effort:** `{level}`\n"
|
||||
f"**Display:** {display_state}\n\n"
|
||||
"_Usage:_ `/reasoning <none|minimal|low|medium|high|xhigh|show|hide>`"
|
||||
"_Usage:_ `/reasoning <none|low|medium|high|xhigh|show|hide>`"
|
||||
)
|
||||
|
||||
# Display toggle
|
||||
@@ -4897,12 +4887,12 @@ class GatewayRunner:
|
||||
effort = args.strip()
|
||||
if effort == "none":
|
||||
parsed = {"enabled": False}
|
||||
elif effort in ("minimal", "low", "medium", "high", "xhigh"):
|
||||
elif effort in ("xhigh", "high", "medium", "low", "minimal"):
|
||||
parsed = {"enabled": True, "effort": effort}
|
||||
else:
|
||||
return (
|
||||
f"⚠️ Unknown argument: `{effort}`\n\n"
|
||||
"**Valid levels:** none, minimal, low, medium, high, xhigh\n"
|
||||
"**Valid levels:** none, low, minimal, medium, high, xhigh\n"
|
||||
"**Display:** show, hide"
|
||||
)
|
||||
|
||||
@@ -5274,28 +5264,19 @@ class GatewayRunner:
|
||||
|
||||
agent = self._running_agents.get(session_key)
|
||||
if agent and hasattr(agent, "session_total_tokens") and agent.session_api_calls > 0:
|
||||
lines = []
|
||||
|
||||
# Rate limits first (when available from provider headers)
|
||||
rl_state = agent.get_rate_limit_state()
|
||||
if rl_state and rl_state.has_data:
|
||||
from agent.rate_limit_tracker import format_rate_limit_compact
|
||||
lines.append(f"⏱️ **Rate Limits:** {format_rate_limit_compact(rl_state)}")
|
||||
lines.append("")
|
||||
|
||||
# Session token usage
|
||||
lines.append("📊 **Session Token Usage**")
|
||||
lines.append(f"Prompt (input): {agent.session_prompt_tokens:,}")
|
||||
lines.append(f"Completion (output): {agent.session_completion_tokens:,}")
|
||||
lines.append(f"Total: {agent.session_total_tokens:,}")
|
||||
lines.append(f"API calls: {agent.session_api_calls}")
|
||||
lines = [
|
||||
"📊 **Session Token Usage**",
|
||||
f"Prompt (input): {agent.session_prompt_tokens:,}",
|
||||
f"Completion (output): {agent.session_completion_tokens:,}",
|
||||
f"Total: {agent.session_total_tokens:,}",
|
||||
f"API calls: {agent.session_api_calls}",
|
||||
]
|
||||
ctx = agent.context_compressor
|
||||
if ctx.last_prompt_tokens:
|
||||
pct = min(100, ctx.last_prompt_tokens / ctx.context_length * 100) if ctx.context_length else 0
|
||||
lines.append(f"Context: {ctx.last_prompt_tokens:,} / {ctx.context_length:,} ({pct:.0f}%)")
|
||||
if ctx.compression_count:
|
||||
lines.append(f"Compressions: {ctx.compression_count}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
# No running agent -- check session history for a rough count
|
||||
@@ -5537,7 +5518,7 @@ class GatewayRunner:
|
||||
Platform.TELEGRAM, Platform.DISCORD, Platform.SLACK, Platform.WHATSAPP,
|
||||
Platform.SIGNAL, Platform.MATTERMOST, Platform.MATRIX,
|
||||
Platform.HOMEASSISTANT, Platform.EMAIL, Platform.SMS, Platform.DINGTALK,
|
||||
Platform.FEISHU, Platform.WECOM, Platform.BLUEBUBBLES, Platform.LOCAL,
|
||||
Platform.FEISHU, Platform.WECOM, Platform.LOCAL,
|
||||
})
|
||||
|
||||
async def _handle_update_command(self, event: MessageEvent) -> str:
|
||||
@@ -6177,7 +6158,6 @@ class GatewayRunner:
|
||||
text=synth_text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=_source,
|
||||
internal=True,
|
||||
)
|
||||
logger.info(
|
||||
"Process %s finished — injecting agent notification for session %s",
|
||||
@@ -6440,18 +6420,6 @@ class GatewayRunner:
|
||||
if not adapter:
|
||||
return
|
||||
|
||||
# Skip tool progress for platforms that don't support message
|
||||
# editing (e.g. iMessage/BlueBubbles) — each progress update
|
||||
# would become a separate message bubble, which is noisy.
|
||||
from gateway.platforms.base import BasePlatformAdapter as _BaseAdapter
|
||||
if type(adapter).edit_message is _BaseAdapter.edit_message:
|
||||
while not progress_queue.empty():
|
||||
try:
|
||||
progress_queue.get_nowait()
|
||||
except Exception:
|
||||
break
|
||||
return
|
||||
|
||||
progress_lines = [] # Accumulated tool lines
|
||||
progress_msg_id = None # ID of the progress message to edit
|
||||
can_edit = True # False once an edit fails (platform doesn't support it)
|
||||
@@ -7146,9 +7114,6 @@ class GatewayRunner:
|
||||
# Default 1800s (30 min inactivity). 0 = unlimited.
|
||||
_agent_timeout_raw = float(os.getenv("HERMES_AGENT_TIMEOUT", 1800))
|
||||
_agent_timeout = _agent_timeout_raw if _agent_timeout_raw > 0 else None
|
||||
_agent_warning_raw = float(os.getenv("HERMES_AGENT_TIMEOUT_WARNING", 900))
|
||||
_agent_warning = _agent_warning_raw if _agent_warning_raw > 0 else None
|
||||
_warning_fired = False
|
||||
loop = asyncio.get_event_loop()
|
||||
_executor_task = asyncio.ensure_future(
|
||||
loop.run_in_executor(None, run_sync)
|
||||
@@ -7181,25 +7146,6 @@ class GatewayRunner:
|
||||
_idle_secs = _act.get("seconds_since_activity", 0.0)
|
||||
except Exception:
|
||||
pass
|
||||
# Staged warning: fire once before escalating to full timeout.
|
||||
if (not _warning_fired and _agent_warning is not None
|
||||
and _idle_secs >= _agent_warning):
|
||||
_warning_fired = True
|
||||
_warn_adapter = self.adapters.get(source.platform)
|
||||
if _warn_adapter:
|
||||
_elapsed_warn = int(_agent_warning // 60) or 1
|
||||
_remaining_mins = int((_agent_timeout - _agent_warning) // 60) or 1
|
||||
try:
|
||||
await _warn_adapter.send(
|
||||
source.chat_id,
|
||||
f"⚠️ No activity for {_elapsed_warn} min. "
|
||||
f"If the agent does not respond soon, it will "
|
||||
f"be timed out in {_remaining_mins} min. "
|
||||
f"You can continue waiting or use /reset.",
|
||||
metadata=_status_thread_metadata,
|
||||
)
|
||||
except Exception as _warn_err:
|
||||
logger.debug("Inactivity warning send error: %s", _warn_err)
|
||||
if _idle_secs >= _agent_timeout:
|
||||
_inactivity_timeout = True
|
||||
break
|
||||
@@ -7274,9 +7220,15 @@ class GatewayRunner:
|
||||
if _agent is not None and hasattr(_agent, 'model'):
|
||||
_cfg_model = _resolve_gateway_model()
|
||||
if _agent.model != _cfg_model:
|
||||
self._effective_model = _agent.model
|
||||
self._effective_provider = getattr(_agent, 'provider', None)
|
||||
# Fallback activated — evict cached agent so the next
|
||||
# message starts fresh and retries the primary model.
|
||||
self._evict_cached_agent(session_key)
|
||||
else:
|
||||
# Primary model worked — clear any stale fallback state
|
||||
self._effective_model = None
|
||||
self._effective_provider = None
|
||||
|
||||
# Check if we were interrupted OR have a queued message (/queue).
|
||||
result = result_holder[0]
|
||||
|
||||
+18
-2
@@ -32,6 +32,9 @@ def _now() -> datetime:
|
||||
# PII redaction helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PHONE_RE = re.compile(r"^\+?\d[\d\-\s]{6,}$")
|
||||
|
||||
|
||||
def _hash_id(value: str) -> str:
|
||||
"""Deterministic 12-char hex hash of an identifier."""
|
||||
return hashlib.sha256(value.encode("utf-8")).hexdigest()[:12]
|
||||
@@ -55,6 +58,10 @@ def _hash_chat_id(value: str) -> str:
|
||||
return _hash_id(value)
|
||||
|
||||
|
||||
def _looks_like_phone(value: str) -> bool:
|
||||
"""Return True if *value* looks like a phone number (E.164 or similar)."""
|
||||
return bool(_PHONE_RE.match(value.strip()))
|
||||
|
||||
from .config import (
|
||||
Platform,
|
||||
GatewayConfig,
|
||||
@@ -137,6 +144,15 @@ class SessionSource:
|
||||
chat_id_alt=data.get("chat_id_alt"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def local_cli(cls) -> "SessionSource":
|
||||
"""Create a source representing the local CLI."""
|
||||
return cls(
|
||||
platform=Platform.LOCAL,
|
||||
chat_id="cli",
|
||||
chat_name="CLI terminal",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -177,7 +193,6 @@ _PII_SAFE_PLATFORMS = frozenset({
|
||||
Platform.WHATSAPP,
|
||||
Platform.SIGNAL,
|
||||
Platform.TELEGRAM,
|
||||
Platform.BLUEBUBBLES,
|
||||
})
|
||||
"""Platforms where user IDs can be safely redacted (no in-message mention system
|
||||
that requires raw IDs). Discord is excluded because mentions use ``<@user_id>``
|
||||
@@ -494,7 +509,8 @@ class SessionStore:
|
||||
"""
|
||||
|
||||
def __init__(self, sessions_dir: Path, config: GatewayConfig,
|
||||
has_active_processes_fn=None):
|
||||
has_active_processes_fn=None,
|
||||
on_auto_reset=None):
|
||||
self.sessions_dir = sessions_dir
|
||||
self.config = config
|
||||
self._entries: Dict[str, SessionEntry] = {}
|
||||
|
||||
@@ -136,34 +136,7 @@ class GatewayStreamConsumer:
|
||||
|
||||
if should_edit and self._accumulated:
|
||||
# Split overflow: if accumulated text exceeds the platform
|
||||
# limit, split into properly sized chunks.
|
||||
if (
|
||||
len(self._accumulated) > _safe_limit
|
||||
and self._message_id is None
|
||||
):
|
||||
# No existing message to edit (first message or after a
|
||||
# segment break). Use truncate_message — the same
|
||||
# helper the non-streaming path uses — to split with
|
||||
# proper word/code-fence boundaries and chunk
|
||||
# indicators like "(1/2)".
|
||||
chunks = self.adapter.truncate_message(
|
||||
self._accumulated, _safe_limit
|
||||
)
|
||||
for chunk in chunks:
|
||||
await self._send_new_chunk(chunk, self._message_id)
|
||||
self._accumulated = ""
|
||||
self._last_sent_text = ""
|
||||
self._last_edit_time = time.monotonic()
|
||||
if got_done:
|
||||
return
|
||||
if got_segment_break:
|
||||
self._message_id = None
|
||||
self._fallback_final_send = False
|
||||
self._fallback_prefix = ""
|
||||
continue
|
||||
|
||||
# Existing message: edit it with the first chunk, then
|
||||
# start a new message for the overflow remainder.
|
||||
# limit, finalize the current message and start a new one.
|
||||
while (
|
||||
len(self._accumulated) > _safe_limit
|
||||
and self._message_id is not None
|
||||
@@ -253,34 +226,6 @@ class GatewayStreamConsumer:
|
||||
# Strip trailing whitespace/newlines but preserve leading content
|
||||
return cleaned.rstrip()
|
||||
|
||||
async def _send_new_chunk(self, text: str, reply_to_id: Optional[str]) -> Optional[str]:
|
||||
"""Send a new message chunk, optionally threaded to a previous message.
|
||||
|
||||
Returns the message_id so callers can thread subsequent chunks.
|
||||
"""
|
||||
text = self._clean_for_display(text)
|
||||
if not text.strip():
|
||||
return reply_to_id
|
||||
try:
|
||||
meta = dict(self.metadata) if self.metadata else {}
|
||||
result = await self.adapter.send(
|
||||
chat_id=self.chat_id,
|
||||
content=text,
|
||||
reply_to=reply_to_id,
|
||||
metadata=meta,
|
||||
)
|
||||
if result.success and result.message_id:
|
||||
self._message_id = str(result.message_id)
|
||||
self._already_sent = True
|
||||
self._last_sent_text = text
|
||||
return str(result.message_id)
|
||||
else:
|
||||
self._edit_supported = False
|
||||
return reply_to_id
|
||||
except Exception as e:
|
||||
logger.error("Stream send chunk error: %s", e)
|
||||
return reply_to_id
|
||||
|
||||
def _visible_prefix(self) -> str:
|
||||
"""Return the visible text already shown in the streamed message."""
|
||||
prefix = self._last_sent_text or ""
|
||||
|
||||
+34
-15
@@ -70,6 +70,7 @@ DEFAULT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
DEFAULT_QWEN_BASE_URL = "https://portal.qwen.ai/v1"
|
||||
DEFAULT_GITHUB_MODELS_BASE_URL = "https://api.githubcopilot.com"
|
||||
DEFAULT_COPILOT_ACP_BASE_URL = "acp://copilot"
|
||||
DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
CODEX_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
CODEX_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
|
||||
@@ -249,7 +250,7 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
# Kimi Code Endpoint Detection
|
||||
# =============================================================================
|
||||
|
||||
# Kimi Code (kimi.com/code) issues keys prefixed "sk-kimi-" that only work
|
||||
# Kimi Code (platform.kimi.ai) issues keys prefixed "sk-kimi-" that only work
|
||||
# on api.kimi.com/coding/v1. Legacy keys from platform.moonshot.ai work on
|
||||
# api.moonshot.ai/v1 (the default). Auto-detect when user hasn't set
|
||||
# KIMI_BASE_URL explicitly.
|
||||
@@ -2341,6 +2342,33 @@ def resolve_external_process_provider_credentials(provider_id: str) -> Dict[str,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# External credential detection
|
||||
# =============================================================================
|
||||
|
||||
def detect_external_credentials() -> List[Dict[str, Any]]:
|
||||
"""Scan for credentials from other CLI tools that Hermes can reuse.
|
||||
|
||||
Returns a list of dicts, each with:
|
||||
- provider: str -- Hermes provider id (e.g. "openai-codex")
|
||||
- path: str -- filesystem path where creds were found
|
||||
- label: str -- human-friendly description for the setup UI
|
||||
"""
|
||||
found: List[Dict[str, Any]] = []
|
||||
|
||||
# Codex CLI: ~/.codex/auth.json (importable, not shared)
|
||||
cli_tokens = _import_codex_cli_tokens()
|
||||
if cli_tokens:
|
||||
codex_path = Path.home() / ".codex" / "auth.json"
|
||||
found.append({
|
||||
"provider": "openai-codex",
|
||||
"path": str(codex_path),
|
||||
"label": f"Codex CLI credentials found ({codex_path}) — run `hermes auth` to create a separate session",
|
||||
})
|
||||
|
||||
return found
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CLI Commands — login / logout
|
||||
# =============================================================================
|
||||
@@ -2989,15 +3017,12 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
|
||||
_save_provider_state(auth_store, "nous", auth_state)
|
||||
saved_to = _save_auth_store(auth_store)
|
||||
|
||||
config_path = _update_config_for_provider("nous", inference_base_url)
|
||||
print()
|
||||
print("Login successful!")
|
||||
print(f" Auth state: {saved_to}")
|
||||
print(f" Config updated: {config_path} (model.provider=nous)")
|
||||
|
||||
# Resolve model BEFORE writing provider to config.yaml so we never
|
||||
# leave the config in a half-updated state (provider=nous but model
|
||||
# still set to the previous provider's model, e.g. opus from
|
||||
# OpenRouter). The auth.json active_provider was already set above.
|
||||
selected_model = None
|
||||
try:
|
||||
runtime_key = auth_state.get("agent_key") or auth_state.get("access_token")
|
||||
if not isinstance(runtime_key, str) or not runtime_key:
|
||||
@@ -3031,6 +3056,9 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
|
||||
unavailable_models=unavailable_models,
|
||||
portal_url=_portal,
|
||||
)
|
||||
if selected_model:
|
||||
_save_model_choice(selected_model)
|
||||
print(f"Default model set to: {selected_model}")
|
||||
elif unavailable_models:
|
||||
_url = (_portal or DEFAULT_NOUS_PORTAL_URL).rstrip("/")
|
||||
print("No free models currently available.")
|
||||
@@ -3042,15 +3070,6 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
|
||||
print()
|
||||
print(f"Login succeeded, but could not fetch available models. Reason: {message}")
|
||||
|
||||
# Write provider + model atomically so config is never mismatched.
|
||||
config_path = _update_config_for_provider(
|
||||
"nous", inference_base_url, default_model=selected_model,
|
||||
)
|
||||
if selected_model:
|
||||
_save_model_choice(selected_model)
|
||||
print(f"Default model set to: {selected_model}")
|
||||
print(f" Config updated: {config_path} (model.provider=nous)")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nLogin cancelled.")
|
||||
raise SystemExit(130)
|
||||
|
||||
@@ -90,6 +90,12 @@ HERMES_CADUCEUS = """[#CD7F32]⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⡀⠀⣀⣀
|
||||
[#B8860B]⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠳⠈⣡⠞⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[/]
|
||||
[#B8860B]⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[/]"""
|
||||
|
||||
COMPACT_BANNER = """
|
||||
[bold #FFD700]╔══════════════════════════════════════════════════════════════╗[/]
|
||||
[bold #FFD700]║[/] [#FFBF00]⚕ NOUS HERMES[/] [dim #B8860B]- AI Agent Framework[/] [bold #FFD700]║[/]
|
||||
[bold #FFD700]║[/] [#CD7F32]Messenger of the Digital Gods[/] [dim #B8860B]Nous Research[/] [bold #FFD700]║[/]
|
||||
[bold #FFD700]╚══════════════════════════════════════════════════════════════╝[/]
|
||||
"""
|
||||
|
||||
|
||||
# =========================================================================
|
||||
@@ -289,16 +295,10 @@ def _format_context_length(tokens: int) -> str:
|
||||
"""Format a token count for display (e.g. 128000 → '128K', 1048576 → '1M')."""
|
||||
if tokens >= 1_000_000:
|
||||
val = tokens / 1_000_000
|
||||
rounded = round(val)
|
||||
if abs(val - rounded) < 0.05:
|
||||
return f"{rounded}M"
|
||||
return f"{val:.1f}M"
|
||||
return f"{val:g}M"
|
||||
elif tokens >= 1_000:
|
||||
val = tokens / 1_000
|
||||
rounded = round(val)
|
||||
if abs(val - rounded) < 0.05:
|
||||
return f"{rounded}K"
|
||||
return f"{val:.1f}K"
|
||||
return f"{val:g}K"
|
||||
return str(tokens)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
"""Shared curses-based multi-select checklist for Hermes CLI.
|
||||
|
||||
Used by both ``hermes tools`` and ``hermes skills`` to present a
|
||||
toggleable list of items. Falls back to a numbered text UI when
|
||||
curses is unavailable (Windows without curses, piped stdin, etc.).
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import List, Set
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
|
||||
def curses_checklist(
|
||||
title: str,
|
||||
items: List[str],
|
||||
pre_selected: Set[int],
|
||||
) -> Set[int]:
|
||||
"""Multi-select checklist. Returns set of **selected** indices.
|
||||
|
||||
Args:
|
||||
title: Header text shown at the top of the checklist.
|
||||
items: Display labels for each row.
|
||||
pre_selected: Indices that start checked.
|
||||
|
||||
Returns:
|
||||
The indices the user confirmed as checked. On cancel (ESC/q),
|
||||
returns ``pre_selected`` unchanged.
|
||||
"""
|
||||
# Safety: return defaults when stdin is not a terminal.
|
||||
if not sys.stdin.isatty():
|
||||
return set(pre_selected)
|
||||
|
||||
try:
|
||||
import curses
|
||||
selected = set(pre_selected)
|
||||
result = [None]
|
||||
|
||||
def _ui(stdscr):
|
||||
curses.curs_set(0)
|
||||
if curses.has_colors():
|
||||
curses.start_color()
|
||||
curses.use_default_colors()
|
||||
curses.init_pair(1, curses.COLOR_GREEN, -1)
|
||||
curses.init_pair(2, curses.COLOR_YELLOW, -1)
|
||||
curses.init_pair(3, 8, -1) # dim gray
|
||||
cursor = 0
|
||||
scroll_offset = 0
|
||||
|
||||
while True:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
|
||||
# Header
|
||||
try:
|
||||
hattr = curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0)
|
||||
stdscr.addnstr(0, 0, title, max_x - 1, hattr)
|
||||
stdscr.addnstr(
|
||||
1, 0,
|
||||
" ↑↓ navigate SPACE toggle ENTER confirm ESC cancel",
|
||||
max_x - 1, curses.A_DIM,
|
||||
)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
# Scrollable item list
|
||||
visible_rows = max_y - 3
|
||||
if cursor < scroll_offset:
|
||||
scroll_offset = cursor
|
||||
elif cursor >= scroll_offset + visible_rows:
|
||||
scroll_offset = cursor - visible_rows + 1
|
||||
|
||||
for draw_i, i in enumerate(
|
||||
range(scroll_offset, min(len(items), scroll_offset + visible_rows))
|
||||
):
|
||||
y = draw_i + 3
|
||||
if y >= max_y - 1:
|
||||
break
|
||||
check = "✓" if i in selected else " "
|
||||
arrow = "→" if i == cursor else " "
|
||||
line = f" {arrow} [{check}] {items[i]}"
|
||||
|
||||
attr = curses.A_NORMAL
|
||||
if i == cursor:
|
||||
attr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
attr |= curses.color_pair(1)
|
||||
try:
|
||||
stdscr.addnstr(y, 0, line, max_x - 1, attr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
|
||||
if key in (curses.KEY_UP, ord("k")):
|
||||
cursor = (cursor - 1) % len(items)
|
||||
elif key in (curses.KEY_DOWN, ord("j")):
|
||||
cursor = (cursor + 1) % len(items)
|
||||
elif key == ord(" "):
|
||||
selected.symmetric_difference_update({cursor})
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result[0] = set(selected)
|
||||
return
|
||||
elif key in (27, ord("q")):
|
||||
result[0] = set(pre_selected)
|
||||
return
|
||||
|
||||
curses.wrapper(_ui)
|
||||
return result[0] if result[0] is not None else set(pre_selected)
|
||||
|
||||
except Exception:
|
||||
pass # fall through to numbered fallback
|
||||
|
||||
# ── Numbered text fallback ────────────────────────────────────────────
|
||||
selected = set(pre_selected)
|
||||
print(color(f"\n {title}", Colors.YELLOW))
|
||||
print(color(" Toggle by number, Enter to confirm.\n", Colors.DIM))
|
||||
|
||||
while True:
|
||||
for i, label in enumerate(items):
|
||||
check = "✓" if i in selected else " "
|
||||
print(f" {i + 1:3}. [{check}] {label}")
|
||||
print()
|
||||
|
||||
try:
|
||||
raw = input(color(" Number to toggle, 's' to save, 'q' to cancel: ", Colors.DIM)).strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return set(pre_selected)
|
||||
|
||||
if raw.lower() == "s" or raw == "":
|
||||
return selected
|
||||
if raw.lower() == "q":
|
||||
return set(pre_selected)
|
||||
try:
|
||||
idx = int(raw) - 1
|
||||
if 0 <= idx < len(items):
|
||||
selected.symmetric_difference_update({idx})
|
||||
except ValueError:
|
||||
print(color(" Invalid input", Colors.DIM))
|
||||
+10
-3
@@ -87,7 +87,8 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
CommandDef("model", "Switch model for this session", "Configuration", args_hint="[model] [--global]"),
|
||||
CommandDef("provider", "Show available providers and current provider",
|
||||
"Configuration"),
|
||||
|
||||
CommandDef("prompt", "View/set custom system prompt", "Configuration",
|
||||
cli_only=True, args_hint="[text]", subcommands=("clear",)),
|
||||
CommandDef("personality", "Set a predefined personality", "Configuration",
|
||||
args_hint="[name]"),
|
||||
CommandDef("statusbar", "Toggle the context/model status bar", "Configuration",
|
||||
@@ -99,7 +100,7 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
"Configuration"),
|
||||
CommandDef("reasoning", "Manage reasoning effort and display", "Configuration",
|
||||
args_hint="[level|show|hide]",
|
||||
subcommands=("none", "minimal", "low", "medium", "high", "xhigh", "show", "hide", "on", "off")),
|
||||
subcommands=("none", "low", "minimal", "medium", "high", "xhigh", "show", "hide", "on", "off")),
|
||||
CommandDef("skin", "Show or change the display skin/theme", "Configuration",
|
||||
cli_only=True, args_hint="[name]"),
|
||||
CommandDef("voice", "Toggle voice mode", "Configuration",
|
||||
@@ -128,7 +129,7 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
CommandDef("commands", "Browse all commands and skills (paginated)", "Info",
|
||||
gateway_only=True, args_hint="[page]"),
|
||||
CommandDef("help", "Show available commands", "Info"),
|
||||
CommandDef("usage", "Show token usage and rate limits for the current session", "Info"),
|
||||
CommandDef("usage", "Show token usage for the current session", "Info"),
|
||||
CommandDef("insights", "Show usage insights and analytics", "Info",
|
||||
args_hint="[days]"),
|
||||
CommandDef("platforms", "Show gateway/messaging platform status", "Info",
|
||||
@@ -169,6 +170,12 @@ def resolve_command(name: str) -> CommandDef | None:
|
||||
return _COMMAND_LOOKUP.get(name.lower().lstrip("/"))
|
||||
|
||||
|
||||
def register_plugin_command(cmd: CommandDef) -> None:
|
||||
"""Append a plugin-defined command to the registry and refresh lookups."""
|
||||
COMMAND_REGISTRY.append(cmd)
|
||||
rebuild_lookups()
|
||||
|
||||
|
||||
def rebuild_lookups() -> None:
|
||||
"""Rebuild all derived lookup dicts from the current COMMAND_REGISTRY.
|
||||
|
||||
|
||||
+7
-78
@@ -39,7 +39,6 @@ _EXTRA_ENV_KEYS = frozenset({
|
||||
"DINGTALK_CLIENT_ID", "DINGTALK_CLIENT_SECRET",
|
||||
"FEISHU_APP_ID", "FEISHU_APP_SECRET", "FEISHU_ENCRYPT_KEY", "FEISHU_VERIFICATION_TOKEN",
|
||||
"WECOM_BOT_ID", "WECOM_SECRET",
|
||||
"BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_PASSWORD",
|
||||
"TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT",
|
||||
"WHATSAPP_MODE", "WHATSAPP_ENABLED",
|
||||
"MATTERMOST_HOME_CHANNEL", "MATTERMOST_REPLY_MODE",
|
||||
@@ -197,44 +196,14 @@ def _ensure_default_soul_md(home: Path) -> None:
|
||||
|
||||
|
||||
def ensure_hermes_home():
|
||||
"""Ensure ~/.hermes directory structure exists with secure permissions.
|
||||
|
||||
In managed mode (NixOS), dirs are created by the activation script with
|
||||
setgid + group-writable (2770). We skip mkdir and set umask(0o007) so
|
||||
any files created (e.g. SOUL.md) are group-writable (0660).
|
||||
"""
|
||||
"""Ensure ~/.hermes directory structure exists with secure permissions."""
|
||||
home = get_hermes_home()
|
||||
if is_managed():
|
||||
old_umask = os.umask(0o007)
|
||||
try:
|
||||
_ensure_hermes_home_managed(home)
|
||||
finally:
|
||||
os.umask(old_umask)
|
||||
else:
|
||||
home.mkdir(parents=True, exist_ok=True)
|
||||
_secure_dir(home)
|
||||
for subdir in ("cron", "sessions", "logs", "memories"):
|
||||
d = home / subdir
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
_secure_dir(d)
|
||||
_ensure_default_soul_md(home)
|
||||
|
||||
|
||||
def _ensure_hermes_home_managed(home: Path):
|
||||
"""Managed-mode variant: verify dirs exist (activation creates them), seed SOUL.md."""
|
||||
if not home.is_dir():
|
||||
raise RuntimeError(
|
||||
f"HERMES_HOME {home} does not exist. "
|
||||
"Run 'sudo nixos-rebuild switch' first."
|
||||
)
|
||||
home.mkdir(parents=True, exist_ok=True)
|
||||
_secure_dir(home)
|
||||
for subdir in ("cron", "sessions", "logs", "memories"):
|
||||
d = home / subdir
|
||||
if not d.is_dir():
|
||||
raise RuntimeError(
|
||||
f"{d} does not exist. "
|
||||
"Run 'sudo nixos-rebuild switch' first."
|
||||
)
|
||||
# Inside umask(0o007) scope — SOUL.md will be created as 0660
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
_secure_dir(d)
|
||||
_ensure_default_soul_md(home)
|
||||
|
||||
|
||||
@@ -261,10 +230,6 @@ DEFAULT_CONFIG = {
|
||||
# (force on/off for all models), or a list of model-name substrings
|
||||
# to match (e.g. ["gpt", "codex", "gemini", "qwen"]).
|
||||
"tool_use_enforcement": "auto",
|
||||
# Staged inactivity warning: send a warning to the user at this
|
||||
# threshold before escalating to a full timeout. The warning fires
|
||||
# once per run and does not interrupt the agent. 0 = disable warning.
|
||||
"gateway_timeout_warning": 900,
|
||||
},
|
||||
|
||||
"terminal": {
|
||||
@@ -599,7 +564,7 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
|
||||
# Config schema version - bump this when adding new required fields
|
||||
"_config_version": 13,
|
||||
"_config_version": 12,
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
@@ -1155,27 +1120,6 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "messaging",
|
||||
"advanced": True,
|
||||
},
|
||||
"BLUEBUBBLES_SERVER_URL": {
|
||||
"description": "BlueBubbles server URL for iMessage integration (e.g. http://192.168.1.10:1234)",
|
||||
"prompt": "BlueBubbles server URL",
|
||||
"url": "https://bluebubbles.app/",
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"BLUEBUBBLES_PASSWORD": {
|
||||
"description": "BlueBubbles server password (from BlueBubbles Server → Settings → API)",
|
||||
"prompt": "BlueBubbles server password",
|
||||
"url": None,
|
||||
"password": True,
|
||||
"category": "messaging",
|
||||
},
|
||||
"BLUEBUBBLES_ALLOWED_USERS": {
|
||||
"description": "Comma-separated iMessage addresses (email or phone) allowed to use the bot",
|
||||
"prompt": "Allowed iMessage addresses (comma-separated)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"GATEWAY_ALLOW_ALL_USERS": {
|
||||
"description": "Allow all users to interact with messaging bots (true/false). Default: false.",
|
||||
"prompt": "Allow all users (true/false)",
|
||||
@@ -1247,7 +1191,7 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "setting",
|
||||
},
|
||||
"SUDO_PASSWORD": {
|
||||
"description": "Sudo password for terminal commands requiring root access; set to an explicit empty string to try empty without prompting",
|
||||
"description": "Sudo password for terminal commands requiring root access",
|
||||
"prompt": "Sudo password",
|
||||
"url": None,
|
||||
"password": True,
|
||||
@@ -1731,21 +1675,6 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
|
||||
ep = providers_dict[key]
|
||||
print(f" → {key}: {ep.get('api', '')}")
|
||||
|
||||
# ── Version 12 → 13: clear dead LLM_MODEL / OPENAI_MODEL from .env ──
|
||||
# These env vars were written by the old setup wizard but nothing reads
|
||||
# them anymore (config.yaml is the sole source of truth since March 2026).
|
||||
# Stale entries cause user confusion — see issue report.
|
||||
if current_ver < 13:
|
||||
for dead_var in ("LLM_MODEL", "OPENAI_MODEL"):
|
||||
try:
|
||||
old_val = get_env_value(dead_var)
|
||||
if old_val:
|
||||
save_env_value(dead_var, "")
|
||||
if not quiet:
|
||||
print(f" ✓ Cleared {dead_var} from .env (no longer used — config.yaml is source of truth)")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if current_ver < latest_ver and not quiet:
|
||||
print(f"Config version: {current_ver} → {latest_ver}")
|
||||
|
||||
|
||||
@@ -31,6 +31,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# OAuth device code flow constants (same client ID as opencode/Copilot CLI)
|
||||
COPILOT_OAUTH_CLIENT_ID = "Ov23li8tweQw6odWQebz"
|
||||
COPILOT_DEVICE_CODE_URL = "https://github.com/login/device/code"
|
||||
COPILOT_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token"
|
||||
|
||||
# Copilot API constants
|
||||
COPILOT_TOKEN_EXCHANGE_URL = "https://api.github.com/copilot_internal/v2/token"
|
||||
COPILOT_API_BASE_URL = "https://api.githubcopilot.com"
|
||||
|
||||
# Token type prefixes
|
||||
_CLASSIC_PAT_PREFIX = "ghp_"
|
||||
_SUPPORTED_PREFIXES = ("gho_", "github_pat_", "ghu_")
|
||||
@@ -43,6 +50,11 @@ _DEVICE_CODE_POLL_INTERVAL = 5 # seconds
|
||||
_DEVICE_CODE_POLL_SAFETY_MARGIN = 3 # seconds
|
||||
|
||||
|
||||
def is_classic_pat(token: str) -> bool:
|
||||
"""Check if a token is a classic PAT (ghp_*), which Copilot doesn't support."""
|
||||
return token.strip().startswith(_CLASSIC_PAT_PREFIX)
|
||||
|
||||
|
||||
def validate_copilot_token(token: str) -> tuple[bool, str]:
|
||||
"""Validate that a token is usable with the Copilot API.
|
||||
|
||||
|
||||
@@ -1,332 +0,0 @@
|
||||
"""
|
||||
Dump command for hermes CLI.
|
||||
|
||||
Outputs a compact, plain-text summary of the user's Hermes setup
|
||||
that can be copy-pasted into Discord/GitHub/Telegram for support context.
|
||||
No ANSI colors, no checkmarks — just data.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_cli.config import get_hermes_home, get_env_path, get_project_root, load_config
|
||||
from hermes_constants import display_hermes_home
|
||||
|
||||
|
||||
def _get_git_commit(project_root: Path) -> str:
|
||||
"""Return short git commit hash, or '(unknown)'."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "--short=8", "HEAD"],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
cwd=str(project_root),
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout.strip()
|
||||
except Exception:
|
||||
pass
|
||||
return "(unknown)"
|
||||
|
||||
|
||||
def _redact(value: str) -> str:
|
||||
"""Redact all but first 4 and last 4 chars."""
|
||||
if not value:
|
||||
return ""
|
||||
if len(value) < 12:
|
||||
return "***"
|
||||
return value[:4] + "..." + value[-4:]
|
||||
|
||||
|
||||
def _gateway_status() -> str:
|
||||
"""Return a short gateway status string."""
|
||||
if sys.platform.startswith("linux"):
|
||||
try:
|
||||
from hermes_cli.gateway import get_service_name
|
||||
svc = get_service_name()
|
||||
except Exception:
|
||||
svc = "hermes-gateway"
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["systemctl", "--user", "is-active", svc],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
return "running (systemd)" if r.stdout.strip() == "active" else "stopped"
|
||||
except Exception:
|
||||
return "unknown"
|
||||
elif sys.platform == "darwin":
|
||||
try:
|
||||
from hermes_cli.gateway import get_launchd_label
|
||||
r = subprocess.run(
|
||||
["launchctl", "list", get_launchd_label()],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
return "loaded (launchd)" if r.returncode == 0 else "not loaded"
|
||||
except Exception:
|
||||
return "unknown"
|
||||
return "N/A"
|
||||
|
||||
|
||||
def _count_skills(hermes_home: Path) -> int:
|
||||
"""Count installed skills."""
|
||||
skills_dir = hermes_home / "skills"
|
||||
if not skills_dir.is_dir():
|
||||
return 0
|
||||
count = 0
|
||||
for item in skills_dir.rglob("SKILL.md"):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
def _count_mcp_servers(config: dict) -> int:
|
||||
"""Count configured MCP servers."""
|
||||
mcp = config.get("mcp", {})
|
||||
servers = mcp.get("servers", {})
|
||||
return len(servers)
|
||||
|
||||
|
||||
def _cron_summary(hermes_home: Path) -> str:
|
||||
"""Return cron jobs summary."""
|
||||
jobs_file = hermes_home / "cron" / "jobs.json"
|
||||
if not jobs_file.exists():
|
||||
return "0"
|
||||
try:
|
||||
with open(jobs_file, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
jobs = data.get("jobs", [])
|
||||
active = sum(1 for j in jobs if j.get("enabled", True))
|
||||
return f"{active} active / {len(jobs)} total"
|
||||
except Exception:
|
||||
return "(error reading)"
|
||||
|
||||
|
||||
def _configured_platforms() -> list[str]:
|
||||
"""Return list of configured messaging platform names."""
|
||||
checks = {
|
||||
"telegram": "TELEGRAM_BOT_TOKEN",
|
||||
"discord": "DISCORD_BOT_TOKEN",
|
||||
"slack": "SLACK_BOT_TOKEN",
|
||||
"whatsapp": "WHATSAPP_ENABLED",
|
||||
"signal": "SIGNAL_HTTP_URL",
|
||||
"email": "EMAIL_ADDRESS",
|
||||
"sms": "TWILIO_ACCOUNT_SID",
|
||||
"matrix": "MATRIX_HOMESERVER_URL",
|
||||
"mattermost": "MATTERMOST_URL",
|
||||
"homeassistant": "HASS_TOKEN",
|
||||
"dingtalk": "DINGTALK_CLIENT_ID",
|
||||
"feishu": "FEISHU_APP_ID",
|
||||
"wecom": "WECOM_BOT_ID",
|
||||
}
|
||||
return [name for name, env in checks.items() if os.getenv(env)]
|
||||
|
||||
|
||||
def _memory_provider(config: dict) -> str:
|
||||
"""Return the active memory provider name."""
|
||||
mem = config.get("memory", {})
|
||||
provider = mem.get("provider", "")
|
||||
return provider if provider else "built-in"
|
||||
|
||||
|
||||
def _get_model_and_provider(config: dict) -> tuple[str, str]:
|
||||
"""Extract model and provider from config."""
|
||||
model_cfg = config.get("model", "")
|
||||
if isinstance(model_cfg, dict):
|
||||
model = model_cfg.get("default") or model_cfg.get("model") or model_cfg.get("name") or "(not set)"
|
||||
provider = model_cfg.get("provider") or "(auto)"
|
||||
elif isinstance(model_cfg, str):
|
||||
model = model_cfg or "(not set)"
|
||||
provider = "(auto)"
|
||||
else:
|
||||
model = "(not set)"
|
||||
provider = "(auto)"
|
||||
return model, provider
|
||||
|
||||
|
||||
def _config_overrides(config: dict) -> dict[str, str]:
|
||||
"""Find non-default config values worth reporting.
|
||||
|
||||
Returns a flat dict of dotpath -> value for interesting overrides.
|
||||
"""
|
||||
from hermes_cli.config import DEFAULT_CONFIG
|
||||
|
||||
overrides = {}
|
||||
|
||||
# Sections with interesting user-facing overrides
|
||||
interesting_paths = [
|
||||
("agent", "max_turns"),
|
||||
("agent", "gateway_timeout"),
|
||||
("agent", "tool_use_enforcement"),
|
||||
("terminal", "backend"),
|
||||
("terminal", "docker_image"),
|
||||
("terminal", "persistent_shell"),
|
||||
("browser", "allow_private_urls"),
|
||||
("compression", "enabled"),
|
||||
("compression", "threshold"),
|
||||
("display", "streaming"),
|
||||
("display", "skin"),
|
||||
("display", "show_reasoning"),
|
||||
("smart_model_routing", "enabled"),
|
||||
("privacy", "redact_pii"),
|
||||
("tts", "provider"),
|
||||
]
|
||||
|
||||
for section, key in interesting_paths:
|
||||
default_section = DEFAULT_CONFIG.get(section, {})
|
||||
user_section = config.get(section, {})
|
||||
if not isinstance(default_section, dict) or not isinstance(user_section, dict):
|
||||
continue
|
||||
default_val = default_section.get(key)
|
||||
user_val = user_section.get(key)
|
||||
if user_val is not None and user_val != default_val:
|
||||
overrides[f"{section}.{key}"] = str(user_val)
|
||||
|
||||
# Toolsets (if different from default)
|
||||
default_toolsets = DEFAULT_CONFIG.get("toolsets", [])
|
||||
user_toolsets = config.get("toolsets", [])
|
||||
if user_toolsets != default_toolsets:
|
||||
overrides["toolsets"] = str(user_toolsets)
|
||||
|
||||
# Fallback providers
|
||||
fallbacks = config.get("fallback_providers", [])
|
||||
if fallbacks:
|
||||
overrides["fallback_providers"] = str(fallbacks)
|
||||
|
||||
return overrides
|
||||
|
||||
|
||||
def run_dump(args):
|
||||
"""Output a compact, copy-pasteable setup summary."""
|
||||
show_keys = getattr(args, "show_keys", False)
|
||||
|
||||
# Load env from .env file so key checks work
|
||||
from dotenv import load_dotenv
|
||||
env_path = get_env_path()
|
||||
if env_path.exists():
|
||||
try:
|
||||
load_dotenv(env_path, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(env_path, encoding="latin-1")
|
||||
# Also try project .env as dev fallback
|
||||
load_dotenv(get_project_root() / ".env", override=False, encoding="utf-8")
|
||||
|
||||
project_root = get_project_root()
|
||||
hermes_home = get_hermes_home()
|
||||
|
||||
try:
|
||||
from hermes_cli import __version__, __release_date__
|
||||
except ImportError:
|
||||
__version__ = "(unknown)"
|
||||
__release_date__ = ""
|
||||
|
||||
commit = _get_git_commit(project_root)
|
||||
|
||||
try:
|
||||
config = load_config()
|
||||
except Exception:
|
||||
config = {}
|
||||
|
||||
model, provider = _get_model_and_provider(config)
|
||||
|
||||
# Profile
|
||||
try:
|
||||
from hermes_cli.profiles import get_active_profile_name
|
||||
profile = get_active_profile_name() or "(default)"
|
||||
except Exception:
|
||||
profile = "(default)"
|
||||
|
||||
# Terminal backend
|
||||
terminal_cfg = config.get("terminal", {})
|
||||
backend = terminal_cfg.get("backend", "local")
|
||||
|
||||
# OpenAI SDK version
|
||||
try:
|
||||
import openai
|
||||
openai_ver = openai.__version__
|
||||
except ImportError:
|
||||
openai_ver = "not installed"
|
||||
|
||||
# OS info
|
||||
os_info = f"{platform.system()} {platform.release()} {platform.machine()}"
|
||||
|
||||
lines = []
|
||||
lines.append("--- hermes dump ---")
|
||||
ver_str = f"{__version__}"
|
||||
if __release_date__:
|
||||
ver_str += f" ({__release_date__})"
|
||||
ver_str += f" [{commit}]"
|
||||
lines.append(f"version: {ver_str}")
|
||||
lines.append(f"os: {os_info}")
|
||||
lines.append(f"python: {sys.version.split()[0]}")
|
||||
lines.append(f"openai_sdk: {openai_ver}")
|
||||
lines.append(f"profile: {profile}")
|
||||
lines.append(f"hermes_home: {display_hermes_home()}")
|
||||
lines.append(f"model: {model}")
|
||||
lines.append(f"provider: {provider}")
|
||||
lines.append(f"terminal: {backend}")
|
||||
|
||||
# API keys
|
||||
lines.append("")
|
||||
lines.append("api_keys:")
|
||||
api_keys = [
|
||||
("OPENROUTER_API_KEY", "openrouter"),
|
||||
("OPENAI_API_KEY", "openai"),
|
||||
("ANTHROPIC_API_KEY", "anthropic"),
|
||||
("ANTHROPIC_TOKEN", "anthropic_token"),
|
||||
("NOUS_API_KEY", "nous"),
|
||||
("GLM_API_KEY", "glm/zai"),
|
||||
("ZAI_API_KEY", "zai"),
|
||||
("KIMI_API_KEY", "kimi"),
|
||||
("MINIMAX_API_KEY", "minimax"),
|
||||
("DEEPSEEK_API_KEY", "deepseek"),
|
||||
("DASHSCOPE_API_KEY", "dashscope"),
|
||||
("HF_TOKEN", "huggingface"),
|
||||
("AI_GATEWAY_API_KEY", "ai_gateway"),
|
||||
("OPENCODE_ZEN_API_KEY", "opencode_zen"),
|
||||
("OPENCODE_GO_API_KEY", "opencode_go"),
|
||||
("KILOCODE_API_KEY", "kilocode"),
|
||||
("FIRECRAWL_API_KEY", "firecrawl"),
|
||||
("TAVILY_API_KEY", "tavily"),
|
||||
("BROWSERBASE_API_KEY", "browserbase"),
|
||||
("FAL_KEY", "fal"),
|
||||
("ELEVENLABS_API_KEY", "elevenlabs"),
|
||||
("GITHUB_TOKEN", "github"),
|
||||
]
|
||||
|
||||
for env_var, label in api_keys:
|
||||
val = os.getenv(env_var, "")
|
||||
if show_keys and val:
|
||||
display = _redact(val)
|
||||
else:
|
||||
display = "set" if val else "not set"
|
||||
lines.append(f" {label:<20} {display}")
|
||||
|
||||
# Features summary
|
||||
lines.append("")
|
||||
lines.append("features:")
|
||||
|
||||
toolsets = config.get("toolsets", ["hermes-cli"])
|
||||
lines.append(f" toolsets: {', '.join(toolsets) if toolsets else '(default)'}")
|
||||
lines.append(f" mcp_servers: {_count_mcp_servers(config)}")
|
||||
lines.append(f" memory_provider: {_memory_provider(config)}")
|
||||
lines.append(f" gateway: {_gateway_status()}")
|
||||
|
||||
platforms = _configured_platforms()
|
||||
lines.append(f" platforms: {', '.join(platforms) if platforms else 'none'}")
|
||||
lines.append(f" cron_jobs: {_cron_summary(hermes_home)}")
|
||||
lines.append(f" skills: {_count_skills(hermes_home)}")
|
||||
|
||||
# Config overrides (non-default values)
|
||||
overrides = _config_overrides(config)
|
||||
if overrides:
|
||||
lines.append("")
|
||||
lines.append("config_overrides:")
|
||||
for key, val in overrides.items():
|
||||
lines.append(f" {key}: {val}")
|
||||
|
||||
lines.append("--- end dump ---")
|
||||
|
||||
output = "\n".join(lines)
|
||||
print(output)
|
||||
+13
-28
@@ -308,6 +308,8 @@ def get_service_name() -> str:
|
||||
return f"{_SERVICE_BASE}-{suffix}"
|
||||
|
||||
|
||||
SERVICE_NAME = _SERVICE_BASE # backward-compat for external importers; prefer get_service_name()
|
||||
|
||||
|
||||
def get_systemd_unit_path(system: bool = False) -> Path:
|
||||
name = get_service_name()
|
||||
@@ -579,6 +581,17 @@ def get_python_path() -> str:
|
||||
return str(venv_python)
|
||||
return sys.executable
|
||||
|
||||
def get_hermes_cli_path() -> str:
|
||||
"""Get the path to the hermes CLI."""
|
||||
# Check if installed via pip
|
||||
import shutil
|
||||
hermes_bin = shutil.which("hermes")
|
||||
if hermes_bin:
|
||||
return hermes_bin
|
||||
|
||||
# Fallback to direct module execution
|
||||
return f"{get_python_path()} -m hermes_cli.main"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Systemd (Linux)
|
||||
@@ -1575,34 +1588,6 @@ _PLATFORMS = [
|
||||
"help": "Chat ID for scheduled results and notifications."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"key": "bluebubbles",
|
||||
"label": "BlueBubbles (iMessage)",
|
||||
"emoji": "💬",
|
||||
"token_var": "BLUEBUBBLES_SERVER_URL",
|
||||
"setup_instructions": [
|
||||
"1. Install BlueBubbles on a Mac that will act as your iMessage server:",
|
||||
" https://bluebubbles.app/",
|
||||
"2. Complete the BlueBubbles setup wizard — sign in with your Apple ID",
|
||||
"3. In BlueBubbles Settings → API, note the Server URL and password",
|
||||
"4. The server URL is typically http://<your-mac-ip>:1234",
|
||||
"5. Hermes connects via the BlueBubbles REST API and receives",
|
||||
" incoming messages via a local webhook",
|
||||
"6. To authorize users, use DM pairing: hermes pairing generate bluebubbles",
|
||||
" Share the code — the user sends it via iMessage to get approved",
|
||||
],
|
||||
"vars": [
|
||||
{"name": "BLUEBUBBLES_SERVER_URL", "prompt": "BlueBubbles server URL (e.g. http://192.168.1.10:1234)", "password": False,
|
||||
"help": "The URL shown in BlueBubbles Settings → API."},
|
||||
{"name": "BLUEBUBBLES_PASSWORD", "prompt": "BlueBubbles server password", "password": True,
|
||||
"help": "The password shown in BlueBubbles Settings → API."},
|
||||
{"name": "BLUEBUBBLES_ALLOWED_USERS", "prompt": "Pre-authorized phone numbers or iMessage IDs (comma-separated, or leave empty for DM pairing)", "password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Optional — pre-authorize specific users. Leave empty to use DM pairing instead (recommended)."},
|
||||
{"name": "BLUEBUBBLES_HOME_CHANNEL", "prompt": "Home channel (phone number or iMessage ID for cron/notifications, or empty)", "password": False,
|
||||
"help": "Phone number or Apple ID to deliver cron results and notifications to."},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
|
||||
+2
-31
@@ -1474,11 +1474,7 @@ def _model_flow_custom(config):
|
||||
f"Hermes will still save it."
|
||||
)
|
||||
if probe.get("suggested_base_url"):
|
||||
suggested = probe["suggested_base_url"]
|
||||
if suggested.endswith("/v1"):
|
||||
print(f" If this server expects /v1 in the path, try base URL: {suggested}")
|
||||
else:
|
||||
print(f" If /v1 should not be in the base URL, try: {suggested}")
|
||||
print(f" If this server expects /v1, try base URL: {probe['suggested_base_url']}")
|
||||
|
||||
# Select model — use probe results when available, fall back to manual input
|
||||
model_name = ""
|
||||
@@ -1811,10 +1807,7 @@ def _set_reasoning_effort(config, effort: str) -> None:
|
||||
|
||||
def _prompt_reasoning_effort_selection(efforts, current_effort=""):
|
||||
"""Prompt for a reasoning effort. Returns effort, 'none', or None to keep current."""
|
||||
deduped = list(dict.fromkeys(str(effort).strip().lower() for effort in efforts if str(effort).strip()))
|
||||
canonical_order = ("minimal", "low", "medium", "high", "xhigh")
|
||||
ordered = [effort for effort in canonical_order if effort in deduped]
|
||||
ordered.extend(effort for effort in deduped if effort not in canonical_order)
|
||||
ordered = list(dict.fromkeys(str(effort).strip().lower() for effort in efforts if str(effort).strip()))
|
||||
if not ordered:
|
||||
return None
|
||||
|
||||
@@ -2646,12 +2639,6 @@ def cmd_doctor(args):
|
||||
run_doctor(args)
|
||||
|
||||
|
||||
def cmd_dump(args):
|
||||
"""Dump setup summary for support/debugging."""
|
||||
from hermes_cli.dump import run_dump
|
||||
run_dump(args)
|
||||
|
||||
|
||||
def cmd_config(args):
|
||||
"""Configuration management."""
|
||||
from hermes_cli.config import config_command
|
||||
@@ -4733,22 +4720,6 @@ For more help on a command:
|
||||
help="Attempt to fix issues automatically"
|
||||
)
|
||||
doctor_parser.set_defaults(func=cmd_doctor)
|
||||
|
||||
# =========================================================================
|
||||
# dump command
|
||||
# =========================================================================
|
||||
dump_parser = subparsers.add_parser(
|
||||
"dump",
|
||||
help="Dump setup summary for support/debugging",
|
||||
description="Output a compact, plain-text summary of your Hermes setup "
|
||||
"that can be copy-pasted into Discord/GitHub for support context"
|
||||
)
|
||||
dump_parser.add_argument(
|
||||
"--show-keys",
|
||||
action="store_true",
|
||||
help="Show redacted API key prefixes (first/last 4 chars) instead of just set/not set"
|
||||
)
|
||||
dump_parser.set_defaults(func=cmd_dump)
|
||||
|
||||
# =========================================================================
|
||||
# config command
|
||||
|
||||
@@ -332,3 +332,31 @@ def normalize_model_for_provider(model_input: str, target_provider: str) -> str:
|
||||
# Batch / convenience helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def model_display_name(model_id: str) -> str:
|
||||
"""Return a short, human-readable display name for a model id.
|
||||
|
||||
Strips the vendor prefix (if any) for a cleaner display in menus
|
||||
and status bars, while preserving dots for readability.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> model_display_name("anthropic/claude-sonnet-4.6")
|
||||
'claude-sonnet-4.6'
|
||||
>>> model_display_name("claude-sonnet-4-6")
|
||||
'claude-sonnet-4-6'
|
||||
"""
|
||||
return _strip_vendor_prefix((model_id or "").strip())
|
||||
|
||||
|
||||
def is_aggregator_provider(provider: str) -> bool:
|
||||
"""Check if a provider is an aggregator that needs vendor/model format."""
|
||||
return (provider or "").strip().lower() in _AGGREGATOR_PROVIDERS
|
||||
|
||||
|
||||
def vendor_for_model(model_name: str) -> str:
|
||||
"""Return the vendor slug for a model, or ``""`` if unknown.
|
||||
|
||||
Convenience wrapper around :func:`detect_vendor` that never returns
|
||||
``None``.
|
||||
"""
|
||||
return detect_vendor(model_name) or ""
|
||||
|
||||
+75
-15
@@ -537,11 +537,8 @@ def switch_model(
|
||||
)
|
||||
else:
|
||||
# --- Step c: On aggregator, convert vendor:model to vendor/model ---
|
||||
# Only convert when there's no slash — a slash means the name
|
||||
# is already in vendor/model format and the colon is a variant
|
||||
# tag (:free, :extended, :fast) that must be preserved.
|
||||
colon_pos = raw_input.find(":")
|
||||
if colon_pos > 0 and "/" not in raw_input and is_aggregator(current_provider):
|
||||
if colon_pos > 0 and is_aggregator(current_provider):
|
||||
left = raw_input[:colon_pos].strip().lower()
|
||||
right = raw_input[colon_pos + 1:].strip()
|
||||
if left and right:
|
||||
@@ -733,7 +730,6 @@ def list_authenticated_providers(
|
||||
fetch_models_dev,
|
||||
get_provider_info as _mdev_pinfo,
|
||||
)
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY
|
||||
from hermes_cli.models import OPENROUTER_MODELS, _PROVIDER_MODELS
|
||||
|
||||
results: List[dict] = []
|
||||
@@ -754,16 +750,9 @@ def list_authenticated_providers(
|
||||
if not isinstance(pdata, dict):
|
||||
continue
|
||||
|
||||
# Prefer auth.py PROVIDER_REGISTRY for env var names — it's our
|
||||
# source of truth. models.dev can have wrong mappings (e.g.
|
||||
# minimax-cn → MINIMAX_API_KEY instead of MINIMAX_CN_API_KEY).
|
||||
pconfig = PROVIDER_REGISTRY.get(hermes_id)
|
||||
if pconfig and pconfig.api_key_env_vars:
|
||||
env_vars = list(pconfig.api_key_env_vars)
|
||||
else:
|
||||
env_vars = pdata.get("env", [])
|
||||
if not isinstance(env_vars, list):
|
||||
continue
|
||||
env_vars = pdata.get("env", [])
|
||||
if not isinstance(env_vars, list):
|
||||
continue
|
||||
|
||||
# Check if any env var is set
|
||||
has_creds = any(os.environ.get(ev) for ev in env_vars)
|
||||
@@ -859,3 +848,74 @@ def list_authenticated_providers(
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fuzzy suggestions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def suggest_models(raw_input: str, limit: int = 3) -> List[str]:
|
||||
"""Return fuzzy model suggestions for a (possibly misspelled) input."""
|
||||
query = raw_input.strip()
|
||||
if not query:
|
||||
return []
|
||||
|
||||
results = search_models_dev(query, limit=limit)
|
||||
suggestions: list[str] = []
|
||||
for r in results:
|
||||
mid = r.get("model_id", "")
|
||||
if mid:
|
||||
suggestions.append(mid)
|
||||
|
||||
return suggestions[:limit]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom provider switch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def switch_to_custom_provider() -> CustomAutoResult:
|
||||
"""Handle bare '/model --provider custom' — resolve endpoint and auto-detect model."""
|
||||
from hermes_cli.runtime_provider import (
|
||||
resolve_runtime_provider,
|
||||
_auto_detect_local_model,
|
||||
)
|
||||
|
||||
try:
|
||||
runtime = resolve_runtime_provider(requested="custom")
|
||||
except Exception as e:
|
||||
return CustomAutoResult(
|
||||
success=False,
|
||||
error_message=f"Could not resolve custom endpoint: {e}",
|
||||
)
|
||||
|
||||
cust_base = runtime.get("base_url", "")
|
||||
cust_key = runtime.get("api_key", "")
|
||||
|
||||
if not cust_base or "openrouter.ai" in cust_base:
|
||||
return CustomAutoResult(
|
||||
success=False,
|
||||
error_message=(
|
||||
"No custom endpoint configured. "
|
||||
"Set model.base_url in config.yaml, or set OPENAI_BASE_URL "
|
||||
"in .env, or run: hermes setup -> Custom OpenAI-compatible endpoint"
|
||||
),
|
||||
)
|
||||
|
||||
detected_model = _auto_detect_local_model(cust_base)
|
||||
if not detected_model:
|
||||
return CustomAutoResult(
|
||||
success=False,
|
||||
base_url=cust_base,
|
||||
api_key=cust_key,
|
||||
error_message=(
|
||||
f"Custom endpoint at {cust_base} is reachable but no single "
|
||||
f"model was auto-detected. Specify the model explicitly: "
|
||||
f"/model <model-name> --provider custom"
|
||||
),
|
||||
)
|
||||
|
||||
return CustomAutoResult(
|
||||
success=True,
|
||||
model=detected_model,
|
||||
base_url=cust_base,
|
||||
api_key=cust_key,
|
||||
)
|
||||
|
||||
+44
-1
@@ -20,6 +20,10 @@ COPILOT_EDITOR_VERSION = "vscode/1.104.1"
|
||||
COPILOT_REASONING_EFFORTS_GPT5 = ["minimal", "low", "medium", "high"]
|
||||
COPILOT_REASONING_EFFORTS_O_SERIES = ["low", "medium", "high"]
|
||||
|
||||
# Backward-compatible aliases for the earlier GitHub Models-backed Copilot work.
|
||||
GITHUB_MODELS_BASE_URL = COPILOT_BASE_URL
|
||||
GITHUB_MODELS_CATALOG_URL = COPILOT_MODELS_URL
|
||||
|
||||
# (model_id, display description shown in menus)
|
||||
OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("anthropic/claude-opus-4.6", "recommended"),
|
||||
@@ -412,6 +416,12 @@ _FREE_TIER_CACHE_TTL: int = 180 # seconds (3 minutes)
|
||||
_free_tier_cache: tuple[bool, float] | None = None # (result, timestamp)
|
||||
|
||||
|
||||
def clear_nous_free_tier_cache() -> None:
|
||||
"""Invalidate the cached free-tier result (e.g. after login/logout)."""
|
||||
global _free_tier_cache
|
||||
_free_tier_cache = None
|
||||
|
||||
|
||||
def check_nous_free_tier() -> bool:
|
||||
"""Check if the current Nous Portal user is on a free (unpaid) tier.
|
||||
|
||||
@@ -525,6 +535,14 @@ def model_ids() -> list[str]:
|
||||
return [mid for mid, _ in OPENROUTER_MODELS]
|
||||
|
||||
|
||||
def menu_labels() -> list[str]:
|
||||
"""Return display labels like 'anthropic/claude-opus-4.6 (recommended)'."""
|
||||
labels = []
|
||||
for mid, desc in OPENROUTER_MODELS:
|
||||
labels.append(f"{mid} ({desc})" if desc else mid)
|
||||
return labels
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pricing helpers — fetch live pricing from OpenRouter-compatible /v1/models
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -557,6 +575,31 @@ def _format_price_per_mtok(per_token_str: str) -> str:
|
||||
return f"${per_m:.2f}"
|
||||
|
||||
|
||||
def format_pricing_label(pricing: dict[str, str] | None) -> str:
|
||||
"""Build a compact pricing label like 'in $3 · out $15 · cache $0.30/Mtok'.
|
||||
|
||||
Returns empty string when pricing is unavailable.
|
||||
"""
|
||||
if not pricing:
|
||||
return ""
|
||||
prompt_price = pricing.get("prompt", "")
|
||||
completion_price = pricing.get("completion", "")
|
||||
if not prompt_price and not completion_price:
|
||||
return ""
|
||||
inp = _format_price_per_mtok(prompt_price)
|
||||
out = _format_price_per_mtok(completion_price)
|
||||
if inp == "free" and out == "free":
|
||||
return "free"
|
||||
cache_read = pricing.get("input_cache_read", "")
|
||||
cache_str = _format_price_per_mtok(cache_read) if cache_read else ""
|
||||
if inp == out and not cache_str:
|
||||
return f"{inp}/Mtok"
|
||||
parts = [f"in {inp}", f"out {out}"]
|
||||
if cache_str and cache_str != "?" and cache_str != inp:
|
||||
parts.append(f"cache {cache_str}")
|
||||
return " · ".join(parts) + "/Mtok"
|
||||
|
||||
|
||||
def format_model_pricing_table(
|
||||
models: list[tuple[str, str]],
|
||||
pricing_map: dict[str, dict[str, str]],
|
||||
@@ -1489,7 +1532,7 @@ def probe_api_models(
|
||||
|
||||
return {
|
||||
"models": None,
|
||||
"probed_url": tried[0] if tried else normalized.rstrip("/") + "/models",
|
||||
"probed_url": tried[-1] if tried else normalized.rstrip("/") + "/models",
|
||||
"resolved_base_url": normalized,
|
||||
"suggested_base_url": alternate_base if alternate_base != normalized else None,
|
||||
"used_fallback": False,
|
||||
|
||||
@@ -102,7 +102,7 @@ _RESERVED_NAMES = frozenset({
|
||||
# Hermes subcommands that cannot be used as profile names/aliases
|
||||
_HERMES_SUBCOMMANDS = frozenset({
|
||||
"chat", "model", "gateway", "setup", "whatsapp", "login", "logout",
|
||||
"status", "cron", "doctor", "dump", "config", "pairing", "skills", "tools",
|
||||
"status", "cron", "doctor", "config", "pairing", "skills", "tools",
|
||||
"mcp", "sessions", "insights", "version", "update", "uninstall",
|
||||
"profile", "plugins", "honcho", "acp",
|
||||
})
|
||||
@@ -1007,7 +1007,7 @@ _hermes_completion() {
|
||||
|
||||
# Top-level subcommands
|
||||
if [[ "$COMP_CWORD" == 1 ]]; then
|
||||
local commands="chat model gateway setup status cron doctor dump config skills tools mcp sessions profile update version"
|
||||
local commands="chat model gateway setup status cron doctor config skills tools mcp sessions profile update version"
|
||||
COMPREPLY=($(compgen -W "$commands" -- "$cur"))
|
||||
fi
|
||||
}
|
||||
@@ -1032,7 +1032,7 @@ _hermes() {
|
||||
_arguments \\
|
||||
'-p[Profile name]:profile:($profiles)' \\
|
||||
'--profile[Profile name]:profile:($profiles)' \\
|
||||
'1:command:(chat model gateway setup status cron doctor dump config skills tools mcp sessions profile update version)' \\
|
||||
'1:command:(chat model gateway setup status cron doctor config skills tools mcp sessions profile update version)' \\
|
||||
'*::arg:->args'
|
||||
|
||||
case $words[1] in
|
||||
|
||||
@@ -148,6 +148,10 @@ class ProviderDef:
|
||||
doc: str = ""
|
||||
source: str = "" # "models.dev", "hermes", "user-config"
|
||||
|
||||
@property
|
||||
def is_user_defined(self) -> bool:
|
||||
return self.source == "user-config"
|
||||
|
||||
|
||||
# -- Aliases ------------------------------------------------------------------
|
||||
# Maps human-friendly / legacy names to canonical provider IDs.
|
||||
@@ -258,6 +262,12 @@ def normalize_provider(name: str) -> str:
|
||||
return ALIASES.get(key, key)
|
||||
|
||||
|
||||
def get_overlay(provider_id: str) -> Optional[HermesOverlay]:
|
||||
"""Get Hermes overlay for a provider, if one exists."""
|
||||
canonical = normalize_provider(provider_id)
|
||||
return HERMES_OVERLAYS.get(canonical)
|
||||
|
||||
|
||||
def get_provider(name: str) -> Optional[ProviderDef]:
|
||||
"""Look up a provider by id or alias, merging all data sources.
|
||||
|
||||
@@ -340,6 +350,37 @@ def get_label(provider_id: str) -> str:
|
||||
return canonical
|
||||
|
||||
|
||||
# For direct import compat, expose as module-level dict
|
||||
# Built on demand by get_label() calls
|
||||
LABELS: Dict[str, str] = {
|
||||
# Static entries for backward compat — get_label() is the proper API
|
||||
"openrouter": "OpenRouter",
|
||||
"nous": "Nous Portal",
|
||||
"openai-codex": "OpenAI Codex",
|
||||
"copilot-acp": "GitHub Copilot ACP",
|
||||
"github-copilot": "GitHub Copilot",
|
||||
"anthropic": "Anthropic",
|
||||
"zai": "Z.AI / GLM",
|
||||
"kimi-for-coding": "Kimi / Moonshot",
|
||||
"minimax": "MiniMax",
|
||||
"minimax-cn": "MiniMax (China)",
|
||||
"deepseek": "DeepSeek",
|
||||
"alibaba": "Alibaba Cloud (DashScope)",
|
||||
"vercel": "Vercel AI Gateway",
|
||||
"opencode": "OpenCode Zen",
|
||||
"opencode-go": "OpenCode Go",
|
||||
"kilo": "Kilo Gateway",
|
||||
"huggingface": "Hugging Face",
|
||||
"local": "Local endpoint",
|
||||
"custom": "Custom endpoint",
|
||||
# Legacy Hermes IDs (point to same providers)
|
||||
"ai-gateway": "Vercel AI Gateway",
|
||||
"kilocode": "Kilo Gateway",
|
||||
"copilot": "GitHub Copilot",
|
||||
"kimi-coding": "Kimi / Moonshot",
|
||||
"opencode-zen": "OpenCode Zen",
|
||||
}
|
||||
|
||||
|
||||
def is_aggregator(provider: str) -> bool:
|
||||
"""Return True when the provider is a multi-model aggregator."""
|
||||
|
||||
+164
-242
@@ -172,6 +172,147 @@ def _setup_copilot_reasoning_selection(
|
||||
_set_reasoning_effort(config, "none")
|
||||
|
||||
|
||||
def _setup_provider_model_selection(config, provider_id, current_model, prompt_choice, prompt_fn):
|
||||
"""Model selection for API-key providers with live /models detection.
|
||||
|
||||
Tries the provider's /models endpoint first. Falls back to a
|
||||
hardcoded default list with a warning if the endpoint is unreachable.
|
||||
Always offers a 'Custom model' escape hatch.
|
||||
"""
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY, resolve_api_key_provider_credentials
|
||||
from hermes_cli.config import get_env_value
|
||||
from hermes_cli.models import (
|
||||
copilot_model_api_mode,
|
||||
fetch_api_models,
|
||||
fetch_github_model_catalog,
|
||||
normalize_copilot_model_id,
|
||||
normalize_opencode_model_id,
|
||||
opencode_model_api_mode,
|
||||
)
|
||||
|
||||
pconfig = PROVIDER_REGISTRY[provider_id]
|
||||
is_copilot_catalog_provider = provider_id in {"copilot", "copilot-acp"}
|
||||
|
||||
# Resolve API key and base URL for the probe
|
||||
if is_copilot_catalog_provider:
|
||||
api_key = ""
|
||||
if provider_id == "copilot":
|
||||
creds = resolve_api_key_provider_credentials(provider_id)
|
||||
api_key = creds.get("api_key", "")
|
||||
base_url = creds.get("base_url", "") or pconfig.inference_base_url
|
||||
else:
|
||||
try:
|
||||
creds = resolve_api_key_provider_credentials("copilot")
|
||||
api_key = creds.get("api_key", "")
|
||||
except Exception:
|
||||
pass
|
||||
base_url = pconfig.inference_base_url
|
||||
catalog = fetch_github_model_catalog(api_key)
|
||||
current_model = normalize_copilot_model_id(
|
||||
current_model,
|
||||
catalog=catalog,
|
||||
api_key=api_key,
|
||||
) or current_model
|
||||
else:
|
||||
api_key = ""
|
||||
for ev in pconfig.api_key_env_vars:
|
||||
api_key = get_env_value(ev) or os.getenv(ev, "")
|
||||
if api_key:
|
||||
break
|
||||
base_url_env = pconfig.base_url_env_var or ""
|
||||
base_url = (get_env_value(base_url_env) if base_url_env else "") or pconfig.inference_base_url
|
||||
catalog = None
|
||||
|
||||
# Try live /models endpoint
|
||||
if is_copilot_catalog_provider and catalog:
|
||||
live_models = [item.get("id", "") for item in catalog if item.get("id")]
|
||||
else:
|
||||
live_models = fetch_api_models(api_key, base_url)
|
||||
|
||||
if live_models:
|
||||
provider_models = live_models
|
||||
print_info(f"Found {len(live_models)} model(s) from {pconfig.name} API")
|
||||
else:
|
||||
fallback_provider_id = "copilot" if provider_id == "copilot-acp" else provider_id
|
||||
provider_models = _DEFAULT_PROVIDER_MODELS.get(fallback_provider_id, [])
|
||||
if provider_models:
|
||||
print_warning(
|
||||
f"Could not auto-detect models from {pconfig.name} API — showing defaults.\n"
|
||||
f" Use \"Custom model\" if the model you expect isn't listed."
|
||||
)
|
||||
|
||||
if provider_id in {"opencode-zen", "opencode-go"}:
|
||||
provider_models = [normalize_opencode_model_id(provider_id, mid) for mid in provider_models]
|
||||
current_model = normalize_opencode_model_id(provider_id, current_model)
|
||||
provider_models = list(dict.fromkeys(mid for mid in provider_models if mid))
|
||||
|
||||
model_choices = list(provider_models)
|
||||
model_choices.append("Custom model")
|
||||
model_choices.append(f"Keep current ({current_model})")
|
||||
|
||||
keep_idx = len(model_choices) - 1
|
||||
model_idx = prompt_choice("Select default model:", model_choices, keep_idx)
|
||||
|
||||
selected_model = current_model
|
||||
|
||||
if model_idx < len(provider_models):
|
||||
selected_model = provider_models[model_idx]
|
||||
if is_copilot_catalog_provider:
|
||||
selected_model = normalize_copilot_model_id(
|
||||
selected_model,
|
||||
catalog=catalog,
|
||||
api_key=api_key,
|
||||
) or selected_model
|
||||
elif provider_id in {"opencode-zen", "opencode-go"}:
|
||||
selected_model = normalize_opencode_model_id(provider_id, selected_model)
|
||||
_set_default_model(config, selected_model)
|
||||
elif model_idx == len(provider_models):
|
||||
custom = prompt_fn("Enter model name")
|
||||
if custom:
|
||||
if is_copilot_catalog_provider:
|
||||
selected_model = normalize_copilot_model_id(
|
||||
custom,
|
||||
catalog=catalog,
|
||||
api_key=api_key,
|
||||
) or custom
|
||||
elif provider_id in {"opencode-zen", "opencode-go"}:
|
||||
selected_model = normalize_opencode_model_id(provider_id, custom)
|
||||
else:
|
||||
selected_model = custom
|
||||
_set_default_model(config, selected_model)
|
||||
else:
|
||||
# "Keep current" selected — validate it's compatible with the new
|
||||
# provider. OpenRouter-formatted names (containing "/") won't work
|
||||
# on direct-API providers and would silently break the gateway.
|
||||
if "/" in (current_model or "") and provider_models:
|
||||
print_warning(
|
||||
f"Current model \"{current_model}\" looks like an OpenRouter model "
|
||||
f"and won't work with {pconfig.name}. "
|
||||
f"Switching to {provider_models[0]}."
|
||||
)
|
||||
selected_model = provider_models[0]
|
||||
_set_default_model(config, provider_models[0])
|
||||
|
||||
if provider_id == "copilot" and selected_model:
|
||||
model_cfg = _model_config_dict(config)
|
||||
model_cfg["api_mode"] = copilot_model_api_mode(
|
||||
selected_model,
|
||||
catalog=catalog,
|
||||
api_key=api_key,
|
||||
)
|
||||
config["model"] = model_cfg
|
||||
_setup_copilot_reasoning_selection(
|
||||
config,
|
||||
selected_model,
|
||||
prompt_choice,
|
||||
catalog=catalog,
|
||||
api_key=api_key,
|
||||
)
|
||||
elif provider_id in {"opencode-zen", "opencode-go"} and selected_model:
|
||||
model_cfg = _model_config_dict(config)
|
||||
model_cfg["api_mode"] = opencode_model_api_mode(provider_id, selected_model)
|
||||
config["model"] = model_cfg
|
||||
|
||||
|
||||
# Import config helpers
|
||||
from hermes_cli.config import (
|
||||
@@ -2026,71 +2167,6 @@ def _setup_whatsapp():
|
||||
print_info("or personal self-chat) and pair via QR code.")
|
||||
|
||||
|
||||
def _setup_bluebubbles():
|
||||
"""Configure BlueBubbles iMessage gateway."""
|
||||
print_header("BlueBubbles (iMessage)")
|
||||
existing = get_env_value("BLUEBUBBLES_SERVER_URL")
|
||||
if existing:
|
||||
print_info("BlueBubbles: already configured")
|
||||
if not prompt_yes_no("Reconfigure BlueBubbles?", False):
|
||||
return
|
||||
|
||||
print_info("Connects Hermes to iMessage via BlueBubbles — a free, open-source")
|
||||
print_info("macOS server that bridges iMessage to any device.")
|
||||
print_info(" Requires a Mac running BlueBubbles Server v1.0.0+")
|
||||
print_info(" Download: https://bluebubbles.app/")
|
||||
print()
|
||||
print_info("In BlueBubbles Server → Settings → API, note your Server URL and Password.")
|
||||
print()
|
||||
|
||||
server_url = prompt("BlueBubbles server URL (e.g. http://192.168.1.10:1234)")
|
||||
if not server_url:
|
||||
print_warning("Server URL is required — skipping BlueBubbles setup")
|
||||
return
|
||||
save_env_value("BLUEBUBBLES_SERVER_URL", server_url.rstrip("/"))
|
||||
|
||||
password = prompt("BlueBubbles server password", password=True)
|
||||
if not password:
|
||||
print_warning("Password is required — skipping BlueBubbles setup")
|
||||
return
|
||||
save_env_value("BLUEBUBBLES_PASSWORD", password)
|
||||
print_success("BlueBubbles credentials saved")
|
||||
|
||||
print()
|
||||
print_info("🔒 Security: Restrict who can message your bot")
|
||||
print_info(" Use iMessage addresses: email (user@icloud.com) or phone (+15551234567)")
|
||||
print()
|
||||
allowed_users = prompt("Allowed iMessage addresses (comma-separated, leave empty for open access)")
|
||||
if allowed_users:
|
||||
save_env_value("BLUEBUBBLES_ALLOWED_USERS", allowed_users.replace(" ", ""))
|
||||
print_success("BlueBubbles allowlist configured")
|
||||
else:
|
||||
print_info("⚠️ No allowlist set — anyone who can iMessage you can use the bot!")
|
||||
|
||||
print()
|
||||
print_info("📬 Home Channel: phone or email for cron job delivery and notifications.")
|
||||
print_info(" You can also set this later with /set-home in your iMessage chat.")
|
||||
home_channel = prompt("Home channel address (leave empty to set later)")
|
||||
if home_channel:
|
||||
save_env_value("BLUEBUBBLES_HOME_CHANNEL", home_channel)
|
||||
|
||||
print()
|
||||
print_info("Advanced settings (defaults are fine for most setups):")
|
||||
if prompt_yes_no("Configure webhook listener settings?", False):
|
||||
webhook_port = prompt("Webhook listener port (default: 8645)")
|
||||
if webhook_port:
|
||||
try:
|
||||
save_env_value("BLUEBUBBLES_WEBHOOK_PORT", str(int(webhook_port)))
|
||||
print_success(f"Webhook port set to {webhook_port}")
|
||||
except ValueError:
|
||||
print_warning("Invalid port number, using default 8645")
|
||||
|
||||
print()
|
||||
print_info("Requires the BlueBubbles Private API helper for typing indicators,")
|
||||
print_info("read receipts, and tapback reactions. Basic messaging works without it.")
|
||||
print_info(" Install: https://docs.bluebubbles.app/helper-bundle/installation")
|
||||
|
||||
|
||||
def _setup_webhooks():
|
||||
"""Configure webhook integration."""
|
||||
print_header("Webhooks")
|
||||
@@ -2145,7 +2221,6 @@ _GATEWAY_PLATFORMS = [
|
||||
("Matrix", "MATRIX_ACCESS_TOKEN", _setup_matrix),
|
||||
("Mattermost", "MATTERMOST_TOKEN", _setup_mattermost),
|
||||
("WhatsApp", "WHATSAPP_ENABLED", _setup_whatsapp),
|
||||
("BlueBubbles (iMessage)", "BLUEBUBBLES_SERVER_URL", _setup_bluebubbles),
|
||||
("Webhooks (GitHub, GitLab, etc.)", "WEBHOOK_ENABLED", _setup_webhooks),
|
||||
]
|
||||
|
||||
@@ -2189,7 +2264,6 @@ def setup_gateway(config: dict):
|
||||
or get_env_value("MATRIX_ACCESS_TOKEN")
|
||||
or get_env_value("MATRIX_PASSWORD")
|
||||
or get_env_value("WHATSAPP_ENABLED")
|
||||
or get_env_value("BLUEBUBBLES_SERVER_URL")
|
||||
or get_env_value("WEBHOOK_ENABLED")
|
||||
)
|
||||
if any_messaging:
|
||||
@@ -2209,8 +2283,6 @@ def setup_gateway(config: dict):
|
||||
missing_home.append("Discord")
|
||||
if get_env_value("SLACK_BOT_TOKEN") and not get_env_value("SLACK_HOME_CHANNEL"):
|
||||
missing_home.append("Slack")
|
||||
if get_env_value("BLUEBUBBLES_SERVER_URL") and not get_env_value("BLUEBUBBLES_HOME_CHANNEL"):
|
||||
missing_home.append("BlueBubbles")
|
||||
|
||||
if missing_home:
|
||||
print()
|
||||
@@ -2381,8 +2453,6 @@ def _get_section_config_summary(config: dict, section_key: str) -> Optional[str]
|
||||
platforms.append("WhatsApp")
|
||||
if get_env_value("SIGNAL_ACCOUNT"):
|
||||
platforms.append("Signal")
|
||||
if get_env_value("BLUEBUBBLES_SERVER_URL"):
|
||||
platforms.append("BlueBubbles")
|
||||
if platforms:
|
||||
return ", ".join(platforms)
|
||||
return None # No platforms configured — section must run
|
||||
@@ -2431,120 +2501,9 @@ _OPENCLAW_SCRIPT = (
|
||||
)
|
||||
|
||||
|
||||
def _load_openclaw_migration_module():
|
||||
"""Load the openclaw_to_hermes migration script as a module.
|
||||
|
||||
Returns the loaded module, or None if the script can't be loaded.
|
||||
"""
|
||||
if not _OPENCLAW_SCRIPT.exists():
|
||||
return None
|
||||
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"openclaw_to_hermes", _OPENCLAW_SCRIPT
|
||||
)
|
||||
if spec is None or spec.loader is None:
|
||||
return None
|
||||
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
# Register in sys.modules so @dataclass can resolve the module
|
||||
# (Python 3.11+ requires this for dynamically loaded modules)
|
||||
import sys as _sys
|
||||
_sys.modules[spec.name] = mod
|
||||
try:
|
||||
spec.loader.exec_module(mod)
|
||||
except Exception:
|
||||
_sys.modules.pop(spec.name, None)
|
||||
raise
|
||||
return mod
|
||||
|
||||
|
||||
# Item kinds that represent high-impact changes warranting explicit warnings.
|
||||
# Gateway tokens/channels can hijack messaging platforms from the old agent.
|
||||
# Config values may have different semantics between OpenClaw and Hermes.
|
||||
# Instruction/context files (.md) can contain incompatible setup procedures.
|
||||
_HIGH_IMPACT_KIND_KEYWORDS = {
|
||||
"gateway": "⚠ Gateway/messaging — this will configure Hermes to use your OpenClaw messaging channels",
|
||||
"telegram": "⚠ Telegram — this will point Hermes at your OpenClaw Telegram bot",
|
||||
"slack": "⚠ Slack — this will point Hermes at your OpenClaw Slack workspace",
|
||||
"discord": "⚠ Discord — this will point Hermes at your OpenClaw Discord bot",
|
||||
"whatsapp": "⚠ WhatsApp — this will point Hermes at your OpenClaw WhatsApp connection",
|
||||
"config": "⚠ Config values — OpenClaw settings may not map 1:1 to Hermes equivalents",
|
||||
"soul": "⚠ Instruction file — may contain OpenClaw-specific setup/restart procedures",
|
||||
"memory": "⚠ Memory/context file — may reference OpenClaw-specific infrastructure",
|
||||
"context": "⚠ Context file — may contain OpenClaw-specific instructions",
|
||||
}
|
||||
|
||||
|
||||
def _print_migration_preview(report: dict):
|
||||
"""Print a detailed dry-run preview of what migration would do.
|
||||
|
||||
Groups items by category and adds explicit warnings for high-impact
|
||||
changes like gateway token takeover and config value differences.
|
||||
"""
|
||||
items = report.get("items", [])
|
||||
if not items:
|
||||
print_info("Nothing to migrate.")
|
||||
return
|
||||
|
||||
migrated_items = [i for i in items if i.get("status") == "migrated"]
|
||||
conflict_items = [i for i in items if i.get("status") == "conflict"]
|
||||
skipped_items = [i for i in items if i.get("status") == "skipped"]
|
||||
|
||||
warnings_shown = set()
|
||||
|
||||
if migrated_items:
|
||||
print(color(" Would import:", Colors.GREEN))
|
||||
for item in migrated_items:
|
||||
kind = item.get("kind", "unknown")
|
||||
dest = item.get("destination", "")
|
||||
if dest:
|
||||
dest_short = str(dest).replace(str(Path.home()), "~")
|
||||
print(f" {kind:<22s} → {dest_short}")
|
||||
else:
|
||||
print(f" {kind}")
|
||||
|
||||
# Check for high-impact items and collect warnings
|
||||
kind_lower = kind.lower()
|
||||
dest_lower = str(dest).lower()
|
||||
for keyword, warning in _HIGH_IMPACT_KIND_KEYWORDS.items():
|
||||
if keyword in kind_lower or keyword in dest_lower:
|
||||
warnings_shown.add(warning)
|
||||
print()
|
||||
|
||||
if conflict_items:
|
||||
print(color(" Would overwrite (conflicts with existing Hermes config):", Colors.YELLOW))
|
||||
for item in conflict_items:
|
||||
kind = item.get("kind", "unknown")
|
||||
reason = item.get("reason", "already exists")
|
||||
print(f" {kind:<22s} {reason}")
|
||||
print()
|
||||
|
||||
if skipped_items:
|
||||
print(color(" Would skip:", Colors.DIM))
|
||||
for item in skipped_items:
|
||||
kind = item.get("kind", "unknown")
|
||||
reason = item.get("reason", "")
|
||||
print(f" {kind:<22s} {reason}")
|
||||
print()
|
||||
|
||||
# Print collected warnings
|
||||
if warnings_shown:
|
||||
print(color(" ── Warnings ──", Colors.YELLOW))
|
||||
for warning in sorted(warnings_shown):
|
||||
print(color(f" {warning}", Colors.YELLOW))
|
||||
print()
|
||||
print(color(" Note: OpenClaw config values may have different semantics in Hermes.", Colors.YELLOW))
|
||||
print(color(" For example, OpenClaw's tool_call_execution: \"auto\" ≠ Hermes's yolo mode.", Colors.YELLOW))
|
||||
print(color(" Instruction files (.md) from OpenClaw may contain incompatible procedures.", Colors.YELLOW))
|
||||
print()
|
||||
|
||||
|
||||
def _offer_openclaw_migration(hermes_home: Path) -> bool:
|
||||
"""Detect ~/.openclaw and offer to migrate during first-time setup.
|
||||
|
||||
Runs a dry-run first to show the user exactly what would be imported,
|
||||
overwritten, or taken over. Only executes after explicit confirmation.
|
||||
|
||||
Returns True if migration ran successfully, False otherwise.
|
||||
"""
|
||||
openclaw_dir = Path.home() / ".openclaw"
|
||||
@@ -2557,12 +2516,12 @@ def _offer_openclaw_migration(hermes_home: Path) -> bool:
|
||||
print()
|
||||
print_header("OpenClaw Installation Detected")
|
||||
print_info(f"Found OpenClaw data at {openclaw_dir}")
|
||||
print_info("Hermes can preview what would be imported before making any changes.")
|
||||
print_info("Hermes can import your settings, memories, skills, and API keys.")
|
||||
print()
|
||||
|
||||
if not prompt_yes_no("Would you like to see what can be imported?", default=True):
|
||||
if not prompt_yes_no("Would you like to import from OpenClaw?", default=True):
|
||||
print_info(
|
||||
"Skipping migration. You can run it later with: hermes claw migrate --dry-run"
|
||||
"Skipping migration. You can run it later via the openclaw-migration skill."
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -2571,71 +2530,34 @@ def _offer_openclaw_migration(hermes_home: Path) -> bool:
|
||||
if not config_path.exists():
|
||||
save_config(load_config())
|
||||
|
||||
# Load the migration module
|
||||
# Dynamically load the migration script
|
||||
try:
|
||||
mod = _load_openclaw_migration_module()
|
||||
if mod is None:
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"openclaw_to_hermes", _OPENCLAW_SCRIPT
|
||||
)
|
||||
if spec is None or spec.loader is None:
|
||||
print_warning("Could not load migration script.")
|
||||
return False
|
||||
except Exception as e:
|
||||
print_warning(f"Could not load migration script: {e}")
|
||||
logger.debug("OpenClaw migration module load error", exc_info=True)
|
||||
return False
|
||||
|
||||
# ── Phase 1: Dry-run preview ──
|
||||
try:
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
# Register in sys.modules so @dataclass can resolve the module
|
||||
# (Python 3.11+ requires this for dynamically loaded modules)
|
||||
import sys as _sys
|
||||
_sys.modules[spec.name] = mod
|
||||
try:
|
||||
spec.loader.exec_module(mod)
|
||||
except Exception:
|
||||
_sys.modules.pop(spec.name, None)
|
||||
raise
|
||||
|
||||
# Run migration with the "full" preset, execute mode, no overwrite
|
||||
selected = mod.resolve_selected_options(None, None, preset="full")
|
||||
dry_migrator = mod.Migrator(
|
||||
source_root=openclaw_dir.resolve(),
|
||||
target_root=hermes_home.resolve(),
|
||||
execute=False, # dry-run — no files modified
|
||||
workspace_target=None,
|
||||
overwrite=True, # show everything including conflicts
|
||||
migrate_secrets=True,
|
||||
output_dir=None,
|
||||
selected_options=selected,
|
||||
preset_name="full",
|
||||
)
|
||||
preview_report = dry_migrator.migrate()
|
||||
except Exception as e:
|
||||
print_warning(f"Migration preview failed: {e}")
|
||||
logger.debug("OpenClaw migration preview error", exc_info=True)
|
||||
return False
|
||||
|
||||
# Display the full preview
|
||||
preview_summary = preview_report.get("summary", {})
|
||||
preview_count = preview_summary.get("migrated", 0)
|
||||
|
||||
if preview_count == 0:
|
||||
print()
|
||||
print_info("Nothing to import from OpenClaw.")
|
||||
return False
|
||||
|
||||
print()
|
||||
print_header(f"Migration Preview — {preview_count} item(s) would be imported")
|
||||
print_info("No changes have been made yet. Review the list below:")
|
||||
print()
|
||||
_print_migration_preview(preview_report)
|
||||
|
||||
# ── Phase 2: Confirm and execute ──
|
||||
if not prompt_yes_no("Proceed with migration?", default=False):
|
||||
print_info(
|
||||
"Migration cancelled. You can run it later with: hermes claw migrate"
|
||||
)
|
||||
print_info(
|
||||
"Use --dry-run to preview again, or --preset minimal for a lighter import."
|
||||
)
|
||||
return False
|
||||
|
||||
# Execute the migration — overwrite=False so existing Hermes configs are
|
||||
# preserved. The user saw the preview; conflicts are skipped by default.
|
||||
try:
|
||||
migrator = mod.Migrator(
|
||||
source_root=openclaw_dir.resolve(),
|
||||
target_root=hermes_home.resolve(),
|
||||
execute=True,
|
||||
workspace_target=None,
|
||||
overwrite=False, # preserve existing Hermes config
|
||||
overwrite=True,
|
||||
migrate_secrets=True,
|
||||
output_dir=None,
|
||||
selected_options=selected,
|
||||
@@ -2647,7 +2569,7 @@ def _offer_openclaw_migration(hermes_home: Path) -> bool:
|
||||
logger.debug("OpenClaw migration error", exc_info=True)
|
||||
return False
|
||||
|
||||
# Print final summary
|
||||
# Print summary
|
||||
summary = report.get("summary", {})
|
||||
migrated = summary.get("migrated", 0)
|
||||
skipped = summary.get("skipped", 0)
|
||||
@@ -2658,7 +2580,7 @@ def _offer_openclaw_migration(hermes_home: Path) -> bool:
|
||||
if migrated:
|
||||
print_success(f"Imported {migrated} item(s) from OpenClaw.")
|
||||
if conflicts:
|
||||
print_info(f"Skipped {conflicts} item(s) that already exist in Hermes (use hermes claw migrate --overwrite to force).")
|
||||
print_info(f"Skipped {conflicts} item(s) that already exist in Hermes.")
|
||||
if skipped:
|
||||
print_info(f"Skipped {skipped} item(s) (not found or unchanged).")
|
||||
if errors:
|
||||
|
||||
@@ -23,7 +23,6 @@ PLATFORMS = {
|
||||
"slack": "💼 Slack",
|
||||
"whatsapp": "📱 WhatsApp",
|
||||
"signal": "📡 Signal",
|
||||
"bluebubbles": "💬 BlueBubbles",
|
||||
"email": "📧 Email",
|
||||
"homeassistant": "🏠 Home Assistant",
|
||||
"mattermost": "💬 Mattermost",
|
||||
|
||||
@@ -302,7 +302,6 @@ def show_status(args):
|
||||
"DingTalk": ("DINGTALK_CLIENT_ID", None),
|
||||
"Feishu": ("FEISHU_APP_ID", "FEISHU_HOME_CHANNEL"),
|
||||
"WeCom": ("WECOM_BOT_ID", "WECOM_HOME_CHANNEL"),
|
||||
"BlueBubbles": ("BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_HOME_CHANNEL"),
|
||||
}
|
||||
|
||||
for name, (token_var, home_var) in platforms.items():
|
||||
|
||||
@@ -126,7 +126,6 @@ PLATFORMS = {
|
||||
"slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"},
|
||||
"whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"},
|
||||
"signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"},
|
||||
"bluebubbles": {"label": "💙 BlueBubbles", "default_toolset": "hermes-bluebubbles"},
|
||||
"homeassistant": {"label": "🏠 Home Assistant", "default_toolset": "hermes-homeassistant"},
|
||||
"email": {"label": "📧 Email", "default_toolset": "hermes-email"},
|
||||
"matrix": {"label": "💬 Matrix", "default_toolset": "hermes-matrix"},
|
||||
|
||||
+6
-2
@@ -72,13 +72,13 @@ def display_hermes_home() -> str:
|
||||
return str(home)
|
||||
|
||||
|
||||
VALID_REASONING_EFFORTS = ("minimal", "low", "medium", "high", "xhigh")
|
||||
VALID_REASONING_EFFORTS = ("xhigh", "high", "medium", "low", "minimal")
|
||||
|
||||
|
||||
def parse_reasoning_effort(effort: str) -> dict | None:
|
||||
"""Parse a reasoning effort level into a config dict.
|
||||
|
||||
Valid levels: "none", "minimal", "low", "medium", "high", "xhigh".
|
||||
Valid levels: "xhigh", "high", "medium", "low", "minimal", "none".
|
||||
Returns None when the input is empty or unrecognized (caller uses default).
|
||||
Returns {"enabled": False} for "none".
|
||||
Returns {"enabled": True, "effort": <level>} for valid effort levels.
|
||||
@@ -95,7 +95,11 @@ def parse_reasoning_effort(effort: str) -> dict | None:
|
||||
|
||||
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
|
||||
OPENROUTER_MODELS_URL = f"{OPENROUTER_BASE_URL}/models"
|
||||
OPENROUTER_CHAT_URL = f"{OPENROUTER_BASE_URL}/chat/completions"
|
||||
|
||||
AI_GATEWAY_BASE_URL = "https://ai-gateway.vercel.sh/v1"
|
||||
AI_GATEWAY_MODELS_URL = f"{AI_GATEWAY_BASE_URL}/models"
|
||||
AI_GATEWAY_CHAT_URL = f"{AI_GATEWAY_BASE_URL}/chat/completions"
|
||||
|
||||
NOUS_API_BASE_URL = "https://inference-api.nousresearch.com/v1"
|
||||
NOUS_API_CHAT_URL = f"{NOUS_API_BASE_URL}/chat/completions"
|
||||
|
||||
+1
-34
@@ -13,7 +13,6 @@ secrets are never written to disk.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@@ -178,38 +177,6 @@ def setup_verbose_logging() -> None:
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _ManagedRotatingFileHandler(RotatingFileHandler):
|
||||
"""RotatingFileHandler that ensures group-writable perms in managed mode.
|
||||
|
||||
In managed mode (NixOS), the stateDir uses setgid (2770) so new files
|
||||
inherit the hermes group. However, both _open() (initial creation) and
|
||||
doRollover() create files via open(), which uses the process umask —
|
||||
typically 0022, producing 0644. This subclass applies chmod 0660 after
|
||||
both operations so the gateway and interactive users can share log files.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
from hermes_cli.config import is_managed
|
||||
self._managed = is_managed()
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _chmod_if_managed(self):
|
||||
if self._managed:
|
||||
try:
|
||||
os.chmod(self.baseFilename, 0o660)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _open(self):
|
||||
stream = super()._open()
|
||||
self._chmod_if_managed()
|
||||
return stream
|
||||
|
||||
def doRollover(self):
|
||||
super().doRollover()
|
||||
self._chmod_if_managed()
|
||||
|
||||
|
||||
def _add_rotating_handler(
|
||||
logger: logging.Logger,
|
||||
path: Path,
|
||||
@@ -231,7 +198,7 @@ def _add_rotating_handler(
|
||||
return # already attached
|
||||
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
handler = _ManagedRotatingFileHandler(
|
||||
handler = RotatingFileHandler(
|
||||
str(path), maxBytes=max_bytes, backupCount=backup_count,
|
||||
)
|
||||
handler.setLevel(level)
|
||||
|
||||
+95
-29
@@ -520,6 +520,72 @@ class SessionDB:
|
||||
)
|
||||
self._execute_write(_do)
|
||||
|
||||
def set_token_counts(
|
||||
self,
|
||||
session_id: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
model: str = None,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_write_tokens: int = 0,
|
||||
reasoning_tokens: int = 0,
|
||||
estimated_cost_usd: Optional[float] = None,
|
||||
actual_cost_usd: Optional[float] = None,
|
||||
cost_status: Optional[str] = None,
|
||||
cost_source: Optional[str] = None,
|
||||
pricing_version: Optional[str] = None,
|
||||
billing_provider: Optional[str] = None,
|
||||
billing_base_url: Optional[str] = None,
|
||||
billing_mode: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Set token counters to absolute values (not increment).
|
||||
|
||||
Use this when the caller provides cumulative totals from a completed
|
||||
conversation run (e.g. the gateway, where the cached agent's
|
||||
session_prompt_tokens already reflects the running total).
|
||||
"""
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"""UPDATE sessions SET
|
||||
input_tokens = ?,
|
||||
output_tokens = ?,
|
||||
cache_read_tokens = ?,
|
||||
cache_write_tokens = ?,
|
||||
reasoning_tokens = ?,
|
||||
estimated_cost_usd = ?,
|
||||
actual_cost_usd = CASE
|
||||
WHEN ? IS NULL THEN actual_cost_usd
|
||||
ELSE ?
|
||||
END,
|
||||
cost_status = COALESCE(?, cost_status),
|
||||
cost_source = COALESCE(?, cost_source),
|
||||
pricing_version = COALESCE(?, pricing_version),
|
||||
billing_provider = COALESCE(billing_provider, ?),
|
||||
billing_base_url = COALESCE(billing_base_url, ?),
|
||||
billing_mode = COALESCE(billing_mode, ?),
|
||||
model = COALESCE(model, ?)
|
||||
WHERE id = ?""",
|
||||
(
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
reasoning_tokens,
|
||||
estimated_cost_usd,
|
||||
actual_cost_usd,
|
||||
actual_cost_usd,
|
||||
cost_status,
|
||||
cost_source,
|
||||
pricing_version,
|
||||
billing_provider,
|
||||
billing_base_url,
|
||||
billing_mode,
|
||||
model,
|
||||
session_id,
|
||||
),
|
||||
)
|
||||
self._execute_write(_do)
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a session by ID."""
|
||||
with self._lock:
|
||||
@@ -878,8 +944,7 @@ class SessionDB:
|
||||
try:
|
||||
msg["tool_calls"] = json.loads(msg["tool_calls"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning("Failed to deserialize tool_calls in get_messages, falling back to []")
|
||||
msg["tool_calls"] = []
|
||||
pass
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
@@ -907,8 +972,7 @@ class SessionDB:
|
||||
try:
|
||||
msg["tool_calls"] = json.loads(row["tool_calls"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning("Failed to deserialize tool_calls in conversation replay, falling back to []")
|
||||
msg["tool_calls"] = []
|
||||
pass
|
||||
# Restore reasoning fields on assistant messages so providers
|
||||
# that replay reasoning (OpenRouter, OpenAI, Nous) receive
|
||||
# coherent multi-turn reasoning context.
|
||||
@@ -919,14 +983,12 @@ class SessionDB:
|
||||
try:
|
||||
msg["reasoning_details"] = json.loads(row["reasoning_details"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning("Failed to deserialize reasoning_details, falling back to None")
|
||||
msg["reasoning_details"] = None
|
||||
pass
|
||||
if row["codex_reasoning_items"]:
|
||||
try:
|
||||
msg["codex_reasoning_items"] = json.loads(row["codex_reasoning_items"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning("Failed to deserialize codex_reasoning_items, falling back to None")
|
||||
msg["codex_reasoning_items"] = None
|
||||
pass
|
||||
messages.append(msg)
|
||||
return messages
|
||||
|
||||
@@ -1173,10 +1235,10 @@ class SessionDB:
|
||||
self._execute_write(_do)
|
||||
|
||||
def delete_session(self, session_id: str) -> bool:
|
||||
"""Delete a session and all its messages.
|
||||
"""Delete a session, its child sessions, and all their messages.
|
||||
|
||||
Child sessions are orphaned (parent_session_id set to NULL) rather
|
||||
than cascade-deleted, so they remain accessible independently.
|
||||
Child sessions (subagent runs, compression continuations) are deleted
|
||||
first to satisfy the ``parent_session_id`` foreign key constraint.
|
||||
Returns True if the session was found and deleted.
|
||||
"""
|
||||
def _do(conn):
|
||||
@@ -1185,12 +1247,15 @@ class SessionDB:
|
||||
)
|
||||
if cursor.fetchone()[0] == 0:
|
||||
return False
|
||||
# Orphan child sessions so FK constraint is satisfied
|
||||
conn.execute(
|
||||
"UPDATE sessions SET parent_session_id = NULL "
|
||||
"WHERE parent_session_id = ?",
|
||||
# Delete child sessions first (FK constraint)
|
||||
child_ids = [r[0] for r in conn.execute(
|
||||
"SELECT id FROM sessions WHERE parent_session_id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
).fetchall()]
|
||||
for cid in child_ids:
|
||||
conn.execute("DELETE FROM messages WHERE session_id = ?", (cid,))
|
||||
conn.execute("DELETE FROM sessions WHERE id = ?", (cid,))
|
||||
# Delete the session itself
|
||||
conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
|
||||
return True
|
||||
@@ -1199,9 +1264,9 @@ class SessionDB:
|
||||
def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int:
|
||||
"""Delete sessions older than N days. Returns count of deleted sessions.
|
||||
|
||||
Only prunes ended sessions (not active ones). Child sessions outside
|
||||
the prune window are orphaned (parent_session_id set to NULL) rather
|
||||
than cascade-deleted.
|
||||
Only prunes ended sessions (not active ones). Child sessions whose
|
||||
parents are being pruned are deleted first to satisfy the
|
||||
``parent_session_id`` foreign key constraint.
|
||||
"""
|
||||
cutoff = time.time() - (older_than_days * 86400)
|
||||
|
||||
@@ -1219,16 +1284,17 @@ class SessionDB:
|
||||
)
|
||||
session_ids = set(row["id"] for row in cursor.fetchall())
|
||||
|
||||
if not session_ids:
|
||||
return 0
|
||||
|
||||
# Orphan any sessions whose parent is about to be deleted
|
||||
placeholders = ",".join("?" * len(session_ids))
|
||||
conn.execute(
|
||||
f"UPDATE sessions SET parent_session_id = NULL "
|
||||
f"WHERE parent_session_id IN ({placeholders})",
|
||||
list(session_ids),
|
||||
)
|
||||
# Delete children first whose parents are in the prune set
|
||||
# (avoids FK constraint errors)
|
||||
for sid in list(session_ids):
|
||||
child_ids = [r[0] for r in conn.execute(
|
||||
"SELECT id FROM sessions WHERE parent_session_id = ?",
|
||||
(sid,),
|
||||
).fetchall()]
|
||||
for cid in child_ids:
|
||||
conn.execute("DELETE FROM messages WHERE session_id = ?", (cid,))
|
||||
conn.execute("DELETE FROM sessions WHERE id = ?", (cid,))
|
||||
session_ids.discard(cid) # don't double-delete
|
||||
|
||||
for sid in session_ids:
|
||||
conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
|
||||
|
||||
@@ -89,6 +89,13 @@ def get_timezone() -> Optional[ZoneInfo]:
|
||||
return _cached_tz
|
||||
|
||||
|
||||
def get_timezone_name() -> str:
|
||||
"""Return the IANA name of the configured timezone, or empty string."""
|
||||
if not _cache_resolved:
|
||||
get_timezone() # populates cache
|
||||
return _cached_tz_name or ""
|
||||
|
||||
|
||||
def now() -> datetime:
|
||||
"""
|
||||
Return the current time as a timezone-aware datetime.
|
||||
@@ -103,3 +110,9 @@ def now() -> datetime:
|
||||
return datetime.now().astimezone()
|
||||
|
||||
|
||||
def reset_cache() -> None:
|
||||
"""Clear the cached timezone. Used by tests and after config changes."""
|
||||
global _cached_tz, _cached_tz_name, _cache_resolved
|
||||
_cached_tz = None
|
||||
_cached_tz_name = None
|
||||
_cache_resolved = False
|
||||
|
||||
+5
-27
@@ -560,40 +560,22 @@
|
||||
# ── Directories ───────────────────────────────────────────────────
|
||||
{
|
||||
systemd.tmpfiles.rules = [
|
||||
"d ${cfg.stateDir} 2770 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.stateDir}/.hermes 2770 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.stateDir}/.hermes/cron 2770 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.stateDir}/.hermes/sessions 2770 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.stateDir}/.hermes/logs 2770 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.stateDir}/.hermes/memories 2770 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.stateDir} 0750 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.stateDir}/.hermes 0750 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.stateDir}/home 0750 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.workingDirectory} 2770 ${cfg.user} ${cfg.group} - -"
|
||||
"d ${cfg.workingDirectory} 0750 ${cfg.user} ${cfg.group} - -"
|
||||
];
|
||||
}
|
||||
|
||||
# ── Activation: link config + auth + documents ────────────────────
|
||||
{
|
||||
system.activationScripts."hermes-agent-setup" = lib.stringAfter ([ "users" ] ++ lib.optional (config.system.activationScripts ? setupSecrets) "setupSecrets") ''
|
||||
system.activationScripts."hermes-agent-setup" = lib.stringAfter [ "users" "setupSecrets" ] ''
|
||||
# Ensure directories exist (activation runs before tmpfiles)
|
||||
mkdir -p ${cfg.stateDir}/.hermes
|
||||
mkdir -p ${cfg.stateDir}/home
|
||||
mkdir -p ${cfg.workingDirectory}
|
||||
chown ${cfg.user}:${cfg.group} ${cfg.stateDir} ${cfg.stateDir}/.hermes ${cfg.stateDir}/home ${cfg.workingDirectory}
|
||||
chmod 2770 ${cfg.stateDir} ${cfg.stateDir}/.hermes ${cfg.workingDirectory}
|
||||
chmod 0750 ${cfg.stateDir}/home
|
||||
|
||||
# Create subdirs, set setgid + group-writable, migrate existing files.
|
||||
# Nix-managed files (config.yaml, .env, .managed) stay 0640/0644.
|
||||
find ${cfg.stateDir}/.hermes -maxdepth 1 \
|
||||
\( -name "*.db" -o -name "*.db-wal" -o -name "*.db-shm" -o -name "SOUL.md" \) \
|
||||
-exec chmod g+rw {} + 2>/dev/null || true
|
||||
for _subdir in cron sessions logs memories; do
|
||||
mkdir -p "${cfg.stateDir}/.hermes/$_subdir"
|
||||
chown ${cfg.user}:${cfg.group} "${cfg.stateDir}/.hermes/$_subdir"
|
||||
chmod 2770 "${cfg.stateDir}/.hermes/$_subdir"
|
||||
find "${cfg.stateDir}/.hermes/$_subdir" -type f \
|
||||
-exec chmod g+rw {} + 2>/dev/null || true
|
||||
done
|
||||
chmod 0750 ${cfg.stateDir} ${cfg.stateDir}/.hermes ${cfg.stateDir}/home ${cfg.workingDirectory}
|
||||
|
||||
# Merge Nix settings into existing config.yaml.
|
||||
# Preserves user-added keys (skills, streaming, etc.); Nix keys win.
|
||||
@@ -680,10 +662,6 @@ HERMES_NIX_ENV_EOF
|
||||
Restart = cfg.restart;
|
||||
RestartSec = cfg.restartSec;
|
||||
|
||||
# Shared-state: files created by the gateway should be group-writable
|
||||
# so interactive users in the hermes group can read/write them.
|
||||
UMask = "0007";
|
||||
|
||||
# Hardening
|
||||
NoNewPrivileges = true;
|
||||
ProtectSystem = "strict";
|
||||
|
||||
+1
-1
@@ -14,7 +14,7 @@
|
||||
};
|
||||
|
||||
runtimeDeps = with pkgs; [
|
||||
nodejs_20 ripgrep git openssh ffmpeg tirith
|
||||
nodejs_20 ripgrep git openssh ffmpeg
|
||||
];
|
||||
|
||||
runtimePath = pkgs.lib.makeBinPath runtimeDeps;
|
||||
|
||||
@@ -1803,34 +1803,30 @@ class Migrator:
|
||||
def migrate_cron_jobs(self, config: Optional[Dict[str, Any]] = None) -> None:
|
||||
config = config or self.load_openclaw_config()
|
||||
cron = config.get("cron") or {}
|
||||
if not cron:
|
||||
self.record("cron-jobs", None, None, "skipped", "No cron configuration found")
|
||||
return
|
||||
|
||||
# Archive the full cron config
|
||||
if self.archive_dir and self.execute:
|
||||
self.archive_dir.mkdir(parents=True, exist_ok=True)
|
||||
dest = self.archive_dir / "cron-config.json"
|
||||
dest.write_text(json.dumps(cron, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
|
||||
self.record("cron-jobs", "openclaw.json cron.*", str(dest), "archived",
|
||||
"Cron config archived. Use 'hermes cron' to recreate jobs manually.")
|
||||
else:
|
||||
self.record("cron-jobs", "openclaw.json cron.*", "archive/cron-config.json",
|
||||
"archived", "Would archive cron config")
|
||||
|
||||
# Also check for cron store files
|
||||
cron_store = self.source_root / "cron"
|
||||
found_any = False
|
||||
|
||||
# Archive the full cron config when present
|
||||
if cron:
|
||||
found_any = True
|
||||
if self.archive_dir and self.execute:
|
||||
self.archive_dir.mkdir(parents=True, exist_ok=True)
|
||||
dest = self.archive_dir / "cron-config.json"
|
||||
dest.write_text(json.dumps(cron, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
|
||||
self.record("cron-jobs", "openclaw.json cron.*", str(dest), "archived",
|
||||
"Cron config archived. Use 'hermes cron' to recreate jobs manually.")
|
||||
else:
|
||||
self.record("cron-jobs", "openclaw.json cron.*", "archive/cron-config.json",
|
||||
"archived", "Would archive cron config")
|
||||
|
||||
# Also check for cron store files even when config.cron is missing
|
||||
if cron_store.is_dir() and self.archive_dir:
|
||||
found_any = True
|
||||
dest_cron = self.archive_dir / "cron-store"
|
||||
if self.execute:
|
||||
shutil.copytree(cron_store, dest_cron, dirs_exist_ok=True)
|
||||
self.record("cron-jobs", str(cron_store), str(dest_cron), "archived",
|
||||
"Cron job store archived")
|
||||
|
||||
if not found_any:
|
||||
self.record("cron-jobs", None, None, "skipped", "No cron configuration found")
|
||||
|
||||
# ── Hooks ─────────────────────────────────────────────────
|
||||
def migrate_hooks_config(self, config: Optional[Dict[str, Any]] = None) -> None:
|
||||
config = config or self.load_openclaw_config()
|
||||
@@ -2458,15 +2454,6 @@ class Migrator:
|
||||
notes.append(f"- **{item.kind}**: {item.reason}")
|
||||
notes.append("")
|
||||
|
||||
has_cron_config_archive = any(
|
||||
i.kind == "cron-jobs" and i.status == "archived" and i.destination and i.destination.endswith("cron-config.json")
|
||||
for i in self.items
|
||||
)
|
||||
has_cron_store_archive = any(
|
||||
i.kind == "cron-jobs" and i.status == "archived" and i.destination and i.destination.endswith("cron-store")
|
||||
for i in self.items
|
||||
)
|
||||
|
||||
notes.extend([
|
||||
"## IMPORTANT: Archive the OpenClaw Directory",
|
||||
"",
|
||||
@@ -2488,14 +2475,7 @@ class Migrator:
|
||||
"- Run `hermes claw cleanup` to archive the OpenClaw directory (prevents state confusion)",
|
||||
"- Run `hermes setup` to configure any remaining settings",
|
||||
"- Run `hermes mcp list` to verify MCP servers were imported correctly",
|
||||
])
|
||||
|
||||
if has_cron_config_archive:
|
||||
notes.append("- Run `hermes cron` to recreate scheduled tasks (see archive/cron-config.json)")
|
||||
elif has_cron_store_archive:
|
||||
notes.append("- Run `hermes cron` to recreate scheduled tasks (see archived cron-store)")
|
||||
|
||||
notes.extend([
|
||||
"- Run `hermes cron` to recreate scheduled tasks (see archive/cron-config.json)",
|
||||
"- Run `hermes gateway install` if you need the gateway service",
|
||||
"- Review `~/.hermes/config.yaml` for any adjustments",
|
||||
"",
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
# Hindsight Memory Provider
|
||||
|
||||
Long-term memory with knowledge graph, entity resolution, and multi-strategy retrieval. Supports cloud, local embedded, and local external modes.
|
||||
Long-term memory with knowledge graph, entity resolution, and multi-strategy retrieval. Supports cloud and local (embedded) modes.
|
||||
|
||||
## Requirements
|
||||
|
||||
- **Cloud:** API key from [ui.hindsight.vectorize.io](https://ui.hindsight.vectorize.io)
|
||||
- **Local Embedded:** API key for a supported LLM provider (OpenAI, Anthropic, Gemini, Groq, OpenRouter, MiniMax, Ollama, or any OpenAI-compatible endpoint). Embeddings and reranking run locally — no additional API keys needed.
|
||||
- **Local External:** A running Hindsight instance (Docker or self-hosted) reachable over HTTP.
|
||||
- **Local:** API key for a supported LLM provider (OpenAI, Anthropic, Gemini, Groq, MiniMax, or Ollama). Embeddings and reranking run locally — no additional API keys needed.
|
||||
|
||||
## Setup
|
||||
|
||||
@@ -22,28 +21,17 @@ hermes config set memory.provider hindsight
|
||||
echo "HINDSIGHT_API_KEY=your-key" >> ~/.hermes/.env
|
||||
```
|
||||
|
||||
### Cloud
|
||||
### Cloud Mode
|
||||
|
||||
Connects to the Hindsight Cloud API. Requires an API key from [ui.hindsight.vectorize.io](https://ui.hindsight.vectorize.io).
|
||||
|
||||
### Local Embedded
|
||||
### Local Mode
|
||||
|
||||
Hermes spins up a local Hindsight daemon with built-in PostgreSQL. Requires an LLM API key for memory extraction and synthesis. The daemon starts automatically in the background on first use and stops after 5 minutes of inactivity.
|
||||
|
||||
Supports any OpenAI-compatible LLM endpoint (llama.cpp, vLLM, LM Studio, etc.) — pick `openai_compatible` as the provider and enter the base URL.
|
||||
Runs an embedded Hindsight server with built-in PostgreSQL. Requires an LLM API key (e.g. Groq, OpenAI, Anthropic) for memory extraction and synthesis. The daemon starts automatically in the background on first use and stops after 5 minutes of inactivity.
|
||||
|
||||
Daemon startup logs: `~/.hermes/logs/hindsight-embed.log`
|
||||
Daemon runtime logs: `~/.hindsight/profiles/<profile>.log`
|
||||
|
||||
To open the Hindsight web UI (local embedded mode only):
|
||||
```bash
|
||||
hindsight-embed -p hermes ui start
|
||||
```
|
||||
|
||||
### Local External
|
||||
|
||||
Points the plugin at an existing Hindsight instance you're already running (Docker, self-hosted, etc.). No daemon management — just a URL and an optional API key.
|
||||
|
||||
## Config
|
||||
|
||||
Config file: `~/.hermes/hindsight/config.json`
|
||||
@@ -52,58 +40,40 @@ Config file: `~/.hermes/hindsight/config.json`
|
||||
|
||||
| Key | Default | Description |
|
||||
|-----|---------|-------------|
|
||||
| `mode` | `cloud` | `cloud`, `local_embedded`, or `local_external` |
|
||||
| `api_url` | `https://api.hindsight.vectorize.io` | API URL (cloud and local_external modes) |
|
||||
| `mode` | `cloud` | `cloud` or `local` |
|
||||
| `api_url` | `https://api.hindsight.vectorize.io` | API URL (cloud mode) |
|
||||
| `api_url` | `http://localhost:8888` | API URL (local mode, unused — daemon manages its own port) |
|
||||
|
||||
### Memory Bank
|
||||
### Memory
|
||||
|
||||
| Key | Default | Description |
|
||||
|-----|---------|-------------|
|
||||
| `bank_id` | `hermes` | Memory bank name |
|
||||
| `bank_mission` | — | Reflect mission (identity/framing for reflect reasoning). Applied via Banks API. |
|
||||
| `bank_retain_mission` | — | Retain mission (steers what gets extracted). Applied via Banks API. |
|
||||
|
||||
### Recall
|
||||
|
||||
| Key | Default | Description |
|
||||
|-----|---------|-------------|
|
||||
| `recall_budget` | `mid` | Recall thoroughness: `low` / `mid` / `high` |
|
||||
| `recall_prefetch_method` | `recall` | Auto-recall method: `recall` (raw facts) or `reflect` (LLM synthesis) |
|
||||
| `recall_max_tokens` | `4096` | Maximum tokens for recall results |
|
||||
| `recall_max_input_chars` | `800` | Maximum input query length for auto-recall |
|
||||
| `recall_prompt_preamble` | — | Custom preamble for recalled memories in context |
|
||||
| `recall_tags` | — | Tags to filter when searching memories |
|
||||
| `recall_tags_match` | `any` | Tag matching mode: `any` / `all` / `any_strict` / `all_strict` |
|
||||
| `auto_recall` | `true` | Automatically recall memories before each turn |
|
||||
|
||||
### Retain
|
||||
|
||||
| Key | Default | Description |
|
||||
|-----|---------|-------------|
|
||||
| `auto_retain` | `true` | Automatically retain conversation turns |
|
||||
| `retain_async` | `true` | Process retain asynchronously on the Hindsight server |
|
||||
| `retain_every_n_turns` | `1` | Retain every N turns (1 = every turn) |
|
||||
| `retain_context` | `conversation between Hermes Agent and the User` | Context label for retained memories |
|
||||
| `tags` | — | Tags applied when storing memories |
|
||||
| `budget` | `mid` | Recall thoroughness: `low` / `mid` / `high` |
|
||||
|
||||
### Integration
|
||||
|
||||
| Key | Default | Description |
|
||||
|-----|---------|-------------|
|
||||
| `memory_mode` | `hybrid` | How memories are integrated into the agent |
|
||||
| `prefetch_method` | `recall` | Method for automatic context injection |
|
||||
|
||||
**memory_mode:**
|
||||
- `hybrid` — automatic context injection + tools available to the LLM
|
||||
- `context` — automatic injection only, no tools exposed
|
||||
- `tools` — tools only, no automatic injection
|
||||
|
||||
### Local Embedded LLM
|
||||
**prefetch_method:**
|
||||
- `recall` — injects raw memory facts (fast)
|
||||
- `reflect` — injects LLM-synthesized summary (slower, more coherent)
|
||||
|
||||
### Local Mode LLM
|
||||
|
||||
| Key | Default | Description |
|
||||
|-----|---------|-------------|
|
||||
| `llm_provider` | `openai` | `openai`, `anthropic`, `gemini`, `groq`, `openrouter`, `minimax`, `ollama`, `lmstudio`, `openai_compatible` |
|
||||
| `llm_model` | per-provider | Model name (e.g. `gpt-4o-mini`, `qwen/qwen3.5-9b`) |
|
||||
| `llm_base_url` | — | Endpoint URL for `openai_compatible` (e.g. `http://192.168.1.10:8080/v1`) |
|
||||
| `llm_provider` | `openai` | LLM provider: `openai`, `anthropic`, `gemini`, `groq`, `minimax`, `ollama` |
|
||||
| `llm_model` | per-provider | Model name (e.g. `gpt-4o-mini`, `openai/gpt-oss-120b`) |
|
||||
| `llm_base_url` | — | LLM Base URL override (e.g. `https://openrouter.ai/api/v1`) |
|
||||
|
||||
The LLM API key is stored in `~/.hermes/.env` as `HINDSIGHT_LLM_API_KEY`.
|
||||
|
||||
@@ -127,8 +97,4 @@ Available in `hybrid` and `tools` memory modes:
|
||||
| `HINDSIGHT_API_URL` | Override API endpoint |
|
||||
| `HINDSIGHT_BANK_ID` | Override bank name |
|
||||
| `HINDSIGHT_BUDGET` | Override recall budget |
|
||||
| `HINDSIGHT_MODE` | Override mode (`cloud`, `local_embedded`, `local_external`) |
|
||||
|
||||
## Client Version
|
||||
|
||||
Requires `hindsight-client >= 0.4.22`. The plugin auto-upgrades on session start if an older version is detected.
|
||||
| `HINDSIGHT_MODE` | Override mode (`cloud` / `local`) |
|
||||
|
||||
@@ -28,25 +28,21 @@ from hermes_constants import get_hermes_home
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
from hermes_constants import get_hermes_home
|
||||
from tools.registry import tool_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_API_URL = "https://api.hindsight.vectorize.io"
|
||||
_DEFAULT_LOCAL_URL = "http://localhost:8888"
|
||||
_MIN_CLIENT_VERSION = "0.4.22"
|
||||
_VALID_BUDGETS = {"low", "mid", "high"}
|
||||
_PROVIDER_DEFAULT_MODELS = {
|
||||
"openai": "gpt-4o-mini",
|
||||
"anthropic": "claude-haiku-4-5",
|
||||
"gemini": "gemini-2.5-flash",
|
||||
"groq": "openai/gpt-oss-120b",
|
||||
"openrouter": "qwen/qwen3.5-9b",
|
||||
"minimax": "MiniMax-M2.7",
|
||||
"ollama": "gemma3:12b",
|
||||
"lmstudio": "local-model",
|
||||
"openai_compatible": "your-model-name",
|
||||
}
|
||||
|
||||
|
||||
@@ -192,7 +188,6 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
self._bank_id = "hermes"
|
||||
self._budget = "mid"
|
||||
self._mode = "cloud"
|
||||
self._llm_base_url = ""
|
||||
self._memory_mode = "hybrid" # "context", "tools", or "hybrid"
|
||||
self._prefetch_method = "recall" # "recall" or "reflect"
|
||||
self._client = None
|
||||
@@ -200,31 +195,6 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
self._prefetch_lock = threading.Lock()
|
||||
self._prefetch_thread = None
|
||||
self._sync_thread = None
|
||||
self._session_id = ""
|
||||
|
||||
# Tags
|
||||
self._tags: list[str] | None = None
|
||||
self._recall_tags: list[str] | None = None
|
||||
self._recall_tags_match = "any"
|
||||
|
||||
# Retain controls
|
||||
self._auto_retain = True
|
||||
self._retain_every_n_turns = 1
|
||||
self._retain_context = "conversation between Hermes Agent and the User"
|
||||
self._turn_counter = 0
|
||||
self._session_turns: list[str] = [] # accumulates ALL turns for the session
|
||||
|
||||
# Recall controls
|
||||
self._auto_recall = True
|
||||
self._recall_max_tokens = 4096
|
||||
self._recall_types: list[str] | None = None
|
||||
self._recall_prompt_preamble = ""
|
||||
self._recall_max_input_chars = 800
|
||||
|
||||
# Bank
|
||||
self._bank_mission = ""
|
||||
self._bank_retain_mission: str | None = None
|
||||
self._retain_async = True
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -234,7 +204,7 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
try:
|
||||
cfg = _load_config()
|
||||
mode = cfg.get("mode", "cloud")
|
||||
if mode in ("local", "local_embedded", "local_external"):
|
||||
if mode == "local":
|
||||
return True
|
||||
has_key = bool(cfg.get("apiKey") or os.environ.get("HINDSIGHT_API_KEY", ""))
|
||||
has_url = bool(cfg.get("api_url") or os.environ.get("HINDSIGHT_API_URL", ""))
|
||||
@@ -258,306 +228,73 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
existing.update(values)
|
||||
config_path.write_text(json.dumps(existing, indent=2))
|
||||
|
||||
def post_setup(self, hermes_home: str, config: dict) -> None:
|
||||
"""Custom setup wizard — installs only the deps needed for the selected mode."""
|
||||
import getpass
|
||||
import subprocess
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_cli.config import save_config
|
||||
|
||||
from hermes_cli.memory_setup import _curses_select
|
||||
|
||||
print("\n Configuring Hindsight memory:\n")
|
||||
|
||||
# Step 1: Mode selection
|
||||
mode_items = [
|
||||
("Cloud", "Hindsight Cloud API (lightweight, just needs an API key)"),
|
||||
("Local Embedded", "Run Hindsight locally (downloads ~200MB, needs LLM key)"),
|
||||
("Local External", "Connect to an existing Hindsight instance"),
|
||||
]
|
||||
mode_idx = _curses_select(" Select mode", mode_items, default=0)
|
||||
mode = ["cloud", "local_embedded", "local_external"][mode_idx]
|
||||
|
||||
provider_config: dict = {"mode": mode}
|
||||
env_writes: dict = {}
|
||||
|
||||
# Step 2: Install/upgrade deps for selected mode
|
||||
_MIN_CLIENT_VERSION = "0.4.22"
|
||||
cloud_dep = f"hindsight-client>={_MIN_CLIENT_VERSION}"
|
||||
local_dep = "hindsight-all"
|
||||
if mode == "local_embedded":
|
||||
deps_to_install = [local_dep]
|
||||
elif mode == "local_external":
|
||||
deps_to_install = [cloud_dep]
|
||||
else:
|
||||
deps_to_install = [cloud_dep]
|
||||
|
||||
print(f"\n Checking dependencies...")
|
||||
uv_path = shutil.which("uv")
|
||||
if not uv_path:
|
||||
print(" ⚠ uv not found — install it: curl -LsSf https://astral.sh/uv/install.sh | sh")
|
||||
print(f" Then run manually: uv pip install --python {sys.executable} {' '.join(deps_to_install)}")
|
||||
else:
|
||||
try:
|
||||
subprocess.run(
|
||||
[uv_path, "pip", "install", "--python", sys.executable, "--quiet", "--upgrade"] + deps_to_install,
|
||||
check=True, timeout=120, capture_output=True,
|
||||
)
|
||||
print(f" ✓ Dependencies up to date")
|
||||
except Exception as e:
|
||||
print(f" ⚠ Install failed: {e}")
|
||||
print(f" Run manually: uv pip install --python {sys.executable} {' '.join(deps_to_install)}")
|
||||
|
||||
# Step 3: Mode-specific config
|
||||
if mode == "cloud":
|
||||
print(f"\n Get your API key at https://ui.hindsight.vectorize.io\n")
|
||||
existing_key = os.environ.get("HINDSIGHT_API_KEY", "")
|
||||
if existing_key:
|
||||
masked = f"...{existing_key[-4:]}" if len(existing_key) > 4 else "set"
|
||||
sys.stdout.write(f" API key (current: {masked}, blank to keep): ")
|
||||
sys.stdout.flush()
|
||||
api_key = getpass.getpass(prompt="") if sys.stdin.isatty() else sys.stdin.readline().strip()
|
||||
else:
|
||||
sys.stdout.write(" API key: ")
|
||||
sys.stdout.flush()
|
||||
api_key = getpass.getpass(prompt="") if sys.stdin.isatty() else sys.stdin.readline().strip()
|
||||
if api_key:
|
||||
env_writes["HINDSIGHT_API_KEY"] = api_key
|
||||
|
||||
val = input(f" API URL [{_DEFAULT_API_URL}]: ").strip()
|
||||
if val:
|
||||
provider_config["api_url"] = val
|
||||
|
||||
elif mode == "local_external":
|
||||
val = input(f" Hindsight API URL [{_DEFAULT_LOCAL_URL}]: ").strip()
|
||||
provider_config["api_url"] = val or _DEFAULT_LOCAL_URL
|
||||
|
||||
sys.stdout.write(" API key (optional, blank to skip): ")
|
||||
sys.stdout.flush()
|
||||
api_key = getpass.getpass(prompt="") if sys.stdin.isatty() else sys.stdin.readline().strip()
|
||||
if api_key:
|
||||
env_writes["HINDSIGHT_API_KEY"] = api_key
|
||||
|
||||
else: # local_embedded
|
||||
providers_list = list(_PROVIDER_DEFAULT_MODELS.keys())
|
||||
llm_items = [
|
||||
(p, f"default model: {_PROVIDER_DEFAULT_MODELS[p]}")
|
||||
for p in providers_list
|
||||
]
|
||||
llm_idx = _curses_select(" Select LLM provider", llm_items, default=0)
|
||||
llm_provider = providers_list[llm_idx]
|
||||
|
||||
provider_config["llm_provider"] = llm_provider
|
||||
|
||||
if llm_provider == "openai_compatible":
|
||||
val = input(" LLM endpoint URL (e.g. http://192.168.1.10:8080/v1): ").strip()
|
||||
if val:
|
||||
provider_config["llm_base_url"] = val
|
||||
elif llm_provider == "openrouter":
|
||||
provider_config["llm_base_url"] = "https://openrouter.ai/api/v1"
|
||||
|
||||
default_model = _PROVIDER_DEFAULT_MODELS.get(llm_provider, "gpt-4o-mini")
|
||||
val = input(f" LLM model [{default_model}]: ").strip()
|
||||
provider_config["llm_model"] = val or default_model
|
||||
|
||||
sys.stdout.write(" LLM API key: ")
|
||||
sys.stdout.flush()
|
||||
llm_key = getpass.getpass(prompt="") if sys.stdin.isatty() else sys.stdin.readline().strip()
|
||||
if llm_key:
|
||||
env_writes["HINDSIGHT_LLM_API_KEY"] = llm_key
|
||||
|
||||
# Step 4: Save everything
|
||||
provider_config["bank_id"] = "hermes"
|
||||
provider_config["recall_budget"] = "mid"
|
||||
bank_id = "hermes"
|
||||
config["memory"]["provider"] = "hindsight"
|
||||
save_config(config)
|
||||
|
||||
self.save_config(provider_config, hermes_home)
|
||||
|
||||
if env_writes:
|
||||
env_path = Path(hermes_home) / ".env"
|
||||
env_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
existing_lines = []
|
||||
if env_path.exists():
|
||||
existing_lines = env_path.read_text().splitlines()
|
||||
updated_keys = set()
|
||||
new_lines = []
|
||||
for line in existing_lines:
|
||||
key_match = line.split("=", 1)[0].strip() if "=" in line and not line.startswith("#") else None
|
||||
if key_match and key_match in env_writes:
|
||||
new_lines.append(f"{key_match}={env_writes[key_match]}")
|
||||
updated_keys.add(key_match)
|
||||
else:
|
||||
new_lines.append(line)
|
||||
for k, v in env_writes.items():
|
||||
if k not in updated_keys:
|
||||
new_lines.append(f"{k}={v}")
|
||||
env_path.write_text("\n".join(new_lines) + "\n")
|
||||
|
||||
print(f"\n ✓ Hindsight memory configured ({mode} mode)")
|
||||
if env_writes:
|
||||
print(f" API keys saved to .env")
|
||||
print(f"\n Start a new session to activate.\n")
|
||||
|
||||
def get_config_schema(self):
|
||||
return [
|
||||
{"key": "mode", "description": "Connection mode", "default": "cloud", "choices": ["cloud", "local_embedded", "local_external"]},
|
||||
# Cloud mode
|
||||
{"key": "api_url", "description": "Hindsight Cloud API URL", "default": _DEFAULT_API_URL, "when": {"mode": "cloud"}},
|
||||
{"key": "mode", "description": "Cloud API or local embedded mode", "default": "cloud", "choices": ["cloud", "local"]},
|
||||
{"key": "api_url", "description": "Hindsight API URL", "default": _DEFAULT_API_URL, "when": {"mode": "cloud"}},
|
||||
{"key": "api_key", "description": "Hindsight Cloud API key", "secret": True, "env_var": "HINDSIGHT_API_KEY", "url": "https://ui.hindsight.vectorize.io", "when": {"mode": "cloud"}},
|
||||
# Local external mode
|
||||
{"key": "api_url", "description": "Hindsight API URL", "default": _DEFAULT_LOCAL_URL, "when": {"mode": "local_external"}},
|
||||
{"key": "api_key", "description": "API key (optional)", "secret": True, "env_var": "HINDSIGHT_API_KEY", "when": {"mode": "local_external"}},
|
||||
# Local embedded mode
|
||||
{"key": "llm_provider", "description": "LLM provider", "default": "openai", "choices": ["openai", "anthropic", "gemini", "groq", "openrouter", "minimax", "ollama", "lmstudio", "openai_compatible"], "when": {"mode": "local_embedded"}},
|
||||
{"key": "llm_base_url", "description": "Endpoint URL (e.g. http://192.168.1.10:8080/v1)", "default": "", "when": {"mode": "local_embedded", "llm_provider": "openai_compatible"}},
|
||||
{"key": "llm_api_key", "description": "LLM API key (optional for openai_compatible)", "secret": True, "env_var": "HINDSIGHT_LLM_API_KEY", "when": {"mode": "local_embedded"}},
|
||||
{"key": "llm_model", "description": "LLM model", "default": "gpt-4o-mini", "default_from": {"field": "llm_provider", "map": _PROVIDER_DEFAULT_MODELS}, "when": {"mode": "local_embedded"}},
|
||||
{"key": "llm_provider", "description": "LLM provider for local mode", "default": "openai", "choices": ["openai", "anthropic", "gemini", "groq", "minimax", "ollama"], "when": {"mode": "local"}},
|
||||
{"key": "llm_api_key", "description": "LLM API key for local Hindsight", "secret": True, "env_var": "HINDSIGHT_LLM_API_KEY", "when": {"mode": "local"}},
|
||||
{"key": "llm_base_url", "description": "LLM Base URL (e.g. for OpenRouter)", "default": "", "env_var": "HINDSIGHT_API_LLM_BASE_URL", "when": {"mode": "local"}},
|
||||
{"key": "llm_model", "description": "LLM model for local mode", "default": "gpt-4o-mini", "default_from": {"field": "llm_provider", "map": _PROVIDER_DEFAULT_MODELS}, "when": {"mode": "local"}},
|
||||
{"key": "bank_id", "description": "Memory bank name", "default": "hermes"},
|
||||
{"key": "bank_mission", "description": "Mission/purpose description for the memory bank"},
|
||||
{"key": "bank_retain_mission", "description": "Custom extraction prompt for memory retention"},
|
||||
{"key": "recall_budget", "description": "Recall thoroughness", "default": "mid", "choices": ["low", "mid", "high"]},
|
||||
{"key": "budget", "description": "Recall thoroughness", "default": "mid", "choices": ["low", "mid", "high"]},
|
||||
{"key": "memory_mode", "description": "Memory integration mode", "default": "hybrid", "choices": ["hybrid", "context", "tools"]},
|
||||
{"key": "recall_prefetch_method", "description": "Auto-recall method", "default": "recall", "choices": ["recall", "reflect"]},
|
||||
{"key": "tags", "description": "Tags applied when storing memories (comma-separated)", "default": ""},
|
||||
{"key": "recall_tags", "description": "Tags to filter when searching memories (comma-separated)", "default": ""},
|
||||
{"key": "recall_tags_match", "description": "Tag matching mode for recall", "default": "any", "choices": ["any", "all", "any_strict", "all_strict"]},
|
||||
{"key": "auto_recall", "description": "Automatically recall memories before each turn", "default": True},
|
||||
{"key": "auto_retain", "description": "Automatically retain conversation turns", "default": True},
|
||||
{"key": "retain_every_n_turns", "description": "Retain every N turns (1 = every turn)", "default": 1},
|
||||
{"key": "retain_async","description": "Process retain asynchronously on the Hindsight server", "default": True},
|
||||
{"key": "retain_context", "description": "Context label for retained memories", "default": "conversation between Hermes Agent and the User"},
|
||||
{"key": "recall_max_tokens", "description": "Maximum tokens for recall results", "default": 4096},
|
||||
{"key": "recall_max_input_chars", "description": "Maximum input query length for auto-recall", "default": 800},
|
||||
{"key": "recall_prompt_preamble", "description": "Custom preamble for recalled memories in context"},
|
||||
{"key": "prefetch_method", "description": "Auto-recall method", "default": "recall", "choices": ["recall", "reflect"]},
|
||||
]
|
||||
|
||||
def _get_client(self):
|
||||
"""Return the cached Hindsight client (created once, reused)."""
|
||||
if self._client is None:
|
||||
if self._mode == "local_embedded":
|
||||
if self._mode == "local":
|
||||
from hindsight import HindsightEmbedded
|
||||
# Disable __del__ on the class to prevent "attached to a
|
||||
# different loop" errors during GC — we handle cleanup in
|
||||
# shutdown() instead.
|
||||
HindsightEmbedded.__del__ = lambda self: None
|
||||
llm_provider = self._config.get("llm_provider", "")
|
||||
if llm_provider in ("openai_compatible", "openrouter"):
|
||||
llm_provider = "openai"
|
||||
logger.debug("Creating HindsightEmbedded client (profile=%s, provider=%s)",
|
||||
self._config.get("profile", "hermes"), llm_provider)
|
||||
kwargs = dict(
|
||||
profile=self._config.get("profile", "hermes"),
|
||||
llm_provider=llm_provider,
|
||||
llm_api_key=self._config.get("llmApiKey") or self._config.get("llm_api_key") or os.environ.get("HINDSIGHT_LLM_API_KEY", ""),
|
||||
llm_provider=self._config.get("llm_provider", ""),
|
||||
llm_api_key=self._config.get("llm_api_key") or os.environ.get("HINDSIGHT_LLM_API_KEY", ""),
|
||||
llm_model=self._config.get("llm_model", ""),
|
||||
)
|
||||
if self._llm_base_url:
|
||||
kwargs["llm_base_url"] = self._llm_base_url
|
||||
base_url = self._config.get("llm_base_url") or os.environ.get("HINDSIGHT_API_LLM_BASE_URL", "")
|
||||
if base_url:
|
||||
kwargs["llm_base_url"] = base_url
|
||||
self._client = HindsightEmbedded(**kwargs)
|
||||
else:
|
||||
from hindsight_client import Hindsight
|
||||
kwargs = {"base_url": self._api_url, "timeout": 30.0}
|
||||
if self._api_key:
|
||||
kwargs["api_key"] = self._api_key
|
||||
logger.debug("Creating Hindsight cloud client (url=%s, has_key=%s)",
|
||||
self._api_url, bool(self._api_key))
|
||||
self._client = Hindsight(**kwargs)
|
||||
return self._client
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
|
||||
self._session_id = session_id
|
||||
|
||||
# Check client version and auto-upgrade if needed
|
||||
try:
|
||||
from importlib.metadata import version as pkg_version
|
||||
from packaging.version import Version
|
||||
installed = pkg_version("hindsight-client")
|
||||
if Version(installed) < Version(_MIN_CLIENT_VERSION):
|
||||
logger.warning("hindsight-client %s is outdated (need >=%s), attempting upgrade...",
|
||||
installed, _MIN_CLIENT_VERSION)
|
||||
import shutil, subprocess, sys
|
||||
uv_path = shutil.which("uv")
|
||||
if uv_path:
|
||||
try:
|
||||
subprocess.run(
|
||||
[uv_path, "pip", "install", "--python", sys.executable,
|
||||
"--quiet", "--upgrade", f"hindsight-client>={_MIN_CLIENT_VERSION}"],
|
||||
check=True, timeout=120, capture_output=True,
|
||||
)
|
||||
logger.info("hindsight-client upgraded to >=%s", _MIN_CLIENT_VERSION)
|
||||
except Exception as e:
|
||||
logger.warning("Auto-upgrade failed: %s. Run: uv pip install 'hindsight-client>=%s'",
|
||||
e, _MIN_CLIENT_VERSION)
|
||||
else:
|
||||
logger.warning("uv not found. Run: pip install 'hindsight-client>=%s'", _MIN_CLIENT_VERSION)
|
||||
except Exception:
|
||||
pass # packaging not available or other issue — proceed anyway
|
||||
|
||||
self._config = _load_config()
|
||||
self._mode = self._config.get("mode", "cloud")
|
||||
# "local" is a legacy alias for "local_embedded"
|
||||
if self._mode == "local":
|
||||
self._mode = "local_embedded"
|
||||
self._api_key = self._config.get("apiKey") or self._config.get("api_key") or os.environ.get("HINDSIGHT_API_KEY", "")
|
||||
default_url = _DEFAULT_LOCAL_URL if self._mode in ("local_embedded", "local_external") else _DEFAULT_API_URL
|
||||
self._api_key = self._config.get("apiKey") or os.environ.get("HINDSIGHT_API_KEY", "")
|
||||
default_url = _DEFAULT_LOCAL_URL if self._mode == "local" else _DEFAULT_API_URL
|
||||
self._api_url = self._config.get("api_url") or os.environ.get("HINDSIGHT_API_URL", default_url)
|
||||
self._llm_base_url = self._config.get("llm_base_url", "")
|
||||
|
||||
banks = self._config.get("banks", {}).get("hermes", {})
|
||||
self._bank_id = self._config.get("bank_id") or banks.get("bankId", "hermes")
|
||||
budget = self._config.get("recall_budget") or self._config.get("budget") or banks.get("budget", "mid")
|
||||
budget = self._config.get("budget") or banks.get("budget", "mid")
|
||||
self._budget = budget if budget in _VALID_BUDGETS else "mid"
|
||||
|
||||
memory_mode = self._config.get("memory_mode", "hybrid")
|
||||
self._memory_mode = memory_mode if memory_mode in ("context", "tools", "hybrid") else "hybrid"
|
||||
|
||||
prefetch_method = self._config.get("recall_prefetch_method", "recall")
|
||||
prefetch_method = self._config.get("prefetch_method", "recall")
|
||||
self._prefetch_method = prefetch_method if prefetch_method in ("recall", "reflect") else "recall"
|
||||
|
||||
# Bank options
|
||||
self._bank_mission = self._config.get("bank_mission", "")
|
||||
self._bank_retain_mission = self._config.get("bank_retain_mission") or None
|
||||
|
||||
# Tags
|
||||
self._tags = self._config.get("tags") or None
|
||||
self._recall_tags = self._config.get("recall_tags") or None
|
||||
self._recall_tags_match = self._config.get("recall_tags_match", "any")
|
||||
|
||||
# Retain controls
|
||||
self._auto_retain = self._config.get("auto_retain", True)
|
||||
self._retain_every_n_turns = max(1, int(self._config.get("retain_every_n_turns", 1)))
|
||||
self._retain_context = self._config.get("retain_context", "conversation between Hermes Agent and the User")
|
||||
|
||||
# Recall controls
|
||||
self._auto_recall = self._config.get("auto_recall", True)
|
||||
self._recall_max_tokens = int(self._config.get("recall_max_tokens", 4096))
|
||||
self._recall_types = self._config.get("recall_types") or None
|
||||
self._recall_prompt_preamble = self._config.get("recall_prompt_preamble", "")
|
||||
self._recall_max_input_chars = int(self._config.get("recall_max_input_chars", 800))
|
||||
self._retain_async = self._config.get("retain_async", True)
|
||||
|
||||
_client_version = "unknown"
|
||||
try:
|
||||
from importlib.metadata import version as pkg_version
|
||||
_client_version = pkg_version("hindsight-client")
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("Hindsight initialized: mode=%s, api_url=%s, bank=%s, budget=%s, memory_mode=%s, prefetch_method=%s, client=%s",
|
||||
self._mode, self._api_url, self._bank_id, self._budget, self._memory_mode, self._prefetch_method, _client_version)
|
||||
logger.debug("Hindsight config: auto_retain=%s, auto_recall=%s, retain_every_n=%d, "
|
||||
"retain_async=%s, retain_context=%s, "
|
||||
"recall_max_tokens=%d, recall_max_input_chars=%d, tags=%s, recall_tags=%s",
|
||||
self._auto_retain, self._auto_recall, self._retain_every_n_turns,
|
||||
self._retain_async, self._retain_context,
|
||||
self._recall_max_tokens, self._recall_max_input_chars,
|
||||
self._tags, self._recall_tags)
|
||||
logger.info("Hindsight initialized: mode=%s, api_url=%s, bank=%s, budget=%s, memory_mode=%s, prefetch_method=%s",
|
||||
self._mode, self._api_url, self._bank_id, self._budget, self._memory_mode, self._prefetch_method)
|
||||
|
||||
# For local mode, start the embedded daemon in the background so it
|
||||
# doesn't block the chat. Redirect stdout/stderr to a log file to
|
||||
# prevent rich startup output from spamming the terminal.
|
||||
if self._mode == "local_embedded":
|
||||
if self._mode == "local":
|
||||
def _start_daemon():
|
||||
import traceback
|
||||
log_dir = get_hermes_home() / "logs"
|
||||
@@ -583,8 +320,6 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
current_provider = self._config.get("llm_provider", "")
|
||||
current_model = self._config.get("llm_model", "")
|
||||
current_base_url = self._config.get("llm_base_url") or os.environ.get("HINDSIGHT_API_LLM_BASE_URL", "")
|
||||
# Map openai_compatible/openrouter → openai for the daemon (OpenAI wire format)
|
||||
daemon_provider = "openai" if current_provider in ("openai_compatible", "openrouter") else current_provider
|
||||
|
||||
# Read saved profile config
|
||||
saved = {}
|
||||
@@ -595,7 +330,7 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
saved[k.strip()] = v.strip()
|
||||
|
||||
config_changed = (
|
||||
saved.get("HINDSIGHT_API_LLM_PROVIDER") != daemon_provider or
|
||||
saved.get("HINDSIGHT_API_LLM_PROVIDER") != current_provider or
|
||||
saved.get("HINDSIGHT_API_LLM_MODEL") != current_model or
|
||||
saved.get("HINDSIGHT_API_LLM_API_KEY") != current_key or
|
||||
saved.get("HINDSIGHT_API_LLM_BASE_URL", "") != current_base_url
|
||||
@@ -605,7 +340,7 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
# Write updated profile .env
|
||||
profile_env.parent.mkdir(parents=True, exist_ok=True)
|
||||
env_lines = (
|
||||
f"HINDSIGHT_API_LLM_PROVIDER={daemon_provider}\n"
|
||||
f"HINDSIGHT_API_LLM_PROVIDER={current_provider}\n"
|
||||
f"HINDSIGHT_API_LLM_API_KEY={current_key}\n"
|
||||
f"HINDSIGHT_API_LLM_MODEL={current_model}\n"
|
||||
f"HINDSIGHT_API_LOG_LEVEL=info\n"
|
||||
@@ -653,118 +388,47 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
|
||||
def prefetch(self, query: str, *, session_id: str = "") -> str:
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
logger.debug("Prefetch: waiting for background thread to complete")
|
||||
self._prefetch_thread.join(timeout=3.0)
|
||||
with self._prefetch_lock:
|
||||
result = self._prefetch_result
|
||||
self._prefetch_result = ""
|
||||
if not result:
|
||||
logger.debug("Prefetch: no results available")
|
||||
return ""
|
||||
logger.debug("Prefetch: returning %d chars of context", len(result))
|
||||
header = self._recall_prompt_preamble or (
|
||||
"# Hindsight Memory (persistent cross-session context)\n"
|
||||
"Use this to answer questions about the user and prior sessions. "
|
||||
"Do not call tools to look up information that is already present here."
|
||||
)
|
||||
return f"{header}\n\n{result}"
|
||||
return f"## Hindsight Memory\n{result}"
|
||||
|
||||
def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
|
||||
if self._memory_mode == "tools":
|
||||
logger.debug("Prefetch: skipped (tools-only mode)")
|
||||
return
|
||||
if not self._auto_recall:
|
||||
logger.debug("Prefetch: skipped (auto_recall disabled)")
|
||||
return
|
||||
# Truncate query to max chars
|
||||
if self._recall_max_input_chars and len(query) > self._recall_max_input_chars:
|
||||
query = query[:self._recall_max_input_chars]
|
||||
|
||||
def _run():
|
||||
try:
|
||||
client = self._get_client()
|
||||
if self._prefetch_method == "reflect":
|
||||
logger.debug("Prefetch: calling reflect (bank=%s, query_len=%d)", self._bank_id, len(query))
|
||||
resp = _run_sync(client.areflect(bank_id=self._bank_id, query=query, budget=self._budget))
|
||||
text = resp.text or ""
|
||||
else:
|
||||
recall_kwargs: dict = {
|
||||
"bank_id": self._bank_id, "query": query,
|
||||
"budget": self._budget, "max_tokens": self._recall_max_tokens,
|
||||
}
|
||||
if self._recall_tags:
|
||||
recall_kwargs["tags"] = self._recall_tags
|
||||
recall_kwargs["tags_match"] = self._recall_tags_match
|
||||
if self._recall_types:
|
||||
recall_kwargs["types"] = self._recall_types
|
||||
logger.debug("Prefetch: calling recall (bank=%s, query_len=%d, budget=%s)",
|
||||
self._bank_id, len(query), self._budget)
|
||||
resp = _run_sync(client.arecall(**recall_kwargs))
|
||||
num_results = len(resp.results) if resp.results else 0
|
||||
logger.debug("Prefetch: recall returned %d results", num_results)
|
||||
text = "\n".join(f"- {r.text}" for r in resp.results if r.text) if resp.results else ""
|
||||
resp = _run_sync(client.arecall(bank_id=self._bank_id, query=query, budget=self._budget))
|
||||
text = "\n".join(r.text for r in resp.results if r.text) if resp.results else ""
|
||||
if text:
|
||||
with self._prefetch_lock:
|
||||
self._prefetch_result = text
|
||||
except Exception as e:
|
||||
logger.debug("Hindsight prefetch failed: %s", e, exc_info=True)
|
||||
logger.debug("Hindsight prefetch failed: %s", e)
|
||||
|
||||
self._prefetch_thread = threading.Thread(target=_run, daemon=True, name="hindsight-prefetch")
|
||||
self._prefetch_thread.start()
|
||||
|
||||
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
|
||||
"""Retain conversation turn in background (non-blocking).
|
||||
|
||||
Respects retain_every_n_turns for batching.
|
||||
"""
|
||||
if not self._auto_retain:
|
||||
logger.debug("sync_turn: skipped (auto_retain disabled)")
|
||||
return
|
||||
|
||||
from datetime import datetime, timezone
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": user_content, "timestamp": now},
|
||||
{"role": "assistant", "content": assistant_content, "timestamp": now},
|
||||
]
|
||||
|
||||
turn = json.dumps(messages)
|
||||
self._session_turns.append(turn)
|
||||
self._turn_counter += 1
|
||||
|
||||
# Only retain every N turns
|
||||
if self._turn_counter % self._retain_every_n_turns != 0:
|
||||
logger.debug("sync_turn: buffered turn %d (will retain at turn %d)",
|
||||
self._turn_counter, self._turn_counter + (self._retain_every_n_turns - self._turn_counter % self._retain_every_n_turns))
|
||||
return
|
||||
|
||||
logger.debug("sync_turn: retaining %d turns, total session content %d chars",
|
||||
len(self._session_turns), sum(len(t) for t in self._session_turns))
|
||||
# Send the ENTIRE session as a single JSON array (document_id deduplicates).
|
||||
# Each element in _session_turns is a JSON string of that turn's messages.
|
||||
content = "[" + ",".join(self._session_turns) + "]"
|
||||
"""Retain conversation turn in background (non-blocking)."""
|
||||
combined = f"User: {user_content}\nAssistant: {assistant_content}"
|
||||
|
||||
def _sync():
|
||||
try:
|
||||
client = self._get_client()
|
||||
item: dict = {
|
||||
"content": content,
|
||||
"context": self._retain_context,
|
||||
}
|
||||
if self._tags:
|
||||
item["tags"] = self._tags
|
||||
logger.debug("Hindsight retain: bank=%s, doc=%s, async=%s, content_len=%d, num_turns=%d",
|
||||
self._bank_id, self._session_id, self._retain_async, len(content), len(self._session_turns))
|
||||
_run_sync(client.aretain_batch(
|
||||
bank_id=self._bank_id,
|
||||
items=[item],
|
||||
document_id=self._session_id,
|
||||
retain_async=self._retain_async,
|
||||
_run_sync(client.aretain(
|
||||
bank_id=self._bank_id, content=combined, context="conversation"
|
||||
))
|
||||
logger.debug("Hindsight retain succeeded")
|
||||
except Exception as e:
|
||||
logger.warning("Hindsight sync failed: %s", e, exc_info=True)
|
||||
logger.warning("Hindsight sync failed: %s", e)
|
||||
|
||||
if self._sync_thread and self._sync_thread.is_alive():
|
||||
self._sync_thread.join(timeout=5.0)
|
||||
@@ -789,18 +453,12 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
return tool_error("Missing required parameter: content")
|
||||
context = args.get("context")
|
||||
try:
|
||||
retain_kwargs: dict = {
|
||||
"bank_id": self._bank_id, "content": content, "context": context,
|
||||
}
|
||||
if self._tags:
|
||||
retain_kwargs["tags"] = self._tags
|
||||
logger.debug("Tool hindsight_retain: bank=%s, content_len=%d, context=%s",
|
||||
self._bank_id, len(content), context)
|
||||
_run_sync(client.aretain(**retain_kwargs))
|
||||
logger.debug("Tool hindsight_retain: success")
|
||||
_run_sync(client.aretain(
|
||||
bank_id=self._bank_id, content=content, context=context
|
||||
))
|
||||
return json.dumps({"result": "Memory stored successfully."})
|
||||
except Exception as e:
|
||||
logger.warning("hindsight_retain failed: %s", e, exc_info=True)
|
||||
logger.warning("hindsight_retain failed: %s", e)
|
||||
return tool_error(f"Failed to store memory: {e}")
|
||||
|
||||
elif tool_name == "hindsight_recall":
|
||||
@@ -808,26 +466,15 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
if not query:
|
||||
return tool_error("Missing required parameter: query")
|
||||
try:
|
||||
recall_kwargs: dict = {
|
||||
"bank_id": self._bank_id, "query": query, "budget": self._budget,
|
||||
"max_tokens": self._recall_max_tokens,
|
||||
}
|
||||
if self._recall_tags:
|
||||
recall_kwargs["tags"] = self._recall_tags
|
||||
recall_kwargs["tags_match"] = self._recall_tags_match
|
||||
if self._recall_types:
|
||||
recall_kwargs["types"] = self._recall_types
|
||||
logger.debug("Tool hindsight_recall: bank=%s, query_len=%d, budget=%s",
|
||||
self._bank_id, len(query), self._budget)
|
||||
resp = _run_sync(client.arecall(**recall_kwargs))
|
||||
num_results = len(resp.results) if resp.results else 0
|
||||
logger.debug("Tool hindsight_recall: %d results", num_results)
|
||||
resp = _run_sync(client.arecall(
|
||||
bank_id=self._bank_id, query=query, budget=self._budget
|
||||
))
|
||||
if not resp.results:
|
||||
return json.dumps({"result": "No relevant memories found."})
|
||||
lines = [f"{i}. {r.text}" for i, r in enumerate(resp.results, 1)]
|
||||
return json.dumps({"result": "\n".join(lines)})
|
||||
except Exception as e:
|
||||
logger.warning("hindsight_recall failed: %s", e, exc_info=True)
|
||||
logger.warning("hindsight_recall failed: %s", e)
|
||||
return tool_error(f"Failed to search memory: {e}")
|
||||
|
||||
elif tool_name == "hindsight_reflect":
|
||||
@@ -835,28 +482,24 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
if not query:
|
||||
return tool_error("Missing required parameter: query")
|
||||
try:
|
||||
logger.debug("Tool hindsight_reflect: bank=%s, query_len=%d, budget=%s",
|
||||
self._bank_id, len(query), self._budget)
|
||||
resp = _run_sync(client.areflect(
|
||||
bank_id=self._bank_id, query=query, budget=self._budget
|
||||
))
|
||||
logger.debug("Tool hindsight_reflect: response_len=%d", len(resp.text or ""))
|
||||
return json.dumps({"result": resp.text or "No relevant memories found."})
|
||||
except Exception as e:
|
||||
logger.warning("hindsight_reflect failed: %s", e, exc_info=True)
|
||||
logger.warning("hindsight_reflect failed: %s", e)
|
||||
return tool_error(f"Failed to reflect: {e}")
|
||||
|
||||
return tool_error(f"Unknown tool: {tool_name}")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
logger.debug("Hindsight shutdown: waiting for background threads")
|
||||
global _loop, _loop_thread
|
||||
for t in (self._prefetch_thread, self._sync_thread):
|
||||
if t and t.is_alive():
|
||||
t.join(timeout=5.0)
|
||||
if self._client is not None:
|
||||
try:
|
||||
if self._mode == "local_embedded":
|
||||
if self._mode == "local":
|
||||
# Use the public close() API. The RuntimeError from
|
||||
# aiohttp's "attached to a different loop" is expected
|
||||
# and harmless — the daemon keeps running independently.
|
||||
|
||||
@@ -2,7 +2,9 @@ name: hindsight
|
||||
version: 1.0.0
|
||||
description: "Hindsight — long-term memory with knowledge graph, entity resolution, and multi-strategy retrieval."
|
||||
pip_dependencies:
|
||||
- "hindsight-client>=0.4.22"
|
||||
requires_env: []
|
||||
- hindsight-client
|
||||
- hindsight-all
|
||||
requires_env:
|
||||
- HINDSIGHT_API_KEY
|
||||
hooks:
|
||||
- on_session_end
|
||||
|
||||
+161
-330
@@ -66,7 +66,7 @@ from model_tools import (
|
||||
handle_function_call,
|
||||
check_toolset_requirements,
|
||||
)
|
||||
from tools.terminal_tool import cleanup_vm, get_active_env, is_persistent_env
|
||||
from tools.terminal_tool import cleanup_vm, get_active_env
|
||||
from tools.tool_result_storage import maybe_persist_tool_result, enforce_turn_budget
|
||||
from tools.interrupt import set_interrupt as _set_interrupt
|
||||
from tools.browser_tool import cleanup_browser
|
||||
@@ -77,7 +77,6 @@ from hermes_constants import OPENROUTER_BASE_URL
|
||||
# Agent internals extracted to agent/ package for modularity
|
||||
from agent.memory_manager import build_memory_context_block
|
||||
from agent.retry_utils import jittered_backoff
|
||||
from agent.error_classifier import classify_api_error, FailoverReason
|
||||
from agent.prompt_builder import (
|
||||
DEFAULT_AGENT_IDENTITY, PLATFORM_HINTS,
|
||||
MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, SKILLS_GUIDANCE,
|
||||
@@ -87,7 +86,6 @@ from agent.model_metadata import (
|
||||
fetch_model_metadata,
|
||||
estimate_tokens_rough, estimate_messages_tokens_rough, estimate_request_tokens_rough,
|
||||
get_next_probe_tier, parse_context_limit_from_error,
|
||||
parse_available_output_tokens_from_error,
|
||||
save_context_length, is_local_endpoint,
|
||||
query_ollama_num_ctx,
|
||||
)
|
||||
@@ -444,13 +442,6 @@ class AIAgent:
|
||||
for AI models that support function calling.
|
||||
"""
|
||||
|
||||
# ── Class-level context pressure dedup (survives across instances) ──
|
||||
# The gateway creates a new AIAgent per message, so instance-level flags
|
||||
# reset every time. This dict tracks {session_id: (warn_level, timestamp)}
|
||||
# to suppress duplicate warnings within a cooldown window.
|
||||
_context_pressure_last_warned: dict = {}
|
||||
_CONTEXT_PRESSURE_COOLDOWN = 300 # seconds between re-warning same session
|
||||
|
||||
@property
|
||||
def base_url(self) -> str:
|
||||
return self._base_url
|
||||
@@ -624,6 +615,7 @@ class AIAgent:
|
||||
self.tool_complete_callback = tool_complete_callback
|
||||
self.thinking_callback = thinking_callback
|
||||
self.reasoning_callback = reasoning_callback
|
||||
self._reasoning_deltas_fired = False # Set by _fire_reasoning_delta, reset per API call
|
||||
self.clarify_callback = clarify_callback
|
||||
self.step_callback = step_callback
|
||||
self.stream_delta_callback = stream_delta_callback
|
||||
@@ -681,8 +673,7 @@ class AIAgent:
|
||||
# Context pressure warnings: notify the USER (not the LLM) as context
|
||||
# fills up. Purely informational — displayed in CLI output and sent via
|
||||
# status_callback for gateway platforms. Does NOT inject into messages.
|
||||
# Tiered: fires at 85% and again at 95% of compaction threshold.
|
||||
self._context_pressure_warned_at = 0.0 # highest tier already shown
|
||||
self._context_pressure_warned = False
|
||||
|
||||
# Activity tracking — updated on each API call, tool execution, and
|
||||
# stream chunk. Used by the gateway timeout handler to report what the
|
||||
@@ -693,10 +684,6 @@ class AIAgent:
|
||||
self._current_tool: str | None = None
|
||||
self._api_call_count: int = 0
|
||||
|
||||
# Rate limit tracking — updated from x-ratelimit-* response headers
|
||||
# after each API call. Accessed by /usage slash command.
|
||||
self._rate_limit_state: Optional["RateLimitState"] = None
|
||||
|
||||
# Centralized logging — agent.log (INFO+) and errors.log (WARNING+)
|
||||
# both live under ~/.hermes/logs/. Idempotent, so gateway mode
|
||||
# (which creates a new AIAgent per message) won't duplicate handlers.
|
||||
@@ -1298,6 +1285,7 @@ class AIAgent:
|
||||
if hasattr(self, "context_compressor") and self.context_compressor:
|
||||
self.context_compressor.last_prompt_tokens = 0
|
||||
self.context_compressor.last_completion_tokens = 0
|
||||
self.context_compressor.last_total_tokens = 0
|
||||
self.context_compressor.compression_count = 0
|
||||
self.context_compressor._context_probed = False
|
||||
self.context_compressor._context_probe_persistable = False
|
||||
@@ -1699,25 +1687,9 @@ class AIAgent:
|
||||
return None
|
||||
|
||||
def _cleanup_task_resources(self, task_id: str) -> None:
|
||||
"""Clean up VM and browser resources for a given task.
|
||||
|
||||
Skips ``cleanup_vm`` when the active terminal environment is marked
|
||||
persistent (``persistent_filesystem=True``) so that long-lived sandbox
|
||||
containers survive between turns. The idle reaper in
|
||||
``terminal_tool._cleanup_inactive_envs`` still tears them down once
|
||||
``terminal.lifetime_seconds`` is exceeded. Non-persistent backends are
|
||||
torn down per-turn as before to prevent resource leakage (the original
|
||||
intent of this hook for the Morph backend, see commit fbd3a2fd).
|
||||
"""
|
||||
"""Clean up VM and browser resources for a given task."""
|
||||
try:
|
||||
if is_persistent_env(task_id):
|
||||
if self.verbose_logging:
|
||||
logging.debug(
|
||||
f"Skipping per-turn cleanup_vm for persistent env {task_id}; "
|
||||
f"idle reaper will handle it."
|
||||
)
|
||||
else:
|
||||
cleanup_vm(task_id)
|
||||
cleanup_vm(task_id)
|
||||
except Exception as e:
|
||||
if self.verbose_logging:
|
||||
logging.warning(f"Failed to cleanup VM for task {task_id}: {e}")
|
||||
@@ -2549,29 +2521,6 @@ class AIAgent:
|
||||
self._last_activity_ts = time.time()
|
||||
self._last_activity_desc = desc
|
||||
|
||||
def _capture_rate_limits(self, http_response: Any) -> None:
|
||||
"""Parse x-ratelimit-* headers from an HTTP response and cache the state.
|
||||
|
||||
Called after each streaming API call. The httpx Response object is
|
||||
available on the OpenAI SDK Stream via ``stream.response``.
|
||||
"""
|
||||
if http_response is None:
|
||||
return
|
||||
headers = getattr(http_response, "headers", None)
|
||||
if not headers:
|
||||
return
|
||||
try:
|
||||
from agent.rate_limit_tracker import parse_rate_limit_headers
|
||||
state = parse_rate_limit_headers(headers, provider=self.provider)
|
||||
if state is not None:
|
||||
self._rate_limit_state = state
|
||||
except Exception:
|
||||
pass # Never let header parsing break the agent loop
|
||||
|
||||
def get_rate_limit_state(self):
|
||||
"""Return the last captured RateLimitState, or None."""
|
||||
return self._rate_limit_state
|
||||
|
||||
def get_activity_summary(self) -> dict:
|
||||
"""Return a snapshot of the agent's current activity for diagnostics.
|
||||
|
||||
@@ -3847,6 +3796,7 @@ class AIAgent:
|
||||
max_stream_retries = 1
|
||||
has_tool_calls = False
|
||||
first_delta_fired = False
|
||||
self._reasoning_deltas_fired = False
|
||||
# Accumulate streamed text so we can recover if get_final_response()
|
||||
# returns empty output (e.g. chatgpt.com backend-api sends
|
||||
# response.incomplete instead of response.completed).
|
||||
@@ -4324,6 +4274,7 @@ class AIAgent:
|
||||
|
||||
def _fire_reasoning_delta(self, text: str) -> None:
|
||||
"""Fire reasoning callback if registered."""
|
||||
self._reasoning_deltas_fired = True
|
||||
cb = self.reasoning_callback
|
||||
if cb is not None:
|
||||
try:
|
||||
@@ -4424,11 +4375,6 @@ class AIAgent:
|
||||
self._touch_activity("waiting for provider response (streaming)")
|
||||
stream = request_client_holder["client"].chat.completions.create(**stream_kwargs)
|
||||
|
||||
# Capture rate limit headers from the initial HTTP response.
|
||||
# The OpenAI SDK Stream object exposes the underlying httpx
|
||||
# response via .response before any chunks are consumed.
|
||||
self._capture_rate_limits(getattr(stream, "response", None))
|
||||
|
||||
content_parts: list = []
|
||||
tool_calls_acc: dict = {}
|
||||
tool_gen_notified: set = set()
|
||||
@@ -4443,6 +4389,10 @@ class AIAgent:
|
||||
role = "assistant"
|
||||
reasoning_parts: list = []
|
||||
usage_obj = None
|
||||
# Reset per-call reasoning tracking so _build_assistant_message
|
||||
# knows whether reasoning was already displayed during streaming.
|
||||
self._reasoning_deltas_fired = False
|
||||
|
||||
_first_chunk_seen = False
|
||||
for chunk in stream:
|
||||
last_chunk_time["t"] = time.time()
|
||||
@@ -4599,6 +4549,7 @@ class AIAgent:
|
||||
works unchanged.
|
||||
"""
|
||||
has_tool_use = False
|
||||
self._reasoning_deltas_fired = False
|
||||
|
||||
# Reset stale-stream timer for this attempt
|
||||
last_chunk_time["t"] = time.time()
|
||||
@@ -4777,25 +4728,18 @@ class AIAgent:
|
||||
self._close_request_openai_client(request_client, reason="stream_request_complete")
|
||||
|
||||
_stream_stale_timeout_base = float(os.getenv("HERMES_STREAM_STALE_TIMEOUT", 180.0))
|
||||
# Local providers (Ollama, oMLX, llama-cpp) can take 300+ seconds
|
||||
# for prefill on large contexts. Disable the stale detector unless
|
||||
# the user explicitly set HERMES_STREAM_STALE_TIMEOUT.
|
||||
if _stream_stale_timeout_base == 180.0 and self.base_url and is_local_endpoint(self.base_url):
|
||||
_stream_stale_timeout = float("inf")
|
||||
logger.debug("Local provider detected (%s) — stale stream timeout disabled", self.base_url)
|
||||
# Scale the stale timeout for large contexts: slow models (like Opus)
|
||||
# can legitimately think for minutes before producing the first token
|
||||
# when the context is large. Without this, the stale detector kills
|
||||
# healthy connections during the model's thinking phase, producing
|
||||
# spurious RemoteProtocolError ("peer closed connection").
|
||||
_est_tokens = sum(len(str(v)) for v in api_kwargs.get("messages", [])) // 4
|
||||
if _est_tokens > 100_000:
|
||||
_stream_stale_timeout = max(_stream_stale_timeout_base, 300.0)
|
||||
elif _est_tokens > 50_000:
|
||||
_stream_stale_timeout = max(_stream_stale_timeout_base, 240.0)
|
||||
else:
|
||||
# Scale the stale timeout for large contexts: slow models (like Opus)
|
||||
# can legitimately think for minutes before producing the first token
|
||||
# when the context is large. Without this, the stale detector kills
|
||||
# healthy connections during the model's thinking phase, producing
|
||||
# spurious RemoteProtocolError ("peer closed connection").
|
||||
_est_tokens = sum(len(str(v)) for v in api_kwargs.get("messages", [])) // 4
|
||||
if _est_tokens > 100_000:
|
||||
_stream_stale_timeout = max(_stream_stale_timeout_base, 300.0)
|
||||
elif _est_tokens > 50_000:
|
||||
_stream_stale_timeout = max(_stream_stale_timeout_base, 240.0)
|
||||
else:
|
||||
_stream_stale_timeout = _stream_stale_timeout_base
|
||||
_stream_stale_timeout = _stream_stale_timeout_base
|
||||
|
||||
t = threading.Thread(target=_call, daemon=True)
|
||||
t.start()
|
||||
@@ -4960,21 +4904,9 @@ class AIAgent:
|
||||
# Swap OpenAI client and config in-place
|
||||
self.api_key = fb_client.api_key
|
||||
self.client = fb_client
|
||||
# Preserve provider-specific headers that
|
||||
# resolve_provider_client() may have baked into
|
||||
# fb_client via the default_headers kwarg. The OpenAI
|
||||
# SDK stores these in _custom_headers. Without this,
|
||||
# subsequent request-client rebuilds (via
|
||||
# _create_request_openai_client) drop the headers,
|
||||
# causing 403s from providers like Kimi Coding that
|
||||
# require a User-Agent sentinel.
|
||||
fb_headers = getattr(fb_client, "_custom_headers", None)
|
||||
if not fb_headers:
|
||||
fb_headers = getattr(fb_client, "default_headers", None)
|
||||
self._client_kwargs = {
|
||||
"api_key": fb_client.api_key,
|
||||
"base_url": fb_base_url,
|
||||
**({"default_headers": dict(fb_headers)} if fb_headers else {}),
|
||||
}
|
||||
|
||||
# Re-evaluate prompt caching for the new provider/model
|
||||
@@ -5389,22 +5321,15 @@ class AIAgent:
|
||||
if self.api_mode == "anthropic_messages":
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs
|
||||
anthropic_messages = self._prepare_anthropic_messages_for_api(api_messages)
|
||||
# Pass context_length (total input+output window) so the adapter can
|
||||
# clamp max_tokens (output cap) when the user configured a smaller
|
||||
# context window than the model's native output limit.
|
||||
# Pass context_length so the adapter can clamp max_tokens if the
|
||||
# user configured a smaller context window than the model's output limit.
|
||||
ctx_len = getattr(self, "context_compressor", None)
|
||||
ctx_len = ctx_len.context_length if ctx_len else None
|
||||
# _ephemeral_max_output_tokens is set for one call when the API
|
||||
# returns "max_tokens too large given prompt" — it caps output to
|
||||
# the available window space without touching context_length.
|
||||
ephemeral_out = getattr(self, "_ephemeral_max_output_tokens", None)
|
||||
if ephemeral_out is not None:
|
||||
self._ephemeral_max_output_tokens = None # consume immediately
|
||||
return build_anthropic_kwargs(
|
||||
model=self.model,
|
||||
messages=anthropic_messages,
|
||||
tools=self.tools,
|
||||
max_tokens=ephemeral_out if ephemeral_out is not None else self.max_tokens,
|
||||
max_tokens=self.max_tokens,
|
||||
reasoning_config=self.reasoning_config,
|
||||
is_oauth=self._is_anthropic_oauth,
|
||||
preserve_dots=self._anthropic_preserve_dots(),
|
||||
@@ -5939,7 +5864,7 @@ class AIAgent:
|
||||
tools=[memory_tool_def],
|
||||
temperature=0.3,
|
||||
max_tokens=5120,
|
||||
# timeout resolved from auxiliary.flush_memories.timeout config
|
||||
timeout=30.0,
|
||||
)
|
||||
except RuntimeError:
|
||||
_aux_available = False
|
||||
@@ -5971,10 +5896,7 @@ class AIAgent:
|
||||
"temperature": 0.3,
|
||||
**self._max_tokens_param(5120),
|
||||
}
|
||||
from agent.auxiliary_client import _get_task_timeout
|
||||
response = self._ensure_primary_openai_client(reason="flush_memories").chat.completions.create(
|
||||
**api_kwargs, timeout=_get_task_timeout("flush_memories")
|
||||
)
|
||||
response = self._ensure_primary_openai_client(reason="flush_memories").chat.completions.create(**api_kwargs, timeout=30.0)
|
||||
|
||||
# Extract tool calls from the response, handling all API formats
|
||||
tool_calls = []
|
||||
@@ -6081,15 +6003,6 @@ class AIAgent:
|
||||
except Exception as e:
|
||||
logger.warning("Session DB compression split failed — new session will NOT be indexed: %s", e)
|
||||
|
||||
# Warn on repeated compressions (quality degrades with each pass)
|
||||
_cc = self.context_compressor.compression_count
|
||||
if _cc >= 2:
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Session compressed {_cc} times — "
|
||||
f"accuracy may degrade. Consider /new to start fresh.",
|
||||
force=True,
|
||||
)
|
||||
|
||||
# Update token estimate after compaction so pressure calculations
|
||||
# use the post-compression count, not the stale pre-compression one.
|
||||
_compressed_est = (
|
||||
@@ -6102,16 +6015,12 @@ class AIAgent:
|
||||
# Only reset the pressure warning if compression actually brought
|
||||
# us below the warning level (85% of threshold). When compression
|
||||
# can't reduce enough (e.g. threshold is very low, or system prompt
|
||||
# alone exceeds the warning level), keep the tier set to prevent
|
||||
# alone exceeds the warning level), keep the flag set to prevent
|
||||
# spamming the user with repeated warnings every loop iteration.
|
||||
if self.context_compressor.threshold_tokens > 0:
|
||||
_post_progress = _compressed_est / self.context_compressor.threshold_tokens
|
||||
if _post_progress < 0.85:
|
||||
self._context_pressure_warned_at = 0.0
|
||||
# Clear class-level dedup for this session so a fresh
|
||||
# warning cycle can start if context grows again.
|
||||
_sid = self.session_id or "default"
|
||||
AIAgent._context_pressure_last_warned.pop(_sid, None)
|
||||
self._context_pressure_warned = False
|
||||
|
||||
# Clear the file-read dedup cache. After compression the original
|
||||
# read content is summarised away — if the model re-reads the same
|
||||
@@ -7293,7 +7202,6 @@ class AIAgent:
|
||||
length_continue_retries = 0
|
||||
truncated_response_prefix = ""
|
||||
compression_attempts = 0
|
||||
_turn_exit_reason = "unknown" # Diagnostic: why the loop ended
|
||||
|
||||
# Clear any stale interrupt state at start
|
||||
self.clear_interrupt()
|
||||
@@ -7318,7 +7226,6 @@ class AIAgent:
|
||||
# Check for interrupt request (e.g., user sent new message)
|
||||
if self._interrupt_requested:
|
||||
interrupted = True
|
||||
_turn_exit_reason = "interrupted_by_user"
|
||||
if not self.quiet_mode:
|
||||
self._safe_print("\n⚡ Breaking out of tool loop due to interrupt...")
|
||||
break
|
||||
@@ -7327,7 +7234,6 @@ class AIAgent:
|
||||
self._api_call_count = api_call_count
|
||||
self._touch_activity(f"starting API call #{api_call_count}")
|
||||
if not self.iteration_budget.consume():
|
||||
_turn_exit_reason = "budget_exhausted"
|
||||
if not self.quiet_mode:
|
||||
self._safe_print(f"\n⚠️ Iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} iterations used)")
|
||||
break
|
||||
@@ -8032,25 +7938,6 @@ class AIAgent:
|
||||
|
||||
status_code = getattr(api_error, "status_code", None)
|
||||
error_context = self._extract_api_error_context(api_error)
|
||||
|
||||
# ── Classify the error for structured recovery decisions ──
|
||||
_compressor = getattr(self, "context_compressor", None)
|
||||
_ctx_len = getattr(_compressor, "context_length", 200000) if _compressor else 200000
|
||||
classified = classify_api_error(
|
||||
api_error,
|
||||
provider=getattr(self, "provider", "") or "",
|
||||
model=getattr(self, "model", "") or "",
|
||||
approx_tokens=approx_tokens,
|
||||
context_length=_ctx_len,
|
||||
num_messages=len(api_messages) if api_messages else 0,
|
||||
)
|
||||
logger.debug(
|
||||
"Error classified: reason=%s status=%s retryable=%s compress=%s rotate=%s fallback=%s",
|
||||
classified.reason.value, classified.status_code,
|
||||
classified.retryable, classified.should_compress,
|
||||
classified.should_rotate_credential, classified.should_fallback,
|
||||
)
|
||||
|
||||
recovered_with_pool, has_retried_429 = self._recover_with_credential_pool(
|
||||
status_code=status_code,
|
||||
has_retried_429=has_retried_429,
|
||||
@@ -8113,24 +8000,27 @@ class AIAgent:
|
||||
# from all messages so the next retry sends no thinking
|
||||
# blocks at all. One-shot — don't retry infinitely.
|
||||
if (
|
||||
classified.reason == FailoverReason.thinking_signature
|
||||
self.api_mode == "anthropic_messages"
|
||||
and status_code == 400
|
||||
and not thinking_sig_retry_attempted
|
||||
):
|
||||
thinking_sig_retry_attempted = True
|
||||
for _m in messages:
|
||||
if isinstance(_m, dict):
|
||||
_m.pop("reasoning_details", None)
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Thinking block signature invalid — "
|
||||
f"stripped all thinking blocks, retrying...",
|
||||
force=True,
|
||||
)
|
||||
logging.warning(
|
||||
"%sThinking block signature recovery: stripped "
|
||||
"reasoning_details from %d messages",
|
||||
self.log_prefix, len(messages),
|
||||
)
|
||||
continue
|
||||
_err_msg_lower = str(api_error).lower()
|
||||
if "signature" in _err_msg_lower and "thinking" in _err_msg_lower:
|
||||
thinking_sig_retry_attempted = True
|
||||
for _m in messages:
|
||||
if isinstance(_m, dict):
|
||||
_m.pop("reasoning_details", None)
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Thinking block signature invalid — "
|
||||
f"stripped all thinking blocks, retrying...",
|
||||
force=True,
|
||||
)
|
||||
logging.warning(
|
||||
"%sThinking block signature recovery: stripped "
|
||||
"reasoning_details from %d messages",
|
||||
self.log_prefix, len(messages),
|
||||
)
|
||||
continue
|
||||
|
||||
retry_count += 1
|
||||
elapsed_time = time.time() - api_start_time
|
||||
@@ -8187,7 +8077,14 @@ class AIAgent:
|
||||
# is NOT a transient rate limit — retrying or switching
|
||||
# credentials won't help. Reduce context to 200k (the
|
||||
# standard tier) and compress.
|
||||
if classified.reason == FailoverReason.long_context_tier:
|
||||
# Only applies to Sonnet — Opus 1M is general access.
|
||||
_is_long_context_tier_error = (
|
||||
status_code == 429
|
||||
and "extra usage" in error_msg
|
||||
and "long context" in error_msg
|
||||
and "sonnet" in self.model.lower()
|
||||
)
|
||||
if _is_long_context_tier_error:
|
||||
_reduced_ctx = 200000
|
||||
compressor = self.context_compressor
|
||||
old_ctx = compressor.context_length
|
||||
@@ -8232,9 +8129,13 @@ class AIAgent:
|
||||
# When a fallback model is configured, switch immediately instead
|
||||
# of burning through retries with exponential backoff -- the
|
||||
# primary provider won't recover within the retry window.
|
||||
is_rate_limited = classified.reason in (
|
||||
FailoverReason.rate_limit,
|
||||
FailoverReason.billing,
|
||||
is_rate_limited = (
|
||||
status_code == 429
|
||||
or "rate limit" in error_msg
|
||||
or "too many requests" in error_msg
|
||||
or "rate_limit" in error_msg
|
||||
or "usage limit" in error_msg
|
||||
or "quota" in error_msg
|
||||
)
|
||||
if is_rate_limited and self._fallback_index < len(self._fallback_chain):
|
||||
# Don't eagerly fallback if credential pool rotation may
|
||||
@@ -8250,7 +8151,10 @@ class AIAgent:
|
||||
continue
|
||||
|
||||
is_payload_too_large = (
|
||||
classified.reason == FailoverReason.payload_too_large
|
||||
status_code == 413
|
||||
or 'request entity too large' in error_msg
|
||||
or 'payload too large' in error_msg
|
||||
or 'error code: 413' in error_msg
|
||||
)
|
||||
|
||||
if is_payload_too_large:
|
||||
@@ -8294,59 +8198,69 @@ class AIAgent:
|
||||
}
|
||||
|
||||
# Check for context-length errors BEFORE generic 4xx handler.
|
||||
# The classifier detects context overflow from: explicit error
|
||||
# messages, generic 400 + large session heuristic (#1630), and
|
||||
# server disconnect + large session pattern (#2153).
|
||||
is_context_length_error = (
|
||||
classified.reason == FailoverReason.context_overflow
|
||||
)
|
||||
# Local backends (LM Studio, Ollama, llama.cpp) often return
|
||||
# HTTP 400 with messages like "Context size has been exceeded"
|
||||
# which must trigger compression, not an immediate abort.
|
||||
is_context_length_error = any(phrase in error_msg for phrase in [
|
||||
'context length', 'context size', 'maximum context',
|
||||
'token limit', 'too many tokens', 'reduce the length',
|
||||
'exceeds the limit', 'context window',
|
||||
'request entity too large', # OpenRouter/Nous 413 safety net
|
||||
'prompt is too long', # Anthropic: "prompt is too long: N tokens > M maximum"
|
||||
'prompt exceeds max length', # Z.AI / GLM: generic 400 overflow wording
|
||||
])
|
||||
|
||||
# Fallback heuristic: Anthropic sometimes returns a generic
|
||||
# 400 invalid_request_error with just "Error" as the message
|
||||
# when the context is too large. If the error message is very
|
||||
# short/generic AND the session is large, treat it as a
|
||||
# probable context-length error and attempt compression rather
|
||||
# than aborting. This prevents an infinite failure loop where
|
||||
# each failed message gets persisted, making the session even
|
||||
# larger. (#1630)
|
||||
if not is_context_length_error and status_code == 400:
|
||||
ctx_len = getattr(getattr(self, 'context_compressor', None), 'context_length', 200000)
|
||||
is_large_session = approx_tokens > ctx_len * 0.4 or len(api_messages) > 80
|
||||
is_generic_error = len(error_msg.strip()) < 30 # e.g. just "error"
|
||||
if is_large_session and is_generic_error:
|
||||
is_context_length_error = True
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Generic 400 with large session "
|
||||
f"(~{approx_tokens:,} tokens, {len(api_messages)} msgs) — "
|
||||
f"treating as probable context overflow.",
|
||||
force=True,
|
||||
)
|
||||
|
||||
# Server disconnects on large sessions are often caused by
|
||||
# the request exceeding the provider's context/payload limit
|
||||
# without a proper HTTP error response. Treat these as
|
||||
# context-length errors to trigger compression rather than
|
||||
# burning through retries that will all fail the same way.
|
||||
# This breaks the death spiral: disconnect → no token data
|
||||
# → no compression → bigger session → more disconnects.
|
||||
# (#2153)
|
||||
if not is_context_length_error and not status_code:
|
||||
_is_server_disconnect = (
|
||||
'server disconnected' in error_msg
|
||||
or 'peer closed connection' in error_msg
|
||||
or error_type in ('ReadError', 'RemoteProtocolError', 'ServerDisconnectedError')
|
||||
)
|
||||
if _is_server_disconnect:
|
||||
ctx_len = getattr(getattr(self, 'context_compressor', None), 'context_length', 200000)
|
||||
_is_large = approx_tokens > ctx_len * 0.6 or len(api_messages) > 200
|
||||
if _is_large:
|
||||
is_context_length_error = True
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Server disconnected with large session "
|
||||
f"(~{approx_tokens:,} tokens, {len(api_messages)} msgs) — "
|
||||
f"treating as context-length error, attempting compression.",
|
||||
force=True,
|
||||
)
|
||||
|
||||
if is_context_length_error:
|
||||
compressor = self.context_compressor
|
||||
old_ctx = compressor.context_length
|
||||
|
||||
# ── Distinguish two very different errors ───────────
|
||||
# 1. "Prompt too long": the INPUT exceeds the context window.
|
||||
# Fix: reduce context_length + compress history.
|
||||
# 2. "max_tokens too large": input is fine, but
|
||||
# input_tokens + requested max_tokens > context_window.
|
||||
# Fix: reduce max_tokens (the OUTPUT cap) for this call.
|
||||
# Do NOT shrink context_length — the window is unchanged.
|
||||
#
|
||||
# Note: max_tokens = output token cap (one response).
|
||||
# context_length = total window (input + output combined).
|
||||
available_out = parse_available_output_tokens_from_error(error_msg)
|
||||
if available_out is not None:
|
||||
# Error is purely about the output cap being too large.
|
||||
# Cap output to the available space and retry without
|
||||
# touching context_length or triggering compression.
|
||||
safe_out = max(1, available_out - 64) # small safety margin
|
||||
self._ephemeral_max_output_tokens = safe_out
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Output cap too large for current prompt — "
|
||||
f"retrying with max_tokens={safe_out:,} "
|
||||
f"(available_tokens={available_out:,}; context_length unchanged at {old_ctx:,})",
|
||||
force=True,
|
||||
)
|
||||
# Still count against compression_attempts so we don't
|
||||
# loop forever if the error keeps recurring.
|
||||
compression_attempts += 1
|
||||
if compression_attempts > max_compression_attempts:
|
||||
self._vprint(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached.", force=True)
|
||||
self._vprint(f"{self.log_prefix} 💡 Try /new to start a fresh conversation, or /compress to retry compression.", force=True)
|
||||
logging.error(f"{self.log_prefix}Context compression failed after {max_compression_attempts} attempts.")
|
||||
self._persist_session(messages, conversation_history)
|
||||
return {
|
||||
"messages": messages,
|
||||
"completed": False,
|
||||
"api_calls": api_call_count,
|
||||
"error": f"Context length exceeded: max compression attempts ({max_compression_attempts}) reached.",
|
||||
"partial": True
|
||||
}
|
||||
restart_with_compressed_messages = True
|
||||
break
|
||||
|
||||
# Error is about the INPUT being too large — reduce context_length.
|
||||
# Try to parse the actual limit from the error message
|
||||
parsed_limit = parse_context_limit_from_error(error_msg)
|
||||
if parsed_limit and parsed_limit < old_ctx:
|
||||
@@ -8413,30 +8327,35 @@ class AIAgent:
|
||||
"partial": True
|
||||
}
|
||||
|
||||
# Check for non-retryable client errors. The classifier
|
||||
# already accounts for 413, 429, 529 (transient), context
|
||||
# overflow, and generic-400 heuristics. Local validation
|
||||
# errors (ValueError, TypeError) are programming bugs.
|
||||
# Check for non-retryable client errors (4xx HTTP status codes).
|
||||
# These indicate a problem with the request itself (bad model ID,
|
||||
# invalid API key, forbidden, etc.) and will never succeed on retry.
|
||||
# Note: 413 and context-length errors are excluded — handled above.
|
||||
# 429 (rate limit) is transient and MUST be retried with backoff.
|
||||
# 529 (Anthropic overloaded) is also transient.
|
||||
# Also catch local validation errors (ValueError, TypeError) — these
|
||||
# are programming bugs, not transient failures.
|
||||
# Exclude UnicodeEncodeError — it's a ValueError subclass but is
|
||||
# handled separately by the surrogate sanitization path above.
|
||||
_RETRYABLE_STATUS_CODES = {413, 429, 529}
|
||||
is_local_validation_error = (
|
||||
isinstance(api_error, (ValueError, TypeError))
|
||||
and not isinstance(api_error, UnicodeEncodeError)
|
||||
)
|
||||
is_client_error = (
|
||||
is_local_validation_error
|
||||
or (
|
||||
not classified.retryable
|
||||
and not classified.should_compress
|
||||
and classified.reason not in (
|
||||
FailoverReason.rate_limit,
|
||||
FailoverReason.billing,
|
||||
FailoverReason.overloaded,
|
||||
FailoverReason.context_overflow,
|
||||
FailoverReason.payload_too_large,
|
||||
FailoverReason.long_context_tier,
|
||||
FailoverReason.thinking_signature,
|
||||
)
|
||||
)
|
||||
) and not is_context_length_error
|
||||
# Detect generic 400s from Anthropic OAuth (transient server-side failures).
|
||||
# Real invalid_request_error responses include a descriptive message;
|
||||
# transient ones contain only "Error" or are empty. (ref: issue #1608)
|
||||
_err_body = getattr(api_error, "body", None) or {}
|
||||
_err_message = (_err_body.get("error", {}).get("message", "") if isinstance(_err_body, dict) else "")
|
||||
_is_generic_400 = (status_code == 400 and _err_message.strip().lower() in ("error", ""))
|
||||
is_client_status_error = isinstance(status_code, int) and 400 <= status_code < 500 and status_code not in _RETRYABLE_STATUS_CODES and not _is_generic_400
|
||||
is_client_error = (is_local_validation_error or is_client_status_error or any(phrase in error_msg for phrase in [
|
||||
'error code: 401', 'error code: 403',
|
||||
'error code: 404', 'error code: 422',
|
||||
'is not a valid model', 'invalid model', 'model not found',
|
||||
'invalid api key', 'invalid_api_key', 'authentication',
|
||||
'unauthorized', 'forbidden', 'not found',
|
||||
])) and not is_context_length_error
|
||||
|
||||
if is_client_error:
|
||||
# Try fallback before aborting — a different provider
|
||||
@@ -8456,7 +8375,7 @@ class AIAgent:
|
||||
self._vprint(f"{self.log_prefix} 🔌 Provider: {_provider} Model: {_model}", force=True)
|
||||
self._vprint(f"{self.log_prefix} 🌐 Endpoint: {_base}", force=True)
|
||||
# Actionable guidance for common auth errors
|
||||
if classified.is_auth or classified.reason == FailoverReason.billing:
|
||||
if status_code in (401, 403) or "unauthorized" in error_msg or "forbidden" in error_msg or "permission" in error_msg:
|
||||
if _provider == "openai-codex" and status_code == 401:
|
||||
self._vprint(f"{self.log_prefix} 💡 Codex OAuth token was rejected (HTTP 401). Your token may have been", force=True)
|
||||
self._vprint(f"{self.log_prefix} refreshed by another client (Codex CLI, VS Code). To fix:", force=True)
|
||||
@@ -8616,7 +8535,6 @@ class AIAgent:
|
||||
|
||||
# If the API call was interrupted, skip response processing
|
||||
if interrupted:
|
||||
_turn_exit_reason = "interrupted_during_api_call"
|
||||
break
|
||||
|
||||
if restart_with_compressed_messages:
|
||||
@@ -8636,7 +8554,6 @@ class AIAgent:
|
||||
# (e.g. repeated context-length errors that exhausted retry_count),
|
||||
# the `response` variable is still None. Break out cleanly.
|
||||
if response is None:
|
||||
_turn_exit_reason = "all_retries_exhausted_no_response"
|
||||
print(f"{self.log_prefix}❌ All API retries exhausted with no successful response.")
|
||||
self._persist_session(messages, conversation_history)
|
||||
break
|
||||
@@ -9043,34 +8960,13 @@ class AIAgent:
|
||||
# compaction fires, not the raw context window.
|
||||
# Does not inject into messages — just prints to CLI output
|
||||
# and fires status_callback for gateway platforms.
|
||||
# Tiered: 85% (orange) and 95% (red/critical).
|
||||
if _compressor.threshold_tokens > 0:
|
||||
_compaction_progress = _real_tokens / _compressor.threshold_tokens
|
||||
# Determine the warning tier for this progress level
|
||||
_warn_tier = 0.0
|
||||
if _compaction_progress >= 0.95:
|
||||
_warn_tier = 0.95
|
||||
elif _compaction_progress >= 0.85:
|
||||
_warn_tier = 0.85
|
||||
if _warn_tier > self._context_pressure_warned_at:
|
||||
# Class-level dedup: check if this session was already
|
||||
# warned at this tier within the cooldown window.
|
||||
_sid = self.session_id or "default"
|
||||
_last = AIAgent._context_pressure_last_warned.get(_sid)
|
||||
_now = time.time()
|
||||
if _last is None or _last[0] < _warn_tier or (_now - _last[1]) >= self._CONTEXT_PRESSURE_COOLDOWN:
|
||||
self._context_pressure_warned_at = _warn_tier
|
||||
AIAgent._context_pressure_last_warned[_sid] = (_warn_tier, _now)
|
||||
self._emit_context_pressure(_compaction_progress, _compressor)
|
||||
# Evict stale entries (older than 2x cooldown)
|
||||
_cutoff = _now - self._CONTEXT_PRESSURE_COOLDOWN * 2
|
||||
AIAgent._context_pressure_last_warned = {
|
||||
k: v for k, v in AIAgent._context_pressure_last_warned.items()
|
||||
if v[1] > _cutoff
|
||||
}
|
||||
if _compaction_progress >= 0.85 and not self._context_pressure_warned:
|
||||
self._context_pressure_warned = True
|
||||
self._emit_context_pressure(_compaction_progress, _compressor)
|
||||
|
||||
if self.compression_enabled and _compressor.should_compress(_real_tokens):
|
||||
self._safe_print(" ⟳ compacting context…")
|
||||
messages, active_system_prompt = self._compress_context(
|
||||
messages, system_message,
|
||||
approx_tokens=self.context_compressor.last_prompt_tokens,
|
||||
@@ -9100,7 +8996,6 @@ class AIAgent:
|
||||
# instead of wasting API calls on retries that won't help.
|
||||
fallback = getattr(self, '_last_content_with_tools', None)
|
||||
if fallback:
|
||||
_turn_exit_reason = "fallback_prior_turn_content"
|
||||
logger.debug("Empty follow-up after tool calls — using prior turn content as final response")
|
||||
self._last_content_with_tools = None
|
||||
self._empty_content_retries = 0
|
||||
@@ -9146,28 +9041,8 @@ class AIAgent:
|
||||
self._save_session_log(messages)
|
||||
continue
|
||||
|
||||
# ── Empty response retry (no reasoning) ──────
|
||||
# Model returned nothing — no content, no
|
||||
# structured reasoning, no tool calls. Common
|
||||
# with open models (transient provider issues,
|
||||
# rate limits, sampling flukes). Silently retry
|
||||
# up to 3 times before giving up. Skip when
|
||||
# content has inline <think> tags (model chose
|
||||
# to reason, just no visible text).
|
||||
_truly_empty = not final_response.strip()
|
||||
if _truly_empty and not _has_structured and self._empty_content_retries < 3:
|
||||
self._empty_content_retries += 1
|
||||
self._vprint(
|
||||
f"{self.log_prefix}↻ Empty response (no content or reasoning) "
|
||||
f"— retrying ({self._empty_content_retries}/3)",
|
||||
force=True,
|
||||
)
|
||||
continue
|
||||
|
||||
# Exhausted prefill attempts, empty retries, or
|
||||
# structured reasoning with no content —
|
||||
# fall through to "(empty)" terminal.
|
||||
_turn_exit_reason = "empty_response_exhausted"
|
||||
# Exhausted prefill attempts or no structured
|
||||
# reasoning — fall through to "(empty)" terminal.
|
||||
reasoning_text = self._extract_reasoning(assistant_message)
|
||||
assistant_msg = self._build_assistant_message(assistant_message, finish_reason)
|
||||
assistant_msg["content"] = "(empty)"
|
||||
@@ -9177,7 +9052,7 @@ class AIAgent:
|
||||
reasoning_preview = reasoning_text[:500] + "..." if len(reasoning_text) > 500 else reasoning_text
|
||||
self._vprint(f"{self.log_prefix}ℹ️ Reasoning-only response (no visible content). Reasoning: {reasoning_preview}")
|
||||
else:
|
||||
self._vprint(f"{self.log_prefix}ℹ️ Empty response (no content or reasoning) after 3 retries.")
|
||||
self._vprint(f"{self.log_prefix}ℹ️ Empty response (no content or reasoning).")
|
||||
|
||||
final_response = "(empty)"
|
||||
break
|
||||
@@ -9185,6 +9060,7 @@ class AIAgent:
|
||||
# Reset retry counter/signature on successful content
|
||||
if hasattr(self, '_empty_content_retries'):
|
||||
self._empty_content_retries = 0
|
||||
self._last_empty_content_signature = None
|
||||
self._thinking_prefill_retries = 0
|
||||
|
||||
if (
|
||||
@@ -9238,7 +9114,6 @@ class AIAgent:
|
||||
|
||||
messages.append(final_msg)
|
||||
|
||||
_turn_exit_reason = f"text_response(finish_reason={finish_reason})"
|
||||
if not self.quiet_mode:
|
||||
self._safe_print(f"🎉 Conversation completed after {api_call_count} OpenAI-compatible API call(s)")
|
||||
break
|
||||
@@ -9256,6 +9131,7 @@ class AIAgent:
|
||||
# If an assistant message with tool_calls was already appended,
|
||||
# the API expects a role="tool" result for every tool_call_id.
|
||||
# Fill in error results for any that weren't answered yet.
|
||||
pending_handled = False
|
||||
for idx in range(len(messages) - 1, -1, -1):
|
||||
msg = messages[idx]
|
||||
if not isinstance(msg, dict):
|
||||
@@ -9287,7 +9163,6 @@ class AIAgent:
|
||||
|
||||
# If we're near the limit, break to avoid infinite loops
|
||||
if api_call_count >= self.max_iterations - 1:
|
||||
_turn_exit_reason = f"error_near_max_iterations({error_msg[:80]})"
|
||||
final_response = f"I apologize, but I encountered repeated errors: {error_msg}"
|
||||
# Append as assistant so the history stays valid for
|
||||
# session resume (avoids consecutive user messages).
|
||||
@@ -9298,7 +9173,6 @@ class AIAgent:
|
||||
api_call_count >= self.max_iterations
|
||||
or self.iteration_budget.remaining <= 0
|
||||
):
|
||||
_turn_exit_reason = f"max_iterations_reached({api_call_count}/{self.max_iterations})"
|
||||
if self.iteration_budget.remaining <= 0 and not self.quiet_mode:
|
||||
print(f"\n⚠️ Iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} iterations used)")
|
||||
final_response = self._handle_max_iterations(messages, api_call_count)
|
||||
@@ -9315,49 +9189,6 @@ class AIAgent:
|
||||
# Persist session to both JSON log and SQLite
|
||||
self._persist_session(messages, conversation_history)
|
||||
|
||||
# ── Turn-exit diagnostic log ─────────────────────────────────────
|
||||
# Always logged at INFO so agent.log captures WHY every turn ended.
|
||||
# When the last message is a tool result (agent was mid-work), log
|
||||
# at WARNING — this is the "just stops" scenario users report.
|
||||
_last_msg_role = messages[-1].get("role") if messages else None
|
||||
_last_tool_name = None
|
||||
if _last_msg_role == "tool":
|
||||
# Walk back to find the assistant message with the tool call
|
||||
for _m in reversed(messages):
|
||||
if _m.get("role") == "assistant" and _m.get("tool_calls"):
|
||||
_tcs = _m["tool_calls"]
|
||||
if _tcs and isinstance(_tcs[0], dict):
|
||||
_last_tool_name = _tcs[-1].get("function", {}).get("name")
|
||||
break
|
||||
|
||||
_turn_tool_count = sum(
|
||||
1 for m in messages
|
||||
if isinstance(m, dict) and m.get("role") == "assistant" and m.get("tool_calls")
|
||||
)
|
||||
_resp_len = len(final_response) if final_response else 0
|
||||
_budget_used = self.iteration_budget.used if self.iteration_budget else 0
|
||||
_budget_max = self.iteration_budget.max_total if self.iteration_budget else 0
|
||||
|
||||
_diag_msg = (
|
||||
"Turn ended: reason=%s model=%s api_calls=%d/%d budget=%d/%d "
|
||||
"tool_turns=%d last_msg_role=%s response_len=%d session=%s"
|
||||
)
|
||||
_diag_args = (
|
||||
_turn_exit_reason, self.model, api_call_count, self.max_iterations,
|
||||
_budget_used, _budget_max,
|
||||
_turn_tool_count, _last_msg_role, _resp_len,
|
||||
self.session_id or "none",
|
||||
)
|
||||
|
||||
if _last_msg_role == "tool" and not interrupted:
|
||||
# Agent was mid-work — this is the "just stops" case.
|
||||
logger.warning(
|
||||
"Turn ended with pending tool result (agent may appear stuck). "
|
||||
+ _diag_msg + " last_tool=%s",
|
||||
*_diag_args, _last_tool_name,
|
||||
)
|
||||
else:
|
||||
logger.info(_diag_msg, *_diag_args)
|
||||
|
||||
# Plugin hook: post_llm_call
|
||||
# Fired once per turn after the tool-calling loop completes.
|
||||
|
||||
@@ -249,8 +249,9 @@ Type these during an interactive chat session.
|
||||
/config Show config (CLI)
|
||||
/model [name] Show or change model
|
||||
/provider Show provider info
|
||||
/prompt [text] View/set system prompt (CLI)
|
||||
/personality [name] Set personality
|
||||
/reasoning [level] Set reasoning (none|minimal|low|medium|high|xhigh|show|hide)
|
||||
/reasoning [level] Set reasoning (none|low|medium|high|xhigh|show|hide)
|
||||
/verbose Cycle: off → new → all → verbose
|
||||
/voice [on|off|tts] Voice mode
|
||||
/yolo Toggle approval bypass
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
---
|
||||
name: google-workspace
|
||||
description: Gmail, Calendar, Drive, Contacts, Sheets, and Docs integration via gws CLI (googleworkspace/cli). Uses OAuth2 with automatic token refresh via bridge script. Requires gws binary.
|
||||
version: 2.0.0
|
||||
description: Gmail, Calendar, Drive, Contacts, Sheets, and Docs integration via Python. Uses OAuth2 with automatic token refresh. No external binaries needed — runs entirely with Google's Python client libraries in the Hermes venv.
|
||||
version: 1.0.0
|
||||
author: Nous Research
|
||||
license: MIT
|
||||
required_credential_files:
|
||||
@@ -11,25 +11,14 @@ required_credential_files:
|
||||
description: Google OAuth2 client credentials (downloaded from Google Cloud Console)
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Google, Gmail, Calendar, Drive, Sheets, Docs, Contacts, Email, OAuth, gws]
|
||||
tags: [Google, Gmail, Calendar, Drive, Sheets, Docs, Contacts, Email, OAuth]
|
||||
homepage: https://github.com/NousResearch/hermes-agent
|
||||
related_skills: [himalaya]
|
||||
---
|
||||
|
||||
# Google Workspace
|
||||
|
||||
Gmail, Calendar, Drive, Contacts, Sheets, and Docs — powered by `gws` (Google's official Rust CLI). The skill provides a backward-compatible Python wrapper that handles OAuth token refresh and delegates to `gws`.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
google_api.py → gws_bridge.py → gws CLI
|
||||
(argparse compat) (token refresh) (Google APIs)
|
||||
```
|
||||
|
||||
- `setup.py` handles OAuth2 (headless-compatible, works on CLI/Telegram/Discord)
|
||||
- `gws_bridge.py` refreshes the Hermes token and injects it into `gws` via `GOOGLE_WORKSPACE_CLI_TOKEN`
|
||||
- `google_api.py` provides the same CLI interface as v1 but delegates to `gws`
|
||||
Gmail, Calendar, Drive, Contacts, Sheets, and Docs — all through Python scripts in this skill. No external binaries to install.
|
||||
|
||||
## References
|
||||
|
||||
@@ -38,22 +27,7 @@ google_api.py → gws_bridge.py → gws CLI
|
||||
## Scripts
|
||||
|
||||
- `scripts/setup.py` — OAuth2 setup (run once to authorize)
|
||||
- `scripts/gws_bridge.py` — Token refresh bridge to gws CLI
|
||||
- `scripts/google_api.py` — Backward-compatible API wrapper (delegates to gws)
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Install `gws`:
|
||||
|
||||
```bash
|
||||
cargo install google-workspace-cli
|
||||
# or via npm (recommended, downloads prebuilt binary):
|
||||
npm install -g @googleworkspace/cli
|
||||
# or via Homebrew:
|
||||
brew install googleworkspace-cli
|
||||
```
|
||||
|
||||
Verify: `gws --version`
|
||||
- `scripts/google_api.py` — API wrapper CLI (agent uses this for all operations)
|
||||
|
||||
## First-Time Setup
|
||||
|
||||
@@ -82,29 +56,42 @@ If it prints `AUTHENTICATED`, skip to Usage — setup is already done.
|
||||
|
||||
### Step 1: Triage — ask the user what they need
|
||||
|
||||
Before starting OAuth setup, ask the user TWO questions:
|
||||
|
||||
**Question 1: "What Google services do you need? Just email, or also
|
||||
Calendar/Drive/Sheets/Docs?"**
|
||||
|
||||
- **Email only** → Use the `himalaya` skill instead — simpler setup.
|
||||
- **Calendar, Drive, Sheets, Docs (or email + these)** → Continue below.
|
||||
- **Email only** → They don't need this skill at all. Use the `himalaya` skill
|
||||
instead — it works with a Gmail App Password (Settings → Security → App
|
||||
Passwords) and takes 2 minutes to set up. No Google Cloud project needed.
|
||||
Load the himalaya skill and follow its setup instructions.
|
||||
|
||||
**Partial scopes**: Users can authorize only a subset of services. The setup
|
||||
script accepts partial scopes and warns about missing ones.
|
||||
- **Calendar, Drive, Sheets, Docs (or email + these)** → Continue with this
|
||||
skill's OAuth setup below.
|
||||
|
||||
**Question 2: "Does your Google account use Advanced Protection?"**
|
||||
**Question 2: "Does your Google account use Advanced Protection (hardware
|
||||
security keys required to sign in)? If you're not sure, you probably don't
|
||||
— it's something you would have explicitly enrolled in."**
|
||||
|
||||
- **No / Not sure** → Normal setup.
|
||||
- **Yes** → Workspace admin must add the OAuth client ID to allowed apps first.
|
||||
- **No / Not sure** → Normal setup. Continue below.
|
||||
- **Yes** → Their Workspace admin must add the OAuth client ID to the org's
|
||||
allowed apps list before Step 4 will work. Let them know upfront.
|
||||
|
||||
### Step 2: Create OAuth credentials (one-time, ~5 minutes)
|
||||
|
||||
Tell the user:
|
||||
|
||||
> You need a Google Cloud OAuth client. This is a one-time setup:
|
||||
>
|
||||
> 1. Go to https://console.cloud.google.com/apis/credentials
|
||||
> 2. Create a project (or use an existing one)
|
||||
> 3. Enable the APIs you need (Gmail, Calendar, Drive, Sheets, Docs, People)
|
||||
> 4. Credentials → Create Credentials → OAuth 2.0 Client ID → Desktop app
|
||||
> 5. Download JSON and tell me the file path
|
||||
> 3. Click "Enable APIs" and enable: Gmail API, Google Calendar API,
|
||||
> Google Drive API, Google Sheets API, Google Docs API, People API
|
||||
> 4. Go to Credentials → Create Credentials → OAuth 2.0 Client ID
|
||||
> 5. Application type: "Desktop app" → Create
|
||||
> 6. Click "Download JSON" and tell me the file path
|
||||
|
||||
Once they provide the path:
|
||||
|
||||
```bash
|
||||
$GSETUP --client-secret /path/to/client_secret.json
|
||||
@@ -116,10 +103,20 @@ $GSETUP --client-secret /path/to/client_secret.json
|
||||
$GSETUP --auth-url
|
||||
```
|
||||
|
||||
Send the URL to the user. After authorizing, they paste back the redirect URL or code.
|
||||
This prints a URL. **Send the URL to the user** and tell them:
|
||||
|
||||
> Open this link in your browser, sign in with your Google account, and
|
||||
> authorize access. After authorizing, you'll be redirected to a page that
|
||||
> may show an error — that's expected. Copy the ENTIRE URL from your
|
||||
> browser's address bar and paste it back to me.
|
||||
|
||||
### Step 4: Exchange the code
|
||||
|
||||
The user will paste back either a URL like `http://localhost:1/?code=4/0A...&scope=...`
|
||||
or just the code string. Either works. The `--auth-url` step stores a temporary
|
||||
pending OAuth session locally so `--auth-code` can complete the PKCE exchange
|
||||
later, even on headless systems:
|
||||
|
||||
```bash
|
||||
$GSETUP --auth-code "THE_URL_OR_CODE_THE_USER_PASTED"
|
||||
```
|
||||
@@ -130,11 +127,18 @@ $GSETUP --auth-code "THE_URL_OR_CODE_THE_USER_PASTED"
|
||||
$GSETUP --check
|
||||
```
|
||||
|
||||
Should print `AUTHENTICATED`. Token refreshes automatically from now on.
|
||||
Should print `AUTHENTICATED`. Setup is complete — token refreshes automatically from now on.
|
||||
|
||||
### Notes
|
||||
|
||||
- Token is stored at `google_token.json` under the active profile's `HERMES_HOME` and auto-refreshes.
|
||||
- Pending OAuth session state/verifier are stored temporarily at `google_oauth_pending.json` under the active profile's `HERMES_HOME` until exchange completes.
|
||||
- Hermes now refuses to overwrite a full Google Workspace token with a narrower re-auth token missing Gmail scopes, so one profile's partial consent cannot silently break email actions later.
|
||||
- To revoke: `$GSETUP --revoke`
|
||||
|
||||
## Usage
|
||||
|
||||
All commands go through the API script:
|
||||
All commands go through the API script. Set `GAPI` as a shorthand:
|
||||
|
||||
```bash
|
||||
HERMES_HOME="${HERMES_HOME:-$HOME/.hermes}"
|
||||
@@ -149,21 +153,40 @@ GAPI="$PYTHON_BIN $GWORKSPACE_SKILL_DIR/scripts/google_api.py"
|
||||
### Gmail
|
||||
|
||||
```bash
|
||||
# Search (returns JSON array with id, from, subject, date, snippet)
|
||||
$GAPI gmail search "is:unread" --max 10
|
||||
$GAPI gmail search "from:boss@company.com newer_than:1d"
|
||||
$GAPI gmail search "has:attachment filename:pdf newer_than:7d"
|
||||
|
||||
# Read full message (returns JSON with body text)
|
||||
$GAPI gmail get MESSAGE_ID
|
||||
|
||||
# Send
|
||||
$GAPI gmail send --to user@example.com --subject "Hello" --body "Message text"
|
||||
$GAPI gmail send --to user@example.com --subject "Report" --body "<h1>Q4</h1>" --html
|
||||
$GAPI gmail send --to user@example.com --subject "Report" --body "<h1>Q4</h1><p>Details...</p>" --html
|
||||
|
||||
# Reply (automatically threads and sets In-Reply-To)
|
||||
$GAPI gmail reply MESSAGE_ID --body "Thanks, that works for me."
|
||||
|
||||
# Labels
|
||||
$GAPI gmail labels
|
||||
$GAPI gmail modify MESSAGE_ID --add-labels LABEL_ID
|
||||
$GAPI gmail modify MESSAGE_ID --remove-labels UNREAD
|
||||
```
|
||||
|
||||
### Calendar
|
||||
|
||||
```bash
|
||||
# List events (defaults to next 7 days)
|
||||
$GAPI calendar list
|
||||
$GAPI calendar create --summary "Standup" --start 2026-03-01T10:00:00+01:00 --end 2026-03-01T10:30:00+01:00
|
||||
$GAPI calendar create --summary "Review" --start ... --end ... --attendees "alice@co.com,bob@co.com"
|
||||
$GAPI calendar list --start 2026-03-01T00:00:00Z --end 2026-03-07T23:59:59Z
|
||||
|
||||
# Create event (ISO 8601 with timezone required)
|
||||
$GAPI calendar create --summary "Team Standup" --start 2026-03-01T10:00:00-06:00 --end 2026-03-01T10:30:00-06:00
|
||||
$GAPI calendar create --summary "Lunch" --start 2026-03-01T12:00:00Z --end 2026-03-01T13:00:00Z --location "Cafe"
|
||||
$GAPI calendar create --summary "Review" --start 2026-03-01T14:00:00Z --end 2026-03-01T15:00:00Z --attendees "alice@co.com,bob@co.com"
|
||||
|
||||
# Delete event
|
||||
$GAPI calendar delete EVENT_ID
|
||||
```
|
||||
|
||||
@@ -183,8 +206,13 @@ $GAPI contacts list --max 20
|
||||
### Sheets
|
||||
|
||||
```bash
|
||||
# Read
|
||||
$GAPI sheets get SHEET_ID "Sheet1!A1:D10"
|
||||
|
||||
# Write
|
||||
$GAPI sheets update SHEET_ID "Sheet1!A1:B2" --values '[["Name","Score"],["Alice","95"]]'
|
||||
|
||||
# Append rows
|
||||
$GAPI sheets append SHEET_ID "Sheet1!A:C" --values '[["new","row","data"]]'
|
||||
```
|
||||
|
||||
@@ -194,52 +222,37 @@ $GAPI sheets append SHEET_ID "Sheet1!A:C" --values '[["new","row","data"]]'
|
||||
$GAPI docs get DOC_ID
|
||||
```
|
||||
|
||||
### Direct gws access (advanced)
|
||||
|
||||
For operations not covered by the wrapper, use `gws_bridge.py` directly:
|
||||
|
||||
```bash
|
||||
GBRIDGE="$PYTHON_BIN $GWORKSPACE_SKILL_DIR/scripts/gws_bridge.py"
|
||||
$GBRIDGE calendar +agenda --today --format table
|
||||
$GBRIDGE gmail +triage --labels --format json
|
||||
$GBRIDGE drive +upload ./report.pdf
|
||||
$GBRIDGE sheets +read --spreadsheet SHEET_ID --range "Sheet1!A1:D10"
|
||||
```
|
||||
|
||||
## Output Format
|
||||
|
||||
All commands return JSON via `gws --format json`. Key output shapes:
|
||||
All commands return JSON. Parse with `jq` or read directly. Key fields:
|
||||
|
||||
- **Gmail search/triage**: Array of message summaries (sender, subject, date, snippet)
|
||||
- **Gmail get/read**: Message object with headers and body text
|
||||
- **Gmail send/reply**: Confirmation with message ID
|
||||
- **Calendar list/agenda**: Array of event objects (summary, start, end, location)
|
||||
- **Calendar create**: Confirmation with event ID and htmlLink
|
||||
- **Drive search**: Array of file objects (id, name, mimeType, webViewLink)
|
||||
- **Sheets get/read**: 2D array of cell values
|
||||
- **Docs get**: Full document JSON (use `body.content` for text extraction)
|
||||
- **Contacts list**: Array of person objects with names, emails, phones
|
||||
|
||||
Parse output with `jq` or read JSON directly.
|
||||
- **Gmail search**: `[{id, threadId, from, to, subject, date, snippet, labels}]`
|
||||
- **Gmail get**: `{id, threadId, from, to, subject, date, labels, body}`
|
||||
- **Gmail send/reply**: `{status: "sent", id, threadId}`
|
||||
- **Calendar list**: `[{id, summary, start, end, location, description, htmlLink}]`
|
||||
- **Calendar create**: `{status: "created", id, summary, htmlLink}`
|
||||
- **Drive search**: `[{id, name, mimeType, modifiedTime, webViewLink}]`
|
||||
- **Contacts list**: `[{name, emails: [...], phones: [...]}]`
|
||||
- **Sheets get**: `[[cell, cell, ...], ...]`
|
||||
|
||||
## Rules
|
||||
|
||||
1. **Never send email or create/delete events without confirming with the user first.**
|
||||
2. **Check auth before first use** — run `setup.py --check`.
|
||||
3. **Use the Gmail search syntax reference** for complex queries.
|
||||
4. **Calendar times must include timezone** — ISO 8601 with offset or UTC.
|
||||
5. **Respect rate limits** — avoid rapid-fire sequential API calls.
|
||||
1. **Never send email or create/delete events without confirming with the user first.** Show the draft content and ask for approval.
|
||||
2. **Check auth before first use** — run `setup.py --check`. If it fails, guide the user through setup.
|
||||
3. **Use the Gmail search syntax reference** for complex queries — load it with `skill_view("google-workspace", file_path="references/gmail-search-syntax.md")`.
|
||||
4. **Calendar times must include timezone** — always use ISO 8601 with offset (e.g., `2026-03-01T10:00:00-06:00`) or UTC (`Z`).
|
||||
5. **Respect rate limits** — avoid rapid-fire sequential API calls. Batch reads when possible.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Problem | Fix |
|
||||
|---------|-----|
|
||||
| `NOT_AUTHENTICATED` | Run setup Steps 2-5 |
|
||||
| `REFRESH_FAILED` | Token revoked — redo Steps 3-5 |
|
||||
| `gws: command not found` | Install: `npm install -g @googleworkspace/cli` |
|
||||
| `HttpError 403` | Missing scope — `$GSETUP --revoke` then redo Steps 3-5 |
|
||||
| `HttpError 403: Access Not Configured` | Enable API in Google Cloud Console |
|
||||
| Advanced Protection blocks auth | Admin must allowlist the OAuth client ID |
|
||||
| `NOT_AUTHENTICATED` | Run setup Steps 2-5 above |
|
||||
| `REFRESH_FAILED` | Token revoked or expired — redo Steps 3-5 |
|
||||
| `HttpError 403: Insufficient Permission` | Missing API scope — `$GSETUP --revoke` then redo Steps 3-5 |
|
||||
| `HttpError 403: Access Not Configured` | API not enabled — user needs to enable it in Google Cloud Console |
|
||||
| `ModuleNotFoundError` | Run `$GSETUP --install-deps` |
|
||||
| Advanced Protection blocks auth | Workspace admin must allowlist the OAuth client ID |
|
||||
|
||||
## Revoking Access
|
||||
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Google Workspace API CLI for Hermes Agent.
|
||||
|
||||
Thin wrapper that delegates to gws (googleworkspace/cli) via gws_bridge.py.
|
||||
Maintains the same CLI interface for backward compatibility with Hermes skills.
|
||||
A thin CLI wrapper around Google's Python client libraries.
|
||||
Authenticates using the token stored by setup.py.
|
||||
|
||||
Usage:
|
||||
python google_api.py gmail search "is:unread" [--max 10]
|
||||
python google_api.py gmail get MESSAGE_ID
|
||||
python google_api.py gmail send --to user@example.com --subject "Hi" --body "Hello"
|
||||
python google_api.py gmail reply MESSAGE_ID --body "Thanks"
|
||||
python google_api.py calendar list [--start DATE] [--end DATE] [--calendar primary]
|
||||
python google_api.py calendar list [--from DATE] [--to DATE] [--calendar primary]
|
||||
python google_api.py calendar create --summary "Meeting" --start DATETIME --end DATETIME
|
||||
python google_api.py calendar delete EVENT_ID
|
||||
python google_api.py drive search "budget report" [--max 10]
|
||||
python google_api.py contacts list [--max 20]
|
||||
python google_api.py sheets get SHEET_ID RANGE
|
||||
@@ -21,193 +20,386 @@ Usage:
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from email.mime.text import MIMEText
|
||||
from pathlib import Path
|
||||
|
||||
BRIDGE = Path(__file__).parent / "gws_bridge.py"
|
||||
PYTHON = sys.executable
|
||||
try:
|
||||
from hermes_constants import display_hermes_home, get_hermes_home
|
||||
except ModuleNotFoundError:
|
||||
HERMES_AGENT_ROOT = Path(__file__).resolve().parents[4]
|
||||
if HERMES_AGENT_ROOT.exists():
|
||||
sys.path.insert(0, str(HERMES_AGENT_ROOT))
|
||||
from hermes_constants import display_hermes_home, get_hermes_home
|
||||
|
||||
HERMES_HOME = get_hermes_home()
|
||||
TOKEN_PATH = HERMES_HOME / "google_token.json"
|
||||
|
||||
SCOPES = [
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/gmail.send",
|
||||
"https://www.googleapis.com/auth/gmail.modify",
|
||||
"https://www.googleapis.com/auth/calendar",
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/contacts.readonly",
|
||||
"https://www.googleapis.com/auth/spreadsheets",
|
||||
"https://www.googleapis.com/auth/documents.readonly",
|
||||
]
|
||||
|
||||
|
||||
def gws(*args: str) -> None:
|
||||
"""Call gws via the bridge and exit with its return code."""
|
||||
result = subprocess.run(
|
||||
[PYTHON, str(BRIDGE)] + list(args),
|
||||
env={**os.environ, "HERMES_HOME": os.environ.get("HERMES_HOME", str(Path.home() / ".hermes"))},
|
||||
)
|
||||
sys.exit(result.returncode)
|
||||
def _missing_scopes() -> list[str]:
|
||||
try:
|
||||
payload = json.loads(TOKEN_PATH.read_text())
|
||||
except Exception:
|
||||
return []
|
||||
raw = payload.get("scopes") or payload.get("scope")
|
||||
if not raw:
|
||||
return []
|
||||
granted = {s.strip() for s in (raw.split() if isinstance(raw, str) else raw) if s.strip()}
|
||||
return sorted(scope for scope in SCOPES if scope not in granted)
|
||||
|
||||
|
||||
# -- Gmail --
|
||||
def get_credentials():
|
||||
"""Load and refresh credentials from token file."""
|
||||
if not TOKEN_PATH.exists():
|
||||
print("Not authenticated. Run the setup script first:", file=sys.stderr)
|
||||
print(f" python {Path(__file__).parent / 'setup.py'}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google.auth.transport.requests import Request
|
||||
|
||||
creds = Credentials.from_authorized_user_file(str(TOKEN_PATH), SCOPES)
|
||||
if creds.expired and creds.refresh_token:
|
||||
creds.refresh(Request())
|
||||
TOKEN_PATH.write_text(creds.to_json())
|
||||
if not creds.valid:
|
||||
print("Token is invalid. Re-run setup.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
missing_scopes = _missing_scopes()
|
||||
if missing_scopes:
|
||||
print(
|
||||
"Token is valid but missing Google Workspace scopes required by this skill.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
for scope in missing_scopes:
|
||||
print(f" - {scope}", file=sys.stderr)
|
||||
print(
|
||||
f"Re-run setup.py from the active Hermes profile ({display_hermes_home()}) to restore full access.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
return creds
|
||||
|
||||
|
||||
def build_service(api, version):
|
||||
from googleapiclient.discovery import build
|
||||
return build(api, version, credentials=get_credentials())
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Gmail
|
||||
# =========================================================================
|
||||
|
||||
def gmail_search(args):
|
||||
cmd = ["gmail", "+triage", "--query", args.query, "--max", str(args.max), "--format", "json"]
|
||||
gws(*cmd)
|
||||
service = build_service("gmail", "v1")
|
||||
results = service.users().messages().list(
|
||||
userId="me", q=args.query, maxResults=args.max
|
||||
).execute()
|
||||
messages = results.get("messages", [])
|
||||
if not messages:
|
||||
print("No messages found.")
|
||||
return
|
||||
|
||||
output = []
|
||||
for msg_meta in messages:
|
||||
msg = service.users().messages().get(
|
||||
userId="me", id=msg_meta["id"], format="metadata",
|
||||
metadataHeaders=["From", "To", "Subject", "Date"],
|
||||
).execute()
|
||||
headers = {h["name"]: h["value"] for h in msg.get("payload", {}).get("headers", [])}
|
||||
output.append({
|
||||
"id": msg["id"],
|
||||
"threadId": msg["threadId"],
|
||||
"from": headers.get("From", ""),
|
||||
"to": headers.get("To", ""),
|
||||
"subject": headers.get("Subject", ""),
|
||||
"date": headers.get("Date", ""),
|
||||
"snippet": msg.get("snippet", ""),
|
||||
"labels": msg.get("labelIds", []),
|
||||
})
|
||||
print(json.dumps(output, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
def gmail_get(args):
|
||||
gws("gmail", "+read", "--id", args.message_id, "--headers", "--format", "json")
|
||||
service = build_service("gmail", "v1")
|
||||
msg = service.users().messages().get(
|
||||
userId="me", id=args.message_id, format="full"
|
||||
).execute()
|
||||
|
||||
headers = {h["name"]: h["value"] for h in msg.get("payload", {}).get("headers", [])}
|
||||
|
||||
# Extract body text
|
||||
body = ""
|
||||
payload = msg.get("payload", {})
|
||||
if payload.get("body", {}).get("data"):
|
||||
body = base64.urlsafe_b64decode(payload["body"]["data"]).decode("utf-8", errors="replace")
|
||||
elif payload.get("parts"):
|
||||
for part in payload["parts"]:
|
||||
if part.get("mimeType") == "text/plain" and part.get("body", {}).get("data"):
|
||||
body = base64.urlsafe_b64decode(part["body"]["data"]).decode("utf-8", errors="replace")
|
||||
break
|
||||
if not body:
|
||||
for part in payload["parts"]:
|
||||
if part.get("mimeType") == "text/html" and part.get("body", {}).get("data"):
|
||||
body = base64.urlsafe_b64decode(part["body"]["data"]).decode("utf-8", errors="replace")
|
||||
break
|
||||
|
||||
result = {
|
||||
"id": msg["id"],
|
||||
"threadId": msg["threadId"],
|
||||
"from": headers.get("From", ""),
|
||||
"to": headers.get("To", ""),
|
||||
"subject": headers.get("Subject", ""),
|
||||
"date": headers.get("Date", ""),
|
||||
"labels": msg.get("labelIds", []),
|
||||
"body": body,
|
||||
}
|
||||
print(json.dumps(result, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
def gmail_send(args):
|
||||
cmd = ["gmail", "+send", "--to", args.to, "--subject", args.subject, "--body", args.body, "--format", "json"]
|
||||
service = build_service("gmail", "v1")
|
||||
message = MIMEText(args.body, "html" if args.html else "plain")
|
||||
message["to"] = args.to
|
||||
message["subject"] = args.subject
|
||||
if args.cc:
|
||||
cmd += ["--cc", args.cc]
|
||||
if args.html:
|
||||
cmd.append("--html")
|
||||
gws(*cmd)
|
||||
message["cc"] = args.cc
|
||||
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
body = {"raw": raw}
|
||||
|
||||
if args.thread_id:
|
||||
body["threadId"] = args.thread_id
|
||||
|
||||
result = service.users().messages().send(userId="me", body=body).execute()
|
||||
print(json.dumps({"status": "sent", "id": result["id"], "threadId": result.get("threadId", "")}, indent=2))
|
||||
|
||||
|
||||
def gmail_reply(args):
|
||||
gws("gmail", "+reply", "--message-id", args.message_id, "--body", args.body, "--format", "json")
|
||||
service = build_service("gmail", "v1")
|
||||
# Fetch original to get thread ID and headers
|
||||
original = service.users().messages().get(
|
||||
userId="me", id=args.message_id, format="metadata",
|
||||
metadataHeaders=["From", "Subject", "Message-ID"],
|
||||
).execute()
|
||||
headers = {h["name"]: h["value"] for h in original.get("payload", {}).get("headers", [])}
|
||||
|
||||
subject = headers.get("Subject", "")
|
||||
if not subject.startswith("Re:"):
|
||||
subject = f"Re: {subject}"
|
||||
|
||||
message = MIMEText(args.body)
|
||||
message["to"] = headers.get("From", "")
|
||||
message["subject"] = subject
|
||||
if headers.get("Message-ID"):
|
||||
message["In-Reply-To"] = headers["Message-ID"]
|
||||
message["References"] = headers["Message-ID"]
|
||||
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
body = {"raw": raw, "threadId": original["threadId"]}
|
||||
|
||||
result = service.users().messages().send(userId="me", body=body).execute()
|
||||
print(json.dumps({"status": "sent", "id": result["id"], "threadId": result.get("threadId", "")}, indent=2))
|
||||
|
||||
|
||||
def gmail_labels(args):
|
||||
gws("gmail", "users", "labels", "list", "--params", json.dumps({"userId": "me"}), "--format", "json")
|
||||
service = build_service("gmail", "v1")
|
||||
results = service.users().labels().list(userId="me").execute()
|
||||
labels = [{"id": l["id"], "name": l["name"], "type": l.get("type", "")} for l in results.get("labels", [])]
|
||||
print(json.dumps(labels, indent=2))
|
||||
|
||||
|
||||
def gmail_modify(args):
|
||||
service = build_service("gmail", "v1")
|
||||
body = {}
|
||||
if args.add_labels:
|
||||
body["addLabelIds"] = args.add_labels.split(",")
|
||||
if args.remove_labels:
|
||||
body["removeLabelIds"] = args.remove_labels.split(",")
|
||||
gws(
|
||||
"gmail", "users", "messages", "modify",
|
||||
"--params", json.dumps({"userId": "me", "id": args.message_id}),
|
||||
"--json", json.dumps(body),
|
||||
"--format", "json",
|
||||
)
|
||||
result = service.users().messages().modify(userId="me", id=args.message_id, body=body).execute()
|
||||
print(json.dumps({"id": result["id"], "labels": result.get("labelIds", [])}, indent=2))
|
||||
|
||||
|
||||
# -- Calendar --
|
||||
# =========================================================================
|
||||
# Calendar
|
||||
# =========================================================================
|
||||
|
||||
def calendar_list(args):
|
||||
if args.start or args.end:
|
||||
# Specific date range — use raw Calendar API for precise timeMin/timeMax
|
||||
from datetime import datetime, timedelta, timezone as tz
|
||||
now = datetime.now(tz.utc)
|
||||
time_min = args.start or now.isoformat()
|
||||
time_max = args.end or (now + timedelta(days=7)).isoformat()
|
||||
gws(
|
||||
"calendar", "events", "list",
|
||||
"--params", json.dumps({
|
||||
"calendarId": args.calendar,
|
||||
"timeMin": time_min,
|
||||
"timeMax": time_max,
|
||||
"maxResults": args.max,
|
||||
"singleEvents": True,
|
||||
"orderBy": "startTime",
|
||||
}),
|
||||
"--format", "json",
|
||||
)
|
||||
else:
|
||||
# No date range — use +agenda helper (defaults to 7 days)
|
||||
cmd = ["calendar", "+agenda", "--days", "7", "--format", "json"]
|
||||
if args.calendar != "primary":
|
||||
cmd += ["--calendar", args.calendar]
|
||||
gws(*cmd)
|
||||
service = build_service("calendar", "v3")
|
||||
now = datetime.now(timezone.utc)
|
||||
time_min = args.start or now.isoformat()
|
||||
time_max = args.end or (now + timedelta(days=7)).isoformat()
|
||||
|
||||
# Ensure timezone info
|
||||
for val in [time_min, time_max]:
|
||||
if "T" in val and "Z" not in val and "+" not in val and "-" not in val[11:]:
|
||||
val += "Z"
|
||||
|
||||
results = service.events().list(
|
||||
calendarId=args.calendar, timeMin=time_min, timeMax=time_max,
|
||||
maxResults=args.max, singleEvents=True, orderBy="startTime",
|
||||
).execute()
|
||||
|
||||
events = []
|
||||
for e in results.get("items", []):
|
||||
events.append({
|
||||
"id": e["id"],
|
||||
"summary": e.get("summary", "(no title)"),
|
||||
"start": e.get("start", {}).get("dateTime", e.get("start", {}).get("date", "")),
|
||||
"end": e.get("end", {}).get("dateTime", e.get("end", {}).get("date", "")),
|
||||
"location": e.get("location", ""),
|
||||
"description": e.get("description", ""),
|
||||
"status": e.get("status", ""),
|
||||
"htmlLink": e.get("htmlLink", ""),
|
||||
})
|
||||
print(json.dumps(events, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
def calendar_create(args):
|
||||
cmd = [
|
||||
"calendar", "+insert",
|
||||
"--summary", args.summary,
|
||||
"--start", args.start,
|
||||
"--end", args.end,
|
||||
"--format", "json",
|
||||
]
|
||||
service = build_service("calendar", "v3")
|
||||
event = {
|
||||
"summary": args.summary,
|
||||
"start": {"dateTime": args.start},
|
||||
"end": {"dateTime": args.end},
|
||||
}
|
||||
if args.location:
|
||||
cmd += ["--location", args.location]
|
||||
event["location"] = args.location
|
||||
if args.description:
|
||||
cmd += ["--description", args.description]
|
||||
event["description"] = args.description
|
||||
if args.attendees:
|
||||
for email in args.attendees.split(","):
|
||||
cmd += ["--attendee", email.strip()]
|
||||
if args.calendar != "primary":
|
||||
cmd += ["--calendar", args.calendar]
|
||||
gws(*cmd)
|
||||
event["attendees"] = [{"email": e.strip()} for e in args.attendees.split(",")]
|
||||
|
||||
result = service.events().insert(calendarId=args.calendar, body=event).execute()
|
||||
print(json.dumps({
|
||||
"status": "created",
|
||||
"id": result["id"],
|
||||
"summary": result.get("summary", ""),
|
||||
"htmlLink": result.get("htmlLink", ""),
|
||||
}, indent=2))
|
||||
|
||||
|
||||
def calendar_delete(args):
|
||||
gws(
|
||||
"calendar", "events", "delete",
|
||||
"--params", json.dumps({"calendarId": args.calendar, "eventId": args.event_id}),
|
||||
"--format", "json",
|
||||
)
|
||||
service = build_service("calendar", "v3")
|
||||
service.events().delete(calendarId=args.calendar, eventId=args.event_id).execute()
|
||||
print(json.dumps({"status": "deleted", "eventId": args.event_id}))
|
||||
|
||||
|
||||
# -- Drive --
|
||||
# =========================================================================
|
||||
# Drive
|
||||
# =========================================================================
|
||||
|
||||
def drive_search(args):
|
||||
query = args.query if args.raw_query else f"fullText contains '{args.query}'"
|
||||
gws(
|
||||
"drive", "files", "list",
|
||||
"--params", json.dumps({
|
||||
"q": query,
|
||||
"pageSize": args.max,
|
||||
"fields": "files(id,name,mimeType,modifiedTime,webViewLink)",
|
||||
}),
|
||||
"--format", "json",
|
||||
)
|
||||
service = build_service("drive", "v3")
|
||||
query = f"fullText contains '{args.query}'" if not args.raw_query else args.query
|
||||
results = service.files().list(
|
||||
q=query, pageSize=args.max, fields="files(id, name, mimeType, modifiedTime, webViewLink)",
|
||||
).execute()
|
||||
files = results.get("files", [])
|
||||
print(json.dumps(files, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
# -- Contacts --
|
||||
# =========================================================================
|
||||
# Contacts
|
||||
# =========================================================================
|
||||
|
||||
def contacts_list(args):
|
||||
gws(
|
||||
"people", "people", "connections", "list",
|
||||
"--params", json.dumps({
|
||||
"resourceName": "people/me",
|
||||
"pageSize": args.max,
|
||||
"personFields": "names,emailAddresses,phoneNumbers",
|
||||
}),
|
||||
"--format", "json",
|
||||
)
|
||||
service = build_service("people", "v1")
|
||||
results = service.people().connections().list(
|
||||
resourceName="people/me",
|
||||
pageSize=args.max,
|
||||
personFields="names,emailAddresses,phoneNumbers",
|
||||
).execute()
|
||||
contacts = []
|
||||
for person in results.get("connections", []):
|
||||
names = person.get("names", [{}])
|
||||
emails = person.get("emailAddresses", [])
|
||||
phones = person.get("phoneNumbers", [])
|
||||
contacts.append({
|
||||
"name": names[0].get("displayName", "") if names else "",
|
||||
"emails": [e.get("value", "") for e in emails],
|
||||
"phones": [p.get("value", "") for p in phones],
|
||||
})
|
||||
print(json.dumps(contacts, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
# -- Sheets --
|
||||
# =========================================================================
|
||||
# Sheets
|
||||
# =========================================================================
|
||||
|
||||
def sheets_get(args):
|
||||
gws(
|
||||
"sheets", "+read",
|
||||
"--spreadsheet", args.sheet_id,
|
||||
"--range", args.range,
|
||||
"--format", "json",
|
||||
)
|
||||
service = build_service("sheets", "v4")
|
||||
result = service.spreadsheets().values().get(
|
||||
spreadsheetId=args.sheet_id, range=args.range,
|
||||
).execute()
|
||||
print(json.dumps(result.get("values", []), indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
def sheets_update(args):
|
||||
service = build_service("sheets", "v4")
|
||||
values = json.loads(args.values)
|
||||
gws(
|
||||
"sheets", "spreadsheets", "values", "update",
|
||||
"--params", json.dumps({
|
||||
"spreadsheetId": args.sheet_id,
|
||||
"range": args.range,
|
||||
"valueInputOption": "USER_ENTERED",
|
||||
}),
|
||||
"--json", json.dumps({"values": values}),
|
||||
"--format", "json",
|
||||
)
|
||||
body = {"values": values}
|
||||
result = service.spreadsheets().values().update(
|
||||
spreadsheetId=args.sheet_id, range=args.range,
|
||||
valueInputOption="USER_ENTERED", body=body,
|
||||
).execute()
|
||||
print(json.dumps({"updatedCells": result.get("updatedCells", 0), "updatedRange": result.get("updatedRange", "")}, indent=2))
|
||||
|
||||
|
||||
def sheets_append(args):
|
||||
service = build_service("sheets", "v4")
|
||||
values = json.loads(args.values)
|
||||
gws(
|
||||
"sheets", "+append",
|
||||
"--spreadsheet", args.sheet_id,
|
||||
"--json-values", json.dumps(values),
|
||||
"--format", "json",
|
||||
)
|
||||
body = {"values": values}
|
||||
result = service.spreadsheets().values().append(
|
||||
spreadsheetId=args.sheet_id, range=args.range,
|
||||
valueInputOption="USER_ENTERED", insertDataOption="INSERT_ROWS", body=body,
|
||||
).execute()
|
||||
print(json.dumps({"updatedCells": result.get("updates", {}).get("updatedCells", 0)}, indent=2))
|
||||
|
||||
|
||||
# -- Docs --
|
||||
# =========================================================================
|
||||
# Docs
|
||||
# =========================================================================
|
||||
|
||||
def docs_get(args):
|
||||
gws(
|
||||
"docs", "documents", "get",
|
||||
"--params", json.dumps({"documentId": args.doc_id}),
|
||||
"--format", "json",
|
||||
)
|
||||
service = build_service("docs", "v1")
|
||||
doc = service.documents().get(documentId=args.doc_id).execute()
|
||||
# Extract plain text from the document structure
|
||||
text_parts = []
|
||||
for element in doc.get("body", {}).get("content", []):
|
||||
paragraph = element.get("paragraph", {})
|
||||
for pe in paragraph.get("elements", []):
|
||||
text_run = pe.get("textRun", {})
|
||||
if text_run.get("content"):
|
||||
text_parts.append(text_run["content"])
|
||||
result = {
|
||||
"title": doc.get("title", ""),
|
||||
"documentId": doc.get("documentId", ""),
|
||||
"body": "".join(text_parts),
|
||||
}
|
||||
print(json.dumps(result, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
# -- CLI parser (backward-compatible interface) --
|
||||
# =========================================================================
|
||||
# CLI parser
|
||||
# =========================================================================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Google Workspace API for Hermes Agent (gws backend)")
|
||||
parser = argparse.ArgumentParser(description="Google Workspace API for Hermes Agent")
|
||||
sub = parser.add_subparsers(dest="service", required=True)
|
||||
|
||||
# --- Gmail ---
|
||||
@@ -229,7 +421,7 @@ def main():
|
||||
p.add_argument("--body", required=True)
|
||||
p.add_argument("--cc", default="")
|
||||
p.add_argument("--html", action="store_true", help="Send body as HTML")
|
||||
p.add_argument("--thread-id", default="", help="Thread ID (unused with gws, kept for compat)")
|
||||
p.add_argument("--thread-id", default="", help="Thread ID for threading")
|
||||
p.set_defaults(func=gmail_send)
|
||||
|
||||
p = gmail_sub.add_parser("reply")
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Bridge between Hermes OAuth token and gws CLI.
|
||||
|
||||
Refreshes the token if expired, then executes gws with the valid access token.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def get_hermes_home() -> Path:
|
||||
return Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
||||
|
||||
def get_token_path() -> Path:
|
||||
return get_hermes_home() / "google_token.json"
|
||||
|
||||
|
||||
def refresh_token(token_data: dict) -> dict:
|
||||
"""Refresh the access token using the refresh token."""
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
|
||||
params = urllib.parse.urlencode({
|
||||
"client_id": token_data["client_id"],
|
||||
"client_secret": token_data["client_secret"],
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"grant_type": "refresh_token",
|
||||
}).encode()
|
||||
|
||||
req = urllib.request.Request(token_data["token_uri"], data=params)
|
||||
try:
|
||||
with urllib.request.urlopen(req) as resp:
|
||||
result = json.loads(resp.read())
|
||||
except urllib.error.HTTPError as e:
|
||||
body = e.read().decode("utf-8", errors="replace")
|
||||
print(f"ERROR: Token refresh failed (HTTP {e.code}): {body}", file=sys.stderr)
|
||||
print("Re-run setup.py to re-authenticate.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
token_data["token"] = result["access_token"]
|
||||
token_data["expiry"] = datetime.fromtimestamp(
|
||||
datetime.now(timezone.utc).timestamp() + result["expires_in"],
|
||||
tz=timezone.utc,
|
||||
).isoformat()
|
||||
|
||||
get_token_path().write_text(json.dumps(token_data, indent=2))
|
||||
return token_data
|
||||
|
||||
|
||||
def get_valid_token() -> str:
|
||||
"""Return a valid access token, refreshing if needed."""
|
||||
token_path = get_token_path()
|
||||
if not token_path.exists():
|
||||
print("ERROR: No Google token found. Run setup.py --auth-url first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
token_data = json.loads(token_path.read_text())
|
||||
|
||||
expiry = token_data.get("expiry", "")
|
||||
if expiry:
|
||||
exp_dt = datetime.fromisoformat(expiry.replace("Z", "+00:00"))
|
||||
now = datetime.now(timezone.utc)
|
||||
if now >= exp_dt:
|
||||
token_data = refresh_token(token_data)
|
||||
|
||||
return token_data["token"]
|
||||
|
||||
|
||||
def main():
|
||||
"""Refresh token if needed, then exec gws with remaining args."""
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: gws_bridge.py <gws args...>", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
access_token = get_valid_token()
|
||||
env = os.environ.copy()
|
||||
env["GOOGLE_WORKSPACE_CLI_TOKEN"] = access_token
|
||||
|
||||
result = subprocess.run(["gws"] + sys.argv[1:], env=env)
|
||||
sys.exit(result.returncode)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -23,7 +23,6 @@ Agent workflow:
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@@ -129,11 +128,7 @@ def check_auth():
|
||||
from google.auth.transport.requests import Request
|
||||
|
||||
try:
|
||||
# Don't pass scopes — user may have authorized only a subset.
|
||||
# Passing scopes forces google-auth to validate them on refresh,
|
||||
# which fails with invalid_scope if the token has fewer scopes
|
||||
# than requested.
|
||||
creds = Credentials.from_authorized_user_file(str(TOKEN_PATH))
|
||||
creds = Credentials.from_authorized_user_file(str(TOKEN_PATH), SCOPES)
|
||||
except Exception as e:
|
||||
print(f"TOKEN_CORRUPT: {e}")
|
||||
return False
|
||||
@@ -142,9 +137,8 @@ def check_auth():
|
||||
if creds.valid:
|
||||
missing_scopes = _missing_scopes_from_payload(payload)
|
||||
if missing_scopes:
|
||||
print(f"AUTHENTICATED (partial): Token valid but missing {len(missing_scopes)} scopes:")
|
||||
for s in missing_scopes:
|
||||
print(f" - {s}")
|
||||
print(f"AUTH_SCOPE_MISMATCH: {_format_missing_scopes(missing_scopes)}")
|
||||
return False
|
||||
print(f"AUTHENTICATED: Token valid at {TOKEN_PATH}")
|
||||
return True
|
||||
|
||||
@@ -154,9 +148,8 @@ def check_auth():
|
||||
TOKEN_PATH.write_text(creds.to_json())
|
||||
missing_scopes = _missing_scopes_from_payload(_load_token_payload(TOKEN_PATH))
|
||||
if missing_scopes:
|
||||
print(f"AUTHENTICATED (partial): Token refreshed but missing {len(missing_scopes)} scopes:")
|
||||
for s in missing_scopes:
|
||||
print(f" - {s}")
|
||||
print(f"AUTH_SCOPE_MISMATCH: {_format_missing_scopes(missing_scopes)}")
|
||||
return False
|
||||
print(f"AUTHENTICATED: Token refreshed at {TOKEN_PATH}")
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -279,33 +272,16 @@ def exchange_auth_code(code: str):
|
||||
|
||||
_ensure_deps()
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
# Extract granted scopes from the callback URL if present
|
||||
if returned_state and "scope" in parse_qs(urlparse(code).query if isinstance(code, str) and code.startswith("http") else {}):
|
||||
granted_scopes = parse_qs(urlparse(code).query)["scope"][0].split()
|
||||
else:
|
||||
# Try to extract from code_or_url parameter
|
||||
if isinstance(code, str) and code.startswith("http"):
|
||||
params = parse_qs(urlparse(code).query)
|
||||
if "scope" in params:
|
||||
granted_scopes = params["scope"][0].split()
|
||||
else:
|
||||
granted_scopes = SCOPES
|
||||
else:
|
||||
granted_scopes = SCOPES
|
||||
|
||||
flow = Flow.from_client_secrets_file(
|
||||
str(CLIENT_SECRET_PATH),
|
||||
scopes=granted_scopes,
|
||||
scopes=SCOPES,
|
||||
redirect_uri=pending_auth.get("redirect_uri", REDIRECT_URI),
|
||||
state=pending_auth["state"],
|
||||
code_verifier=pending_auth["code_verifier"],
|
||||
)
|
||||
|
||||
try:
|
||||
# Accept partial scopes — user may deselect some permissions in the consent screen
|
||||
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
|
||||
flow.fetch_token(code=code)
|
||||
except Exception as e:
|
||||
print(f"ERROR: Token exchange failed: {e}")
|
||||
@@ -314,21 +290,11 @@ def exchange_auth_code(code: str):
|
||||
|
||||
creds = flow.credentials
|
||||
token_payload = json.loads(creds.to_json())
|
||||
|
||||
# Store only the scopes actually granted by the user, not what was requested.
|
||||
# creds.to_json() writes the requested scopes, which causes refresh to fail
|
||||
# with invalid_scope if the user only authorized a subset.
|
||||
actually_granted = list(creds.granted_scopes or []) if hasattr(creds, "granted_scopes") and creds.granted_scopes else []
|
||||
if actually_granted:
|
||||
token_payload["scopes"] = actually_granted
|
||||
elif granted_scopes != SCOPES:
|
||||
# granted_scopes was extracted from the callback URL
|
||||
token_payload["scopes"] = granted_scopes
|
||||
|
||||
missing_scopes = _missing_scopes_from_payload(token_payload)
|
||||
if missing_scopes:
|
||||
print(f"WARNING: Token missing some Google Workspace scopes: {', '.join(missing_scopes)}")
|
||||
print("Some services may not be available.")
|
||||
print(f"ERROR: Refusing to save incomplete Google Workspace token. {_format_missing_scopes(missing_scopes)}")
|
||||
print(f"Existing token at {TOKEN_PATH} was left unchanged.")
|
||||
sys.exit(1)
|
||||
|
||||
TOKEN_PATH.write_text(json.dumps(token_payload, indent=2))
|
||||
PENDING_AUTH_PATH.unlink(missing_ok=True)
|
||||
|
||||
@@ -17,6 +17,7 @@ from agent.anthropic_adapter import (
|
||||
build_anthropic_kwargs,
|
||||
convert_messages_to_anthropic,
|
||||
convert_tools_to_anthropic,
|
||||
get_anthropic_token_source,
|
||||
is_claude_code_token_valid,
|
||||
normalize_anthropic_response,
|
||||
normalize_model_name,
|
||||
@@ -164,6 +165,15 @@ class TestResolveAnthropicToken:
|
||||
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
|
||||
assert resolve_anthropic_token() == "sk-ant-oat01-mytoken"
|
||||
|
||||
def test_reports_claude_json_primary_key_source(self, monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
|
||||
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
|
||||
(tmp_path / ".claude.json").write_text(json.dumps({"primaryApiKey": "sk-ant-api03-primary"}))
|
||||
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
|
||||
|
||||
assert get_anthropic_token_source("sk-ant-api03-primary") == "claude_json_primary_api_key"
|
||||
|
||||
def test_does_not_resolve_primary_api_key_as_native_anthropic_token(self, monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
|
||||
|
||||
@@ -9,6 +9,7 @@ import pytest
|
||||
|
||||
from agent.auxiliary_client import (
|
||||
get_text_auxiliary_client,
|
||||
get_vision_auxiliary_client,
|
||||
get_available_vision_backends,
|
||||
resolve_vision_provider_client,
|
||||
resolve_provider_client,
|
||||
@@ -19,6 +20,7 @@ from agent.auxiliary_client import (
|
||||
_get_provider_chain,
|
||||
_is_payment_error,
|
||||
_try_payment_fallback,
|
||||
_resolve_forced_provider,
|
||||
_resolve_auto,
|
||||
)
|
||||
|
||||
@@ -75,20 +77,6 @@ class TestReadCodexAccessToken:
|
||||
result = _read_codex_access_token()
|
||||
assert result == "tok-123"
|
||||
|
||||
def test_pool_without_selected_entry_falls_back_to_auth_store(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
valid_jwt = "eyJhbGciOiJSUzI1NiJ9.eyJleHAiOjk5OTk5OTk5OTl9.sig"
|
||||
with patch("agent.auxiliary_client._select_pool_entry", return_value=(True, None)), \
|
||||
patch("hermes_cli.auth._read_codex_tokens", return_value={
|
||||
"tokens": {"access_token": valid_jwt, "refresh_token": "refresh"}
|
||||
}):
|
||||
result = _read_codex_access_token()
|
||||
|
||||
assert result == valid_jwt
|
||||
|
||||
def test_missing_returns_none(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
@@ -250,24 +238,6 @@ class TestAnthropicOAuthFlag:
|
||||
assert mock_build.call_args.args[0] == "sk-ant-oat01-pooled"
|
||||
|
||||
|
||||
class TestTryCodex:
|
||||
def test_pool_without_selected_entry_falls_back_to_auth_store(self):
|
||||
with (
|
||||
patch("agent.auxiliary_client._select_pool_entry", return_value=(True, None)),
|
||||
patch("agent.auxiliary_client._read_codex_access_token", return_value="codex-auth-token"),
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
):
|
||||
mock_openai.return_value = MagicMock()
|
||||
from agent.auxiliary_client import _try_codex
|
||||
|
||||
client, model = _try_codex()
|
||||
|
||||
assert client is not None
|
||||
assert model == "gpt-5.2-codex"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "codex-auth-token"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "https://chatgpt.com/backend-api/codex"
|
||||
|
||||
|
||||
class TestExpiredCodexFallback:
|
||||
"""Test that expired Codex tokens don't block the auto chain."""
|
||||
|
||||
@@ -662,6 +632,15 @@ class TestGetTextAuxiliaryClient:
|
||||
class TestVisionClientFallback:
|
||||
"""Vision client auto mode resolves known-good multimodal backends."""
|
||||
|
||||
def test_vision_returns_none_without_any_credentials(self):
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
|
||||
patch("agent.auxiliary_client._try_anthropic", return_value=(None, None)),
|
||||
):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_vision_auto_includes_active_provider_when_configured(self, monkeypatch):
|
||||
"""Active provider appears in available backends when credentials exist."""
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "***")
|
||||
@@ -743,6 +722,21 @@ class TestAuxiliaryPoolAwareness:
|
||||
assert call_kwargs["base_url"] == "https://api.githubcopilot.com"
|
||||
assert call_kwargs["default_headers"]["Editor-Version"]
|
||||
|
||||
def test_vision_auto_uses_active_provider_as_fallback(self, monkeypatch):
|
||||
"""When no OpenRouter/Nous available, vision auto falls back to active provider."""
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "***")
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
|
||||
patch("agent.auxiliary_client._read_main_provider", return_value="anthropic"),
|
||||
patch("agent.auxiliary_client._read_main_model", return_value="claude-sonnet-4"),
|
||||
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="***"),
|
||||
):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
|
||||
assert client is not None
|
||||
assert client.__class__.__name__ == "AnthropicAuxiliaryClient"
|
||||
|
||||
def test_vision_auto_prefers_active_provider_over_openrouter(self, monkeypatch):
|
||||
"""Active provider is tried before OpenRouter in vision auto."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
@@ -774,6 +768,43 @@ class TestAuxiliaryPoolAwareness:
|
||||
assert client is not None
|
||||
assert provider == "custom:local"
|
||||
|
||||
def test_vision_direct_endpoint_override(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
monkeypatch.setenv("AUXILIARY_VISION_BASE_URL", "http://localhost:4567/v1")
|
||||
monkeypatch.setenv("AUXILIARY_VISION_API_KEY", "vision-key")
|
||||
monkeypatch.setenv("AUXILIARY_VISION_MODEL", "vision-model")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert model == "vision-model"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "http://localhost:4567/v1"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "vision-key"
|
||||
|
||||
def test_vision_direct_endpoint_without_key_uses_placeholder(self, monkeypatch):
|
||||
"""Vision endpoint without API key should use 'no-key-required' placeholder."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
monkeypatch.setenv("AUXILIARY_VISION_BASE_URL", "http://localhost:4567/v1")
|
||||
monkeypatch.setenv("AUXILIARY_VISION_MODEL", "vision-model")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is not None
|
||||
assert model == "vision-model"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "no-key-required"
|
||||
|
||||
def test_vision_uses_openrouter_when_available(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert client is not None
|
||||
|
||||
def test_vision_uses_nous_when_available(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth") as mock_nous, \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
mock_nous.return_value = {"access_token": "nous-tok"}
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert client is not None
|
||||
|
||||
def test_vision_config_google_provider_uses_gemini_credentials(self, monkeypatch):
|
||||
config = {
|
||||
"auxiliary": {
|
||||
@@ -799,6 +830,53 @@ class TestAuxiliaryPoolAwareness:
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "gemini-key"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
|
||||
def test_vision_forced_main_uses_custom_endpoint(self, monkeypatch):
|
||||
"""When explicitly forced to 'main', vision CAN use custom endpoint."""
|
||||
config = {
|
||||
"model": {
|
||||
"provider": "custom",
|
||||
"base_url": "http://localhost:1234/v1",
|
||||
"default": "my-local-model",
|
||||
}
|
||||
}
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "main")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is not None
|
||||
assert model == "my-local-model"
|
||||
|
||||
def test_vision_forced_main_returns_none_without_creds(self, monkeypatch):
|
||||
"""Forced main with no credentials still returns None."""
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "main")
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
# Clear client cache to avoid stale entries from previous tests
|
||||
from agent.auxiliary_client import _client_cache
|
||||
_client_cache.clear()
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client._read_main_provider", return_value=""), \
|
||||
patch("agent.auxiliary_client._read_main_model", return_value=""), \
|
||||
patch("agent.auxiliary_client._select_pool_entry", return_value=(False, None)), \
|
||||
patch("agent.auxiliary_client._resolve_custom_runtime", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client._read_codex_access_token", return_value=None), \
|
||||
patch("agent.auxiliary_client._resolve_api_key_provider", return_value=(None, None)):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_vision_forced_codex(self, monkeypatch, codex_auth_dir):
|
||||
"""When forced to 'codex', vision uses Codex OAuth."""
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "codex")
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
|
||||
class TestGetAuxiliaryProvider:
|
||||
@@ -838,6 +916,122 @@ class TestGetAuxiliaryProvider:
|
||||
assert _get_auxiliary_provider("web_extract") == "main"
|
||||
|
||||
|
||||
class TestResolveForcedProvider:
|
||||
"""Tests for _resolve_forced_provider with explicit provider selection."""
|
||||
|
||||
def test_forced_openrouter(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = _resolve_forced_provider("openrouter")
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert client is not None
|
||||
|
||||
def test_forced_openrouter_no_key(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None):
|
||||
client, model = _resolve_forced_provider("openrouter")
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_forced_nous(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth") as mock_nous, \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
mock_nous.return_value = {"access_token": "nous-tok"}
|
||||
client, model = _resolve_forced_provider("nous")
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert client is not None
|
||||
|
||||
def test_forced_nous_not_configured(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None):
|
||||
client, model = _resolve_forced_provider("nous")
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_forced_main_uses_custom(self, monkeypatch):
|
||||
config = {
|
||||
"model": {
|
||||
"provider": "custom",
|
||||
"base_url": "http://local:8080/v1",
|
||||
"default": "my-local-model",
|
||||
}
|
||||
}
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = _resolve_forced_provider("main")
|
||||
assert model == "my-local-model"
|
||||
|
||||
def test_forced_main_uses_config_saved_custom_endpoint(self, monkeypatch):
|
||||
config = {
|
||||
"model": {
|
||||
"provider": "custom",
|
||||
"base_url": "http://local:8080/v1",
|
||||
"default": "my-local-model",
|
||||
}
|
||||
}
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client._read_codex_access_token", return_value=None), \
|
||||
patch("agent.auxiliary_client._resolve_api_key_provider", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = _resolve_forced_provider("main")
|
||||
assert client is not None
|
||||
assert model == "my-local-model"
|
||||
call_kwargs = mock_openai.call_args
|
||||
assert call_kwargs.kwargs["base_url"] == "http://local:8080/v1"
|
||||
|
||||
def test_forced_main_skips_openrouter_nous(self, monkeypatch):
|
||||
"""Even if OpenRouter key is set, 'main' skips it."""
|
||||
config = {
|
||||
"model": {
|
||||
"provider": "custom",
|
||||
"base_url": "http://local:8080/v1",
|
||||
"default": "my-local-model",
|
||||
}
|
||||
}
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = _resolve_forced_provider("main")
|
||||
# Should use custom endpoint, not OpenRouter
|
||||
assert model == "my-local-model"
|
||||
|
||||
def test_forced_main_falls_to_codex(self, codex_auth_dir, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = _resolve_forced_provider("main")
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
def test_forced_codex(self, codex_auth_dir, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = _resolve_forced_provider("codex")
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
def test_forced_codex_no_token(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_codex_access_token", return_value=None):
|
||||
client, model = _resolve_forced_provider("codex")
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_forced_unknown_returns_none(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client._read_codex_access_token", return_value=None):
|
||||
client, model = _resolve_forced_provider("invalid-provider")
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
|
||||
class TestTaskSpecificOverrides:
|
||||
"""Integration tests for per-task provider routing via get_text_auxiliary_client(task=...)."""
|
||||
|
||||
|
||||
@@ -38,6 +38,16 @@ class TestShouldCompress:
|
||||
assert compressor.should_compress(prompt_tokens=50000) is False
|
||||
|
||||
|
||||
class TestShouldCompressPreflight:
|
||||
def test_short_messages(self, compressor):
|
||||
msgs = [{"role": "user", "content": "short"}]
|
||||
assert compressor.should_compress_preflight(msgs) is False
|
||||
|
||||
def test_long_messages(self, compressor):
|
||||
# Each message ~100k chars / 4 = 25k tokens, need >85k threshold
|
||||
msgs = [{"role": "user", "content": "x" * 400000}]
|
||||
assert compressor.should_compress_preflight(msgs) is True
|
||||
|
||||
|
||||
class TestUpdateFromResponse:
|
||||
def test_updates_fields(self, compressor):
|
||||
@@ -48,12 +58,27 @@ class TestUpdateFromResponse:
|
||||
})
|
||||
assert compressor.last_prompt_tokens == 5000
|
||||
assert compressor.last_completion_tokens == 1000
|
||||
assert compressor.last_total_tokens == 6000
|
||||
|
||||
def test_missing_fields_default_zero(self, compressor):
|
||||
compressor.update_from_response({})
|
||||
assert compressor.last_prompt_tokens == 0
|
||||
|
||||
|
||||
class TestGetStatus:
|
||||
def test_returns_expected_keys(self, compressor):
|
||||
status = compressor.get_status()
|
||||
assert "last_prompt_tokens" in status
|
||||
assert "threshold_tokens" in status
|
||||
assert "context_length" in status
|
||||
assert "usage_percent" in status
|
||||
assert "compression_count" in status
|
||||
|
||||
def test_usage_percent_calculation(self, compressor):
|
||||
compressor.last_prompt_tokens = 50000
|
||||
status = compressor.get_status()
|
||||
assert status["usage_percent"] == 50.0
|
||||
|
||||
|
||||
class TestCompress:
|
||||
def _make_messages(self, n):
|
||||
@@ -299,10 +324,7 @@ class TestCompressWithClient:
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=100000):
|
||||
c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2)
|
||||
|
||||
# Last head message (index 1) is "assistant" → summary should be "user".
|
||||
# With min_tail=3, tail = last 3 messages (indices 5-7).
|
||||
# head_last=assistant, tail_first=assistant → summary_role="user", no collision.
|
||||
# Need 8 messages: min_for_compress = 2+3+1 = 6, must have > 6.
|
||||
# Last head message (index 1) is "assistant" → summary should be "user"
|
||||
msgs = [
|
||||
{"role": "user", "content": "msg 0"},
|
||||
{"role": "assistant", "content": "msg 1"},
|
||||
@@ -310,8 +332,6 @@ class TestCompressWithClient:
|
||||
{"role": "assistant", "content": "msg 3"},
|
||||
{"role": "user", "content": "msg 4"},
|
||||
{"role": "assistant", "content": "msg 5"},
|
||||
{"role": "user", "content": "msg 6"},
|
||||
{"role": "assistant", "content": "msg 7"},
|
||||
]
|
||||
with patch("agent.context_compressor.call_llm", return_value=mock_response):
|
||||
result = c.compress(msgs)
|
||||
@@ -440,10 +460,8 @@ class TestCompressWithClient:
|
||||
c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2)
|
||||
|
||||
# Head: [system, user] → last head = user
|
||||
# Tail: [assistant, user, assistant] → first tail = assistant
|
||||
# Tail: [assistant, user] → first tail = assistant
|
||||
# summary_role="assistant" collides with tail, "user" collides with head → merge
|
||||
# With min_tail=3, tail = last 3 messages (indices 5-7).
|
||||
# Need 8 messages: min_for_compress = 2+3+1 = 6, must have > 6.
|
||||
msgs = [
|
||||
{"role": "system", "content": "system prompt"},
|
||||
{"role": "user", "content": "msg 1"},
|
||||
@@ -452,7 +470,6 @@ class TestCompressWithClient:
|
||||
{"role": "assistant", "content": "msg 4"}, # compressed
|
||||
{"role": "assistant", "content": "msg 5"}, # tail start
|
||||
{"role": "user", "content": "msg 6"},
|
||||
{"role": "assistant", "content": "msg 7"},
|
||||
]
|
||||
with patch("agent.context_compressor.call_llm", return_value=mock_response):
|
||||
result = c.compress(msgs)
|
||||
@@ -464,7 +481,7 @@ class TestCompressWithClient:
|
||||
if r1 in ("user", "assistant") and r2 in ("user", "assistant"):
|
||||
assert r1 != r2, f"consecutive {r1} at indices {i-1},{i}"
|
||||
|
||||
# The summary should be merged into the first tail message (assistant at index 5)
|
||||
# The summary should be merged into the first tail message (assistant)
|
||||
first_tail = [m for m in result if "msg 5" in (m.get("content") or "")]
|
||||
assert len(first_tail) == 1
|
||||
assert "summary text" in first_tail[0]["content"]
|
||||
@@ -479,18 +496,14 @@ class TestCompressWithClient:
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=100000):
|
||||
c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2)
|
||||
|
||||
# Head=assistant, Tail=assistant → summary_role="user", no collision.
|
||||
# With min_tail=3, tail = last 3 messages (indices 5-7).
|
||||
# Need 8 messages: min_for_compress = 2+3+1 = 6, must have > 6.
|
||||
# Head=assistant, Tail=assistant → summary_role="user", no collision
|
||||
msgs = [
|
||||
{"role": "user", "content": "msg 0"},
|
||||
{"role": "assistant", "content": "msg 1"},
|
||||
{"role": "user", "content": "msg 2"},
|
||||
{"role": "assistant", "content": "msg 3"},
|
||||
{"role": "user", "content": "msg 4"},
|
||||
{"role": "assistant", "content": "msg 5"},
|
||||
{"role": "user", "content": "msg 6"},
|
||||
{"role": "assistant", "content": "msg 7"},
|
||||
{"role": "assistant", "content": "msg 4"},
|
||||
{"role": "user", "content": "msg 5"},
|
||||
]
|
||||
with patch("agent.context_compressor.call_llm", return_value=mock_response):
|
||||
result = c.compress(msgs)
|
||||
@@ -587,158 +600,3 @@ class TestSummaryTargetRatio:
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=100_000):
|
||||
c = ContextCompressor(model="test", quiet_mode=True)
|
||||
assert c.protect_last_n == 20
|
||||
|
||||
|
||||
class TestTokenBudgetTailProtection:
|
||||
"""Tests for token-budget-based tail protection (PR #6240).
|
||||
|
||||
The core change: tail protection is now based on a token budget rather
|
||||
than a fixed message count. This prevents large tool outputs from
|
||||
blocking compaction.
|
||||
"""
|
||||
|
||||
@pytest.fixture()
|
||||
def budget_compressor(self):
|
||||
"""Compressor with known token budget for tail protection tests."""
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=200_000):
|
||||
c = ContextCompressor(
|
||||
model="test/model",
|
||||
threshold_percent=0.50, # 100K threshold
|
||||
protect_first_n=2,
|
||||
protect_last_n=20,
|
||||
quiet_mode=True,
|
||||
)
|
||||
return c
|
||||
|
||||
def test_large_tool_outputs_no_longer_block_compaction(self, budget_compressor):
|
||||
"""The motivating scenario: 20 messages with large tool outputs should
|
||||
NOT prevent compaction. With message-count tail protection they would
|
||||
all be protected, leaving nothing to summarize."""
|
||||
c = budget_compressor
|
||||
messages = [
|
||||
{"role": "user", "content": "Start task"},
|
||||
{"role": "assistant", "content": "On it"},
|
||||
]
|
||||
# Add 20 messages with large tool outputs (~5K chars each ≈ 1250 tokens)
|
||||
for i in range(10):
|
||||
messages.append({
|
||||
"role": "assistant", "content": None,
|
||||
"tool_calls": [{"function": {"name": f"tool_{i}", "arguments": "{}"}}],
|
||||
})
|
||||
messages.append({
|
||||
"role": "tool", "content": "x" * 5000,
|
||||
"tool_call_id": f"call_{i}",
|
||||
})
|
||||
# Add 3 recent small messages
|
||||
messages.append({"role": "user", "content": "What's the status?"})
|
||||
messages.append({"role": "assistant", "content": "Here's what I found..."})
|
||||
messages.append({"role": "user", "content": "Continue"})
|
||||
|
||||
# The tail cut should NOT protect all 20 tool messages
|
||||
head_end = c.protect_first_n
|
||||
cut = c._find_tail_cut_by_tokens(messages, head_end)
|
||||
tail_size = len(messages) - cut
|
||||
# With token budget, the tail should be much smaller than 20+
|
||||
assert tail_size < 20, f"Tail {tail_size} messages — large tool outputs are blocking compaction"
|
||||
# But at least 3 (hard minimum)
|
||||
assert tail_size >= 3
|
||||
|
||||
def test_min_tail_always_3_messages(self, budget_compressor):
|
||||
"""Even with a tiny token budget, at least 3 messages are protected."""
|
||||
c = budget_compressor
|
||||
# Override to a tiny budget
|
||||
c.tail_token_budget = 10
|
||||
messages = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi"},
|
||||
{"role": "user", "content": "do something"},
|
||||
{"role": "assistant", "content": "working on it"},
|
||||
{"role": "user", "content": "more work"},
|
||||
{"role": "assistant", "content": "done"},
|
||||
{"role": "user", "content": "thanks"},
|
||||
]
|
||||
head_end = 2
|
||||
cut = c._find_tail_cut_by_tokens(messages, head_end)
|
||||
tail_size = len(messages) - cut
|
||||
assert tail_size >= 3, f"Tail is only {tail_size} messages, min should be 3"
|
||||
|
||||
def test_soft_ceiling_allows_oversized_message(self, budget_compressor):
|
||||
"""The 1.5x soft ceiling allows an oversized message to be included
|
||||
rather than splitting it."""
|
||||
c = budget_compressor
|
||||
# Set a small budget — 500 tokens
|
||||
c.tail_token_budget = 500
|
||||
messages = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi"},
|
||||
{"role": "user", "content": "read the file"},
|
||||
# This message is ~600 tokens (> budget of 500, but < 1.5x = 750)
|
||||
{"role": "assistant", "content": "a" * 2400},
|
||||
{"role": "user", "content": "short"},
|
||||
{"role": "assistant", "content": "short reply"},
|
||||
{"role": "user", "content": "continue"},
|
||||
]
|
||||
head_end = 2
|
||||
cut = c._find_tail_cut_by_tokens(messages, head_end)
|
||||
# The oversized message at index 3 should NOT be the cut point
|
||||
# because 1.5x ceiling = 750 tokens and accumulated would be ~610
|
||||
# (short msgs + oversized msg) which is < 750
|
||||
tail_size = len(messages) - cut
|
||||
assert tail_size >= 3
|
||||
|
||||
def test_small_conversation_still_compresses(self, budget_compressor):
|
||||
"""With the new min of 8 messages (head=2 + 3 + 1 guard + 2 middle),
|
||||
a small but compressible conversation should still compress."""
|
||||
c = budget_compressor
|
||||
# 9 messages: head(2) + 4 middle + 3 tail = compressible
|
||||
messages = []
|
||||
for i in range(9):
|
||||
role = "user" if i % 2 == 0 else "assistant"
|
||||
messages.append({"role": role, "content": f"Message {i}"})
|
||||
|
||||
# Should not early-return (needs > protect_first_n + 3 + 1 = 6)
|
||||
# Mock the summary generation to avoid real API call
|
||||
with patch.object(c, "_generate_summary", return_value="Summary of conversation"):
|
||||
result = c.compress(messages, current_tokens=90_000)
|
||||
# Should have compressed (fewer messages than original)
|
||||
assert len(result) < len(messages)
|
||||
|
||||
def test_prune_with_token_budget(self, budget_compressor):
|
||||
"""_prune_old_tool_results with protect_tail_tokens respects the budget."""
|
||||
c = budget_compressor
|
||||
messages = [
|
||||
{"role": "user", "content": "start"},
|
||||
{"role": "assistant", "content": None,
|
||||
"tool_calls": [{"function": {"name": "read_file", "arguments": '{"path": "big.txt"}'}}]},
|
||||
{"role": "tool", "content": "x" * 10000, "tool_call_id": "c1"}, # ~2500 tokens
|
||||
{"role": "assistant", "content": None,
|
||||
"tool_calls": [{"function": {"name": "read_file", "arguments": '{"path": "small.txt"}'}}]},
|
||||
{"role": "tool", "content": "y" * 10000, "tool_call_id": "c2"}, # ~2500 tokens
|
||||
{"role": "user", "content": "short recent message"},
|
||||
{"role": "assistant", "content": "short reply"},
|
||||
]
|
||||
# With a 1000-token budget, only the last couple messages should be protected
|
||||
result, pruned = c._prune_old_tool_results(
|
||||
messages, protect_tail_count=2, protect_tail_tokens=1000,
|
||||
)
|
||||
# At least one old tool result should have been pruned
|
||||
assert pruned >= 1
|
||||
|
||||
def test_prune_without_token_budget_uses_message_count(self, budget_compressor):
|
||||
"""Without protect_tail_tokens, falls back to message-count behavior."""
|
||||
c = budget_compressor
|
||||
messages = [
|
||||
{"role": "user", "content": "start"},
|
||||
{"role": "assistant", "content": None,
|
||||
"tool_calls": [{"function": {"name": "tool", "arguments": "{}"}}]},
|
||||
{"role": "tool", "content": "x" * 5000, "tool_call_id": "c1"},
|
||||
{"role": "user", "content": "recent"},
|
||||
{"role": "assistant", "content": "reply"},
|
||||
]
|
||||
# protect_tail_count=3 means last 3 messages protected
|
||||
result, pruned = c._prune_old_tool_results(
|
||||
messages, protect_tail_count=3,
|
||||
)
|
||||
# Tool at index 2 is outside the protected tail (last 3 = indices 2,3,4)
|
||||
# so it might or might not be pruned depending on boundary
|
||||
assert isinstance(pruned, int)
|
||||
|
||||
@@ -214,42 +214,6 @@ def test_exhausted_entry_resets_after_ttl(tmp_path, monkeypatch):
|
||||
assert entry.last_status == "ok"
|
||||
|
||||
|
||||
def test_exhausted_402_entry_resets_after_one_hour(tmp_path, monkeypatch):
|
||||
"""402-exhausted credentials recover after 1 hour, not 24."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"openrouter": [
|
||||
{
|
||||
"id": "cred-1",
|
||||
"label": "primary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "***",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"last_status": "exhausted",
|
||||
"last_status_at": time.time() - 3700, # ~1h2m ago
|
||||
"last_error_code": 402,
|
||||
}
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("openrouter")
|
||||
entry = pool.select()
|
||||
|
||||
assert entry is not None
|
||||
assert entry.id == "cred-1"
|
||||
assert entry.last_status == "ok"
|
||||
|
||||
|
||||
def test_explicit_reset_timestamp_overrides_default_429_ttl(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
|
||||
@@ -1,782 +0,0 @@
|
||||
"""Tests for agent.error_classifier — structured API error classification."""
|
||||
|
||||
import pytest
|
||||
from agent.error_classifier import (
|
||||
ClassifiedError,
|
||||
FailoverReason,
|
||||
classify_api_error,
|
||||
_extract_status_code,
|
||||
_extract_error_body,
|
||||
_extract_error_code,
|
||||
_classify_402,
|
||||
)
|
||||
|
||||
|
||||
# ── Helper: mock API errors ────────────────────────────────────────────
|
||||
|
||||
class MockAPIError(Exception):
|
||||
"""Simulates an OpenAI SDK APIStatusError."""
|
||||
def __init__(self, message, status_code=None, body=None):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
self.body = body or {}
|
||||
|
||||
|
||||
class MockTransportError(Exception):
|
||||
"""Simulates a transport-level error with a specific type name."""
|
||||
pass
|
||||
|
||||
|
||||
class ReadTimeout(MockTransportError):
|
||||
pass
|
||||
|
||||
|
||||
class ConnectError(MockTransportError):
|
||||
pass
|
||||
|
||||
|
||||
class RemoteProtocolError(MockTransportError):
|
||||
pass
|
||||
|
||||
|
||||
class ServerDisconnectedError(MockTransportError):
|
||||
pass
|
||||
|
||||
|
||||
# ── Test: FailoverReason enum ──────────────────────────────────────────
|
||||
|
||||
class TestFailoverReason:
|
||||
def test_all_reasons_have_string_values(self):
|
||||
for reason in FailoverReason:
|
||||
assert isinstance(reason.value, str)
|
||||
|
||||
def test_enum_members_exist(self):
|
||||
expected = {
|
||||
"auth", "auth_permanent", "billing", "rate_limit",
|
||||
"overloaded", "server_error", "timeout",
|
||||
"context_overflow", "payload_too_large",
|
||||
"model_not_found", "format_error",
|
||||
"thinking_signature", "long_context_tier", "unknown",
|
||||
}
|
||||
actual = {r.value for r in FailoverReason}
|
||||
assert expected == actual
|
||||
|
||||
|
||||
# ── Test: ClassifiedError ──────────────────────────────────────────────
|
||||
|
||||
class TestClassifiedError:
|
||||
def test_is_auth_property(self):
|
||||
e1 = ClassifiedError(reason=FailoverReason.auth)
|
||||
assert e1.is_auth is True
|
||||
|
||||
e2 = ClassifiedError(reason=FailoverReason.auth_permanent)
|
||||
assert e2.is_auth is True
|
||||
|
||||
e3 = ClassifiedError(reason=FailoverReason.billing)
|
||||
assert e3.is_auth is False
|
||||
|
||||
def test_is_transient_property(self):
|
||||
transient_reasons = [
|
||||
FailoverReason.rate_limit,
|
||||
FailoverReason.overloaded,
|
||||
FailoverReason.server_error,
|
||||
FailoverReason.timeout,
|
||||
FailoverReason.unknown,
|
||||
]
|
||||
for reason in transient_reasons:
|
||||
e = ClassifiedError(reason=reason)
|
||||
assert e.is_transient is True, f"{reason} should be transient"
|
||||
|
||||
non_transient = [
|
||||
FailoverReason.auth,
|
||||
FailoverReason.billing,
|
||||
FailoverReason.model_not_found,
|
||||
FailoverReason.format_error,
|
||||
]
|
||||
for reason in non_transient:
|
||||
e = ClassifiedError(reason=reason)
|
||||
assert e.is_transient is False, f"{reason} should NOT be transient"
|
||||
|
||||
def test_defaults(self):
|
||||
e = ClassifiedError(reason=FailoverReason.unknown)
|
||||
assert e.retryable is True
|
||||
assert e.should_compress is False
|
||||
assert e.should_rotate_credential is False
|
||||
assert e.should_fallback is False
|
||||
assert e.status_code is None
|
||||
assert e.message == ""
|
||||
|
||||
|
||||
# ── Test: Status code extraction ───────────────────────────────────────
|
||||
|
||||
class TestExtractStatusCode:
|
||||
def test_from_status_code_attr(self):
|
||||
e = MockAPIError("fail", status_code=429)
|
||||
assert _extract_status_code(e) == 429
|
||||
|
||||
def test_from_status_attr(self):
|
||||
class ErrWithStatus(Exception):
|
||||
status = 503
|
||||
assert _extract_status_code(ErrWithStatus()) == 503
|
||||
|
||||
def test_from_cause_chain(self):
|
||||
inner = MockAPIError("inner", status_code=401)
|
||||
outer = Exception("outer")
|
||||
outer.__cause__ = inner
|
||||
assert _extract_status_code(outer) == 401
|
||||
|
||||
def test_none_when_missing(self):
|
||||
assert _extract_status_code(Exception("generic")) is None
|
||||
|
||||
def test_rejects_non_http_status(self):
|
||||
"""Integers outside 100-599 on .status should be ignored."""
|
||||
class ErrWeirdStatus(Exception):
|
||||
status = 42
|
||||
assert _extract_status_code(ErrWeirdStatus()) is None
|
||||
|
||||
|
||||
# ── Test: Error body extraction ────────────────────────────────────────
|
||||
|
||||
class TestExtractErrorBody:
|
||||
def test_from_body_attr(self):
|
||||
e = MockAPIError("fail", body={"error": {"message": "bad"}})
|
||||
assert _extract_error_body(e) == {"error": {"message": "bad"}}
|
||||
|
||||
def test_empty_when_no_body(self):
|
||||
assert _extract_error_body(Exception("generic")) == {}
|
||||
|
||||
|
||||
# ── Test: Error code extraction ────────────────────────────────────────
|
||||
|
||||
class TestExtractErrorCode:
|
||||
def test_from_nested_error_code(self):
|
||||
body = {"error": {"code": "rate_limit_exceeded"}}
|
||||
assert _extract_error_code(body) == "rate_limit_exceeded"
|
||||
|
||||
def test_from_nested_error_type(self):
|
||||
body = {"error": {"type": "invalid_request_error"}}
|
||||
assert _extract_error_code(body) == "invalid_request_error"
|
||||
|
||||
def test_from_top_level_code(self):
|
||||
body = {"code": "model_not_found"}
|
||||
assert _extract_error_code(body) == "model_not_found"
|
||||
|
||||
def test_empty_when_no_code(self):
|
||||
assert _extract_error_code({}) == ""
|
||||
assert _extract_error_code({"error": {"message": "oops"}}) == ""
|
||||
|
||||
|
||||
# ── Test: 402 disambiguation ───────────────────────────────────────────
|
||||
|
||||
class TestClassify402:
|
||||
"""The critical 402 billing vs rate_limit disambiguation."""
|
||||
|
||||
def test_billing_exhaustion(self):
|
||||
"""Plain 402 = billing."""
|
||||
result = _classify_402(
|
||||
"payment required",
|
||||
lambda reason, **kw: ClassifiedError(reason=reason, **kw),
|
||||
)
|
||||
assert result.reason == FailoverReason.billing
|
||||
assert result.should_rotate_credential is True
|
||||
|
||||
def test_transient_usage_limit(self):
|
||||
"""402 with 'usage limit' + 'try again' = rate limit, not billing."""
|
||||
result = _classify_402(
|
||||
"usage limit exceeded. try again in 5 minutes",
|
||||
lambda reason, **kw: ClassifiedError(reason=reason, **kw),
|
||||
)
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
assert result.should_rotate_credential is True
|
||||
|
||||
def test_quota_with_retry(self):
|
||||
"""402 with 'quota' + 'retry' = rate limit."""
|
||||
result = _classify_402(
|
||||
"quota exceeded, please retry after the window resets",
|
||||
lambda reason, **kw: ClassifiedError(reason=reason, **kw),
|
||||
)
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
|
||||
def test_quota_without_retry(self):
|
||||
"""402 with just 'quota' but no transient signal = billing."""
|
||||
result = _classify_402(
|
||||
"quota exceeded",
|
||||
lambda reason, **kw: ClassifiedError(reason=reason, **kw),
|
||||
)
|
||||
assert result.reason == FailoverReason.billing
|
||||
|
||||
def test_insufficient_credits(self):
|
||||
result = _classify_402(
|
||||
"insufficient credits to complete request",
|
||||
lambda reason, **kw: ClassifiedError(reason=reason, **kw),
|
||||
)
|
||||
assert result.reason == FailoverReason.billing
|
||||
|
||||
|
||||
# ── Test: Full classification pipeline ─────────────────────────────────
|
||||
|
||||
class TestClassifyApiError:
|
||||
"""End-to-end classification tests."""
|
||||
|
||||
# ── Auth errors ──
|
||||
|
||||
def test_401_classified_as_auth(self):
|
||||
e = MockAPIError("Unauthorized", status_code=401)
|
||||
result = classify_api_error(e, provider="openrouter")
|
||||
assert result.reason == FailoverReason.auth
|
||||
assert result.should_rotate_credential is True
|
||||
# 401 is non-retryable on its own — credential rotation runs
|
||||
# before the retryability check in the agent loop.
|
||||
assert result.retryable is False
|
||||
assert result.should_fallback is True
|
||||
|
||||
def test_403_classified_as_auth(self):
|
||||
e = MockAPIError("Forbidden", status_code=403)
|
||||
result = classify_api_error(e, provider="anthropic")
|
||||
assert result.reason == FailoverReason.auth
|
||||
assert result.should_fallback is True
|
||||
|
||||
def test_403_key_limit_classified_as_billing(self):
|
||||
"""OpenRouter 403 'key limit exceeded' is billing, not auth."""
|
||||
e = MockAPIError("Key limit exceeded for this key", status_code=403)
|
||||
result = classify_api_error(e, provider="openrouter")
|
||||
assert result.reason == FailoverReason.billing
|
||||
assert result.should_rotate_credential is True
|
||||
assert result.should_fallback is True
|
||||
|
||||
def test_403_spending_limit_classified_as_billing(self):
|
||||
e = MockAPIError("spending limit reached", status_code=403)
|
||||
result = classify_api_error(e, provider="openrouter")
|
||||
assert result.reason == FailoverReason.billing
|
||||
|
||||
# ── Billing ──
|
||||
|
||||
def test_402_plain_billing(self):
|
||||
e = MockAPIError("Payment Required", status_code=402)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.billing
|
||||
assert result.retryable is False
|
||||
|
||||
def test_402_transient_usage_limit(self):
|
||||
e = MockAPIError("usage limit exceeded, try again later", status_code=402)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
assert result.retryable is True
|
||||
|
||||
# ── Rate limit ──
|
||||
|
||||
def test_429_rate_limit(self):
|
||||
e = MockAPIError("Too Many Requests", status_code=429)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
assert result.should_fallback is True
|
||||
|
||||
# ── Server errors ──
|
||||
|
||||
def test_500_server_error(self):
|
||||
e = MockAPIError("Internal Server Error", status_code=500)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.server_error
|
||||
assert result.retryable is True
|
||||
|
||||
def test_502_server_error(self):
|
||||
e = MockAPIError("Bad Gateway", status_code=502)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.server_error
|
||||
|
||||
def test_503_overloaded(self):
|
||||
e = MockAPIError("Service Unavailable", status_code=503)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.overloaded
|
||||
|
||||
def test_529_anthropic_overloaded(self):
|
||||
e = MockAPIError("Overloaded", status_code=529)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.overloaded
|
||||
|
||||
# ── Model not found ──
|
||||
|
||||
def test_404_model_not_found(self):
|
||||
e = MockAPIError("model not found", status_code=404)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.model_not_found
|
||||
assert result.should_fallback is True
|
||||
assert result.retryable is False
|
||||
|
||||
def test_404_generic(self):
|
||||
e = MockAPIError("Not Found", status_code=404)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.model_not_found
|
||||
|
||||
# ── Payload too large ──
|
||||
|
||||
def test_413_payload_too_large(self):
|
||||
e = MockAPIError("Request Entity Too Large", status_code=413)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.payload_too_large
|
||||
assert result.should_compress is True
|
||||
|
||||
# ── Context overflow ──
|
||||
|
||||
def test_400_context_length(self):
|
||||
e = MockAPIError("context length exceeded: 250000 > 200000", status_code=400)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
assert result.should_compress is True
|
||||
|
||||
def test_400_too_many_tokens(self):
|
||||
e = MockAPIError("This model's maximum context is 128000 tokens, too many tokens", status_code=400)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
def test_400_prompt_too_long(self):
|
||||
e = MockAPIError("prompt is too long: 300000 tokens > 200000 maximum", status_code=400)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
def test_400_generic_large_session(self):
|
||||
"""Generic 400 with large session → context overflow heuristic."""
|
||||
e = MockAPIError(
|
||||
"Error",
|
||||
status_code=400,
|
||||
body={"error": {"message": "Error"}},
|
||||
)
|
||||
result = classify_api_error(e, approx_tokens=100000, context_length=200000)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
def test_400_generic_small_session_is_format_error(self):
|
||||
"""Generic 400 with small session → format error, not context overflow."""
|
||||
e = MockAPIError(
|
||||
"Error",
|
||||
status_code=400,
|
||||
body={"error": {"message": "Error"}},
|
||||
)
|
||||
result = classify_api_error(e, approx_tokens=1000, context_length=200000)
|
||||
assert result.reason == FailoverReason.format_error
|
||||
|
||||
# ── Server disconnect + large session ──
|
||||
|
||||
def test_disconnect_large_session_context_overflow(self):
|
||||
"""Server disconnect with large session → context overflow."""
|
||||
e = Exception("server disconnected without sending complete message")
|
||||
result = classify_api_error(e, approx_tokens=150000, context_length=200000)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
assert result.should_compress is True
|
||||
|
||||
def test_disconnect_small_session_timeout(self):
|
||||
"""Server disconnect with small session → timeout."""
|
||||
e = Exception("server disconnected without sending complete message")
|
||||
result = classify_api_error(e, approx_tokens=5000, context_length=200000)
|
||||
assert result.reason == FailoverReason.timeout
|
||||
|
||||
# ── Provider-specific: Anthropic thinking signature ──
|
||||
|
||||
def test_anthropic_thinking_signature(self):
|
||||
e = MockAPIError(
|
||||
"thinking block has invalid signature",
|
||||
status_code=400,
|
||||
)
|
||||
result = classify_api_error(e, provider="anthropic")
|
||||
assert result.reason == FailoverReason.thinking_signature
|
||||
assert result.retryable is True
|
||||
|
||||
def test_non_anthropic_400_with_signature_not_classified_as_thinking(self):
|
||||
"""400 with 'signature' but from non-Anthropic → format error."""
|
||||
e = MockAPIError("invalid signature", status_code=400)
|
||||
result = classify_api_error(e, provider="openrouter", approx_tokens=0)
|
||||
# Without "thinking" in the message, it shouldn't be thinking_signature
|
||||
assert result.reason != FailoverReason.thinking_signature
|
||||
|
||||
# ── Provider-specific: Anthropic long-context tier ──
|
||||
|
||||
def test_anthropic_long_context_tier(self):
|
||||
e = MockAPIError(
|
||||
"Extra usage is required for long context requests over 200k tokens",
|
||||
status_code=429,
|
||||
)
|
||||
result = classify_api_error(e, provider="anthropic", model="claude-sonnet-4")
|
||||
assert result.reason == FailoverReason.long_context_tier
|
||||
assert result.should_compress is True
|
||||
|
||||
def test_normal_429_not_long_context(self):
|
||||
"""Normal 429 without 'extra usage' + 'long context' → rate_limit."""
|
||||
e = MockAPIError("Too Many Requests", status_code=429)
|
||||
result = classify_api_error(e, provider="anthropic")
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
|
||||
# ── Transport errors ──
|
||||
|
||||
def test_read_timeout(self):
|
||||
e = ReadTimeout("Read timed out")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.timeout
|
||||
assert result.retryable is True
|
||||
|
||||
def test_connect_error(self):
|
||||
e = ConnectError("Connection refused")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.timeout
|
||||
|
||||
def test_connection_error_builtin(self):
|
||||
e = ConnectionError("Connection reset by peer")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.timeout
|
||||
|
||||
def test_timeout_error_builtin(self):
|
||||
e = TimeoutError("timed out")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.timeout
|
||||
|
||||
# ── Error code classification ──
|
||||
|
||||
def test_error_code_resource_exhausted(self):
|
||||
e = MockAPIError(
|
||||
"Resource exhausted",
|
||||
body={"error": {"code": "resource_exhausted", "message": "Too many requests"}},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
|
||||
def test_error_code_model_not_found(self):
|
||||
e = MockAPIError(
|
||||
"Model not available",
|
||||
body={"error": {"code": "model_not_found"}},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.model_not_found
|
||||
|
||||
def test_error_code_context_length_exceeded(self):
|
||||
e = MockAPIError(
|
||||
"Context too large",
|
||||
body={"error": {"code": "context_length_exceeded"}},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
# ── Message-only patterns (no status code) ──
|
||||
|
||||
def test_message_billing_pattern(self):
|
||||
e = Exception("insufficient credits to complete this request")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.billing
|
||||
|
||||
def test_message_rate_limit_pattern(self):
|
||||
e = Exception("rate limit reached for this model")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
|
||||
def test_message_auth_pattern(self):
|
||||
e = Exception("invalid api key provided")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.auth
|
||||
|
||||
def test_message_model_not_found_pattern(self):
|
||||
e = Exception("gpt-99 is not a valid model")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.model_not_found
|
||||
|
||||
def test_message_context_overflow_pattern(self):
|
||||
e = Exception("maximum context length exceeded")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
# ── Unknown / fallback ──
|
||||
|
||||
def test_generic_exception_is_unknown(self):
|
||||
e = Exception("something weird happened")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.unknown
|
||||
assert result.retryable is True
|
||||
|
||||
# ── Format error ──
|
||||
|
||||
def test_400_descriptive_format_error(self):
|
||||
"""400 with descriptive message (not context overflow) → format error."""
|
||||
e = MockAPIError(
|
||||
"Invalid value for parameter 'temperature': must be between 0 and 2",
|
||||
status_code=400,
|
||||
body={"error": {"message": "Invalid value for parameter 'temperature': must be between 0 and 2"}},
|
||||
)
|
||||
result = classify_api_error(e, approx_tokens=1000)
|
||||
assert result.reason == FailoverReason.format_error
|
||||
assert result.retryable is False
|
||||
|
||||
def test_422_format_error(self):
|
||||
e = MockAPIError("Unprocessable Entity", status_code=422)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.format_error
|
||||
assert result.retryable is False
|
||||
|
||||
def test_400_flat_body_descriptive_not_context_overflow(self):
|
||||
"""Responses API flat body with descriptive error + large session → format error.
|
||||
|
||||
The Codex Responses API returns errors in flat body format:
|
||||
{"message": "...", "type": "..."} without an "error" wrapper.
|
||||
A descriptive 400 must NOT be misclassified as context overflow
|
||||
just because the session is large.
|
||||
"""
|
||||
e = MockAPIError(
|
||||
"Invalid 'input[index].name': string does not match pattern.",
|
||||
status_code=400,
|
||||
body={"message": "Invalid 'input[index].name': string does not match pattern.",
|
||||
"type": "invalid_request_error"},
|
||||
)
|
||||
result = classify_api_error(e, approx_tokens=200000, context_length=400000, num_messages=500)
|
||||
assert result.reason == FailoverReason.format_error
|
||||
assert result.retryable is False
|
||||
|
||||
def test_400_flat_body_generic_large_session_still_context_overflow(self):
|
||||
"""Flat body with generic 'Error' message + large session → context overflow.
|
||||
|
||||
Regression: the flat-body fallback must not break the existing heuristic
|
||||
for genuinely generic errors from providers that use flat bodies.
|
||||
"""
|
||||
e = MockAPIError(
|
||||
"Error",
|
||||
status_code=400,
|
||||
body={"message": "Error"},
|
||||
)
|
||||
result = classify_api_error(e, approx_tokens=100000, context_length=200000)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
# ── Peer closed + large session ──
|
||||
|
||||
def test_peer_closed_large_session(self):
|
||||
e = Exception("peer closed connection without sending complete message")
|
||||
result = classify_api_error(e, approx_tokens=130000, context_length=200000)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
# ── Chinese error messages ──
|
||||
|
||||
def test_chinese_context_overflow(self):
|
||||
e = MockAPIError("超过最大长度限制", status_code=400)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
# ── Result metadata ──
|
||||
|
||||
def test_provider_and_model_in_result(self):
|
||||
e = MockAPIError("fail", status_code=500)
|
||||
result = classify_api_error(e, provider="openrouter", model="gpt-5")
|
||||
assert result.provider == "openrouter"
|
||||
assert result.model == "gpt-5"
|
||||
assert result.status_code == 500
|
||||
|
||||
def test_message_extracted(self):
|
||||
e = MockAPIError(
|
||||
"outer",
|
||||
status_code=500,
|
||||
body={"error": {"message": "Internal server error occurred"}},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
assert result.message == "Internal server error occurred"
|
||||
|
||||
|
||||
# ── Test: Adversarial / edge cases (from live testing) ─────────────────
|
||||
|
||||
class TestAdversarialEdgeCases:
|
||||
"""Edge cases discovered during live testing with real SDK objects."""
|
||||
|
||||
def test_empty_exception_message(self):
|
||||
result = classify_api_error(Exception(""))
|
||||
assert result.reason == FailoverReason.unknown
|
||||
assert result.retryable is True
|
||||
|
||||
def test_500_with_none_body(self):
|
||||
e = MockAPIError("fail", status_code=500, body=None)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.server_error
|
||||
|
||||
def test_non_dict_body(self):
|
||||
"""Some providers return strings instead of JSON."""
|
||||
class StringBodyError(Exception):
|
||||
status_code = 400
|
||||
body = "just a string"
|
||||
result = classify_api_error(StringBodyError("bad"))
|
||||
assert result.reason == FailoverReason.format_error
|
||||
|
||||
def test_list_body(self):
|
||||
class ListBodyError(Exception):
|
||||
status_code = 500
|
||||
body = [{"error": "something"}]
|
||||
result = classify_api_error(ListBodyError("server error"))
|
||||
assert result.reason == FailoverReason.server_error
|
||||
|
||||
def test_circular_cause_chain(self):
|
||||
"""Must not infinite-loop on circular __cause__."""
|
||||
e = Exception("circular")
|
||||
e.__cause__ = e
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.unknown
|
||||
|
||||
def test_three_level_cause_chain(self):
|
||||
inner = MockAPIError("inner", status_code=429)
|
||||
middle = Exception("middle")
|
||||
middle.__cause__ = inner
|
||||
outer = RuntimeError("outer")
|
||||
outer.__cause__ = middle
|
||||
result = classify_api_error(outer)
|
||||
assert result.status_code == 429
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
|
||||
def test_400_with_rate_limit_text(self):
|
||||
"""Some providers send rate limits as 400 instead of 429."""
|
||||
e = MockAPIError(
|
||||
"rate limit policy",
|
||||
status_code=400,
|
||||
body={"error": {"message": "rate limit exceeded on this model"}},
|
||||
)
|
||||
result = classify_api_error(e, provider="openrouter")
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
|
||||
def test_400_with_billing_text(self):
|
||||
"""Some providers send billing errors as 400."""
|
||||
e = MockAPIError(
|
||||
"billing",
|
||||
status_code=400,
|
||||
body={"error": {"message": "insufficient credits for this request"}},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.billing
|
||||
|
||||
def test_200_with_error_body(self):
|
||||
"""200 status with error in body — should be unknown, not crash."""
|
||||
class WeirdSuccess(Exception):
|
||||
status_code = 200
|
||||
body = {"error": {"message": "loading"}}
|
||||
result = classify_api_error(WeirdSuccess("model loading"))
|
||||
assert result.reason == FailoverReason.unknown
|
||||
|
||||
def test_ollama_context_size_exceeded(self):
|
||||
e = MockAPIError(
|
||||
"Error",
|
||||
status_code=400,
|
||||
body={"error": {"message": "context size has been exceeded"}},
|
||||
)
|
||||
result = classify_api_error(e, provider="ollama")
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
def test_connection_refused_error(self):
|
||||
e = ConnectionRefusedError("Connection refused: localhost:11434")
|
||||
result = classify_api_error(e, provider="ollama")
|
||||
assert result.reason == FailoverReason.timeout
|
||||
|
||||
def test_body_message_enrichment(self):
|
||||
"""Body message must be included in pattern matching even when
|
||||
str(error) doesn't contain it (OpenAI SDK APIStatusError)."""
|
||||
e = MockAPIError(
|
||||
"Usage limit", # str(e) = "usage limit"
|
||||
status_code=402,
|
||||
body={"error": {"message": "Usage limit reached, try again in 5 minutes"}},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
# "try again" is only in body, not in str(e)
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
|
||||
def test_disconnect_pattern_ordering(self):
|
||||
"""Disconnect + large session must beat generic transport catch."""
|
||||
class FakeRemoteProtocol(Exception):
|
||||
pass
|
||||
# Type name isn't in _TRANSPORT_ERROR_TYPES but message has disconnect pattern
|
||||
e = Exception("peer closed connection without sending complete message")
|
||||
result = classify_api_error(e, approx_tokens=150000, context_length=200000)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
assert result.should_compress is True
|
||||
|
||||
def test_credit_balance_too_low(self):
|
||||
e = MockAPIError(
|
||||
"Credits low",
|
||||
status_code=402,
|
||||
body={"error": {"message": "Your credit balance is too low"}},
|
||||
)
|
||||
result = classify_api_error(e, provider="anthropic")
|
||||
assert result.reason == FailoverReason.billing
|
||||
|
||||
def test_deepseek_402_chinese(self):
|
||||
"""Chinese billing message should still match billing patterns."""
|
||||
# "余额不足" doesn't match English billing patterns, but 402 defaults to billing
|
||||
e = MockAPIError("余额不足", status_code=402)
|
||||
result = classify_api_error(e, provider="deepseek")
|
||||
assert result.reason == FailoverReason.billing
|
||||
|
||||
def test_openrouter_wrapped_context_overflow_in_metadata_raw(self):
|
||||
"""OpenRouter wraps provider errors in metadata.raw JSON string."""
|
||||
e = MockAPIError(
|
||||
"Provider returned error",
|
||||
status_code=400,
|
||||
body={
|
||||
"error": {
|
||||
"message": "Provider returned error",
|
||||
"code": 400,
|
||||
"metadata": {
|
||||
"raw": '{"error":{"message":"context length exceeded: 50000 > 32768"}}'
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
result = classify_api_error(e, provider="openrouter", approx_tokens=10000)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
assert result.should_compress is True
|
||||
|
||||
def test_openrouter_wrapped_rate_limit_in_metadata_raw(self):
|
||||
e = MockAPIError(
|
||||
"Provider returned error",
|
||||
status_code=400,
|
||||
body={
|
||||
"error": {
|
||||
"message": "Provider returned error",
|
||||
"metadata": {
|
||||
"raw": '{"error":{"message":"Rate limit exceeded. Please retry after 30s."}}'
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
result = classify_api_error(e, provider="openrouter")
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
|
||||
def test_thinking_signature_via_openrouter(self):
|
||||
"""Thinking signature errors proxied through OpenRouter must be caught."""
|
||||
e = MockAPIError(
|
||||
"thinking block has invalid signature",
|
||||
status_code=400,
|
||||
)
|
||||
# provider is openrouter, not anthropic — old code missed this
|
||||
result = classify_api_error(e, provider="openrouter", model="anthropic/claude-sonnet-4")
|
||||
assert result.reason == FailoverReason.thinking_signature
|
||||
|
||||
def test_generic_400_large_by_message_count(self):
|
||||
"""Many small messages (>80) should trigger context overflow heuristic."""
|
||||
e = MockAPIError(
|
||||
"Error",
|
||||
status_code=400,
|
||||
body={"error": {"message": "Error"}},
|
||||
)
|
||||
# Low token count but high message count
|
||||
result = classify_api_error(
|
||||
e, approx_tokens=5000, context_length=200000, num_messages=100,
|
||||
)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
def test_disconnect_large_by_message_count(self):
|
||||
"""Server disconnect with 200+ messages should trigger context overflow."""
|
||||
e = Exception("server disconnected without sending complete message")
|
||||
result = classify_api_error(
|
||||
e, approx_tokens=5000, context_length=200000, num_messages=250,
|
||||
)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
def test_openrouter_wrapped_model_not_found_in_metadata_raw(self):
|
||||
e = MockAPIError(
|
||||
"Provider returned error",
|
||||
status_code=400,
|
||||
body={
|
||||
"error": {
|
||||
"message": "Provider returned error",
|
||||
"metadata": {
|
||||
"raw": '{"error":{"message":"The model gpt-99 does not exist"}}'
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
result = classify_api_error(e, provider="openrouter")
|
||||
assert result.reason == FailoverReason.model_not_found
|
||||
@@ -7,6 +7,7 @@ from pathlib import Path
|
||||
from hermes_state import SessionDB
|
||||
from agent.insights import (
|
||||
InsightsEngine,
|
||||
_get_pricing,
|
||||
_estimate_cost,
|
||||
_format_duration,
|
||||
_bar_chart,
|
||||
@@ -117,6 +118,45 @@ def populated_db(db):
|
||||
return db
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Pricing helpers
|
||||
# =========================================================================
|
||||
|
||||
class TestPricing:
|
||||
def test_provider_prefix_stripped(self):
|
||||
pricing = _get_pricing("anthropic/claude-sonnet-4-20250514")
|
||||
assert pricing["input"] == 3.00
|
||||
assert pricing["output"] == 15.00
|
||||
|
||||
def test_unknown_models_do_not_use_heuristics(self):
|
||||
pricing = _get_pricing("some-new-opus-model")
|
||||
assert pricing == _DEFAULT_PRICING
|
||||
pricing = _get_pricing("anthropic/claude-haiku-future")
|
||||
assert pricing == _DEFAULT_PRICING
|
||||
|
||||
def test_unknown_model_returns_zero_cost(self):
|
||||
"""Unknown/custom models should NOT have fabricated costs."""
|
||||
pricing = _get_pricing("totally-unknown-model-xyz")
|
||||
assert pricing == _DEFAULT_PRICING
|
||||
assert pricing["input"] == 0.0
|
||||
assert pricing["output"] == 0.0
|
||||
|
||||
def test_custom_endpoint_model_zero_cost(self):
|
||||
"""Self-hosted models should return zero cost."""
|
||||
for model in ["FP16_Hermes_4.5", "Hermes_4.5_1T_epoch2", "my-local-llama"]:
|
||||
pricing = _get_pricing(model)
|
||||
assert pricing["input"] == 0.0, f"{model} should have zero cost"
|
||||
assert pricing["output"] == 0.0, f"{model} should have zero cost"
|
||||
|
||||
def test_none_model(self):
|
||||
pricing = _get_pricing(None)
|
||||
assert pricing == _DEFAULT_PRICING
|
||||
|
||||
def test_empty_model(self):
|
||||
pricing = _get_pricing("")
|
||||
assert pricing == _DEFAULT_PRICING
|
||||
|
||||
|
||||
class TestHasKnownPricing:
|
||||
def test_known_commercial_model(self):
|
||||
assert _has_known_pricing("gpt-4o", provider="openai") is True
|
||||
|
||||
@@ -0,0 +1,299 @@
|
||||
"""End-to-end test: a SQLite-backed memory plugin exercising the full interface.
|
||||
|
||||
This proves a real plugin can register as a MemoryProvider and get wired
|
||||
into the agent loop via MemoryManager. Uses SQLite + FTS5 (stdlib, no
|
||||
external deps, no API keys).
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
from agent.memory_manager import MemoryManager
|
||||
from agent.builtin_memory_provider import BuiltinMemoryProvider
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQLite FTS5 memory provider — a real, minimal plugin implementation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SQLiteMemoryProvider(MemoryProvider):
|
||||
"""Minimal SQLite + FTS5 memory provider for testing.
|
||||
|
||||
Demonstrates the full MemoryProvider interface with a real backend.
|
||||
No external dependencies — just stdlib sqlite3.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str = ":memory:"):
|
||||
self._db_path = db_path
|
||||
self._conn = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "sqlite_memory"
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return True # SQLite is always available
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
|
||||
self._conn = sqlite3.connect(self._db_path)
|
||||
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||
self._conn.execute("""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS memories
|
||||
USING fts5(content, context, session_id)
|
||||
""")
|
||||
self._session_id = session_id
|
||||
|
||||
def system_prompt_block(self) -> str:
|
||||
if not self._conn:
|
||||
return ""
|
||||
count = self._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
|
||||
if count == 0:
|
||||
return ""
|
||||
return (
|
||||
f"# SQLite Memory Plugin\n"
|
||||
f"Active. {count} memories stored.\n"
|
||||
f"Use sqlite_recall to search, sqlite_retain to store."
|
||||
)
|
||||
|
||||
def prefetch(self, query: str, *, session_id: str = "") -> str:
|
||||
if not self._conn or not query:
|
||||
return ""
|
||||
# FTS5 search
|
||||
try:
|
||||
rows = self._conn.execute(
|
||||
"SELECT content FROM memories WHERE memories MATCH ? LIMIT 5",
|
||||
(query,)
|
||||
).fetchall()
|
||||
if not rows:
|
||||
return ""
|
||||
results = [row[0] for row in rows]
|
||||
return "## SQLite Memory\n" + "\n".join(f"- {r}" for r in results)
|
||||
except sqlite3.OperationalError:
|
||||
return ""
|
||||
|
||||
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
|
||||
if not self._conn:
|
||||
return
|
||||
combined = f"User: {user_content}\nAssistant: {assistant_content}"
|
||||
self._conn.execute(
|
||||
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
|
||||
(combined, "conversation", self._session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def get_tool_schemas(self):
|
||||
return [
|
||||
{
|
||||
"name": "sqlite_retain",
|
||||
"description": "Store a fact to SQLite memory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {"type": "string", "description": "What to remember"},
|
||||
"context": {"type": "string", "description": "Category/context"},
|
||||
},
|
||||
"required": ["content"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "sqlite_recall",
|
||||
"description": "Search SQLite memory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
|
||||
if tool_name == "sqlite_retain":
|
||||
content = args.get("content", "")
|
||||
context = args.get("context", "explicit")
|
||||
if not content:
|
||||
return json.dumps({"error": "content is required"})
|
||||
self._conn.execute(
|
||||
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
|
||||
(content, context, self._session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
return json.dumps({"result": "Stored."})
|
||||
|
||||
elif tool_name == "sqlite_recall":
|
||||
query = args.get("query", "")
|
||||
if not query:
|
||||
return json.dumps({"error": "query is required"})
|
||||
try:
|
||||
rows = self._conn.execute(
|
||||
"SELECT content, context FROM memories WHERE memories MATCH ? LIMIT 10",
|
||||
(query,)
|
||||
).fetchall()
|
||||
results = [{"content": r[0], "context": r[1]} for r in rows]
|
||||
return json.dumps({"results": results})
|
||||
except sqlite3.OperationalError:
|
||||
return json.dumps({"results": []})
|
||||
|
||||
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
||||
|
||||
def on_memory_write(self, action, target, content):
|
||||
"""Mirror built-in memory writes to SQLite."""
|
||||
if action == "add" and self._conn:
|
||||
self._conn.execute(
|
||||
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
|
||||
(content, f"builtin_{target}", self._session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def shutdown(self):
|
||||
if self._conn:
|
||||
self._conn.close()
|
||||
self._conn = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# End-to-end tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSQLiteMemoryPlugin:
|
||||
"""Full lifecycle test with the SQLite provider."""
|
||||
|
||||
def test_full_lifecycle(self):
|
||||
"""Exercise init → store → recall → sync → prefetch → shutdown."""
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider()
|
||||
sqlite_mem = SQLiteMemoryProvider()
|
||||
|
||||
mgr.add_provider(builtin)
|
||||
mgr.add_provider(sqlite_mem)
|
||||
|
||||
# Initialize
|
||||
mgr.initialize_all(session_id="test-session-1", platform="cli")
|
||||
assert sqlite_mem._conn is not None
|
||||
|
||||
# System prompt — empty at first
|
||||
prompt = mgr.build_system_prompt()
|
||||
assert "SQLite Memory Plugin" not in prompt
|
||||
|
||||
# Store via tool call
|
||||
result = json.loads(mgr.handle_tool_call(
|
||||
"sqlite_retain", {"content": "User prefers dark mode", "context": "preference"}
|
||||
))
|
||||
assert result["result"] == "Stored."
|
||||
|
||||
# System prompt now shows count
|
||||
prompt = mgr.build_system_prompt()
|
||||
assert "1 memories stored" in prompt
|
||||
|
||||
# Recall via tool call
|
||||
result = json.loads(mgr.handle_tool_call(
|
||||
"sqlite_recall", {"query": "dark mode"}
|
||||
))
|
||||
assert len(result["results"]) == 1
|
||||
assert "dark mode" in result["results"][0]["content"]
|
||||
|
||||
# Sync a turn (auto-stores conversation)
|
||||
mgr.sync_all("What's my theme?", "You prefer dark mode.")
|
||||
count = sqlite_mem._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
|
||||
assert count == 2 # 1 explicit + 1 synced
|
||||
|
||||
# Prefetch for next turn
|
||||
prefetched = mgr.prefetch_all("dark mode")
|
||||
assert "dark mode" in prefetched
|
||||
|
||||
# Memory bridge — mirroring builtin writes
|
||||
mgr.on_memory_write("add", "user", "Timezone: US Pacific")
|
||||
count = sqlite_mem._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
|
||||
assert count == 3
|
||||
|
||||
# Shutdown
|
||||
mgr.shutdown_all()
|
||||
assert sqlite_mem._conn is None
|
||||
|
||||
def test_tool_routing_with_builtin(self):
|
||||
"""Verify builtin + plugin tools coexist without conflict."""
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider()
|
||||
sqlite_mem = SQLiteMemoryProvider()
|
||||
mgr.add_provider(builtin)
|
||||
mgr.add_provider(sqlite_mem)
|
||||
mgr.initialize_all(session_id="test-2")
|
||||
|
||||
# Builtin has no tools
|
||||
assert len(builtin.get_tool_schemas()) == 0
|
||||
# SQLite has 2 tools
|
||||
schemas = mgr.get_all_tool_schemas()
|
||||
names = {s["name"] for s in schemas}
|
||||
assert names == {"sqlite_retain", "sqlite_recall"}
|
||||
|
||||
# Routing works
|
||||
assert mgr.has_tool("sqlite_retain")
|
||||
assert mgr.has_tool("sqlite_recall")
|
||||
assert not mgr.has_tool("memory") # builtin doesn't register this
|
||||
|
||||
def test_second_external_plugin_rejected(self):
|
||||
"""Only one external memory provider is allowed at a time."""
|
||||
mgr = MemoryManager()
|
||||
p1 = SQLiteMemoryProvider()
|
||||
p2 = SQLiteMemoryProvider()
|
||||
# Hack name for p2
|
||||
p2._name_override = "sqlite_memory_2"
|
||||
original_name = p2.__class__.name
|
||||
type(p2).name = property(lambda self: getattr(self, '_name_override', 'sqlite_memory'))
|
||||
|
||||
mgr.add_provider(p1)
|
||||
mgr.add_provider(p2) # should be rejected
|
||||
|
||||
# Only p1 was accepted
|
||||
assert len(mgr.providers) == 1
|
||||
assert mgr.provider_names == ["sqlite_memory"]
|
||||
|
||||
# Restore class
|
||||
type(p2).name = original_name
|
||||
mgr.shutdown_all()
|
||||
|
||||
def test_provider_failure_isolation(self):
|
||||
"""Failing external provider doesn't break builtin."""
|
||||
from agent.builtin_memory_provider import BuiltinMemoryProvider
|
||||
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider() # name="builtin", always accepted
|
||||
ext = SQLiteMemoryProvider()
|
||||
|
||||
mgr.add_provider(builtin)
|
||||
mgr.add_provider(ext)
|
||||
mgr.initialize_all(session_id="test-4")
|
||||
|
||||
# Break external provider's connection
|
||||
ext._conn.close()
|
||||
ext._conn = None
|
||||
|
||||
# Sync — external fails silently, builtin (no-op sync) succeeds
|
||||
mgr.sync_all("user", "assistant") # should not raise
|
||||
|
||||
mgr.shutdown_all()
|
||||
|
||||
def test_plugin_registration_flow(self):
|
||||
"""Simulate the full plugin load → agent init path."""
|
||||
# Simulate what AIAgent.__init__ does via plugins/memory/ discovery
|
||||
provider = SQLiteMemoryProvider()
|
||||
|
||||
mem_mgr = MemoryManager()
|
||||
mem_mgr.add_provider(BuiltinMemoryProvider())
|
||||
if provider.is_available():
|
||||
mem_mgr.add_provider(provider)
|
||||
mem_mgr.initialize_all(session_id="agent-session")
|
||||
|
||||
assert len(mem_mgr.providers) == 2
|
||||
assert mem_mgr.provider_names == ["builtin", "sqlite_memory"]
|
||||
assert provider._conn is not None # initialized = connection established
|
||||
|
||||
mem_mgr.shutdown_all()
|
||||
@@ -6,6 +6,8 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
from agent.memory_manager import MemoryManager
|
||||
from agent.builtin_memory_provider import BuiltinMemoryProvider
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concrete test provider
|
||||
@@ -116,7 +118,7 @@ class TestMemoryManager:
|
||||
def test_empty_manager(self):
|
||||
mgr = MemoryManager()
|
||||
assert mgr.providers == []
|
||||
assert [p.name for p in mgr.providers] == []
|
||||
assert mgr.provider_names == []
|
||||
assert mgr.get_all_tool_schemas() == []
|
||||
assert mgr.build_system_prompt() == ""
|
||||
assert mgr.prefetch_all("test") == ""
|
||||
@@ -126,7 +128,7 @@ class TestMemoryManager:
|
||||
p = FakeMemoryProvider("test1")
|
||||
mgr.add_provider(p)
|
||||
assert len(mgr.providers) == 1
|
||||
assert [p.name for p in mgr.providers] == ["test1"]
|
||||
assert mgr.provider_names == ["test1"]
|
||||
|
||||
def test_get_provider_by_name(self):
|
||||
mgr = MemoryManager()
|
||||
@@ -141,7 +143,7 @@ class TestMemoryManager:
|
||||
p2 = FakeMemoryProvider("external")
|
||||
mgr.add_provider(p1)
|
||||
mgr.add_provider(p2)
|
||||
assert [p.name for p in mgr.providers] == ["builtin", "external"]
|
||||
assert mgr.provider_names == ["builtin", "external"]
|
||||
|
||||
def test_second_external_rejected(self):
|
||||
"""Only one non-builtin provider is allowed."""
|
||||
@@ -152,7 +154,7 @@ class TestMemoryManager:
|
||||
mgr.add_provider(builtin)
|
||||
mgr.add_provider(ext1)
|
||||
mgr.add_provider(ext2) # should be rejected
|
||||
assert [p.name for p in mgr.providers] == ["builtin", "mem0"]
|
||||
assert mgr.provider_names == ["builtin", "mem0"]
|
||||
assert len(mgr.providers) == 2
|
||||
|
||||
def test_system_prompt_merges_blocks(self):
|
||||
@@ -319,6 +321,17 @@ class TestMemoryManager:
|
||||
mgr.on_pre_compress([{"role": "user", "content": "old"}])
|
||||
assert p.pre_compress_called
|
||||
|
||||
def test_on_memory_write_skips_builtin(self):
|
||||
"""on_memory_write should skip the builtin provider."""
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider()
|
||||
external = FakeMemoryProvider("external")
|
||||
mgr.add_provider(builtin)
|
||||
mgr.add_provider(external)
|
||||
|
||||
mgr.on_memory_write("add", "memory", "test fact")
|
||||
assert external.memory_writes == [("add", "memory", "test fact")]
|
||||
|
||||
def test_shutdown_all_reverse_order(self):
|
||||
mgr = MemoryManager()
|
||||
order = []
|
||||
@@ -372,6 +385,146 @@ class TestMemoryManager:
|
||||
assert result == "works fine"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BuiltinMemoryProvider tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuiltinMemoryProvider:
|
||||
def test_name(self):
|
||||
p = BuiltinMemoryProvider()
|
||||
assert p.name == "builtin"
|
||||
|
||||
def test_always_available(self):
|
||||
p = BuiltinMemoryProvider()
|
||||
assert p.is_available()
|
||||
|
||||
def test_no_tools(self):
|
||||
"""Builtin provider exposes no tools (memory tool is agent-level)."""
|
||||
p = BuiltinMemoryProvider()
|
||||
assert p.get_tool_schemas() == []
|
||||
|
||||
def test_system_prompt_with_store(self):
|
||||
store = MagicMock()
|
||||
store.format_for_system_prompt.side_effect = lambda t: f"BLOCK_{t}" if t == "memory" else f"BLOCK_{t}"
|
||||
|
||||
p = BuiltinMemoryProvider(
|
||||
memory_store=store,
|
||||
memory_enabled=True,
|
||||
user_profile_enabled=True,
|
||||
)
|
||||
block = p.system_prompt_block()
|
||||
assert "BLOCK_memory" in block
|
||||
assert "BLOCK_user" in block
|
||||
|
||||
def test_system_prompt_memory_disabled(self):
|
||||
store = MagicMock()
|
||||
store.format_for_system_prompt.return_value = "content"
|
||||
|
||||
p = BuiltinMemoryProvider(
|
||||
memory_store=store,
|
||||
memory_enabled=False,
|
||||
user_profile_enabled=False,
|
||||
)
|
||||
assert p.system_prompt_block() == ""
|
||||
|
||||
def test_system_prompt_no_store(self):
|
||||
p = BuiltinMemoryProvider(memory_store=None, memory_enabled=True)
|
||||
assert p.system_prompt_block() == ""
|
||||
|
||||
def test_prefetch_returns_empty(self):
|
||||
p = BuiltinMemoryProvider()
|
||||
assert p.prefetch("anything") == ""
|
||||
|
||||
def test_store_property(self):
|
||||
store = MagicMock()
|
||||
p = BuiltinMemoryProvider(memory_store=store)
|
||||
assert p.store is store
|
||||
|
||||
def test_initialize_loads_from_disk(self):
|
||||
store = MagicMock()
|
||||
p = BuiltinMemoryProvider(memory_store=store)
|
||||
p.initialize(session_id="test")
|
||||
store.load_from_disk.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Plugin registration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSingleProviderGating:
|
||||
"""Only the configured provider should activate."""
|
||||
|
||||
def test_no_provider_configured_means_builtin_only(self):
|
||||
"""When memory.provider is empty, no plugin providers activate."""
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider()
|
||||
mgr.add_provider(builtin)
|
||||
|
||||
# Simulate what run_agent.py does when provider=""
|
||||
configured = ""
|
||||
available_plugins = [
|
||||
FakeMemoryProvider("holographic"),
|
||||
FakeMemoryProvider("mem0"),
|
||||
]
|
||||
# With empty config, no plugins should be added
|
||||
if configured:
|
||||
for p in available_plugins:
|
||||
if p.name == configured and p.is_available():
|
||||
mgr.add_provider(p)
|
||||
|
||||
assert mgr.provider_names == ["builtin"]
|
||||
|
||||
def test_configured_provider_activates(self):
|
||||
"""Only the named provider should be added."""
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider()
|
||||
mgr.add_provider(builtin)
|
||||
|
||||
configured = "holographic"
|
||||
p1 = FakeMemoryProvider("holographic")
|
||||
p2 = FakeMemoryProvider("mem0")
|
||||
p3 = FakeMemoryProvider("hindsight")
|
||||
|
||||
for p in [p1, p2, p3]:
|
||||
if p.name == configured and p.is_available():
|
||||
mgr.add_provider(p)
|
||||
|
||||
assert mgr.provider_names == ["builtin", "holographic"]
|
||||
assert p1.initialized is False # not initialized by the gating logic itself
|
||||
|
||||
def test_unavailable_provider_skipped(self):
|
||||
"""If the configured provider is unavailable, it should be skipped."""
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider()
|
||||
mgr.add_provider(builtin)
|
||||
|
||||
configured = "holographic"
|
||||
p1 = FakeMemoryProvider("holographic", available=False)
|
||||
|
||||
for p in [p1]:
|
||||
if p.name == configured and p.is_available():
|
||||
mgr.add_provider(p)
|
||||
|
||||
assert mgr.provider_names == ["builtin"]
|
||||
|
||||
def test_nonexistent_provider_results_in_builtin_only(self):
|
||||
"""If the configured name doesn't match any plugin, only builtin remains."""
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider()
|
||||
mgr.add_provider(builtin)
|
||||
|
||||
configured = "nonexistent"
|
||||
plugins = [FakeMemoryProvider("holographic"), FakeMemoryProvider("mem0")]
|
||||
|
||||
for p in plugins:
|
||||
if p.name == configured and p.is_available():
|
||||
mgr.add_provider(p)
|
||||
|
||||
assert mgr.provider_names == ["builtin"]
|
||||
|
||||
|
||||
class TestPluginMemoryDiscovery:
|
||||
"""Memory providers are discovered from plugins/memory/ directory."""
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from agent.prompt_builder import (
|
||||
_scan_context_content,
|
||||
_truncate_content,
|
||||
_parse_skill_file,
|
||||
_read_skill_conditions,
|
||||
_skill_should_show,
|
||||
_find_hermes_md,
|
||||
_find_git_root,
|
||||
@@ -774,6 +775,61 @@ class TestPromptBuilderConstants:
|
||||
# Conditional skill activation
|
||||
# =========================================================================
|
||||
|
||||
class TestReadSkillConditions:
|
||||
def test_no_conditions_returns_empty_lists(self, tmp_path):
|
||||
skill_file = tmp_path / "SKILL.md"
|
||||
skill_file.write_text("---\nname: test\ndescription: A skill\n---\n")
|
||||
conditions = _read_skill_conditions(skill_file)
|
||||
assert conditions["fallback_for_toolsets"] == []
|
||||
assert conditions["requires_toolsets"] == []
|
||||
assert conditions["fallback_for_tools"] == []
|
||||
assert conditions["requires_tools"] == []
|
||||
|
||||
def test_reads_fallback_for_toolsets(self, tmp_path):
|
||||
skill_file = tmp_path / "SKILL.md"
|
||||
skill_file.write_text(
|
||||
"---\nname: ddg\ndescription: DuckDuckGo\nmetadata:\n hermes:\n fallback_for_toolsets: [web]\n---\n"
|
||||
)
|
||||
conditions = _read_skill_conditions(skill_file)
|
||||
assert conditions["fallback_for_toolsets"] == ["web"]
|
||||
|
||||
def test_reads_requires_toolsets(self, tmp_path):
|
||||
skill_file = tmp_path / "SKILL.md"
|
||||
skill_file.write_text(
|
||||
"---\nname: openhue\ndescription: Hue lights\nmetadata:\n hermes:\n requires_toolsets: [terminal]\n---\n"
|
||||
)
|
||||
conditions = _read_skill_conditions(skill_file)
|
||||
assert conditions["requires_toolsets"] == ["terminal"]
|
||||
|
||||
def test_reads_multiple_conditions(self, tmp_path):
|
||||
skill_file = tmp_path / "SKILL.md"
|
||||
skill_file.write_text(
|
||||
"---\nname: test\ndescription: Test\nmetadata:\n hermes:\n fallback_for_toolsets: [browser]\n requires_tools: [terminal]\n---\n"
|
||||
)
|
||||
conditions = _read_skill_conditions(skill_file)
|
||||
assert conditions["fallback_for_toolsets"] == ["browser"]
|
||||
assert conditions["requires_tools"] == ["terminal"]
|
||||
|
||||
def test_missing_file_returns_empty(self, tmp_path):
|
||||
conditions = _read_skill_conditions(tmp_path / "missing.md")
|
||||
assert conditions == {}
|
||||
|
||||
def test_logs_condition_read_failures_and_returns_empty(self, tmp_path, monkeypatch, caplog):
|
||||
skill_file = tmp_path / "SKILL.md"
|
||||
skill_file.write_text("---\nname: broken\n---\n")
|
||||
|
||||
def boom(*args, **kwargs):
|
||||
raise OSError("read exploded")
|
||||
|
||||
monkeypatch.setattr(type(skill_file), "read_text", boom)
|
||||
with caplog.at_level(logging.DEBUG, logger="agent.prompt_builder"):
|
||||
conditions = _read_skill_conditions(skill_file)
|
||||
|
||||
assert conditions == {}
|
||||
assert "Failed to read skill conditions" in caplog.text
|
||||
assert str(skill_file) in caplog.text
|
||||
|
||||
|
||||
class TestSkillShouldShow:
|
||||
def test_no_filter_info_always_shows(self):
|
||||
assert _skill_should_show({}, None, None) is True
|
||||
|
||||
@@ -1,212 +0,0 @@
|
||||
"""Tests for agent.rate_limit_tracker — header parsing and formatting."""
|
||||
|
||||
import time
|
||||
import pytest
|
||||
from agent.rate_limit_tracker import (
|
||||
RateLimitBucket,
|
||||
RateLimitState,
|
||||
parse_rate_limit_headers,
|
||||
format_rate_limit_display,
|
||||
format_rate_limit_compact,
|
||||
_fmt_count,
|
||||
_fmt_seconds,
|
||||
_bar,
|
||||
)
|
||||
|
||||
|
||||
# ── Sample headers from Nous inference API ──────────────────────────────
|
||||
|
||||
NOUS_HEADERS = {
|
||||
"x-ratelimit-limit-requests": "800",
|
||||
"x-ratelimit-limit-requests-1h": "33600",
|
||||
"x-ratelimit-limit-tokens": "8000000",
|
||||
"x-ratelimit-limit-tokens-1h": "336000000",
|
||||
"x-ratelimit-remaining-requests": "795",
|
||||
"x-ratelimit-remaining-requests-1h": "33590",
|
||||
"x-ratelimit-remaining-tokens": "7999500",
|
||||
"x-ratelimit-remaining-tokens-1h": "335999000",
|
||||
"x-ratelimit-reset-requests": "45.5",
|
||||
"x-ratelimit-reset-requests-1h": "3500.0",
|
||||
"x-ratelimit-reset-tokens": "42.3",
|
||||
"x-ratelimit-reset-tokens-1h": "3490.0",
|
||||
}
|
||||
|
||||
|
||||
class TestParseHeaders:
|
||||
def test_basic_parsing(self):
|
||||
state = parse_rate_limit_headers(NOUS_HEADERS, provider="nous")
|
||||
assert state is not None
|
||||
assert state.provider == "nous"
|
||||
assert state.has_data
|
||||
|
||||
assert state.requests_min.limit == 800
|
||||
assert state.requests_min.remaining == 795
|
||||
assert state.requests_min.reset_seconds == 45.5
|
||||
|
||||
assert state.requests_hour.limit == 33600
|
||||
assert state.requests_hour.remaining == 33590
|
||||
|
||||
assert state.tokens_min.limit == 8000000
|
||||
assert state.tokens_min.remaining == 7999500
|
||||
|
||||
assert state.tokens_hour.limit == 336000000
|
||||
assert state.tokens_hour.remaining == 335999000
|
||||
assert state.tokens_hour.reset_seconds == 3490.0
|
||||
|
||||
def test_no_headers(self):
|
||||
state = parse_rate_limit_headers({})
|
||||
assert state is None
|
||||
|
||||
def test_partial_headers(self):
|
||||
headers = {
|
||||
"x-ratelimit-limit-requests": "100",
|
||||
"x-ratelimit-remaining-requests": "50",
|
||||
}
|
||||
state = parse_rate_limit_headers(headers)
|
||||
assert state is not None
|
||||
assert state.requests_min.limit == 100
|
||||
assert state.requests_min.remaining == 50
|
||||
# Missing fields default to 0
|
||||
assert state.tokens_min.limit == 0
|
||||
|
||||
def test_non_rate_limit_headers_ignored(self):
|
||||
headers = {
|
||||
"content-type": "application/json",
|
||||
"server": "nginx",
|
||||
}
|
||||
state = parse_rate_limit_headers(headers)
|
||||
assert state is None
|
||||
|
||||
def test_malformed_values(self):
|
||||
headers = {
|
||||
"x-ratelimit-limit-requests": "not-a-number",
|
||||
"x-ratelimit-remaining-requests": "",
|
||||
"x-ratelimit-reset-requests": "abc",
|
||||
}
|
||||
state = parse_rate_limit_headers(headers)
|
||||
assert state is not None
|
||||
assert state.requests_min.limit == 0
|
||||
assert state.requests_min.remaining == 0
|
||||
assert state.requests_min.reset_seconds == 0.0
|
||||
|
||||
|
||||
class TestBucket:
|
||||
def test_used(self):
|
||||
b = RateLimitBucket(limit=800, remaining=795, reset_seconds=45.0, captured_at=time.time())
|
||||
assert b.used == 5
|
||||
|
||||
def test_usage_pct(self):
|
||||
b = RateLimitBucket(limit=100, remaining=20, reset_seconds=30.0, captured_at=time.time())
|
||||
assert b.usage_pct == pytest.approx(80.0)
|
||||
|
||||
def test_usage_pct_zero_limit(self):
|
||||
b = RateLimitBucket(limit=0, remaining=0)
|
||||
assert b.usage_pct == 0.0
|
||||
|
||||
def test_remaining_seconds_now(self):
|
||||
now = time.time()
|
||||
b = RateLimitBucket(limit=800, remaining=795, reset_seconds=60.0, captured_at=now - 10)
|
||||
# ~50 seconds should remain
|
||||
assert 49 <= b.remaining_seconds_now <= 51
|
||||
|
||||
def test_remaining_seconds_expired(self):
|
||||
b = RateLimitBucket(limit=800, remaining=795, reset_seconds=30.0, captured_at=time.time() - 60)
|
||||
assert b.remaining_seconds_now == 0.0
|
||||
|
||||
|
||||
class TestFormatting:
|
||||
def test_fmt_count_millions(self):
|
||||
assert _fmt_count(8000000) == "8.0M"
|
||||
assert _fmt_count(336000000) == "336.0M"
|
||||
|
||||
def test_fmt_count_thousands(self):
|
||||
assert _fmt_count(33600) == "33.6K"
|
||||
assert _fmt_count(1500) == "1.5K"
|
||||
|
||||
def test_fmt_count_small(self):
|
||||
assert _fmt_count(800) == "800"
|
||||
assert _fmt_count(0) == "0"
|
||||
|
||||
def test_fmt_seconds_short(self):
|
||||
assert _fmt_seconds(45) == "45s"
|
||||
assert _fmt_seconds(0) == "0s"
|
||||
|
||||
def test_fmt_seconds_minutes(self):
|
||||
assert _fmt_seconds(125) == "2m 5s"
|
||||
assert _fmt_seconds(120) == "2m"
|
||||
|
||||
def test_fmt_seconds_hours(self):
|
||||
assert _fmt_seconds(3660) == "1h 1m"
|
||||
assert _fmt_seconds(3600) == "1h"
|
||||
|
||||
def test_bar(self):
|
||||
bar = _bar(50.0, width=10)
|
||||
assert bar == "[█████░░░░░]"
|
||||
assert _bar(0.0, width=10) == "[░░░░░░░░░░]"
|
||||
assert _bar(100.0, width=10) == "[██████████]"
|
||||
|
||||
def test_format_display_no_data(self):
|
||||
state = RateLimitState()
|
||||
result = format_rate_limit_display(state)
|
||||
assert "No rate limit data" in result
|
||||
|
||||
def test_format_display_with_data(self):
|
||||
state = parse_rate_limit_headers(NOUS_HEADERS, provider="nous")
|
||||
result = format_rate_limit_display(state)
|
||||
assert "Nous" in result
|
||||
assert "Requests/min" in result
|
||||
assert "Requests/hr" in result
|
||||
assert "Tokens/min" in result
|
||||
assert "Tokens/hr" in result
|
||||
assert "resets in" in result
|
||||
|
||||
def test_format_display_warning_on_high_usage(self):
|
||||
headers = {
|
||||
**NOUS_HEADERS,
|
||||
"x-ratelimit-remaining-requests": "50", # 750/800 used = 93.75%
|
||||
}
|
||||
state = parse_rate_limit_headers(headers)
|
||||
result = format_rate_limit_display(state)
|
||||
assert "⚠" in result
|
||||
|
||||
def test_format_compact(self):
|
||||
state = parse_rate_limit_headers(NOUS_HEADERS, provider="nous")
|
||||
result = format_rate_limit_compact(state)
|
||||
assert "RPM:" in result
|
||||
assert "RPH:" in result
|
||||
assert "TPM:" in result
|
||||
assert "TPH:" in result
|
||||
assert "resets" in result
|
||||
|
||||
def test_format_compact_no_data(self):
|
||||
state = RateLimitState()
|
||||
result = format_rate_limit_compact(state)
|
||||
assert "No rate limit data" in result
|
||||
|
||||
|
||||
class TestAgentIntegration:
|
||||
"""Test that AIAgent captures rate limit state correctly."""
|
||||
|
||||
def test_capture_rate_limits_from_headers(self):
|
||||
"""Simulate the header capture path without a real API call."""
|
||||
import sys
|
||||
import os
|
||||
# Use a mock httpx-like response
|
||||
class MockResponse:
|
||||
headers = NOUS_HEADERS
|
||||
|
||||
# Import AIAgent minimally
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Test the parsing directly
|
||||
state = parse_rate_limit_headers(MockResponse.headers, provider="nous")
|
||||
assert state is not None
|
||||
assert state.requests_min.limit == 800
|
||||
assert state.tokens_hour.limit == 336000000
|
||||
|
||||
def test_capture_rate_limits_none_response(self):
|
||||
"""_capture_rate_limits should handle None gracefully."""
|
||||
from agent.rate_limit_tracker import parse_rate_limit_headers
|
||||
# None should not crash
|
||||
result = parse_rate_limit_headers({})
|
||||
assert result is None
|
||||
@@ -3,7 +3,6 @@
|
||||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent.subdirectory_hints import SubdirectoryHintTracker
|
||||
|
||||
@@ -190,45 +189,3 @@ class TestSubdirectoryHintTracker:
|
||||
"terminal", {"command": "curl https://example.com/frontend/api"}
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestPermissionErrorHandling:
|
||||
"""Regression tests for PermissionError in filesystem checks (ref #6214)."""
|
||||
|
||||
def test_is_valid_subdir_permission_error(self, tmp_path):
|
||||
"""_is_valid_subdir should return False when is_dir() raises PermissionError."""
|
||||
tracker = SubdirectoryHintTracker(working_dir=str(tmp_path))
|
||||
restricted = tmp_path / "restricted"
|
||||
restricted.mkdir()
|
||||
with patch.object(Path, "is_dir", side_effect=PermissionError("Permission denied")):
|
||||
assert tracker._is_valid_subdir(restricted) is False
|
||||
|
||||
def test_load_hints_permission_error_on_is_file(self, tmp_path):
|
||||
"""_load_hints_for_directory should skip files when is_file() raises PermissionError."""
|
||||
tracker = SubdirectoryHintTracker(working_dir=str(tmp_path))
|
||||
restricted = tmp_path / "restricted"
|
||||
restricted.mkdir()
|
||||
original_is_file = Path.is_file
|
||||
def patched_is_file(self):
|
||||
if "restricted" in str(self):
|
||||
raise PermissionError("Permission denied")
|
||||
return original_is_file(self)
|
||||
with patch.object(Path, "is_file", patched_is_file):
|
||||
result = tracker._load_hints_for_directory(restricted)
|
||||
assert result is None
|
||||
|
||||
def test_check_tool_call_survives_inaccessible_path(self, project):
|
||||
"""Full check_tool_call should not crash when a path is inaccessible."""
|
||||
tracker = SubdirectoryHintTracker(working_dir=str(project))
|
||||
original_is_dir = Path.is_dir
|
||||
def patched_is_dir(self):
|
||||
if "backend" in str(self) and "src" not in str(self):
|
||||
raise PermissionError("Permission denied")
|
||||
return original_is_dir(self)
|
||||
with patch.object(Path, "is_dir", patched_is_dir):
|
||||
# Should not raise — gracefully skip the inaccessible directory
|
||||
result = tracker.check_tool_call(
|
||||
"read_file", {"path": str(project / "backend" / "src" / "main.py")}
|
||||
)
|
||||
# Result may be None (backend skipped) — the key point is no crash
|
||||
assert result is None or isinstance(result, str)
|
||||
|
||||
@@ -2,65 +2,22 @@ import queue
|
||||
import threading
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import cli as cli_module
|
||||
from cli import HermesCLI
|
||||
|
||||
|
||||
class _FakeBuffer:
|
||||
def __init__(self, text="", cursor_position=None):
|
||||
self.text = text
|
||||
self.cursor_position = len(text) if cursor_position is None else cursor_position
|
||||
|
||||
def reset(self, append_to_history=False):
|
||||
self.text = ""
|
||||
self.cursor_position = 0
|
||||
|
||||
|
||||
def _make_cli_stub():
|
||||
cli = HermesCLI.__new__(HermesCLI)
|
||||
cli._approval_state = None
|
||||
cli._approval_deadline = 0
|
||||
cli._approval_lock = threading.Lock()
|
||||
cli._sudo_state = None
|
||||
cli._sudo_deadline = 0
|
||||
cli._modal_input_snapshot = None
|
||||
cli._invalidate = MagicMock()
|
||||
cli._app = SimpleNamespace(invalidate=MagicMock(), current_buffer=_FakeBuffer())
|
||||
cli._app = SimpleNamespace(invalidate=MagicMock())
|
||||
return cli
|
||||
|
||||
|
||||
class TestCliApprovalUi:
|
||||
def test_sudo_prompt_restores_existing_draft_after_response(self):
|
||||
cli = _make_cli_stub()
|
||||
cli._app.current_buffer = _FakeBuffer("draft command", cursor_position=5)
|
||||
result = {}
|
||||
|
||||
def _run_callback():
|
||||
result["value"] = cli._sudo_password_callback()
|
||||
|
||||
with patch.object(cli_module, "_cprint"):
|
||||
thread = threading.Thread(target=_run_callback, daemon=True)
|
||||
thread.start()
|
||||
|
||||
deadline = time.time() + 2
|
||||
while cli._sudo_state is None and time.time() < deadline:
|
||||
time.sleep(0.01)
|
||||
|
||||
assert cli._sudo_state is not None
|
||||
assert cli._app.current_buffer.text == ""
|
||||
|
||||
cli._app.current_buffer.text = "secret"
|
||||
cli._app.current_buffer.cursor_position = len("secret")
|
||||
cli._sudo_state["response_queue"].put("secret")
|
||||
|
||||
thread.join(timeout=2)
|
||||
|
||||
assert result["value"] == "secret"
|
||||
assert cli._app.current_buffer.text == "draft command"
|
||||
assert cli._app.current_buffer.cursor_position == 5
|
||||
|
||||
def test_approval_callback_includes_view_for_long_commands(self):
|
||||
cli = _make_cli_stub()
|
||||
command = "sudo dd if=/tmp/githubcli-keyring.gpg of=/usr/share/keyrings/githubcli-archive-keyring.gpg bs=4M status=progress"
|
||||
|
||||
@@ -6,17 +6,6 @@ from unittest.mock import patch
|
||||
from cli import HermesCLI
|
||||
|
||||
|
||||
def _assert_chrome_debug_cmd(cmd, expected_chrome, expected_port):
|
||||
"""Verify the auto-launch command has all required flags."""
|
||||
assert cmd[0] == expected_chrome
|
||||
assert f"--remote-debugging-port={expected_port}" in cmd
|
||||
assert "--no-first-run" in cmd
|
||||
assert "--no-default-browser-check" in cmd
|
||||
user_data_args = [a for a in cmd if a.startswith("--user-data-dir=")]
|
||||
assert len(user_data_args) == 1, "Expected exactly one --user-data-dir flag"
|
||||
assert "chrome-debug" in user_data_args[0]
|
||||
|
||||
|
||||
class TestChromeDebugLaunch:
|
||||
def test_windows_launch_uses_browser_found_on_path(self):
|
||||
captured = {}
|
||||
@@ -31,7 +20,7 @@ class TestChromeDebugLaunch:
|
||||
patch("subprocess.Popen", side_effect=fake_popen):
|
||||
assert HermesCLI._try_launch_chrome_debug(9333, "Windows") is True
|
||||
|
||||
_assert_chrome_debug_cmd(captured["cmd"], r"C:\Chrome\chrome.exe", 9333)
|
||||
assert captured["cmd"] == [r"C:\Chrome\chrome.exe", "--remote-debugging-port=9333"]
|
||||
assert captured["kwargs"]["start_new_session"] is True
|
||||
|
||||
def test_windows_launch_falls_back_to_common_install_dirs(self, monkeypatch):
|
||||
@@ -54,4 +43,4 @@ class TestChromeDebugLaunch:
|
||||
patch("subprocess.Popen", side_effect=fake_popen):
|
||||
assert HermesCLI._try_launch_chrome_debug(9222, "Windows") is True
|
||||
|
||||
_assert_chrome_debug_cmd(captured["cmd"], installed, 9222)
|
||||
assert captured["cmd"] == [installed, "--remote-debugging-port=9222"]
|
||||
|
||||
@@ -41,7 +41,6 @@ def _attach_agent(
|
||||
session_completion_tokens=completion_tokens,
|
||||
session_total_tokens=total_tokens,
|
||||
session_api_calls=api_calls,
|
||||
get_rate_limit_state=lambda: None,
|
||||
context_compressor=SimpleNamespace(
|
||||
last_prompt_tokens=context_tokens,
|
||||
context_length=context_length,
|
||||
|
||||
@@ -619,14 +619,17 @@ class TestReasoningDeltasFiredFlag(unittest.TestCase):
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
agent.reasoning_callback = None
|
||||
agent.stream_delta_callback = None
|
||||
agent._reasoning_deltas_fired = False
|
||||
agent.verbose_logging = False
|
||||
return agent
|
||||
|
||||
def test_fire_reasoning_delta_calls_callback(self):
|
||||
def test_fire_reasoning_delta_sets_flag(self):
|
||||
agent = self._make_agent()
|
||||
captured = []
|
||||
agent.reasoning_callback = lambda t: captured.append(t)
|
||||
self.assertFalse(agent._reasoning_deltas_fired)
|
||||
agent._fire_reasoning_delta("thinking...")
|
||||
self.assertTrue(agent._reasoning_deltas_fired)
|
||||
self.assertEqual(captured, ["thinking..."])
|
||||
|
||||
def test_build_assistant_message_skips_callback_when_already_streamed(self):
|
||||
@@ -637,7 +640,8 @@ class TestReasoningDeltasFiredFlag(unittest.TestCase):
|
||||
agent.reasoning_callback = lambda t: captured.append(t)
|
||||
agent.stream_delta_callback = lambda t: None # streaming is active
|
||||
|
||||
# Simulate streaming having already fired reasoning
|
||||
# Simulate streaming having fired reasoning
|
||||
agent._reasoning_deltas_fired = True
|
||||
|
||||
msg = SimpleNamespace(
|
||||
content="I'll merge that.",
|
||||
@@ -661,8 +665,9 @@ class TestReasoningDeltasFiredFlag(unittest.TestCase):
|
||||
agent.reasoning_callback = lambda t: captured.append(t)
|
||||
agent.stream_delta_callback = lambda t: None # streaming active
|
||||
|
||||
# Reasoning came through content tags, not reasoning_content deltas.
|
||||
# Callback should not fire since streaming is active.
|
||||
# Even though _reasoning_deltas_fired is False (reasoning came through
|
||||
# content tags, not reasoning_content deltas), callback should not fire
|
||||
agent._reasoning_deltas_fired = False
|
||||
|
||||
msg = SimpleNamespace(
|
||||
content="I'll merge that.",
|
||||
@@ -684,6 +689,7 @@ class TestReasoningDeltasFiredFlag(unittest.TestCase):
|
||||
agent.reasoning_callback = lambda t: captured.append(t)
|
||||
# No streaming
|
||||
agent.stream_delta_callback = None
|
||||
agent._reasoning_deltas_fired = False
|
||||
|
||||
msg = SimpleNamespace(
|
||||
content="I'll merge that.",
|
||||
|
||||
@@ -38,8 +38,6 @@ def _isolate_hermes_home(tmp_path, monkeypatch):
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_GATEWAY_SESSION", raising=False)
|
||||
# Avoid making real calls during tests if this key is set in the env files
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
||||
@@ -141,7 +141,7 @@ class TestBlockingGatewayApproval:
|
||||
def test_resolve_single_pops_oldest_fifo(self):
|
||||
"""resolve_gateway_approval without resolve_all resolves oldest first."""
|
||||
from tools.approval import (
|
||||
resolve_gateway_approval,
|
||||
resolve_gateway_approval, pending_approval_count,
|
||||
_ApprovalEntry, _gateway_queues,
|
||||
)
|
||||
session_key = "test-fifo"
|
||||
@@ -154,7 +154,7 @@ class TestBlockingGatewayApproval:
|
||||
assert e1.event.is_set()
|
||||
assert e1.result == "once"
|
||||
assert not e2.event.is_set()
|
||||
assert len(_gateway_queues[session_key]) == 1
|
||||
assert pending_approval_count(session_key) == 1
|
||||
|
||||
def test_unregister_signals_all_entries(self):
|
||||
"""unregister_gateway_notify signals all waiting entries to prevent hangs."""
|
||||
@@ -173,6 +173,35 @@ class TestBlockingGatewayApproval:
|
||||
assert e1.event.is_set()
|
||||
assert e2.event.is_set()
|
||||
|
||||
def test_clear_session_signals_all_entries(self):
|
||||
"""clear_session should unblock all waiting approval threads."""
|
||||
from tools.approval import (
|
||||
register_gateway_notify, clear_session,
|
||||
_ApprovalEntry, _gateway_queues,
|
||||
)
|
||||
session_key = "test-clear"
|
||||
register_gateway_notify(session_key, lambda d: None)
|
||||
|
||||
e1 = _ApprovalEntry({"command": "cmd1"})
|
||||
e2 = _ApprovalEntry({"command": "cmd2"})
|
||||
_gateway_queues[session_key] = [e1, e2]
|
||||
|
||||
clear_session(session_key)
|
||||
assert e1.event.is_set()
|
||||
assert e2.event.is_set()
|
||||
|
||||
def test_pending_approval_count(self):
|
||||
from tools.approval import (
|
||||
pending_approval_count, _ApprovalEntry, _gateway_queues,
|
||||
)
|
||||
session_key = "test-count"
|
||||
assert pending_approval_count(session_key) == 0
|
||||
_gateway_queues[session_key] = [
|
||||
_ApprovalEntry({"command": "a"}),
|
||||
_ApprovalEntry({"command": "b"}),
|
||||
]
|
||||
assert pending_approval_count(session_key) == 2
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# /approve command
|
||||
@@ -477,7 +506,7 @@ class TestBlockingApprovalE2E:
|
||||
from tools.approval import (
|
||||
register_gateway_notify, unregister_gateway_notify,
|
||||
resolve_gateway_approval, check_all_command_guards,
|
||||
_gateway_queues,
|
||||
pending_approval_count,
|
||||
)
|
||||
|
||||
session_key = "e2e-parallel"
|
||||
@@ -516,7 +545,7 @@ class TestBlockingApprovalE2E:
|
||||
time.sleep(0.05)
|
||||
|
||||
assert len(notified) == 3
|
||||
assert len(_gateway_queues.get(session_key, [])) == 3
|
||||
assert pending_approval_count(session_key) == 3
|
||||
|
||||
# Approve all at once
|
||||
count = resolve_gateway_approval(session_key, "session", resolve_all=True)
|
||||
|
||||
@@ -1,361 +0,0 @@
|
||||
"""Tests for the BlueBubbles iMessage gateway adapter."""
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
|
||||
def _make_adapter(monkeypatch, **extra):
|
||||
monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234")
|
||||
monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret")
|
||||
from gateway.platforms.bluebubbles import BlueBubblesAdapter
|
||||
|
||||
cfg = PlatformConfig(
|
||||
enabled=True,
|
||||
extra={
|
||||
"server_url": "http://localhost:1234",
|
||||
"password": "secret",
|
||||
**extra,
|
||||
},
|
||||
)
|
||||
return BlueBubblesAdapter(cfg)
|
||||
|
||||
|
||||
class TestBlueBubblesPlatformEnum:
|
||||
def test_bluebubbles_enum_exists(self):
|
||||
assert Platform.BLUEBUBBLES.value == "bluebubbles"
|
||||
|
||||
|
||||
class TestBlueBubblesConfigLoading:
|
||||
def test_apply_env_overrides_bluebubbles(self, monkeypatch):
|
||||
monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234")
|
||||
monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret")
|
||||
monkeypatch.setenv("BLUEBUBBLES_WEBHOOK_PORT", "9999")
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
assert Platform.BLUEBUBBLES in config.platforms
|
||||
bc = config.platforms[Platform.BLUEBUBBLES]
|
||||
assert bc.enabled is True
|
||||
assert bc.extra["server_url"] == "http://localhost:1234"
|
||||
assert bc.extra["password"] == "secret"
|
||||
assert bc.extra["webhook_port"] == 9999
|
||||
|
||||
def test_connected_platforms_includes_bluebubbles(self, monkeypatch):
|
||||
monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234")
|
||||
monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret")
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
assert Platform.BLUEBUBBLES in config.get_connected_platforms()
|
||||
|
||||
def test_home_channel_set_from_env(self, monkeypatch):
|
||||
monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234")
|
||||
monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret")
|
||||
monkeypatch.setenv("BLUEBUBBLES_HOME_CHANNEL", "user@example.com")
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
hc = config.platforms[Platform.BLUEBUBBLES].home_channel
|
||||
assert hc is not None
|
||||
assert hc.chat_id == "user@example.com"
|
||||
|
||||
def test_not_connected_without_password(self, monkeypatch):
|
||||
monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234")
|
||||
monkeypatch.delenv("BLUEBUBBLES_PASSWORD", raising=False)
|
||||
from gateway.config import GatewayConfig, _apply_env_overrides
|
||||
|
||||
config = GatewayConfig()
|
||||
_apply_env_overrides(config)
|
||||
assert Platform.BLUEBUBBLES not in config.get_connected_platforms()
|
||||
|
||||
|
||||
class TestBlueBubblesHelpers:
|
||||
def test_check_requirements(self, monkeypatch):
|
||||
monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234")
|
||||
monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret")
|
||||
from gateway.platforms.bluebubbles import check_bluebubbles_requirements
|
||||
|
||||
assert check_bluebubbles_requirements() is True
|
||||
|
||||
def test_format_message_strips_markdown(self, monkeypatch):
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
assert adapter.format_message("**Hello** `world`") == "Hello world"
|
||||
|
||||
def test_strip_markdown_headers(self, monkeypatch):
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
assert adapter.format_message("## Heading\ntext") == "Heading\ntext"
|
||||
|
||||
def test_strip_markdown_links(self, monkeypatch):
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
assert adapter.format_message("[click here](http://example.com)") == "click here"
|
||||
|
||||
def test_init_normalizes_webhook_path(self, monkeypatch):
|
||||
adapter = _make_adapter(monkeypatch, webhook_path="bluebubbles-webhook")
|
||||
assert adapter.webhook_path == "/bluebubbles-webhook"
|
||||
|
||||
def test_init_preserves_leading_slash(self, monkeypatch):
|
||||
adapter = _make_adapter(monkeypatch, webhook_path="/my-hook")
|
||||
assert adapter.webhook_path == "/my-hook"
|
||||
|
||||
def test_server_url_normalized(self, monkeypatch):
|
||||
adapter = _make_adapter(monkeypatch, server_url="http://localhost:1234/")
|
||||
assert adapter.server_url == "http://localhost:1234"
|
||||
|
||||
def test_server_url_adds_scheme(self, monkeypatch):
|
||||
adapter = _make_adapter(monkeypatch, server_url="localhost:1234")
|
||||
assert adapter.server_url == "http://localhost:1234"
|
||||
|
||||
|
||||
class TestBlueBubblesWebhookParsing:
|
||||
def test_webhook_prefers_chat_guid_over_message_guid(self, monkeypatch):
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
payload = {
|
||||
"guid": "MESSAGE-GUID",
|
||||
"chatGuid": "iMessage;-;user@example.com",
|
||||
"chatIdentifier": "user@example.com",
|
||||
}
|
||||
record = adapter._extract_payload_record(payload) or {}
|
||||
chat_guid = adapter._value(
|
||||
record.get("chatGuid"),
|
||||
payload.get("chatGuid"),
|
||||
record.get("chat_guid"),
|
||||
payload.get("chat_guid"),
|
||||
payload.get("guid"),
|
||||
)
|
||||
assert chat_guid == "iMessage;-;user@example.com"
|
||||
|
||||
def test_webhook_can_fall_back_to_sender_when_chat_fields_missing(self, monkeypatch):
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
payload = {
|
||||
"data": {
|
||||
"guid": "MESSAGE-GUID",
|
||||
"text": "hello",
|
||||
"handle": {"address": "user@example.com"},
|
||||
"isFromMe": False,
|
||||
}
|
||||
}
|
||||
record = adapter._extract_payload_record(payload) or {}
|
||||
chat_guid = adapter._value(
|
||||
record.get("chatGuid"),
|
||||
payload.get("chatGuid"),
|
||||
record.get("chat_guid"),
|
||||
payload.get("chat_guid"),
|
||||
payload.get("guid"),
|
||||
)
|
||||
chat_identifier = adapter._value(
|
||||
record.get("chatIdentifier"),
|
||||
record.get("identifier"),
|
||||
payload.get("chatIdentifier"),
|
||||
payload.get("identifier"),
|
||||
)
|
||||
sender = (
|
||||
adapter._value(
|
||||
record.get("handle", {}).get("address")
|
||||
if isinstance(record.get("handle"), dict)
|
||||
else None,
|
||||
record.get("sender"),
|
||||
record.get("from"),
|
||||
record.get("address"),
|
||||
)
|
||||
or chat_identifier
|
||||
or chat_guid
|
||||
)
|
||||
if not (chat_guid or chat_identifier) and sender:
|
||||
chat_identifier = sender
|
||||
assert chat_identifier == "user@example.com"
|
||||
|
||||
def test_extract_payload_record_accepts_list_data(self, monkeypatch):
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
payload = {
|
||||
"type": "new-message",
|
||||
"data": [
|
||||
{
|
||||
"text": "hello",
|
||||
"chatGuid": "iMessage;-;user@example.com",
|
||||
"chatIdentifier": "user@example.com",
|
||||
}
|
||||
],
|
||||
}
|
||||
record = adapter._extract_payload_record(payload)
|
||||
assert record == payload["data"][0]
|
||||
|
||||
def test_extract_payload_record_dict_data(self, monkeypatch):
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
payload = {"data": {"text": "hello", "chatGuid": "iMessage;-;+1234"}}
|
||||
record = adapter._extract_payload_record(payload)
|
||||
assert record["text"] == "hello"
|
||||
|
||||
def test_extract_payload_record_fallback_to_message(self, monkeypatch):
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
payload = {"message": {"text": "hello"}}
|
||||
record = adapter._extract_payload_record(payload)
|
||||
assert record["text"] == "hello"
|
||||
|
||||
|
||||
class TestBlueBubblesGuidResolution:
|
||||
def test_raw_guid_returned_as_is(self, monkeypatch):
|
||||
"""If target already contains ';' it's a raw GUID — return unchanged."""
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
import asyncio
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter._resolve_chat_guid("iMessage;-;user@example.com")
|
||||
)
|
||||
assert result == "iMessage;-;user@example.com"
|
||||
|
||||
def test_empty_target_returns_none(self, monkeypatch):
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
import asyncio
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter._resolve_chat_guid("")
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestBlueBubblesToolsetIntegration:
|
||||
def test_toolset_exists(self):
|
||||
from toolsets import TOOLSETS
|
||||
|
||||
assert "hermes-bluebubbles" in TOOLSETS
|
||||
|
||||
def test_toolset_in_gateway_composite(self):
|
||||
from toolsets import TOOLSETS
|
||||
|
||||
gateway = TOOLSETS["hermes-gateway"]
|
||||
assert "hermes-bluebubbles" in gateway["includes"]
|
||||
|
||||
|
||||
class TestBlueBubblesPromptHint:
|
||||
def test_platform_hint_exists(self):
|
||||
from agent.prompt_builder import PLATFORM_HINTS
|
||||
|
||||
assert "bluebubbles" in PLATFORM_HINTS
|
||||
hint = PLATFORM_HINTS["bluebubbles"]
|
||||
assert "iMessage" in hint
|
||||
assert "plain text" in hint
|
||||
|
||||
|
||||
class TestBlueBubblesAttachmentDownload:
|
||||
"""Verify _download_attachment routes to the correct cache helper."""
|
||||
|
||||
def test_download_image_uses_image_cache(self, monkeypatch):
|
||||
"""Image MIME routes to cache_image_from_bytes."""
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
import asyncio
|
||||
import httpx
|
||||
|
||||
# Mock the HTTP client response
|
||||
class MockResponse:
|
||||
status_code = 200
|
||||
content = b"\x89PNG\r\n\x1a\n"
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
async def mock_get(*args, **kwargs):
|
||||
return MockResponse()
|
||||
|
||||
adapter.client = type("MockClient", (), {"get": mock_get})()
|
||||
|
||||
cached_path = None
|
||||
|
||||
def mock_cache_image(data, ext):
|
||||
nonlocal cached_path
|
||||
cached_path = f"/tmp/test_image{ext}"
|
||||
return cached_path
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.bluebubbles.cache_image_from_bytes",
|
||||
mock_cache_image,
|
||||
)
|
||||
|
||||
att_meta = {"mimeType": "image/png", "transferName": "photo.png"}
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter._download_attachment("att-guid-123", att_meta)
|
||||
)
|
||||
assert result == "/tmp/test_image.png"
|
||||
|
||||
def test_download_audio_uses_audio_cache(self, monkeypatch):
|
||||
"""Audio MIME routes to cache_audio_from_bytes."""
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
import asyncio
|
||||
|
||||
class MockResponse:
|
||||
status_code = 200
|
||||
content = b"fake-audio-data"
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
async def mock_get(*args, **kwargs):
|
||||
return MockResponse()
|
||||
|
||||
adapter.client = type("MockClient", (), {"get": mock_get})()
|
||||
|
||||
cached_path = None
|
||||
|
||||
def mock_cache_audio(data, ext):
|
||||
nonlocal cached_path
|
||||
cached_path = f"/tmp/test_audio{ext}"
|
||||
return cached_path
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.bluebubbles.cache_audio_from_bytes",
|
||||
mock_cache_audio,
|
||||
)
|
||||
|
||||
att_meta = {"mimeType": "audio/mpeg", "transferName": "voice.mp3"}
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter._download_attachment("att-guid-456", att_meta)
|
||||
)
|
||||
assert result == "/tmp/test_audio.mp3"
|
||||
|
||||
def test_download_document_uses_document_cache(self, monkeypatch):
|
||||
"""Non-image/audio MIME routes to cache_document_from_bytes."""
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
import asyncio
|
||||
|
||||
class MockResponse:
|
||||
status_code = 200
|
||||
content = b"fake-doc-data"
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
async def mock_get(*args, **kwargs):
|
||||
return MockResponse()
|
||||
|
||||
adapter.client = type("MockClient", (), {"get": mock_get})()
|
||||
|
||||
cached_path = None
|
||||
|
||||
def mock_cache_doc(data, filename):
|
||||
nonlocal cached_path
|
||||
cached_path = f"/tmp/{filename}"
|
||||
return cached_path
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.bluebubbles.cache_document_from_bytes",
|
||||
mock_cache_doc,
|
||||
)
|
||||
|
||||
att_meta = {"mimeType": "application/pdf", "transferName": "report.pdf"}
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter._download_attachment("att-guid-789", att_meta)
|
||||
)
|
||||
assert result == "/tmp/report.pdf"
|
||||
|
||||
def test_download_returns_none_without_client(self, monkeypatch):
|
||||
"""No client → returns None gracefully."""
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
adapter.client = None
|
||||
import asyncio
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter._download_attachment("att-guid", {"mimeType": "image/png"})
|
||||
)
|
||||
assert result is None
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Tests for the delivery routing module."""
|
||||
|
||||
from gateway.config import Platform, GatewayConfig, PlatformConfig, HomeChannel
|
||||
from gateway.delivery import DeliveryRouter, DeliveryTarget
|
||||
from gateway.delivery import DeliveryRouter, DeliveryTarget, parse_deliver_spec
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
@@ -41,6 +41,28 @@ class TestParseTargetPlatformChat:
|
||||
assert target.platform == Platform.LOCAL
|
||||
|
||||
|
||||
class TestParseDeliverSpec:
|
||||
def test_none_returns_default(self):
|
||||
result = parse_deliver_spec(None)
|
||||
assert result == "origin"
|
||||
|
||||
def test_empty_string_returns_default(self):
|
||||
result = parse_deliver_spec("")
|
||||
assert result == "origin"
|
||||
|
||||
def test_custom_default(self):
|
||||
result = parse_deliver_spec(None, default="local")
|
||||
assert result == "local"
|
||||
|
||||
def test_passthrough_string(self):
|
||||
result = parse_deliver_spec("telegram")
|
||||
assert result == "telegram"
|
||||
|
||||
def test_passthrough_list(self):
|
||||
result = parse_deliver_spec(["local", "telegram"])
|
||||
assert result == ["local", "telegram"]
|
||||
|
||||
|
||||
class TestTargetToStringRoundtrip:
|
||||
def test_origin_roundtrip(self):
|
||||
origin = SessionSource(platform=Platform.TELEGRAM, chat_id="111", thread_id="42")
|
||||
|
||||
@@ -56,7 +56,7 @@ class FakeTree:
|
||||
|
||||
|
||||
class FakeBot:
|
||||
def __init__(self, *, intents, proxy=None):
|
||||
def __init__(self, *, intents):
|
||||
self.intents = intents
|
||||
self.user = SimpleNamespace(id=999, name="Hermes")
|
||||
self._events = {}
|
||||
@@ -95,7 +95,7 @@ async def test_connect_only_requests_members_intent_when_needed(monkeypatch, all
|
||||
|
||||
created = {}
|
||||
|
||||
def fake_bot_factory(*, command_prefix, intents, proxy=None):
|
||||
def fake_bot_factory(*, command_prefix, intents):
|
||||
created["bot"] = FakeBot(intents=intents)
|
||||
return created["bot"]
|
||||
|
||||
@@ -124,7 +124,7 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
discord_platform.commands,
|
||||
"Bot",
|
||||
lambda **kwargs: FakeBot(intents=kwargs["intents"], proxy=kwargs.get("proxy")),
|
||||
lambda **kwargs: FakeBot(intents=kwargs["intents"]),
|
||||
)
|
||||
|
||||
async def fake_wait_for(awaitable, timeout):
|
||||
|
||||
@@ -209,31 +209,14 @@ class TestIncomingDocumentHandling:
|
||||
assert "[Content of readme.md]:" in event.text
|
||||
assert "# Title" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_content_injected(self, adapter):
|
||||
""".log file under 100KB should be treated as text/plain and injected."""
|
||||
file_content = b"BLE trace line 1\nBLE trace line 2"
|
||||
|
||||
with _mock_aiohttp_download(file_content):
|
||||
msg = make_message(
|
||||
attachments=[make_attachment(filename="btsnoop_hci.log", content_type="text/plain")],
|
||||
content="please inspect this",
|
||||
)
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "[Content of btsnoop_hci.log]:" in event.text
|
||||
assert "BLE trace line 1" in event.text
|
||||
assert "please inspect this" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oversized_document_skipped(self, adapter):
|
||||
"""A document over 32MB should be skipped — media_urls stays empty."""
|
||||
"""A document over 20MB should be skipped — media_urls stays empty."""
|
||||
msg = make_message([
|
||||
make_attachment(
|
||||
filename="huge.pdf",
|
||||
content_type="application/pdf",
|
||||
size=33 * 1024 * 1024,
|
||||
size=25 * 1024 * 1024,
|
||||
)
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
@@ -243,24 +226,6 @@ class TestIncomingDocumentHandling:
|
||||
# handler must still be called
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mid_sized_zip_under_32mb_is_cached(self, adapter):
|
||||
"""A 25MB .zip should be accepted now that Discord documents allow up to 32MB."""
|
||||
msg = make_message([
|
||||
make_attachment(
|
||||
filename="bugreport.zip",
|
||||
content_type="application/zip",
|
||||
size=25 * 1024 * 1024,
|
||||
)
|
||||
])
|
||||
|
||||
with _mock_aiohttp_download(b"PK\x03\x04test"):
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert len(event.media_urls) == 1
|
||||
assert event.media_types == ["application/zip"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zip_document_cached(self, adapter):
|
||||
"""A .zip file should be cached as a supported document."""
|
||||
|
||||
@@ -1,315 +0,0 @@
|
||||
"""Tests for staged inactivity timeout in gateway agent runs.
|
||||
|
||||
Tests cover:
|
||||
- Warning fires once when inactivity reaches gateway_timeout_warning threshold
|
||||
- Warning does not fire when gateway_timeout is 0 (unlimited)
|
||||
- Warning fires only once per run, not on every poll
|
||||
- Full timeout still fires at gateway_timeout threshold
|
||||
- Warning respects HERMES_AGENT_TIMEOUT_WARNING env var
|
||||
- Warning disabled when gateway_timeout_warning is 0
|
||||
"""
|
||||
|
||||
import concurrent.futures
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
|
||||
class FakeAgent:
|
||||
"""Mock agent with controllable activity summary for timeout tests."""
|
||||
|
||||
def __init__(self, idle_seconds=0.0, activity_desc="tool_call",
|
||||
current_tool=None, api_call_count=5, max_iterations=90):
|
||||
self._idle_seconds = idle_seconds
|
||||
self._activity_desc = activity_desc
|
||||
self._current_tool = current_tool
|
||||
self._api_call_count = api_call_count
|
||||
self._max_iterations = max_iterations
|
||||
self._interrupted = False
|
||||
self._interrupt_msg = None
|
||||
|
||||
def get_activity_summary(self):
|
||||
return {
|
||||
"last_activity_ts": time.time() - self._idle_seconds,
|
||||
"last_activity_desc": self._activity_desc,
|
||||
"seconds_since_activity": self._idle_seconds,
|
||||
"current_tool": self._current_tool,
|
||||
"api_call_count": self._api_call_count,
|
||||
"max_iterations": self._max_iterations,
|
||||
}
|
||||
|
||||
def interrupt(self, msg):
|
||||
self._interrupted = True
|
||||
self._interrupt_msg = msg
|
||||
|
||||
def run_conversation(self, prompt):
|
||||
return {"final_response": "Done", "messages": []}
|
||||
|
||||
|
||||
class SlowFakeAgent(FakeAgent):
|
||||
"""Agent that runs for a while, then goes idle."""
|
||||
|
||||
def __init__(self, run_duration=0.5, idle_after=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._run_duration = run_duration
|
||||
self._idle_after = idle_after
|
||||
self._start_time = None
|
||||
|
||||
def get_activity_summary(self):
|
||||
summary = super().get_activity_summary()
|
||||
if self._idle_after is not None and self._start_time:
|
||||
elapsed = time.time() - self._start_time
|
||||
if elapsed > self._idle_after:
|
||||
idle_time = elapsed - self._idle_after
|
||||
summary["seconds_since_activity"] = idle_time
|
||||
summary["last_activity_desc"] = "api_call_streaming"
|
||||
else:
|
||||
summary["seconds_since_activity"] = 0.0
|
||||
return summary
|
||||
|
||||
def run_conversation(self, prompt):
|
||||
self._start_time = time.time()
|
||||
time.sleep(self._run_duration)
|
||||
return {"final_response": "Completed after work", "messages": []}
|
||||
|
||||
|
||||
class TestStagedInactivityWarning:
|
||||
"""Test the staged inactivity warning before full timeout."""
|
||||
|
||||
def test_warning_fires_once_before_timeout(self):
|
||||
"""Warning fires when inactivity reaches warning threshold."""
|
||||
agent = SlowFakeAgent(
|
||||
run_duration=10.0,
|
||||
idle_after=0.1,
|
||||
activity_desc="api_call_streaming",
|
||||
)
|
||||
|
||||
_agent_timeout = 20.0
|
||||
_agent_warning = 5.0
|
||||
_POLL_INTERVAL = 0.1
|
||||
|
||||
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
future = pool.submit(agent.run_conversation, "test prompt")
|
||||
_inactivity_timeout = False
|
||||
_warning_fired = False
|
||||
_warning_send_count = 0
|
||||
|
||||
while True:
|
||||
done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL)
|
||||
if done:
|
||||
result = future.result()
|
||||
break
|
||||
_idle_secs = 0.0
|
||||
if hasattr(agent, "get_activity_summary"):
|
||||
try:
|
||||
_act = agent.get_activity_summary()
|
||||
_idle_secs = _act.get("seconds_since_activity", 0.0)
|
||||
except Exception:
|
||||
pass
|
||||
if (not _warning_fired and _agent_warning > 0
|
||||
and _idle_secs >= _agent_warning):
|
||||
_warning_fired = True
|
||||
_warning_send_count += 1
|
||||
if _idle_secs >= _agent_timeout:
|
||||
_inactivity_timeout = True
|
||||
break
|
||||
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
|
||||
assert _warning_fired
|
||||
assert _warning_send_count == 1
|
||||
assert not _inactivity_timeout
|
||||
|
||||
def test_warning_disabled_when_zero(self):
|
||||
"""No warning fires when gateway_timeout_warning is 0."""
|
||||
agent = SlowFakeAgent(
|
||||
run_duration=5.0,
|
||||
idle_after=0.1,
|
||||
)
|
||||
|
||||
_agent_timeout = 20.0
|
||||
_agent_warning = 0.0
|
||||
_POLL_INTERVAL = 0.1
|
||||
|
||||
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
future = pool.submit(agent.run_conversation, "test")
|
||||
_warning_fired = False
|
||||
|
||||
while True:
|
||||
done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL)
|
||||
if done:
|
||||
future.result()
|
||||
break
|
||||
_idle_secs = 0.0
|
||||
if hasattr(agent, "get_activity_summary"):
|
||||
try:
|
||||
_act = agent.get_activity_summary()
|
||||
_idle_secs = _act.get("seconds_since_activity", 0.0)
|
||||
except Exception:
|
||||
pass
|
||||
if (not _warning_fired and _agent_warning > 0
|
||||
and _idle_secs >= _agent_warning):
|
||||
_warning_fired = True
|
||||
if _idle_secs >= _agent_timeout:
|
||||
break
|
||||
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
assert not _warning_fired
|
||||
|
||||
def test_warning_fires_only_once(self):
|
||||
"""Warning fires exactly once even if agent remains idle."""
|
||||
agent = SlowFakeAgent(
|
||||
run_duration=10.0,
|
||||
idle_after=0.05,
|
||||
)
|
||||
|
||||
_agent_timeout = 20.0
|
||||
_agent_warning = 0.2
|
||||
_POLL_INTERVAL = 0.05
|
||||
|
||||
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
future = pool.submit(agent.run_conversation, "test")
|
||||
_warning_count = 0
|
||||
|
||||
while True:
|
||||
done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL)
|
||||
if done:
|
||||
future.result()
|
||||
break
|
||||
_idle_secs = 0.0
|
||||
if hasattr(agent, "get_activity_summary"):
|
||||
try:
|
||||
_act = agent.get_activity_summary()
|
||||
_idle_secs = _act.get("seconds_since_activity", 0.0)
|
||||
except Exception:
|
||||
pass
|
||||
if (not _warning_count and _agent_warning > 0
|
||||
and _idle_secs >= _agent_warning):
|
||||
_warning_count += 1
|
||||
if _idle_secs >= _agent_timeout:
|
||||
break
|
||||
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
assert _warning_count == 1
|
||||
|
||||
def test_full_timeout_still_fires_after_warning(self):
|
||||
"""Full timeout fires even after warning was sent."""
|
||||
agent = SlowFakeAgent(
|
||||
run_duration=15.0,
|
||||
idle_after=0.1,
|
||||
activity_desc="waiting for provider response (streaming)",
|
||||
)
|
||||
|
||||
_agent_timeout = 1.0
|
||||
_agent_warning = 0.3
|
||||
_POLL_INTERVAL = 0.05
|
||||
|
||||
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
future = pool.submit(agent.run_conversation, "test")
|
||||
_inactivity_timeout = False
|
||||
_warning_fired = False
|
||||
|
||||
while True:
|
||||
done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL)
|
||||
if done:
|
||||
future.result()
|
||||
break
|
||||
_idle_secs = 0.0
|
||||
if hasattr(agent, "get_activity_summary"):
|
||||
try:
|
||||
_act = agent.get_activity_summary()
|
||||
_idle_secs = _act.get("seconds_since_activity", 0.0)
|
||||
except Exception:
|
||||
pass
|
||||
if (not _warning_fired and _agent_warning > 0
|
||||
and _idle_secs >= _agent_warning):
|
||||
_warning_fired = True
|
||||
if _idle_secs >= _agent_timeout:
|
||||
_inactivity_timeout = True
|
||||
break
|
||||
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
assert _warning_fired
|
||||
assert _inactivity_timeout
|
||||
|
||||
def test_warning_env_var_respected(self, monkeypatch):
|
||||
"""HERMES_AGENT_TIMEOUT_WARNING env var is parsed correctly."""
|
||||
monkeypatch.setenv("HERMES_AGENT_TIMEOUT_WARNING", "600")
|
||||
_warning = float(os.getenv("HERMES_AGENT_TIMEOUT_WARNING", 900))
|
||||
assert _warning == 600.0
|
||||
|
||||
def test_warning_zero_means_disabled(self, monkeypatch):
|
||||
"""HERMES_AGENT_TIMEOUT_WARNING=0 disables the warning."""
|
||||
monkeypatch.setenv("HERMES_AGENT_TIMEOUT_WARNING", "0")
|
||||
_raw = float(os.getenv("HERMES_AGENT_TIMEOUT_WARNING", 900))
|
||||
_warning = _raw if _raw > 0 else None
|
||||
assert _warning is None
|
||||
|
||||
def test_unlimited_timeout_no_warning(self):
|
||||
"""When timeout is unlimited (0), no warning fires either."""
|
||||
agent = SlowFakeAgent(
|
||||
run_duration=0.5,
|
||||
idle_after=0.0,
|
||||
)
|
||||
|
||||
_agent_timeout = None
|
||||
_agent_warning = 5.0
|
||||
_POLL_INTERVAL = 0.05
|
||||
|
||||
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
future = pool.submit(agent.run_conversation, "test")
|
||||
|
||||
result = future.result(timeout=2.0)
|
||||
pool.shutdown(wait=False)
|
||||
|
||||
assert result["final_response"] == "Completed after work"
|
||||
|
||||
|
||||
class TestWarningThresholdBelowTimeout:
|
||||
"""Test that warning threshold must be less than timeout threshold."""
|
||||
|
||||
def test_warning_at_half_timeout(self):
|
||||
"""Warning fires at half the timeout duration."""
|
||||
agent = SlowFakeAgent(
|
||||
run_duration=10.0,
|
||||
idle_after=0.1,
|
||||
activity_desc="receiving stream response",
|
||||
)
|
||||
|
||||
_agent_timeout = 2.0
|
||||
_agent_warning = 1.0
|
||||
_POLL_INTERVAL = 0.05
|
||||
|
||||
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
future = pool.submit(agent.run_conversation, "test")
|
||||
_warning_fired = False
|
||||
_timeout_fired = False
|
||||
|
||||
while True:
|
||||
done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL)
|
||||
if done:
|
||||
future.result()
|
||||
break
|
||||
_idle_secs = 0.0
|
||||
if hasattr(agent, "get_activity_summary"):
|
||||
try:
|
||||
_act = agent.get_activity_summary()
|
||||
_idle_secs = _act.get("seconds_since_activity", 0.0)
|
||||
except Exception:
|
||||
pass
|
||||
if (not _warning_fired and _agent_warning > 0
|
||||
and _idle_secs >= _agent_warning):
|
||||
_warning_fired = True
|
||||
if _idle_secs >= _agent_timeout:
|
||||
_timeout_fired = True
|
||||
break
|
||||
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
assert _warning_fired
|
||||
assert _timeout_fired
|
||||
@@ -1,226 +0,0 @@
|
||||
"""Tests that internal synthetic events (e.g. background process completion)
|
||||
bypass user authorization and do not trigger DM pairing.
|
||||
|
||||
Regression test for the bug where ``_run_process_watcher`` with
|
||||
``notify_on_complete=True`` injected a ``MessageEvent`` without ``user_id``,
|
||||
causing ``_is_user_authorized`` to reject it and the gateway to send a
|
||||
pairing code to the chat.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _FakeRegistry:
|
||||
"""Return pre-canned sessions, then None once exhausted."""
|
||||
|
||||
def __init__(self, sessions):
|
||||
self._sessions = list(sessions)
|
||||
|
||||
def get(self, session_id):
|
||||
if self._sessions:
|
||||
return self._sessions.pop(0)
|
||||
return None
|
||||
|
||||
|
||||
def _build_runner(monkeypatch, tmp_path) -> GatewayRunner:
|
||||
"""Create a GatewayRunner with notifications set to 'all'."""
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
"display:\n background_process_notifications: all\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
import gateway.run as gateway_run
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
|
||||
runner = GatewayRunner(GatewayConfig())
|
||||
adapter = SimpleNamespace(send=AsyncMock(), handle_message=AsyncMock())
|
||||
runner.adapters[Platform.DISCORD] = adapter
|
||||
return runner
|
||||
|
||||
|
||||
def _watcher_dict_with_notify():
|
||||
return {
|
||||
"session_id": "proc_test_internal",
|
||||
"check_interval": 0,
|
||||
"session_key": "agent:main:discord:dm:123",
|
||||
"platform": "discord",
|
||||
"chat_id": "123",
|
||||
"thread_id": "",
|
||||
"notify_on_complete": True,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notify_on_complete_sets_internal_flag(monkeypatch, tmp_path):
|
||||
"""Synthetic completion event must have internal=True."""
|
||||
import tools.process_registry as pr_module
|
||||
|
||||
sessions = [
|
||||
SimpleNamespace(
|
||||
output_buffer="done\n", exited=True, exit_code=0, command="echo test"
|
||||
),
|
||||
]
|
||||
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
|
||||
|
||||
async def _instant_sleep(*_a, **_kw):
|
||||
pass
|
||||
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
|
||||
|
||||
runner = _build_runner(monkeypatch, tmp_path)
|
||||
adapter = runner.adapters[Platform.DISCORD]
|
||||
|
||||
await runner._run_process_watcher(_watcher_dict_with_notify())
|
||||
|
||||
assert adapter.handle_message.await_count == 1
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert isinstance(event, MessageEvent)
|
||||
assert event.internal is True, "Synthetic completion event must be marked internal"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_internal_event_bypasses_authorization(monkeypatch, tmp_path):
|
||||
"""An internal event should skip _is_user_authorized entirely."""
|
||||
import gateway.run as gateway_run
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
(tmp_path / "config.yaml").write_text("", encoding="utf-8")
|
||||
|
||||
runner = GatewayRunner(GatewayConfig())
|
||||
|
||||
# Create an internal event with no user_id (simulates the bug scenario)
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="123",
|
||||
chat_type="dm",
|
||||
)
|
||||
event = MessageEvent(
|
||||
text="[SYSTEM: Background process completed]",
|
||||
source=source,
|
||||
internal=True,
|
||||
)
|
||||
|
||||
# Track if _is_user_authorized is called
|
||||
auth_called = False
|
||||
original_auth = GatewayRunner._is_user_authorized
|
||||
|
||||
def tracking_auth(self, src):
|
||||
nonlocal auth_called
|
||||
auth_called = True
|
||||
return original_auth(self, src)
|
||||
|
||||
monkeypatch.setattr(GatewayRunner, "_is_user_authorized", tracking_auth)
|
||||
|
||||
# _handle_message will proceed past auth check and eventually fail on
|
||||
# downstream logic. We just need to verify auth is skipped.
|
||||
try:
|
||||
await runner._handle_message(event)
|
||||
except Exception:
|
||||
pass # Expected — downstream code needs more setup
|
||||
|
||||
assert not auth_called, (
|
||||
"_is_user_authorized should NOT be called for internal events"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_internal_event_does_not_trigger_pairing(monkeypatch, tmp_path):
|
||||
"""An internal event with no user_id must not generate a pairing code."""
|
||||
import gateway.run as gateway_run
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
(tmp_path / "config.yaml").write_text("", encoding="utf-8")
|
||||
|
||||
runner = GatewayRunner(GatewayConfig())
|
||||
# Add adapter so pairing would have somewhere to send
|
||||
adapter = SimpleNamespace(send=AsyncMock())
|
||||
runner.adapters[Platform.DISCORD] = adapter
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="123",
|
||||
chat_type="dm", # DM would normally trigger pairing
|
||||
)
|
||||
event = MessageEvent(
|
||||
text="[SYSTEM: Background process completed]",
|
||||
source=source,
|
||||
internal=True,
|
||||
)
|
||||
|
||||
# Track pairing code generation
|
||||
generate_called = False
|
||||
original_generate = runner.pairing_store.generate_code
|
||||
|
||||
def tracking_generate(*args, **kwargs):
|
||||
nonlocal generate_called
|
||||
generate_called = True
|
||||
return original_generate(*args, **kwargs)
|
||||
|
||||
runner.pairing_store.generate_code = tracking_generate
|
||||
|
||||
try:
|
||||
await runner._handle_message(event)
|
||||
except Exception:
|
||||
pass # Expected — downstream code needs more setup
|
||||
|
||||
assert not generate_called, (
|
||||
"Pairing code should NOT be generated for internal events"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_internal_event_without_user_triggers_pairing(monkeypatch, tmp_path):
|
||||
"""Verify the normal (non-internal) path still triggers pairing for unknown users."""
|
||||
import gateway.run as gateway_run
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
(tmp_path / "config.yaml").write_text("", encoding="utf-8")
|
||||
|
||||
# Clear env vars that could let all users through (loaded by
|
||||
# module-level dotenv in gateway/run.py from the real ~/.hermes/.env).
|
||||
monkeypatch.delenv("DISCORD_ALLOW_ALL_USERS", raising=False)
|
||||
monkeypatch.delenv("DISCORD_ALLOWED_USERS", raising=False)
|
||||
monkeypatch.delenv("GATEWAY_ALLOW_ALL_USERS", raising=False)
|
||||
monkeypatch.delenv("GATEWAY_ALLOWED_USERS", raising=False)
|
||||
|
||||
runner = GatewayRunner(GatewayConfig())
|
||||
adapter = SimpleNamespace(send=AsyncMock())
|
||||
runner.adapters[Platform.DISCORD] = adapter
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="123",
|
||||
chat_type="dm",
|
||||
user_id="unknown_user_999",
|
||||
)
|
||||
# Normal event (not internal)
|
||||
event = MessageEvent(
|
||||
text="hello",
|
||||
source=source,
|
||||
internal=False,
|
||||
)
|
||||
|
||||
result = await runner._handle_message(event)
|
||||
|
||||
# Should return None (unauthorized) and send pairing message
|
||||
assert result is None
|
||||
assert adapter.send.await_count == 1
|
||||
sent_text = adapter.send.await_args.args[1]
|
||||
assert "don't recognize you" in sent_text
|
||||
@@ -38,11 +38,10 @@ def _make_timeout_error() -> httpx.TimeoutException:
|
||||
# cache_image_from_url (base.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@patch("tools.url_safety.is_safe_url", return_value=True)
|
||||
class TestCacheImageFromUrl:
|
||||
"""Tests for gateway.platforms.base.cache_image_from_url"""
|
||||
|
||||
def test_success_on_first_attempt(self, _mock_safe, tmp_path, monkeypatch):
|
||||
def test_success_on_first_attempt(self, tmp_path, monkeypatch):
|
||||
"""A clean 200 response caches the image and returns a path."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
@@ -66,7 +65,7 @@ class TestCacheImageFromUrl:
|
||||
assert path.endswith(".jpg")
|
||||
mock_client.get.assert_called_once()
|
||||
|
||||
def test_retries_on_timeout_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
|
||||
def test_retries_on_timeout_then_succeeds(self, tmp_path, monkeypatch):
|
||||
"""A timeout on the first attempt is retried; second attempt succeeds."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
@@ -96,7 +95,7 @@ class TestCacheImageFromUrl:
|
||||
assert mock_client.get.call_count == 2
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
def test_retries_on_429_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
|
||||
def test_retries_on_429_then_succeeds(self, tmp_path, monkeypatch):
|
||||
"""A 429 response on the first attempt is retried; second attempt succeeds."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
@@ -123,7 +122,7 @@ class TestCacheImageFromUrl:
|
||||
assert path.endswith(".jpg")
|
||||
assert mock_client.get.call_count == 2
|
||||
|
||||
def test_raises_after_max_retries_exhausted(self, _mock_safe, tmp_path, monkeypatch):
|
||||
def test_raises_after_max_retries_exhausted(self, tmp_path, monkeypatch):
|
||||
"""Timeout on every attempt raises after all retries are consumed."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
@@ -146,7 +145,7 @@ class TestCacheImageFromUrl:
|
||||
# 3 total calls: initial + 2 retries
|
||||
assert mock_client.get.call_count == 3
|
||||
|
||||
def test_non_retryable_4xx_raises_immediately(self, _mock_safe, tmp_path, monkeypatch):
|
||||
def test_non_retryable_4xx_raises_immediately(self, tmp_path, monkeypatch):
|
||||
"""A 404 (non-retryable) is raised immediately without any retry."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
@@ -176,11 +175,10 @@ class TestCacheImageFromUrl:
|
||||
# cache_audio_from_url (base.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@patch("tools.url_safety.is_safe_url", return_value=True)
|
||||
class TestCacheAudioFromUrl:
|
||||
"""Tests for gateway.platforms.base.cache_audio_from_url"""
|
||||
|
||||
def test_success_on_first_attempt(self, _mock_safe, tmp_path, monkeypatch):
|
||||
def test_success_on_first_attempt(self, tmp_path, monkeypatch):
|
||||
"""A clean 200 response caches the audio and returns a path."""
|
||||
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
|
||||
|
||||
@@ -204,7 +202,7 @@ class TestCacheAudioFromUrl:
|
||||
assert path.endswith(".ogg")
|
||||
mock_client.get.assert_called_once()
|
||||
|
||||
def test_retries_on_timeout_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
|
||||
def test_retries_on_timeout_then_succeeds(self, tmp_path, monkeypatch):
|
||||
"""A timeout on the first attempt is retried; second attempt succeeds."""
|
||||
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
|
||||
|
||||
@@ -234,7 +232,7 @@ class TestCacheAudioFromUrl:
|
||||
assert mock_client.get.call_count == 2
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
def test_retries_on_429_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
|
||||
def test_retries_on_429_then_succeeds(self, tmp_path, monkeypatch):
|
||||
"""A 429 response on the first attempt is retried; second attempt succeeds."""
|
||||
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
|
||||
|
||||
@@ -261,7 +259,7 @@ class TestCacheAudioFromUrl:
|
||||
assert path.endswith(".ogg")
|
||||
assert mock_client.get.call_count == 2
|
||||
|
||||
def test_retries_on_500_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
|
||||
def test_retries_on_500_then_succeeds(self, tmp_path, monkeypatch):
|
||||
"""A 500 response on the first attempt is retried; second attempt succeeds."""
|
||||
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
|
||||
|
||||
@@ -288,7 +286,7 @@ class TestCacheAudioFromUrl:
|
||||
assert path.endswith(".ogg")
|
||||
assert mock_client.get.call_count == 2
|
||||
|
||||
def test_raises_after_max_retries_exhausted(self, _mock_safe, tmp_path, monkeypatch):
|
||||
def test_raises_after_max_retries_exhausted(self, tmp_path, monkeypatch):
|
||||
"""Timeout on every attempt raises after all retries are consumed."""
|
||||
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
|
||||
|
||||
@@ -311,7 +309,7 @@ class TestCacheAudioFromUrl:
|
||||
# 3 total calls: initial + 2 retries
|
||||
assert mock_client.get.call_count == 3
|
||||
|
||||
def test_non_retryable_4xx_raises_immediately(self, _mock_safe, tmp_path, monkeypatch):
|
||||
def test_non_retryable_4xx_raises_immediately(self, tmp_path, monkeypatch):
|
||||
"""A 404 (non-retryable) is raised immediately without any retry."""
|
||||
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user