Compare commits
80 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 46f7b38bb8 | |||
| 93f6f66872 | |||
| a418ddbd8b | |||
| 0d25e1c146 | |||
| 6391b46779 | |||
| d1d425e9d0 | |||
| 7cb06e3bb3 | |||
| 8275fa597a | |||
| 7856d304f2 | |||
| f3ec4b3a16 | |||
| 5082a9f66c | |||
| 0c30385be2 | |||
| 8b167af66b | |||
| 990030c26e | |||
| d2f85383e8 | |||
| 2dc5f9d2d3 | |||
| f61cc464f0 | |||
| 2276b72141 | |||
| dee592a0b1 | |||
| da448d4fce | |||
| aa398ad655 | |||
| 422f2866e6 | |||
| 722331a57d | |||
| 41e2d61b3f | |||
| 4da598b48a | |||
| 33ae403890 | |||
| 47e6ea84bb | |||
| 4bcb2f2d26 | |||
| 1c4d3216d3 | |||
| dedc4600dd | |||
| 8bc9b5a0b4 | |||
| 2546b7acea | |||
| 7b2700c9af | |||
| a4e1842f12 | |||
| e69526be79 | |||
| 180b14442f | |||
| 03446e06bb | |||
| df7be3d8ae | |||
| da8bab77fb | |||
| 9932366f3c | |||
| 029938fbed | |||
| 772cfb6c4e | |||
| 5d5d21556e | |||
| 9855190f23 | |||
| 50c35dcabe | |||
| 93fe4ead83 | |||
| a8b7db35b2 | |||
| 8548893d14 | |||
| c5688e7c8b | |||
| ba24f058ed | |||
| ef04de3e98 | |||
| fc6cb5b970 | |||
| 4b2a1a4337 | |||
| 2871ef1807 | |||
| 5cbb45d93e | |||
| ca0ae56ccb | |||
| 23b87c8ca8 | |||
| 92385679b6 | |||
| 82f364ffd1 | |||
| 31d0620663 | |||
| cf1d718823 | |||
| 302554b158 | |||
| d6c09ab94a | |||
| da528a8207 | |||
| 677f1227c3 | |||
| 4610551d74 | |||
| 498cb7a0fc | |||
| c10fea8d26 | |||
| cda64a5961 | |||
| 2a98098035 | |||
| 6c89306437 | |||
| 847d7cbea5 | |||
| a9c78d0eb0 | |||
| e7475b1582 | |||
| ac1f8fcccd | |||
| 56c34ac4f7 | |||
| 3ca7417c2a | |||
| cfa24532d3 | |||
| b24e5ee4b0 | |||
| 3b50821555 |
@@ -145,6 +145,10 @@
|
||||
# Only override here if you need to force a backend without touching config.yaml:
|
||||
# TERMINAL_ENV=local
|
||||
|
||||
# Override the container runtime binary (e.g. to use Podman instead of Docker).
|
||||
# Useful on systems where Docker's storage driver is broken or unavailable.
|
||||
# HERMES_DOCKER_BINARY=/usr/local/bin/podman
|
||||
|
||||
# Container images (for singularity/docker/modal backends)
|
||||
# TERMINAL_DOCKER_IMAGE=nikolaik/python-nodejs:python3.11-nodejs20
|
||||
# TERMINAL_SINGULARITY_IMAGE=docker://nikolaik/python-nodejs:python3.11-nodejs20
|
||||
|
||||
@@ -13,7 +13,7 @@ source venv/bin/activate # ALWAYS activate before running Python
|
||||
```
|
||||
hermes-agent/
|
||||
├── run_agent.py # AIAgent class — core conversation loop
|
||||
├── model_tools.py # Tool orchestration, _discover_tools(), handle_function_call()
|
||||
├── model_tools.py # Tool orchestration, discover_builtin_tools(), handle_function_call()
|
||||
├── toolsets.py # Toolset definitions, _HERMES_CORE_TOOLS list
|
||||
├── cli.py # HermesCLI class — interactive CLI orchestrator
|
||||
├── hermes_state.py # SessionDB — SQLite session store (FTS5 search)
|
||||
@@ -181,7 +181,7 @@ if canonical == "mycommand":
|
||||
|
||||
## Adding New Tools
|
||||
|
||||
Requires changes in **3 files**:
|
||||
Requires changes in **2 files**:
|
||||
|
||||
**1. Create `tools/your_tool.py`:**
|
||||
```python
|
||||
@@ -204,9 +204,9 @@ registry.register(
|
||||
)
|
||||
```
|
||||
|
||||
**2. Add import** in `model_tools.py` `_discover_tools()` list.
|
||||
**2. Add to `toolsets.py`** — either `_HERMES_CORE_TOOLS` (all platforms) or a new toolset.
|
||||
|
||||
**3. Add to `toolsets.py`** — either `_HERMES_CORE_TOOLS` (all platforms) or a new toolset.
|
||||
Auto-discovery: any `tools/*.py` file with a top-level `registry.register()` call is imported automatically — no manual import list to maintain.
|
||||
|
||||
The registry handles schema collection, dispatch, availability checking, and error wrapping. All handlers MUST return a JSON string.
|
||||
|
||||
|
||||
+40
-18
@@ -1835,9 +1835,15 @@ def auxiliary_max_tokens_param(value: int) -> dict:
|
||||
# Every auxiliary LLM consumer should use these instead of manually
|
||||
# constructing clients and calling .chat.completions.create().
|
||||
|
||||
# Client cache: (provider, async_mode, base_url, api_key) -> (client, default_model)
|
||||
# Client cache: (provider, async_mode, base_url, api_key, api_mode, runtime_key) -> (client, default_model, loop)
|
||||
# NOTE: loop identity is NOT part of the key. On async cache hits we check
|
||||
# whether the cached loop is the *current* loop; if not, the stale entry is
|
||||
# replaced in-place. This bounds cache growth to one entry per unique
|
||||
# provider config rather than one per (config × event-loop), which previously
|
||||
# caused unbounded fd accumulation in long-running gateway processes (#10200).
|
||||
_client_cache: Dict[tuple, tuple] = {}
|
||||
_client_cache_lock = threading.Lock()
|
||||
_CLIENT_CACHE_MAX_SIZE = 64 # safety belt — evict oldest when exceeded
|
||||
|
||||
|
||||
def neuter_async_httpx_del() -> None:
|
||||
@@ -1970,39 +1976,49 @@ def _get_cached_client(
|
||||
Async clients (AsyncOpenAI) use httpx.AsyncClient internally, which
|
||||
binds to the event loop that was current when the client was created.
|
||||
Using such a client on a *different* loop causes deadlocks or
|
||||
RuntimeError. To prevent cross-loop issues (especially in gateway
|
||||
mode where _run_async() may spawn fresh loops in worker threads), the
|
||||
cache key for async clients includes the current event loop's identity
|
||||
so each loop gets its own client instance.
|
||||
RuntimeError. To prevent cross-loop issues, the cache validates on
|
||||
every async hit that the cached loop is the *current, open* loop.
|
||||
If the loop changed (e.g. a new gateway worker-thread loop), the stale
|
||||
entry is replaced in-place rather than creating an additional entry.
|
||||
|
||||
This keeps cache size bounded to one entry per unique provider config,
|
||||
preventing the fd-exhaustion that previously occurred in long-running
|
||||
gateways where recycled worker threads created unbounded entries (#10200).
|
||||
"""
|
||||
# Include loop identity for async clients to prevent cross-loop reuse.
|
||||
# httpx.AsyncClient (inside AsyncOpenAI) is bound to the loop where it
|
||||
# was created — reusing it on a different loop causes deadlocks (#2681).
|
||||
loop_id = 0
|
||||
# Resolve the current event loop for async clients so we can validate
|
||||
# cached entries. Loop identity is NOT in the cache key — instead we
|
||||
# check at hit time whether the cached loop is still current and open.
|
||||
# This prevents unbounded cache growth from recycled worker-thread loops
|
||||
# while still guaranteeing we never reuse a client on the wrong loop
|
||||
# (which causes deadlocks, see #2681).
|
||||
current_loop = None
|
||||
if async_mode:
|
||||
try:
|
||||
import asyncio as _aio
|
||||
current_loop = _aio.get_event_loop()
|
||||
loop_id = id(current_loop)
|
||||
except RuntimeError:
|
||||
pass
|
||||
runtime = _normalize_main_runtime(main_runtime)
|
||||
runtime_key = tuple(runtime.get(field, "") for field in _MAIN_RUNTIME_FIELDS) if provider == "auto" else ()
|
||||
cache_key = (provider, async_mode, base_url or "", api_key or "", api_mode or "", loop_id, runtime_key)
|
||||
cache_key = (provider, async_mode, base_url or "", api_key or "", api_mode or "", runtime_key)
|
||||
with _client_cache_lock:
|
||||
if cache_key in _client_cache:
|
||||
cached_client, cached_default, cached_loop = _client_cache[cache_key]
|
||||
if async_mode:
|
||||
# A cached async client whose loop has been closed will raise
|
||||
# "Event loop is closed" when httpx tries to clean up its
|
||||
# transport. Discard the stale client and create a fresh one.
|
||||
if cached_loop is not None and cached_loop.is_closed():
|
||||
_force_close_async_httpx(cached_client)
|
||||
del _client_cache[cache_key]
|
||||
else:
|
||||
# Validate: the cached client must be bound to the CURRENT,
|
||||
# OPEN loop. If the loop changed or was closed, the httpx
|
||||
# transport inside is dead — force-close and replace.
|
||||
loop_ok = (
|
||||
cached_loop is not None
|
||||
and cached_loop is current_loop
|
||||
and not cached_loop.is_closed()
|
||||
)
|
||||
if loop_ok:
|
||||
effective = _compat_model(cached_client, model, cached_default)
|
||||
return cached_client, effective
|
||||
# Stale — evict and fall through to create a new client.
|
||||
_force_close_async_httpx(cached_client)
|
||||
del _client_cache[cache_key]
|
||||
else:
|
||||
effective = _compat_model(cached_client, model, cached_default)
|
||||
return cached_client, effective
|
||||
@@ -2022,6 +2038,12 @@ def _get_cached_client(
|
||||
bound_loop = current_loop
|
||||
with _client_cache_lock:
|
||||
if cache_key not in _client_cache:
|
||||
# Safety belt: if the cache has grown beyond the max, evict
|
||||
# the oldest entries (FIFO — dict preserves insertion order).
|
||||
while len(_client_cache) >= _CLIENT_CACHE_MAX_SIZE:
|
||||
evict_key, evict_entry = next(iter(_client_cache.items()))
|
||||
_force_close_async_httpx(evict_entry[0])
|
||||
del _client_cache[evict_key]
|
||||
_client_cache[cache_key] = (client, default_model, bound_loop)
|
||||
else:
|
||||
client, default_model, _ = _client_cache[cache_key]
|
||||
|
||||
+307
-36
@@ -17,7 +17,10 @@ Improvements over v2:
|
||||
- Richer tool call/result detail in summarizer input
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -57,6 +60,128 @@ _CHARS_PER_TOKEN = 4
|
||||
_SUMMARY_FAILURE_COOLDOWN_SECONDS = 600
|
||||
|
||||
|
||||
def _summarize_tool_result(tool_name: str, tool_args: str, tool_content: str) -> str:
|
||||
"""Create an informative 1-line summary of a tool call + result.
|
||||
|
||||
Used during the pre-compression pruning pass to replace large tool
|
||||
outputs with a short but useful description of what the tool did,
|
||||
rather than a generic placeholder that carries zero information.
|
||||
|
||||
Returns strings like::
|
||||
|
||||
[terminal] ran `npm test` -> exit 0, 47 lines output
|
||||
[read_file] read config.py from line 1 (1,200 chars)
|
||||
[search_files] content search for 'compress' in agent/ -> 12 matches
|
||||
"""
|
||||
try:
|
||||
args = json.loads(tool_args) if tool_args else {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
|
||||
content = tool_content or ""
|
||||
content_len = len(content)
|
||||
line_count = content.count("\n") + 1 if content.strip() else 0
|
||||
|
||||
if tool_name == "terminal":
|
||||
cmd = args.get("command", "")
|
||||
if len(cmd) > 80:
|
||||
cmd = cmd[:77] + "..."
|
||||
exit_match = re.search(r'"exit_code"\s*:\s*(-?\d+)', content)
|
||||
exit_code = exit_match.group(1) if exit_match else "?"
|
||||
return f"[terminal] ran `{cmd}` -> exit {exit_code}, {line_count} lines output"
|
||||
|
||||
if tool_name == "read_file":
|
||||
path = args.get("path", "?")
|
||||
offset = args.get("offset", 1)
|
||||
return f"[read_file] read {path} from line {offset} ({content_len:,} chars)"
|
||||
|
||||
if tool_name == "write_file":
|
||||
path = args.get("path", "?")
|
||||
written_lines = args.get("content", "").count("\n") + 1 if args.get("content") else "?"
|
||||
return f"[write_file] wrote to {path} ({written_lines} lines)"
|
||||
|
||||
if tool_name == "search_files":
|
||||
pattern = args.get("pattern", "?")
|
||||
path = args.get("path", ".")
|
||||
target = args.get("target", "content")
|
||||
match_count = re.search(r'"total_count"\s*:\s*(\d+)', content)
|
||||
count = match_count.group(1) if match_count else "?"
|
||||
return f"[search_files] {target} search for '{pattern}' in {path} -> {count} matches"
|
||||
|
||||
if tool_name == "patch":
|
||||
path = args.get("path", "?")
|
||||
mode = args.get("mode", "replace")
|
||||
return f"[patch] {mode} in {path} ({content_len:,} chars result)"
|
||||
|
||||
if tool_name in ("browser_navigate", "browser_click", "browser_snapshot",
|
||||
"browser_type", "browser_scroll", "browser_vision"):
|
||||
url = args.get("url", "")
|
||||
ref = args.get("ref", "")
|
||||
detail = f" {url}" if url else (f" ref={ref}" if ref else "")
|
||||
return f"[{tool_name}]{detail} ({content_len:,} chars)"
|
||||
|
||||
if tool_name == "web_search":
|
||||
query = args.get("query", "?")
|
||||
return f"[web_search] query='{query}' ({content_len:,} chars result)"
|
||||
|
||||
if tool_name == "web_extract":
|
||||
urls = args.get("urls", [])
|
||||
url_desc = urls[0] if isinstance(urls, list) and urls else "?"
|
||||
if isinstance(urls, list) and len(urls) > 1:
|
||||
url_desc += f" (+{len(urls) - 1} more)"
|
||||
return f"[web_extract] {url_desc} ({content_len:,} chars)"
|
||||
|
||||
if tool_name == "delegate_task":
|
||||
goal = args.get("goal", "")
|
||||
if len(goal) > 60:
|
||||
goal = goal[:57] + "..."
|
||||
return f"[delegate_task] '{goal}' ({content_len:,} chars result)"
|
||||
|
||||
if tool_name == "execute_code":
|
||||
code_preview = (args.get("code") or "")[:60].replace("\n", " ")
|
||||
if len(args.get("code", "")) > 60:
|
||||
code_preview += "..."
|
||||
return f"[execute_code] `{code_preview}` ({line_count} lines output)"
|
||||
|
||||
if tool_name in ("skill_view", "skills_list", "skill_manage"):
|
||||
name = args.get("name", "?")
|
||||
return f"[{tool_name}] name={name} ({content_len:,} chars)"
|
||||
|
||||
if tool_name == "vision_analyze":
|
||||
question = args.get("question", "")[:50]
|
||||
return f"[vision_analyze] '{question}' ({content_len:,} chars)"
|
||||
|
||||
if tool_name == "memory":
|
||||
action = args.get("action", "?")
|
||||
target = args.get("target", "?")
|
||||
return f"[memory] {action} on {target}"
|
||||
|
||||
if tool_name == "todo":
|
||||
return "[todo] updated task list"
|
||||
|
||||
if tool_name == "clarify":
|
||||
return "[clarify] asked user a question"
|
||||
|
||||
if tool_name == "text_to_speech":
|
||||
return f"[text_to_speech] generated audio ({content_len:,} chars)"
|
||||
|
||||
if tool_name == "cronjob":
|
||||
action = args.get("action", "?")
|
||||
return f"[cronjob] {action}"
|
||||
|
||||
if tool_name == "process":
|
||||
action = args.get("action", "?")
|
||||
sid = args.get("session_id", "?")
|
||||
return f"[process] {action} session={sid}"
|
||||
|
||||
# Generic fallback
|
||||
first_arg = ""
|
||||
for k, v in list(args.items())[:2]:
|
||||
sv = str(v)[:40]
|
||||
first_arg += f" {k}={sv}"
|
||||
return f"[{tool_name}]{first_arg} ({content_len:,} chars result)"
|
||||
|
||||
|
||||
class ContextCompressor(ContextEngine):
|
||||
"""Default context engine — compresses conversation context via lossy summarization.
|
||||
|
||||
@@ -78,6 +203,8 @@ class ContextCompressor(ContextEngine):
|
||||
self._context_probed = False
|
||||
self._context_probe_persistable = False
|
||||
self._previous_summary = None
|
||||
self._last_compression_savings_pct = 100.0
|
||||
self._ineffective_compression_count = 0
|
||||
|
||||
def update_model(
|
||||
self,
|
||||
@@ -167,6 +294,9 @@ class ContextCompressor(ContextEngine):
|
||||
|
||||
# Stores the previous compaction summary for iterative updates
|
||||
self._previous_summary: Optional[str] = None
|
||||
# Anti-thrashing: track whether last compression was effective
|
||||
self._last_compression_savings_pct: float = 100.0
|
||||
self._ineffective_compression_count: int = 0
|
||||
self._summary_failure_cooldown_until: float = 0.0
|
||||
|
||||
def update_from_response(self, usage: Dict[str, Any]):
|
||||
@@ -175,9 +305,26 @@ class ContextCompressor(ContextEngine):
|
||||
self.last_completion_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
def should_compress(self, prompt_tokens: int = None) -> bool:
|
||||
"""Check if context exceeds the compression threshold."""
|
||||
"""Check if context exceeds the compression threshold.
|
||||
|
||||
Includes anti-thrashing protection: if the last two compressions
|
||||
each saved less than 10%, skip compression to avoid infinite loops
|
||||
where each pass removes only 1-2 messages.
|
||||
"""
|
||||
tokens = prompt_tokens if prompt_tokens is not None else self.last_prompt_tokens
|
||||
return tokens >= self.threshold_tokens
|
||||
if tokens < self.threshold_tokens:
|
||||
return False
|
||||
# Anti-thrashing: back off if recent compressions were ineffective
|
||||
if self._ineffective_compression_count >= 2:
|
||||
if not self.quiet_mode:
|
||||
logger.warning(
|
||||
"Compression skipped — last %d compressions saved <10%% each. "
|
||||
"Consider /new to start a fresh session, or /compress <topic> "
|
||||
"for focused compression.",
|
||||
self._ineffective_compression_count,
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool output pruning (cheap pre-pass, no LLM call)
|
||||
@@ -187,7 +334,16 @@ class ContextCompressor(ContextEngine):
|
||||
self, messages: List[Dict[str, Any]], protect_tail_count: int,
|
||||
protect_tail_tokens: int | None = None,
|
||||
) -> tuple[List[Dict[str, Any]], int]:
|
||||
"""Replace old tool result contents with a short placeholder.
|
||||
"""Replace old tool result contents with informative 1-line summaries.
|
||||
|
||||
Instead of a generic placeholder, generates a summary like::
|
||||
|
||||
[terminal] ran `npm test` -> exit 0, 47 lines output
|
||||
[read_file] read config.py from line 1 (3,400 chars)
|
||||
|
||||
Also deduplicates identical tool results (e.g. reading the same file
|
||||
5x keeps only the newest full copy) and truncates large tool_call
|
||||
arguments in assistant messages outside the protected tail.
|
||||
|
||||
Walks backward from the end, protecting the most recent messages that
|
||||
fall within ``protect_tail_tokens`` (when provided) OR the last
|
||||
@@ -203,6 +359,22 @@ class ContextCompressor(ContextEngine):
|
||||
result = [m.copy() for m in messages]
|
||||
pruned = 0
|
||||
|
||||
# Build index: tool_call_id -> (tool_name, arguments_json)
|
||||
call_id_to_tool: Dict[str, tuple] = {}
|
||||
for msg in result:
|
||||
if msg.get("role") == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
if isinstance(tc, dict):
|
||||
cid = tc.get("id", "")
|
||||
fn = tc.get("function", {})
|
||||
call_id_to_tool[cid] = (fn.get("name", "unknown"), fn.get("arguments", ""))
|
||||
else:
|
||||
cid = getattr(tc, "id", "") or ""
|
||||
fn = getattr(tc, "function", None)
|
||||
name = getattr(fn, "name", "unknown") if fn else "unknown"
|
||||
args_str = getattr(fn, "arguments", "") if fn else ""
|
||||
call_id_to_tool[cid] = (name, args_str)
|
||||
|
||||
# Determine the prune boundary
|
||||
if protect_tail_tokens is not None and protect_tail_tokens > 0:
|
||||
# Token-budget approach: walk backward accumulating tokens
|
||||
@@ -211,7 +383,8 @@ class ContextCompressor(ContextEngine):
|
||||
min_protect = min(protect_tail_count, len(result) - 1)
|
||||
for i in range(len(result) - 1, -1, -1):
|
||||
msg = result[i]
|
||||
content_len = len(msg.get("content") or "")
|
||||
raw_content = msg.get("content") or ""
|
||||
content_len = sum(len(p.get("text", "")) for p in raw_content) if isinstance(raw_content, list) else len(raw_content)
|
||||
msg_tokens = content_len // _CHARS_PER_TOKEN + 10
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
if isinstance(tc, dict):
|
||||
@@ -226,18 +399,69 @@ class ContextCompressor(ContextEngine):
|
||||
else:
|
||||
prune_boundary = len(result) - protect_tail_count
|
||||
|
||||
# Pass 1: Deduplicate identical tool results.
|
||||
# When the same file is read multiple times, keep only the most recent
|
||||
# full copy and replace older duplicates with a back-reference.
|
||||
content_hashes: dict = {} # hash -> (index, tool_call_id)
|
||||
for i in range(len(result) - 1, -1, -1):
|
||||
msg = result[i]
|
||||
if msg.get("role") != "tool":
|
||||
continue
|
||||
content = msg.get("content") or ""
|
||||
# Skip multimodal content (list of content blocks)
|
||||
if isinstance(content, list):
|
||||
continue
|
||||
if len(content) < 200:
|
||||
continue
|
||||
h = hashlib.md5(content.encode("utf-8", errors="replace")).hexdigest()[:12]
|
||||
if h in content_hashes:
|
||||
# This is an older duplicate — replace with back-reference
|
||||
result[i] = {**msg, "content": "[Duplicate tool output — same content as a more recent call]"}
|
||||
pruned += 1
|
||||
else:
|
||||
content_hashes[h] = (i, msg.get("tool_call_id", "?"))
|
||||
|
||||
# Pass 2: Replace old tool results with informative summaries
|
||||
for i in range(prune_boundary):
|
||||
msg = result[i]
|
||||
if msg.get("role") != "tool":
|
||||
continue
|
||||
content = msg.get("content", "")
|
||||
# Skip multimodal content (list of content blocks)
|
||||
if isinstance(content, list):
|
||||
continue
|
||||
if not content or content == _PRUNED_TOOL_PLACEHOLDER:
|
||||
continue
|
||||
# Skip already-deduplicated or previously-summarized results
|
||||
if content.startswith("[Duplicate tool output"):
|
||||
continue
|
||||
# Only prune if the content is substantial (>200 chars)
|
||||
if len(content) > 200:
|
||||
result[i] = {**msg, "content": _PRUNED_TOOL_PLACEHOLDER}
|
||||
call_id = msg.get("tool_call_id", "")
|
||||
tool_name, tool_args = call_id_to_tool.get(call_id, ("unknown", ""))
|
||||
summary = _summarize_tool_result(tool_name, tool_args, content)
|
||||
result[i] = {**msg, "content": summary}
|
||||
pruned += 1
|
||||
|
||||
# Pass 3: Truncate large tool_call arguments in assistant messages
|
||||
# outside the protected tail. write_file with 50KB content, for
|
||||
# example, survives pruning entirely without this.
|
||||
for i in range(prune_boundary):
|
||||
msg = result[i]
|
||||
if msg.get("role") != "assistant" or not msg.get("tool_calls"):
|
||||
continue
|
||||
new_tcs = []
|
||||
modified = False
|
||||
for tc in msg["tool_calls"]:
|
||||
if isinstance(tc, dict):
|
||||
args = tc.get("function", {}).get("arguments", "")
|
||||
if len(args) > 500:
|
||||
tc = {**tc, "function": {**tc["function"], "arguments": args[:200] + "...[truncated]"}}
|
||||
modified = True
|
||||
new_tcs.append(tc)
|
||||
if modified:
|
||||
result[i] = {**msg, "tool_calls": new_tcs}
|
||||
|
||||
return result, pruned
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -357,29 +581,37 @@ class ContextCompressor(ContextEngine):
|
||||
)
|
||||
|
||||
# Shared structured template (used by both paths).
|
||||
# Key changes vs v1:
|
||||
# - "Pending User Asks" section (from Claude Code) explicitly tracks
|
||||
# unanswered questions so the model knows what's resolved vs open
|
||||
# - "Remaining Work" replaces "Next Steps" to avoid reading as active
|
||||
# instructions
|
||||
# - "Resolved Questions" makes it clear which questions were already
|
||||
# answered (prevents model from re-answering them)
|
||||
_template_sections = f"""## Goal
|
||||
[What the user is trying to accomplish]
|
||||
|
||||
## Constraints & Preferences
|
||||
[User preferences, coding style, constraints, important decisions]
|
||||
|
||||
## Progress
|
||||
### Done
|
||||
[Completed work — include specific file paths, commands run, results obtained]
|
||||
### In Progress
|
||||
[Work currently underway]
|
||||
### Blocked
|
||||
[Any blockers or issues encountered]
|
||||
## Completed Actions
|
||||
[Numbered list of concrete actions taken — include tool used, target, and outcome.
|
||||
Format each as: N. ACTION target — outcome [tool: name]
|
||||
Example:
|
||||
1. READ config.py:45 — found `==` should be `!=` [tool: read_file]
|
||||
2. PATCH config.py:45 — changed `==` to `!=` [tool: patch]
|
||||
3. TEST `pytest tests/` — 3/50 failed: test_parse, test_validate, test_edge [tool: terminal]
|
||||
Be specific with file paths, commands, line numbers, and results.]
|
||||
|
||||
## Active State
|
||||
[Current working state — include:
|
||||
- Working directory and branch (if applicable)
|
||||
- Modified/created files with brief note on each
|
||||
- Test status (X/Y passing)
|
||||
- Any running processes or servers
|
||||
- Environment details that matter]
|
||||
|
||||
## In Progress
|
||||
[Work currently underway — what was being done when compaction fired]
|
||||
|
||||
## Blocked
|
||||
[Any blockers, errors, or issues not yet resolved. Include exact error messages.]
|
||||
|
||||
## Key Decisions
|
||||
[Important technical decisions and why they were made]
|
||||
[Important technical decisions and WHY they were made]
|
||||
|
||||
## Resolved Questions
|
||||
[Questions the user asked that were ALREADY answered — include the answer so the next assistant does not re-answer them]
|
||||
@@ -396,10 +628,7 @@ class ContextCompressor(ContextEngine):
|
||||
## Critical Context
|
||||
[Any specific values, error messages, configuration details, or data that would be lost without explicit preservation]
|
||||
|
||||
## Tools & Patterns
|
||||
[Which tools were used, how they were used effectively, and any tool-specific discoveries]
|
||||
|
||||
Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions.
|
||||
Target ~{summary_budget} tokens. Be CONCRETE — include file paths, command outputs, error messages, line numbers, and specific values. Avoid vague descriptions like "made some changes" — say exactly what changed.
|
||||
|
||||
Write only the summary body. Do not include any preamble or prefix."""
|
||||
|
||||
@@ -415,7 +644,7 @@ PREVIOUS SUMMARY:
|
||||
NEW TURNS TO INCORPORATE:
|
||||
{content_to_summarize}
|
||||
|
||||
Update the summary using this exact structure. PRESERVE all existing information that is still relevant. ADD new progress. Move items from "In Progress" to "Done" when completed. Move answered questions to "Resolved Questions". Remove information only if it is clearly obsolete.
|
||||
Update the summary using this exact structure. PRESERVE all existing information that is still relevant. ADD new completed actions to the numbered list (continue numbering). Move items from "In Progress" to "Completed Actions" when done. Move answered questions to "Resolved Questions". Update "Active State" to reflect current state. Remove information only if it is clearly obsolete.
|
||||
|
||||
{_template_sections}"""
|
||||
else:
|
||||
@@ -450,7 +679,7 @@ The user has requested that this compaction PRIORITISE preserving all informatio
|
||||
"api_mode": self.api_mode,
|
||||
},
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": summary_budget * 2,
|
||||
"max_tokens": int(summary_budget * 1.3),
|
||||
# timeout resolved from auxiliary.compression.timeout config by call_llm
|
||||
}
|
||||
if self.summary_model:
|
||||
@@ -464,8 +693,10 @@ The user has requested that this compaction PRIORITISE preserving all informatio
|
||||
# Store for iterative updates on next compaction
|
||||
self._previous_summary = summary
|
||||
self._summary_failure_cooldown_until = 0.0
|
||||
self._summary_model_fallen_back = False
|
||||
return self._with_summary_prefix(summary)
|
||||
except RuntimeError:
|
||||
# No provider configured — long cooldown, unlikely to self-resolve
|
||||
self._summary_failure_cooldown_until = time.monotonic() + _SUMMARY_FAILURE_COOLDOWN_SECONDS
|
||||
logging.warning("Context compression: no provider available for "
|
||||
"summary. Middle turns will be dropped without summary "
|
||||
@@ -473,12 +704,42 @@ The user has requested that this compaction PRIORITISE preserving all informatio
|
||||
_SUMMARY_FAILURE_COOLDOWN_SECONDS)
|
||||
return None
|
||||
except Exception as e:
|
||||
self._summary_failure_cooldown_until = time.monotonic() + _SUMMARY_FAILURE_COOLDOWN_SECONDS
|
||||
# If the summary model is different from the main model and the
|
||||
# error looks permanent (model not found, 503, 404), fall back to
|
||||
# using the main model instead of entering cooldown that leaves
|
||||
# context growing unbounded. (#8620 sub-issue 4)
|
||||
_status = getattr(e, "status_code", None) or getattr(getattr(e, "response", None), "status_code", None)
|
||||
_err_str = str(e).lower()
|
||||
_is_model_not_found = (
|
||||
_status in (404, 503)
|
||||
or "model_not_found" in _err_str
|
||||
or "does not exist" in _err_str
|
||||
or "no available channel" in _err_str
|
||||
)
|
||||
if (
|
||||
_is_model_not_found
|
||||
and self.summary_model
|
||||
and self.summary_model != self.model
|
||||
and not getattr(self, "_summary_model_fallen_back", False)
|
||||
):
|
||||
self._summary_model_fallen_back = True
|
||||
logging.warning(
|
||||
"Summary model '%s' not available (%s). "
|
||||
"Falling back to main model '%s' for compression.",
|
||||
self.summary_model, e, self.model,
|
||||
)
|
||||
self.summary_model = "" # empty = use main model
|
||||
self._summary_failure_cooldown_until = 0.0 # no cooldown
|
||||
return self._generate_summary(messages, summary_budget) # retry immediately
|
||||
|
||||
# Transient errors (timeout, rate limit, network) — shorter cooldown
|
||||
_transient_cooldown = 60
|
||||
self._summary_failure_cooldown_until = time.monotonic() + _transient_cooldown
|
||||
logging.warning(
|
||||
"Failed to generate context summary: %s. "
|
||||
"Further summary attempts paused for %d seconds.",
|
||||
e,
|
||||
_SUMMARY_FAILURE_COOLDOWN_SECONDS,
|
||||
_transient_cooldown,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -744,11 +1005,11 @@ The user has requested that this compaction PRIORITISE preserving all informatio
|
||||
compressed = []
|
||||
for i in range(compress_start):
|
||||
msg = messages[i].copy()
|
||||
if i == 0 and msg.get("role") == "system" and self.compression_count == 0:
|
||||
msg["content"] = (
|
||||
(msg.get("content") or "")
|
||||
+ "\n\n[Note: Some earlier conversation turns have been compacted into a handoff summary to preserve context space. The current session state may still reflect earlier work, so build on that summary and state rather than re-doing work.]"
|
||||
)
|
||||
if i == 0 and msg.get("role") == "system":
|
||||
existing = msg.get("content") or ""
|
||||
_compression_note = "[Note: Some earlier conversation turns have been compacted into a handoff summary to preserve context space. The current session state may still reflect earlier work, so build on that summary and state rather than re-doing work.]"
|
||||
if _compression_note not in existing:
|
||||
msg["content"] = existing + "\n\n" + _compression_note
|
||||
compressed.append(msg)
|
||||
|
||||
# If LLM summary failed, insert a static fallback so the model
|
||||
@@ -806,14 +1067,24 @@ The user has requested that this compaction PRIORITISE preserving all informatio
|
||||
|
||||
compressed = self._sanitize_tool_pairs(compressed)
|
||||
|
||||
new_estimate = estimate_messages_tokens_rough(compressed)
|
||||
saved_estimate = display_tokens - new_estimate
|
||||
|
||||
# Anti-thrashing: track compression effectiveness
|
||||
savings_pct = (saved_estimate / display_tokens * 100) if display_tokens > 0 else 0
|
||||
self._last_compression_savings_pct = savings_pct
|
||||
if savings_pct < 10:
|
||||
self._ineffective_compression_count += 1
|
||||
else:
|
||||
self._ineffective_compression_count = 0
|
||||
|
||||
if not self.quiet_mode:
|
||||
new_estimate = estimate_messages_tokens_rough(compressed)
|
||||
saved_estimate = display_tokens - new_estimate
|
||||
logger.info(
|
||||
"Compressed: %d -> %d messages (~%d tokens saved)",
|
||||
"Compressed: %d -> %d messages (~%d tokens saved, %.0f%%)",
|
||||
n_messages,
|
||||
len(compressed),
|
||||
saved_estimate,
|
||||
savings_pct,
|
||||
)
|
||||
logger.info("Compression #%d complete", self.compression_count)
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ _PROVIDER_PREFIXES: frozenset[str] = frozenset({
|
||||
"opencode", "zen", "go", "vercel", "kilo", "dashscope", "aliyun", "qwen",
|
||||
"mimo", "xiaomi-mimo",
|
||||
"arcee-ai", "arceeai",
|
||||
"xai", "x-ai", "x.ai", "grok",
|
||||
"qwen-portal",
|
||||
})
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from hermes_constants import display_hermes_home
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_skill_commands: Dict[str, Dict[str, Any]] = {}
|
||||
@@ -108,7 +110,7 @@ def _inject_skill_config(loaded_skill: dict[str, Any], parts: list[str]) -> None
|
||||
if not resolved:
|
||||
return
|
||||
|
||||
lines = ["", "[Skill config (from ~/.hermes/config.yaml):"]
|
||||
lines = ["", f"[Skill config (from {display_hermes_home()}/config.yaml):"]
|
||||
for key, value in resolved.items():
|
||||
display_val = str(value) if value else "(not set)"
|
||||
lines.append(f" {key} = {display_val}")
|
||||
|
||||
@@ -989,6 +989,7 @@ def _prune_orphaned_branches(repo_root: str) -> None:
|
||||
_ACCENT_ANSI_DEFAULT = "\033[1;38;2;255;215;0m" # True-color #FFD700 bold — fallback
|
||||
_BOLD = "\033[1m"
|
||||
_RST = "\033[0m"
|
||||
_STREAM_PAD = " " # 4-space indent for streamed response text (matches Panel padding)
|
||||
|
||||
|
||||
def _hex_to_ansi(hex_color: str, *, bold: bool = False) -> str:
|
||||
@@ -1712,9 +1713,9 @@ class HermesCLI:
|
||||
# Parse and validate toolsets
|
||||
self.enabled_toolsets = toolsets
|
||||
if toolsets and "all" not in toolsets and "*" not in toolsets:
|
||||
# Validate each toolset — MCP server names are added by
|
||||
# _get_platform_tools() but aren't registered in TOOLSETS yet
|
||||
# (that happens later in _sync_mcp_toolsets), so exclude them.
|
||||
# Validate each toolset — MCP server names are resolved via
|
||||
# live registry aliases (registered during discover_mcp_tools),
|
||||
# but discovery hasn't run yet at this point, so exclude them.
|
||||
mcp_names = set((CLI_CONFIG.get("mcp_servers") or {}).keys())
|
||||
invalid = [t for t in toolsets if not validate_toolset(t) and t not in mcp_names]
|
||||
if invalid:
|
||||
@@ -2580,7 +2581,7 @@ class HermesCLI:
|
||||
_tc = getattr(self, "_stream_text_ansi", "")
|
||||
while "\n" in self._stream_buf:
|
||||
line, self._stream_buf = self._stream_buf.split("\n", 1)
|
||||
_cprint(f"{_tc}{line}{_RST}" if _tc else line)
|
||||
_cprint(f"{_STREAM_PAD}{_tc}{line}{_RST}" if _tc else f"{_STREAM_PAD}{line}")
|
||||
|
||||
def _flush_stream(self) -> None:
|
||||
"""Emit any remaining partial line from the stream buffer and close the box."""
|
||||
@@ -2597,7 +2598,7 @@ class HermesCLI:
|
||||
|
||||
if self._stream_buf:
|
||||
_tc = getattr(self, "_stream_text_ansi", "")
|
||||
_cprint(f"{_tc}{self._stream_buf}{_RST}" if _tc else self._stream_buf)
|
||||
_cprint(f"{_STREAM_PAD}{_tc}{self._stream_buf}{_RST}" if _tc else f"{_STREAM_PAD}{self._stream_buf}")
|
||||
self._stream_buf = ""
|
||||
|
||||
# Close the response box
|
||||
@@ -4099,6 +4100,8 @@ class HermesCLI:
|
||||
self.agent.flush_memories(self.conversation_history)
|
||||
except (Exception, KeyboardInterrupt):
|
||||
pass
|
||||
# Trigger memory extraction on the old session before session_id rotates.
|
||||
self.agent.commit_memory_session(self.conversation_history)
|
||||
self._notify_session_boundary("on_session_finalize")
|
||||
elif self.agent:
|
||||
# First session or empty history — still finalize the old session
|
||||
@@ -4587,16 +4590,19 @@ class HermesCLI:
|
||||
self._close_model_picker()
|
||||
return
|
||||
provider_data = providers[selected]
|
||||
model_list = []
|
||||
try:
|
||||
from hermes_cli.models import provider_model_ids
|
||||
live = provider_model_ids(provider_data["slug"])
|
||||
if live:
|
||||
model_list = live
|
||||
except Exception:
|
||||
pass
|
||||
# Use the curated model list from list_authenticated_providers()
|
||||
# (same lists as `hermes model` and gateway pickers).
|
||||
# Only fall back to the live provider catalog when the curated
|
||||
# list is empty (e.g. user-defined endpoints with no curated list).
|
||||
model_list = provider_data.get("models", [])
|
||||
if not model_list:
|
||||
model_list = provider_data.get("models", [])
|
||||
try:
|
||||
from hermes_cli.models import provider_model_ids
|
||||
live = provider_model_ids(provider_data["slug"])
|
||||
if live:
|
||||
model_list = live
|
||||
except Exception:
|
||||
pass
|
||||
state["stage"] = "model"
|
||||
state["provider_data"] = provider_data
|
||||
state["model_list"] = model_list
|
||||
@@ -5761,7 +5767,7 @@ class HermesCLI:
|
||||
border_style=_resp_color,
|
||||
style=_resp_text,
|
||||
box=rich_box.HORIZONTALS,
|
||||
padding=(1, 2),
|
||||
padding=(1, 4),
|
||||
))
|
||||
else:
|
||||
_cprint(" (No response generated)")
|
||||
@@ -5885,7 +5891,7 @@ class HermesCLI:
|
||||
title_align="left",
|
||||
border_style=_resp_color,
|
||||
box=rich_box.HORIZONTALS,
|
||||
padding=(1, 2),
|
||||
padding=(1, 4),
|
||||
))
|
||||
else:
|
||||
_cprint(" 💬 /btw: (no response)")
|
||||
@@ -5952,7 +5958,7 @@ class HermesCLI:
|
||||
parts = cmd.strip().split(None, 1)
|
||||
sub = parts[1].lower().strip() if len(parts) > 1 else "status"
|
||||
|
||||
_DEFAULT_CDP = "http://localhost:9222"
|
||||
_DEFAULT_CDP = "http://127.0.0.1:9222"
|
||||
current = os.environ.get("BROWSER_CDP_URL", "").strip()
|
||||
|
||||
if sub.startswith("connect"):
|
||||
@@ -7648,7 +7654,7 @@ class HermesCLI:
|
||||
label = " ⚕ Hermes "
|
||||
fill = w - 2 - len(label)
|
||||
_cprint(f"\n{_ACCENT}╭─{label}{'─' * max(fill - 1, 0)}╮{_RST}")
|
||||
_cprint(sentence.rstrip())
|
||||
_cprint(f"{_STREAM_PAD}{sentence.rstrip()}")
|
||||
|
||||
tts_thread = threading.Thread(
|
||||
target=stream_tts_to_speaker,
|
||||
@@ -7879,7 +7885,7 @@ class HermesCLI:
|
||||
border_style=_resp_color,
|
||||
style=_resp_text,
|
||||
box=rich_box.HORIZONTALS,
|
||||
padding=(1, 2),
|
||||
padding=(1, 4),
|
||||
))
|
||||
|
||||
|
||||
|
||||
+9
-2
@@ -10,6 +10,7 @@ runs at a time if multiple processes overlap.
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -288,11 +289,13 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option
|
||||
|
||||
if wrap_response:
|
||||
task_name = job.get("name", job["id"])
|
||||
job_id = job.get("id", "")
|
||||
delivery_content = (
|
||||
f"Cronjob Response: {task_name}\n"
|
||||
f"(job_id: {job_id})\n"
|
||||
f"-------------\n\n"
|
||||
f"{content}\n\n"
|
||||
f"Note: The agent cannot see this message, and therefore cannot respond to it."
|
||||
f"To stop or manage this job, send me a new message (e.g. \"stop reminder {task_name}\")."
|
||||
)
|
||||
else:
|
||||
delivery_content = content
|
||||
@@ -768,7 +771,11 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
_cron_inactivity_limit = _cron_timeout if _cron_timeout > 0 else None
|
||||
_POLL_INTERVAL = 5.0
|
||||
_cron_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
_cron_future = _cron_pool.submit(agent.run_conversation, prompt)
|
||||
# Preserve scheduler-scoped ContextVar state (for example skill-declared
|
||||
# env passthrough registrations) when the cron run hops into the worker
|
||||
# thread used for inactivity timeout monitoring.
|
||||
_cron_context = contextvars.copy_context()
|
||||
_cron_future = _cron_pool.submit(_cron_context.run, agent.run_conversation, prompt)
|
||||
_inactivity_timeout = False
|
||||
try:
|
||||
if _cron_inactivity_limit is None:
|
||||
|
||||
Regular → Executable
+13
-6
@@ -1,13 +1,14 @@
|
||||
#!/bin/bash
|
||||
# Docker entrypoint: bootstrap config files into the mounted volume, then run hermes.
|
||||
# Docker/Podman entrypoint: bootstrap config files into the mounted volume, then run hermes.
|
||||
set -e
|
||||
|
||||
HERMES_HOME="/opt/data"
|
||||
HERMES_HOME="${HERMES_HOME:-/opt/data}"
|
||||
INSTALL_DIR="/opt/hermes"
|
||||
|
||||
# --- Privilege dropping via gosu ---
|
||||
# When started as root (the default), optionally remap the hermes user/group
|
||||
# to match host-side ownership, fix volume permissions, then re-exec as hermes.
|
||||
# When started as root (the default for Docker, or fakeroot in rootless Podman),
|
||||
# optionally remap the hermes user/group to match host-side ownership, fix volume
|
||||
# permissions, then re-exec as hermes.
|
||||
if [ "$(id -u)" = "0" ]; then
|
||||
if [ -n "$HERMES_UID" ] && [ "$HERMES_UID" != "$(id -u hermes)" ]; then
|
||||
echo "Changing hermes UID to $HERMES_UID"
|
||||
@@ -16,13 +17,19 @@ if [ "$(id -u)" = "0" ]; then
|
||||
|
||||
if [ -n "$HERMES_GID" ] && [ "$HERMES_GID" != "$(id -g hermes)" ]; then
|
||||
echo "Changing hermes GID to $HERMES_GID"
|
||||
groupmod -g "$HERMES_GID" hermes
|
||||
# -o allows non-unique GID (e.g. macOS GID 20 "staff" may already exist
|
||||
# as "dialout" in the Debian-based container image)
|
||||
groupmod -o -g "$HERMES_GID" hermes 2>/dev/null || true
|
||||
fi
|
||||
|
||||
actual_hermes_uid=$(id -u hermes)
|
||||
if [ "$(stat -c %u "$HERMES_HOME" 2>/dev/null)" != "$actual_hermes_uid" ]; then
|
||||
echo "$HERMES_HOME is not owned by $actual_hermes_uid, fixing"
|
||||
chown -R hermes:hermes "$HERMES_HOME"
|
||||
# In rootless Podman the container's "root" is mapped to an unprivileged
|
||||
# host UID — chown will fail. That's fine: the volume is already owned
|
||||
# by the mapped user on the host side.
|
||||
chown -R hermes:hermes "$HERMES_HOME" 2>/dev/null || \
|
||||
echo "Warning: chown failed (rootless container?) — continuing anyway"
|
||||
fi
|
||||
|
||||
echo "Dropping root privileges"
|
||||
|
||||
+1
-25
@@ -3,12 +3,11 @@ Event Hook System
|
||||
|
||||
A lightweight event-driven system that fires handlers at key lifecycle points.
|
||||
Hooks are discovered from ~/.hermes/hooks/ directories, each containing:
|
||||
- HOOK.yaml (metadata: name, description, events list, optional startup_readiness)
|
||||
- HOOK.yaml (metadata: name, description, events list)
|
||||
- handler.py (Python handler with async def handle(event_type, context))
|
||||
|
||||
Events:
|
||||
- gateway:startup -- Gateway process starts
|
||||
- gateway:shutdown -- Gateway process is shutting down
|
||||
- session:start -- New session created (first message of a new session)
|
||||
- session:end -- Session ends (user ran /new or /reset)
|
||||
- session:reset -- Session reset completed (new session entry created)
|
||||
@@ -32,26 +31,6 @@ from hermes_cli.config import get_hermes_home
|
||||
HOOKS_DIR = get_hermes_home() / "hooks"
|
||||
|
||||
|
||||
def _normalize_startup_readiness(hook_name: str, manifest: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||
"""Validate and normalize optional startup readiness metadata."""
|
||||
readiness = manifest.get("startup_readiness")
|
||||
if readiness is None:
|
||||
return None
|
||||
if not isinstance(readiness, dict):
|
||||
print(f"[hooks] Ignoring startup_readiness for {hook_name}: expected mapping", flush=True)
|
||||
return None
|
||||
|
||||
check_id = str(readiness.get("id", "")).strip()
|
||||
if not check_id:
|
||||
print(f"[hooks] Ignoring startup_readiness for {hook_name}: missing id", flush=True)
|
||||
return None
|
||||
|
||||
return {
|
||||
"id": check_id,
|
||||
"required": bool(readiness.get("required", True)),
|
||||
}
|
||||
|
||||
|
||||
class HookRegistry:
|
||||
"""
|
||||
Discovers, loads, and fires event hooks.
|
||||
@@ -83,7 +62,6 @@ class HookRegistry:
|
||||
"description": "Run ~/.hermes/BOOT.md on gateway startup",
|
||||
"events": ["gateway:startup"],
|
||||
"path": "(builtin)",
|
||||
"startup_readiness": None,
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"[hooks] Could not load built-in boot-md hook: {e}", flush=True)
|
||||
@@ -124,7 +102,6 @@ class HookRegistry:
|
||||
if not events:
|
||||
print(f"[hooks] Skipping {hook_name}: no events declared", flush=True)
|
||||
continue
|
||||
startup_readiness = _normalize_startup_readiness(hook_name, manifest)
|
||||
|
||||
# Dynamically load the handler module
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
@@ -151,7 +128,6 @@ class HookRegistry:
|
||||
"description": manifest.get("description", ""),
|
||||
"events": events,
|
||||
"path": str(hook_dir),
|
||||
"startup_readiness": startup_readiness,
|
||||
})
|
||||
|
||||
print(f"[hooks] Loaded hook '{hook_name}' for events: {events}", flush=True)
|
||||
|
||||
@@ -515,6 +515,8 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
session_id: Optional[str] = None,
|
||||
stream_delta_callback=None,
|
||||
tool_progress_callback=None,
|
||||
tool_start_callback=None,
|
||||
tool_complete_callback=None,
|
||||
) -> Any:
|
||||
"""
|
||||
Create an AIAgent instance using the gateway's runtime config.
|
||||
@@ -553,6 +555,8 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
platform="api_server",
|
||||
stream_delta_callback=stream_delta_callback,
|
||||
tool_progress_callback=tool_progress_callback,
|
||||
tool_start_callback=tool_start_callback,
|
||||
tool_complete_callback=tool_complete_callback,
|
||||
session_db=self._ensure_session_db(),
|
||||
fallback_model=fallback_model,
|
||||
)
|
||||
@@ -965,6 +969,427 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
|
||||
return response
|
||||
|
||||
async def _write_sse_responses(
|
||||
self,
|
||||
request: "web.Request",
|
||||
response_id: str,
|
||||
model: str,
|
||||
created_at: int,
|
||||
stream_q,
|
||||
agent_task,
|
||||
agent_ref,
|
||||
conversation_history: List[Dict[str, str]],
|
||||
user_message: str,
|
||||
instructions: Optional[str],
|
||||
conversation: Optional[str],
|
||||
store: bool,
|
||||
session_id: str,
|
||||
) -> "web.StreamResponse":
|
||||
"""Write an SSE stream for POST /v1/responses (OpenAI Responses API).
|
||||
|
||||
Emits spec-compliant event types as the agent runs:
|
||||
|
||||
- ``response.created`` — initial envelope (status=in_progress)
|
||||
- ``response.output_text.delta`` / ``response.output_text.done`` —
|
||||
streamed assistant text
|
||||
- ``response.output_item.added`` / ``response.output_item.done``
|
||||
with ``item.type == "function_call"`` — when the agent invokes a
|
||||
tool (both events fire; the ``done`` event carries the finalized
|
||||
``arguments`` string)
|
||||
- ``response.output_item.added`` with
|
||||
``item.type == "function_call_output"`` — tool result with
|
||||
``{call_id, output, status}``
|
||||
- ``response.completed`` — terminal event carrying the full
|
||||
response object with all output items + usage (same payload
|
||||
shape as the non-streaming path for parity)
|
||||
- ``response.failed`` — terminal event on agent error
|
||||
|
||||
If the client disconnects mid-stream, ``agent.interrupt()`` is
|
||||
called so the agent stops issuing upstream LLM calls, then the
|
||||
asyncio task is cancelled. When ``store=True`` the full response
|
||||
is persisted to the ResponseStore in a ``finally`` block so GET
|
||||
/v1/responses/{id} and ``previous_response_id`` chaining work the
|
||||
same as the batch path.
|
||||
"""
|
||||
import queue as _q
|
||||
|
||||
sse_headers = {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
}
|
||||
origin = request.headers.get("Origin", "")
|
||||
cors = self._cors_headers_for_origin(origin) if origin else None
|
||||
if cors:
|
||||
sse_headers.update(cors)
|
||||
if session_id:
|
||||
sse_headers["X-Hermes-Session-Id"] = session_id
|
||||
response = web.StreamResponse(status=200, headers=sse_headers)
|
||||
await response.prepare(request)
|
||||
|
||||
# State accumulated during the stream
|
||||
final_text_parts: List[str] = []
|
||||
# Track open function_call items by name so we can emit a matching
|
||||
# ``done`` event when the tool completes. Order preserved.
|
||||
pending_tool_calls: List[Dict[str, Any]] = []
|
||||
# Output items we've emitted so far (used to build the terminal
|
||||
# response.completed payload). Kept in the order they appeared.
|
||||
emitted_items: List[Dict[str, Any]] = []
|
||||
# Monotonic counter for output_index (spec requires it).
|
||||
output_index = 0
|
||||
# Monotonic counter for call_id generation if the agent doesn't
|
||||
# provide one (it doesn't, from tool_progress_callback).
|
||||
call_counter = 0
|
||||
# Canonical Responses SSE events include a monotonically increasing
|
||||
# sequence_number. Add it server-side for every emitted event so
|
||||
# clients that validate the OpenAI event schema can parse our stream.
|
||||
sequence_number = 0
|
||||
# Track the assistant message item id + content index for text
|
||||
# delta events — the spec ties deltas to a specific item.
|
||||
message_item_id = f"msg_{uuid.uuid4().hex[:24]}"
|
||||
message_output_index: Optional[int] = None
|
||||
message_opened = False
|
||||
|
||||
async def _write_event(event_type: str, data: Dict[str, Any]) -> None:
|
||||
nonlocal sequence_number
|
||||
if "sequence_number" not in data:
|
||||
data["sequence_number"] = sequence_number
|
||||
sequence_number += 1
|
||||
payload = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
||||
await response.write(payload.encode())
|
||||
|
||||
def _envelope(status: str) -> Dict[str, Any]:
|
||||
env: Dict[str, Any] = {
|
||||
"id": response_id,
|
||||
"object": "response",
|
||||
"status": status,
|
||||
"created_at": created_at,
|
||||
"model": model,
|
||||
}
|
||||
return env
|
||||
|
||||
final_response_text = ""
|
||||
agent_error: Optional[str] = None
|
||||
usage: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
|
||||
try:
|
||||
# response.created — initial envelope, status=in_progress
|
||||
created_env = _envelope("in_progress")
|
||||
created_env["output"] = []
|
||||
await _write_event("response.created", {
|
||||
"type": "response.created",
|
||||
"response": created_env,
|
||||
})
|
||||
last_activity = time.monotonic()
|
||||
|
||||
async def _open_message_item() -> None:
|
||||
"""Emit response.output_item.added for the assistant message
|
||||
the first time any text delta arrives."""
|
||||
nonlocal message_opened, message_output_index, output_index
|
||||
if message_opened:
|
||||
return
|
||||
message_opened = True
|
||||
message_output_index = output_index
|
||||
output_index += 1
|
||||
item = {
|
||||
"id": message_item_id,
|
||||
"type": "message",
|
||||
"status": "in_progress",
|
||||
"role": "assistant",
|
||||
"content": [],
|
||||
}
|
||||
await _write_event("response.output_item.added", {
|
||||
"type": "response.output_item.added",
|
||||
"output_index": message_output_index,
|
||||
"item": item,
|
||||
})
|
||||
|
||||
async def _emit_text_delta(delta_text: str) -> None:
|
||||
await _open_message_item()
|
||||
final_text_parts.append(delta_text)
|
||||
await _write_event("response.output_text.delta", {
|
||||
"type": "response.output_text.delta",
|
||||
"item_id": message_item_id,
|
||||
"output_index": message_output_index,
|
||||
"content_index": 0,
|
||||
"delta": delta_text,
|
||||
"logprobs": [],
|
||||
})
|
||||
|
||||
async def _emit_tool_started(payload: Dict[str, Any]) -> str:
|
||||
"""Emit response.output_item.added for a function_call.
|
||||
|
||||
Returns the call_id so the matching completion event can
|
||||
reference it. Prefer the real ``tool_call_id`` from the
|
||||
agent when available; fall back to a generated call id for
|
||||
safety in tests or older code paths.
|
||||
"""
|
||||
nonlocal output_index, call_counter
|
||||
call_counter += 1
|
||||
call_id = payload.get("tool_call_id") or f"call_{response_id[5:]}_{call_counter}"
|
||||
args = payload.get("arguments", {})
|
||||
if isinstance(args, dict):
|
||||
arguments_str = json.dumps(args)
|
||||
else:
|
||||
arguments_str = str(args)
|
||||
item = {
|
||||
"id": f"fc_{uuid.uuid4().hex[:24]}",
|
||||
"type": "function_call",
|
||||
"status": "in_progress",
|
||||
"name": payload.get("name", ""),
|
||||
"call_id": call_id,
|
||||
"arguments": arguments_str,
|
||||
}
|
||||
idx = output_index
|
||||
output_index += 1
|
||||
pending_tool_calls.append({
|
||||
"call_id": call_id,
|
||||
"name": payload.get("name", ""),
|
||||
"arguments": arguments_str,
|
||||
"item_id": item["id"],
|
||||
"output_index": idx,
|
||||
})
|
||||
emitted_items.append({
|
||||
"type": "function_call",
|
||||
"name": payload.get("name", ""),
|
||||
"arguments": arguments_str,
|
||||
"call_id": call_id,
|
||||
})
|
||||
await _write_event("response.output_item.added", {
|
||||
"type": "response.output_item.added",
|
||||
"output_index": idx,
|
||||
"item": item,
|
||||
})
|
||||
return call_id
|
||||
|
||||
async def _emit_tool_completed(payload: Dict[str, Any]) -> None:
|
||||
"""Emit response.output_item.done (function_call) followed
|
||||
by response.output_item.added (function_call_output)."""
|
||||
nonlocal output_index
|
||||
call_id = payload.get("tool_call_id")
|
||||
result = payload.get("result", "")
|
||||
pending = None
|
||||
if call_id:
|
||||
for i, p in enumerate(pending_tool_calls):
|
||||
if p["call_id"] == call_id:
|
||||
pending = pending_tool_calls.pop(i)
|
||||
break
|
||||
if pending is None:
|
||||
# Completion without a matching start — skip to avoid
|
||||
# emitting orphaned done events.
|
||||
return
|
||||
|
||||
# function_call done
|
||||
done_item = {
|
||||
"id": pending["item_id"],
|
||||
"type": "function_call",
|
||||
"status": "completed",
|
||||
"name": pending["name"],
|
||||
"call_id": pending["call_id"],
|
||||
"arguments": pending["arguments"],
|
||||
}
|
||||
await _write_event("response.output_item.done", {
|
||||
"type": "response.output_item.done",
|
||||
"output_index": pending["output_index"],
|
||||
"item": done_item,
|
||||
})
|
||||
|
||||
# function_call_output added (result)
|
||||
result_str = result if isinstance(result, str) else json.dumps(result)
|
||||
output_parts = [{"type": "input_text", "text": result_str}]
|
||||
output_item = {
|
||||
"id": f"fco_{uuid.uuid4().hex[:24]}",
|
||||
"type": "function_call_output",
|
||||
"call_id": pending["call_id"],
|
||||
"output": output_parts,
|
||||
"status": "completed",
|
||||
}
|
||||
idx = output_index
|
||||
output_index += 1
|
||||
emitted_items.append({
|
||||
"type": "function_call_output",
|
||||
"call_id": pending["call_id"],
|
||||
"output": output_parts,
|
||||
})
|
||||
await _write_event("response.output_item.added", {
|
||||
"type": "response.output_item.added",
|
||||
"output_index": idx,
|
||||
"item": output_item,
|
||||
})
|
||||
await _write_event("response.output_item.done", {
|
||||
"type": "response.output_item.done",
|
||||
"output_index": idx,
|
||||
"item": output_item,
|
||||
})
|
||||
|
||||
# Main drain loop — thread-safe queue fed by agent callbacks.
|
||||
async def _dispatch(it) -> None:
|
||||
"""Route a queue item to the correct SSE emitter.
|
||||
|
||||
Plain strings are text deltas. Tagged tuples with
|
||||
``__tool_started__`` / ``__tool_completed__`` prefixes
|
||||
are tool lifecycle events.
|
||||
"""
|
||||
if isinstance(it, tuple) and len(it) == 2 and isinstance(it[0], str):
|
||||
tag, payload = it
|
||||
if tag == "__tool_started__":
|
||||
await _emit_tool_started(payload)
|
||||
elif tag == "__tool_completed__":
|
||||
await _emit_tool_completed(payload)
|
||||
# Unknown tags are silently ignored (forward-compat).
|
||||
elif isinstance(it, str):
|
||||
await _emit_text_delta(it)
|
||||
# Other types (non-string, non-tuple) are silently dropped.
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
while True:
|
||||
try:
|
||||
item = await loop.run_in_executor(None, lambda: stream_q.get(timeout=0.5))
|
||||
except _q.Empty:
|
||||
if agent_task.done():
|
||||
# Drain remaining
|
||||
while True:
|
||||
try:
|
||||
item = stream_q.get_nowait()
|
||||
if item is None:
|
||||
break
|
||||
await _dispatch(item)
|
||||
last_activity = time.monotonic()
|
||||
except _q.Empty:
|
||||
break
|
||||
break
|
||||
if time.monotonic() - last_activity >= CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS:
|
||||
await response.write(b": keepalive\n\n")
|
||||
last_activity = time.monotonic()
|
||||
continue
|
||||
|
||||
if item is None: # EOS sentinel
|
||||
break
|
||||
|
||||
await _dispatch(item)
|
||||
last_activity = time.monotonic()
|
||||
|
||||
# Pick up agent result + usage from the completed task
|
||||
try:
|
||||
result, agent_usage = await agent_task
|
||||
usage = agent_usage or usage
|
||||
# If the agent produced a final_response but no text
|
||||
# deltas were streamed (e.g. some providers only emit
|
||||
# the full response at the end), emit a single fallback
|
||||
# delta so Responses clients still receive a live text part.
|
||||
agent_final = result.get("final_response", "") if isinstance(result, dict) else ""
|
||||
if agent_final and not final_text_parts:
|
||||
await _emit_text_delta(agent_final)
|
||||
if agent_final and not final_response_text:
|
||||
final_response_text = agent_final
|
||||
if isinstance(result, dict) and result.get("error") and not final_response_text:
|
||||
agent_error = result["error"]
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("Error running agent for streaming responses: %s", e, exc_info=True)
|
||||
agent_error = str(e)
|
||||
|
||||
# Close the message item if it was opened
|
||||
final_response_text = "".join(final_text_parts) or final_response_text
|
||||
if message_opened:
|
||||
await _write_event("response.output_text.done", {
|
||||
"type": "response.output_text.done",
|
||||
"item_id": message_item_id,
|
||||
"output_index": message_output_index,
|
||||
"content_index": 0,
|
||||
"text": final_response_text,
|
||||
"logprobs": [],
|
||||
})
|
||||
msg_done_item = {
|
||||
"id": message_item_id,
|
||||
"type": "message",
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "output_text", "text": final_response_text}
|
||||
],
|
||||
}
|
||||
await _write_event("response.output_item.done", {
|
||||
"type": "response.output_item.done",
|
||||
"output_index": message_output_index,
|
||||
"item": msg_done_item,
|
||||
})
|
||||
|
||||
# Always append a final message item in the completed
|
||||
# response envelope so clients that only parse the terminal
|
||||
# payload still see the assistant text. This mirrors the
|
||||
# shape produced by _extract_output_items in the batch path.
|
||||
final_items: List[Dict[str, Any]] = list(emitted_items)
|
||||
final_items.append({
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "output_text", "text": final_response_text or (agent_error or "")}
|
||||
],
|
||||
})
|
||||
|
||||
if agent_error:
|
||||
failed_env = _envelope("failed")
|
||||
failed_env["output"] = final_items
|
||||
failed_env["error"] = {"message": agent_error, "type": "server_error"}
|
||||
failed_env["usage"] = {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
}
|
||||
await _write_event("response.failed", {
|
||||
"type": "response.failed",
|
||||
"response": failed_env,
|
||||
})
|
||||
else:
|
||||
completed_env = _envelope("completed")
|
||||
completed_env["output"] = final_items
|
||||
completed_env["usage"] = {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
}
|
||||
await _write_event("response.completed", {
|
||||
"type": "response.completed",
|
||||
"response": completed_env,
|
||||
})
|
||||
|
||||
# Persist for future chaining / GET retrieval, mirroring
|
||||
# the batch path behavior.
|
||||
if store:
|
||||
full_history = list(conversation_history)
|
||||
full_history.append({"role": "user", "content": user_message})
|
||||
if isinstance(result, dict) and result.get("messages"):
|
||||
full_history.extend(result["messages"])
|
||||
else:
|
||||
full_history.append({"role": "assistant", "content": final_response_text})
|
||||
self._response_store.put(response_id, {
|
||||
"response": completed_env,
|
||||
"conversation_history": full_history,
|
||||
"instructions": instructions,
|
||||
"session_id": session_id,
|
||||
})
|
||||
if conversation:
|
||||
self._response_store.set_conversation(conversation, response_id)
|
||||
|
||||
except (ConnectionResetError, ConnectionAbortedError, BrokenPipeError, OSError):
|
||||
# Client disconnected — interrupt the agent so it stops
|
||||
# making upstream LLM calls, then cancel the task.
|
||||
agent = agent_ref[0] if agent_ref else None
|
||||
if agent is not None:
|
||||
try:
|
||||
agent.interrupt("SSE client disconnected")
|
||||
except Exception:
|
||||
pass
|
||||
if not agent_task.done():
|
||||
agent_task.cancel()
|
||||
try:
|
||||
await agent_task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
logger.info("SSE client disconnected; interrupted agent task %s", response_id)
|
||||
|
||||
return response
|
||||
|
||||
async def _handle_responses(self, request: "web.Request") -> "web.Response":
|
||||
"""POST /v1/responses — OpenAI Responses API format."""
|
||||
auth_err = self._check_auth(request)
|
||||
@@ -1035,11 +1460,13 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
if previous_response_id:
|
||||
logger.debug("Both conversation_history and previous_response_id provided; using conversation_history")
|
||||
|
||||
stored_session_id = None
|
||||
if not conversation_history and previous_response_id:
|
||||
stored = self._response_store.get(previous_response_id)
|
||||
if stored is None:
|
||||
return web.json_response(_openai_error(f"Previous response not found: {previous_response_id}"), status=404)
|
||||
conversation_history = list(stored.get("conversation_history", []))
|
||||
stored_session_id = stored.get("session_id")
|
||||
# If no instructions provided, carry forward from previous
|
||||
if instructions is None:
|
||||
instructions = stored.get("instructions")
|
||||
@@ -1057,8 +1484,83 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
if body.get("truncation") == "auto" and len(conversation_history) > 100:
|
||||
conversation_history = conversation_history[-100:]
|
||||
|
||||
# Run the agent (with Idempotency-Key support)
|
||||
session_id = str(uuid.uuid4())
|
||||
# Reuse session from previous_response_id chain so the dashboard
|
||||
# groups the entire conversation under one session entry.
|
||||
session_id = stored_session_id or str(uuid.uuid4())
|
||||
|
||||
stream = bool(body.get("stream", False))
|
||||
if stream:
|
||||
# Streaming branch — emit OpenAI Responses SSE events as the
|
||||
# agent runs so frontends can render text deltas and tool
|
||||
# calls in real time. See _write_sse_responses for details.
|
||||
import queue as _q
|
||||
_stream_q: _q.Queue = _q.Queue()
|
||||
|
||||
def _on_delta(delta):
|
||||
# None from the agent is a CLI box-close signal, not EOS.
|
||||
# Forwarding would kill the SSE stream prematurely; the
|
||||
# SSE writer detects completion via agent_task.done().
|
||||
if delta is not None:
|
||||
_stream_q.put(delta)
|
||||
|
||||
def _on_tool_progress(event_type, name, preview, args, **kwargs):
|
||||
"""Queue non-start tool progress events if needed in future.
|
||||
|
||||
The structured Responses stream uses ``tool_start_callback``
|
||||
and ``tool_complete_callback`` for exact call-id correlation,
|
||||
so progress events are currently ignored here.
|
||||
"""
|
||||
return
|
||||
|
||||
def _on_tool_start(tool_call_id, function_name, function_args):
|
||||
"""Queue a started tool for live function_call streaming."""
|
||||
_stream_q.put(("__tool_started__", {
|
||||
"tool_call_id": tool_call_id,
|
||||
"name": function_name,
|
||||
"arguments": function_args or {},
|
||||
}))
|
||||
|
||||
def _on_tool_complete(tool_call_id, function_name, function_args, function_result):
|
||||
"""Queue a completed tool result for live function_call_output streaming."""
|
||||
_stream_q.put(("__tool_completed__", {
|
||||
"tool_call_id": tool_call_id,
|
||||
"name": function_name,
|
||||
"arguments": function_args or {},
|
||||
"result": function_result,
|
||||
}))
|
||||
|
||||
agent_ref = [None]
|
||||
agent_task = asyncio.ensure_future(self._run_agent(
|
||||
user_message=user_message,
|
||||
conversation_history=conversation_history,
|
||||
ephemeral_system_prompt=instructions,
|
||||
session_id=session_id,
|
||||
stream_delta_callback=_on_delta,
|
||||
tool_progress_callback=_on_tool_progress,
|
||||
tool_start_callback=_on_tool_start,
|
||||
tool_complete_callback=_on_tool_complete,
|
||||
agent_ref=agent_ref,
|
||||
))
|
||||
|
||||
response_id = f"resp_{uuid.uuid4().hex[:28]}"
|
||||
model_name = body.get("model", self._model_name)
|
||||
created_at = int(time.time())
|
||||
|
||||
return await self._write_sse_responses(
|
||||
request=request,
|
||||
response_id=response_id,
|
||||
model=model_name,
|
||||
created_at=created_at,
|
||||
stream_q=_stream_q,
|
||||
agent_task=agent_task,
|
||||
agent_ref=agent_ref,
|
||||
conversation_history=conversation_history,
|
||||
user_message=user_message,
|
||||
instructions=instructions,
|
||||
conversation=conversation,
|
||||
store=store,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _compute_response():
|
||||
return await self._run_agent(
|
||||
@@ -1133,6 +1635,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
"response": response_data,
|
||||
"conversation_history": full_history,
|
||||
"instructions": instructions,
|
||||
"session_id": session_id,
|
||||
})
|
||||
# Update conversation mapping so the next request with the same
|
||||
# conversation name automatically chains to this response
|
||||
@@ -1486,6 +1989,8 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
session_id: Optional[str] = None,
|
||||
stream_delta_callback=None,
|
||||
tool_progress_callback=None,
|
||||
tool_start_callback=None,
|
||||
tool_complete_callback=None,
|
||||
agent_ref: Optional[list] = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
@@ -1507,6 +2012,8 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
session_id=session_id,
|
||||
stream_delta_callback=stream_delta_callback,
|
||||
tool_progress_callback=tool_progress_callback,
|
||||
tool_start_callback=tool_start_callback,
|
||||
tool_complete_callback=tool_complete_callback,
|
||||
)
|
||||
if agent_ref is not None:
|
||||
agent_ref[0] = agent
|
||||
@@ -1643,10 +2150,12 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
if previous_response_id:
|
||||
logger.debug("Both conversation_history and previous_response_id provided; using conversation_history")
|
||||
|
||||
stored_session_id = None
|
||||
if not conversation_history and previous_response_id:
|
||||
stored = self._response_store.get(previous_response_id)
|
||||
if stored:
|
||||
conversation_history = list(stored.get("conversation_history", []))
|
||||
stored_session_id = stored.get("session_id")
|
||||
if instructions is None:
|
||||
instructions = stored.get("instructions")
|
||||
|
||||
@@ -1665,7 +2174,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
)
|
||||
conversation_history.append({"role": msg["role"], "content": str(content)})
|
||||
|
||||
session_id = body.get("session_id") or run_id
|
||||
session_id = body.get("session_id") or stored_session_id or run_id
|
||||
ephemeral_system_prompt = instructions
|
||||
|
||||
async def _run_and_close():
|
||||
|
||||
@@ -1624,6 +1624,21 @@ class BasePlatformAdapter(ABC):
|
||||
# streaming already delivered the text (already_sent=True) or
|
||||
# when the message was queued behind an active agent. Log at
|
||||
# DEBUG to avoid noisy warnings for expected behavior.
|
||||
#
|
||||
# Suppress stale response when the session was interrupted by a
|
||||
# new message that hasn't been consumed yet. The pending message
|
||||
# is processed by the pending-message handler below (#8221/#2483).
|
||||
if (
|
||||
response
|
||||
and interrupt_event.is_set()
|
||||
and session_key in self._pending_messages
|
||||
):
|
||||
logger.info(
|
||||
"[%s] Suppressing stale response for interrupted session %s",
|
||||
self.name,
|
||||
session_key,
|
||||
)
|
||||
response = None
|
||||
if not response:
|
||||
logger.debug("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
|
||||
if response:
|
||||
|
||||
@@ -1379,6 +1379,68 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
)
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
async def send_animation(
|
||||
self,
|
||||
chat_id: str,
|
||||
animation_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Send an animated GIF natively as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
if not is_safe_url(animation_url):
|
||||
logger.warning("[%s] Blocked unsafe animation URL during Discord send_animation", self.name)
|
||||
return await super().send_animation(chat_id, animation_url, caption, reply_to, metadata=metadata)
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
# Download the GIF and send as a Discord file attachment
|
||||
# (Discord renders .gif attachments as auto-playing animations inline)
|
||||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
|
||||
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||||
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
|
||||
async with aiohttp.ClientSession(**_sess_kw) as session:
|
||||
async with session.get(animation_url, timeout=aiohttp.ClientTimeout(total=30), **_req_kw) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"Failed to download animation: HTTP {resp.status}")
|
||||
|
||||
animation_data = await resp.read()
|
||||
|
||||
import io
|
||||
file = discord.File(io.BytesIO(animation_data), filename="animation.gif")
|
||||
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"[%s] aiohttp not installed, falling back to URL. Run: pip install aiohttp",
|
||||
self.name,
|
||||
exc_info=True,
|
||||
)
|
||||
return await super().send_animation(chat_id, animation_url, caption, reply_to, metadata=metadata)
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.error(
|
||||
"[%s] Failed to send animation attachment, falling back to URL: %s",
|
||||
self.name,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
return await super().send_animation(chat_id, animation_url, caption, reply_to, metadata=metadata)
|
||||
|
||||
async def send_video(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -1696,6 +1758,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
async def slash_update(interaction: discord.Interaction):
|
||||
await self._run_simple_slash(interaction, "/update", "Update initiated~")
|
||||
|
||||
@tree.command(name="restart", description="Gracefully restart the Hermes gateway")
|
||||
async def slash_restart(interaction: discord.Interaction):
|
||||
await self._run_simple_slash(interaction, "/restart", "Restart requested~")
|
||||
|
||||
@tree.command(name="approve", description="Approve a pending dangerous command")
|
||||
@discord.app_commands.describe(scope="Optional: 'all', 'session', 'always', 'all session', 'all always'")
|
||||
async def slash_approve(interaction: discord.Interaction, scope: str = ""):
|
||||
|
||||
@@ -729,6 +729,14 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def stop_typing(self, chat_id: str) -> None:
|
||||
"""Stop the Matrix typing indicator."""
|
||||
if self._client:
|
||||
try:
|
||||
await self._client.set_typing(RoomID(chat_id), timeout=0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def edit_message(
|
||||
self, chat_id: str, message_id: str, content: str
|
||||
) -> SendResult:
|
||||
|
||||
+390
-80
@@ -482,6 +482,27 @@ def _resolve_hermes_bin() -> Optional[list[str]]:
|
||||
return None
|
||||
|
||||
|
||||
def _parse_session_key(session_key: str) -> "dict | None":
|
||||
"""Parse a session key into its component parts.
|
||||
|
||||
Session keys follow the format
|
||||
``agent:main:{platform}:{chat_type}:{chat_id}[:{thread_id}[:{user_id}]]``.
|
||||
Returns a dict with ``platform``, ``chat_type``, ``chat_id``, and
|
||||
optionally ``thread_id`` keys, or None if the key doesn't match.
|
||||
"""
|
||||
parts = session_key.split(":")
|
||||
if len(parts) >= 5 and parts[0] == "agent" and parts[1] == "main":
|
||||
result = {
|
||||
"platform": parts[2],
|
||||
"chat_type": parts[3],
|
||||
"chat_id": parts[4],
|
||||
}
|
||||
if len(parts) > 5:
|
||||
result["thread_id"] = parts[5]
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def _format_gateway_process_notification(evt: dict) -> "str | None":
|
||||
"""Format a watch pattern event from completion_queue into a [SYSTEM:] message."""
|
||||
evt_type = evt.get("type", "completion")
|
||||
@@ -573,6 +594,7 @@ class GatewayRunner:
|
||||
self._running_agents: Dict[str, Any] = {}
|
||||
self._running_agents_ts: Dict[str, float] = {} # start timestamp per session
|
||||
self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt
|
||||
self._busy_ack_ts: Dict[str, float] = {} # last busy-ack timestamp per session (debounce)
|
||||
|
||||
# Cache AIAgent instances per session to preserve prompt caching.
|
||||
# Without this, a new AIAgent is created per message, rebuilding the
|
||||
@@ -1329,26 +1351,100 @@ class GatewayRunner:
|
||||
merge_pending_message_event(adapter._pending_messages, session_key, event)
|
||||
|
||||
async def _handle_active_session_busy_message(self, event: MessageEvent, session_key: str) -> bool:
|
||||
if not self._draining:
|
||||
return False
|
||||
# --- Draining case (gateway restarting/stopping) ---
|
||||
if self._draining:
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
if not adapter:
|
||||
return True
|
||||
|
||||
thread_meta = {"thread_id": event.source.thread_id} if event.source.thread_id else None
|
||||
if self._queue_during_drain_enabled():
|
||||
self._queue_or_replace_pending_event(session_key, event)
|
||||
message = f"⏳ Gateway {self._status_action_gerund()} — queued for the next turn after it comes back."
|
||||
else:
|
||||
message = f"⏳ Gateway is {self._status_action_gerund()} and is not accepting another turn right now."
|
||||
|
||||
await adapter._send_with_retry(
|
||||
chat_id=event.source.chat_id,
|
||||
content=message,
|
||||
reply_to=event.message_id,
|
||||
metadata=thread_meta,
|
||||
)
|
||||
return True
|
||||
|
||||
# --- Normal busy case (agent actively running a task) ---
|
||||
# The user sent a message while the agent is working. Interrupt the
|
||||
# agent immediately so it stops the current tool-calling loop and
|
||||
# processes the new message. The pending message is stored in the
|
||||
# adapter so the base adapter picks it up once the interrupted run
|
||||
# returns. A brief ack tells the user what's happening (debounced
|
||||
# to avoid spam when they fire multiple messages quickly).
|
||||
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
if not adapter:
|
||||
return True
|
||||
return False # let default path handle it
|
||||
|
||||
# Store the message so it's processed as the next turn after the
|
||||
# interrupt causes the current run to exit.
|
||||
from gateway.platforms.base import merge_pending_message_event
|
||||
merge_pending_message_event(adapter._pending_messages, session_key, event)
|
||||
|
||||
# Interrupt the running agent — this aborts in-flight tool calls and
|
||||
# causes the agent loop to exit at the next check point.
|
||||
running_agent = self._running_agents.get(session_key)
|
||||
if running_agent and running_agent is not _AGENT_PENDING_SENTINEL:
|
||||
try:
|
||||
running_agent.interrupt(event.text)
|
||||
except Exception:
|
||||
pass # don't let interrupt failure block the ack
|
||||
|
||||
# Debounce: only send an acknowledgment once every 30 seconds per session
|
||||
# to avoid spamming the user when they send multiple messages quickly
|
||||
_BUSY_ACK_COOLDOWN = 30
|
||||
now = time.time()
|
||||
last_ack = self._busy_ack_ts.get(session_key, 0)
|
||||
if now - last_ack < _BUSY_ACK_COOLDOWN:
|
||||
return True # interrupt sent, ack already delivered recently
|
||||
|
||||
self._busy_ack_ts[session_key] = now
|
||||
|
||||
# Build a status-rich acknowledgment
|
||||
status_parts = []
|
||||
if running_agent and running_agent is not _AGENT_PENDING_SENTINEL:
|
||||
try:
|
||||
summary = running_agent.get_activity_summary()
|
||||
iteration = summary.get("api_call_count", 0)
|
||||
max_iter = summary.get("max_iterations", 0)
|
||||
current_tool = summary.get("current_tool")
|
||||
start_ts = self._running_agents_ts.get(session_key, 0)
|
||||
if start_ts:
|
||||
elapsed_min = int((now - start_ts) / 60)
|
||||
if elapsed_min > 0:
|
||||
status_parts.append(f"{elapsed_min} min elapsed")
|
||||
if max_iter:
|
||||
status_parts.append(f"iteration {iteration}/{max_iter}")
|
||||
if current_tool:
|
||||
status_parts.append(f"running: {current_tool}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
status_detail = f" ({', '.join(status_parts)})" if status_parts else ""
|
||||
message = (
|
||||
f"⚡ Interrupting current task{status_detail}. "
|
||||
f"I'll respond to your message shortly."
|
||||
)
|
||||
|
||||
thread_meta = {"thread_id": event.source.thread_id} if event.source.thread_id else None
|
||||
if self._queue_during_drain_enabled():
|
||||
self._queue_or_replace_pending_event(session_key, event)
|
||||
message = f"⏳ Gateway {self._status_action_gerund()} — queued for the next turn after it comes back."
|
||||
else:
|
||||
message = f"⏳ Gateway is {self._status_action_gerund()} and is not accepting another turn right now."
|
||||
try:
|
||||
await adapter._send_with_retry(
|
||||
chat_id=event.source.chat_id,
|
||||
content=message,
|
||||
reply_to=event.message_id,
|
||||
metadata=thread_meta,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to send busy-ack: %s", e)
|
||||
|
||||
await adapter._send_with_retry(
|
||||
chat_id=event.source.chat_id,
|
||||
content=message,
|
||||
reply_to=event.message_id,
|
||||
metadata=thread_meta,
|
||||
)
|
||||
return True
|
||||
|
||||
async def _drain_active_agents(self, timeout: float) -> tuple[Dict[str, Any], bool]:
|
||||
@@ -1405,7 +1501,7 @@ class GatewayRunner:
|
||||
action = "restarting" if self._restart_requested else "shutting down"
|
||||
hint = (
|
||||
"Your current task will be interrupted. "
|
||||
"Use /retry after restart to continue."
|
||||
"Send any message after restart to resume where it left off."
|
||||
if self._restart_requested
|
||||
else "Your current task will be interrupted."
|
||||
)
|
||||
@@ -1414,12 +1510,11 @@ class GatewayRunner:
|
||||
notified: set = set()
|
||||
for session_key in active:
|
||||
# Parse platform + chat_id from the session key.
|
||||
# Format: agent:main:{platform}:{chat_type}:{chat_id}[:{extra}...]
|
||||
parts = session_key.split(":")
|
||||
if len(parts) < 5:
|
||||
_parsed = _parse_session_key(session_key)
|
||||
if not _parsed:
|
||||
continue
|
||||
platform_str = parts[2]
|
||||
chat_id = parts[4]
|
||||
platform_str = _parsed["platform"]
|
||||
chat_id = _parsed["chat_id"]
|
||||
|
||||
# Deduplicate: one notification per chat, even if multiple
|
||||
# sessions (different users/threads) share the same chat.
|
||||
@@ -1435,7 +1530,7 @@ class GatewayRunner:
|
||||
|
||||
# Include thread_id if present so the message lands in the
|
||||
# correct forum topic / thread.
|
||||
thread_id = parts[5] if len(parts) > 5 else None
|
||||
thread_id = _parsed.get("thread_id")
|
||||
metadata = {"thread_id": thread_id} if thread_id else None
|
||||
|
||||
await adapter.send(chat_id, msg, metadata=metadata)
|
||||
@@ -1475,6 +1570,106 @@ class GatewayRunner:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_STUCK_LOOP_THRESHOLD = 3 # restarts while active before auto-suspend
|
||||
_STUCK_LOOP_FILE = ".restart_failure_counts"
|
||||
|
||||
def _increment_restart_failure_counts(self, active_session_keys: set) -> None:
|
||||
"""Increment restart-failure counters for sessions active at shutdown.
|
||||
|
||||
Persists to a JSON file so counters survive across restarts.
|
||||
Sessions NOT in active_session_keys are removed (they completed
|
||||
successfully, so the loop is broken).
|
||||
"""
|
||||
import json
|
||||
|
||||
path = _hermes_home / self._STUCK_LOOP_FILE
|
||||
try:
|
||||
counts = json.loads(path.read_text()) if path.exists() else {}
|
||||
except Exception:
|
||||
counts = {}
|
||||
|
||||
# Increment active sessions, remove inactive ones (loop broken)
|
||||
new_counts = {}
|
||||
for key in active_session_keys:
|
||||
new_counts[key] = counts.get(key, 0) + 1
|
||||
# Keep any entries that are still above 0 even if not active now
|
||||
# (they might become active again next restart)
|
||||
|
||||
try:
|
||||
path.write_text(json.dumps(new_counts))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _suspend_stuck_loop_sessions(self) -> int:
|
||||
"""Suspend sessions that have been active across too many restarts.
|
||||
|
||||
Returns the number of sessions suspended. Called on gateway startup
|
||||
AFTER suspend_recently_active() to catch the stuck-loop pattern:
|
||||
session loads → agent gets stuck → gateway restarts → repeat.
|
||||
"""
|
||||
import json
|
||||
|
||||
path = _hermes_home / self._STUCK_LOOP_FILE
|
||||
if not path.exists():
|
||||
return 0
|
||||
|
||||
try:
|
||||
counts = json.loads(path.read_text())
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
suspended = 0
|
||||
stuck_keys = [k for k, v in counts.items() if v >= self._STUCK_LOOP_THRESHOLD]
|
||||
|
||||
for session_key in stuck_keys:
|
||||
try:
|
||||
entry = self.session_store._entries.get(session_key)
|
||||
if entry and not entry.suspended:
|
||||
entry.suspended = True
|
||||
suspended += 1
|
||||
logger.warning(
|
||||
"Auto-suspended stuck session %s (active across %d "
|
||||
"consecutive restarts — likely a stuck loop)",
|
||||
session_key[:30], counts[session_key],
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if suspended:
|
||||
try:
|
||||
self.session_store._save()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Clear the file — counters start fresh after suspension
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return suspended
|
||||
|
||||
def _clear_restart_failure_count(self, session_key: str) -> None:
|
||||
"""Clear the restart-failure counter for a session that completed OK.
|
||||
|
||||
Called after a successful agent turn to signal the loop is broken.
|
||||
"""
|
||||
import json
|
||||
|
||||
path = _hermes_home / self._STUCK_LOOP_FILE
|
||||
if not path.exists():
|
||||
return
|
||||
try:
|
||||
counts = json.loads(path.read_text())
|
||||
if session_key in counts:
|
||||
del counts[session_key]
|
||||
if counts:
|
||||
path.write_text(json.dumps(counts))
|
||||
else:
|
||||
path.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _launch_detached_restart_command(self) -> None:
|
||||
import shutil
|
||||
import subprocess
|
||||
@@ -1540,7 +1735,7 @@ class GatewayRunner:
|
||||
pass
|
||||
try:
|
||||
from gateway.status import write_runtime_status
|
||||
write_runtime_status(gateway_state="starting", exit_reason=None, startup_checks={})
|
||||
write_runtime_status(gateway_state="starting", exit_reason=None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -1582,23 +1777,8 @@ class GatewayRunner:
|
||||
"or configure platform allowlists (e.g., TELEGRAM_ALLOWED_USERS=your_id)."
|
||||
)
|
||||
|
||||
# Discover plugins before hooks so plugin-owned hook bundles can
|
||||
# participate in this same startup cycle.
|
||||
try:
|
||||
from hermes_cli.plugins import discover_plugins
|
||||
|
||||
discover_plugins()
|
||||
except Exception as e:
|
||||
logger.warning("Plugin discovery during gateway startup failed: %s", e)
|
||||
|
||||
# Discover and load event hooks
|
||||
self.hooks.discover_and_load()
|
||||
try:
|
||||
from gateway.status import reset_startup_checks
|
||||
|
||||
reset_startup_checks(self.hooks.loaded_hooks)
|
||||
except Exception as e:
|
||||
logger.warning("Startup readiness initialization failed: %s", e)
|
||||
|
||||
# Recover background processes from checkpoint (crash recovery)
|
||||
try:
|
||||
@@ -1633,6 +1813,17 @@ class GatewayRunner:
|
||||
except Exception as e:
|
||||
logger.warning("Session suspension on startup failed: %s", e)
|
||||
|
||||
# Stuck-loop detection (#7536): if a session has been active across
|
||||
# 3+ consecutive restarts, it's probably stuck in a loop (the same
|
||||
# history keeps causing the agent to hang). Auto-suspend it so the
|
||||
# user gets a clean slate on the next message.
|
||||
try:
|
||||
stuck = self._suspend_stuck_loop_sessions()
|
||||
if stuck:
|
||||
logger.warning("Auto-suspended %d stuck-loop session(s)", stuck)
|
||||
except Exception as e:
|
||||
logger.debug("Stuck-loop detection failed: %s", e)
|
||||
|
||||
connected_count = 0
|
||||
enabled_platform_count = 0
|
||||
startup_nonretryable_errors: list[str] = []
|
||||
@@ -2119,11 +2310,6 @@ class GatewayRunner:
|
||||
logger.error("Failed to launch detached gateway restart: %s", e)
|
||||
|
||||
self._finalize_shutdown_agents(active_agents)
|
||||
await self.hooks.emit("gateway:shutdown", {
|
||||
"restart": self._restart_requested,
|
||||
"service_restart": self._restart_via_service,
|
||||
"detached_restart": self._restart_detached,
|
||||
})
|
||||
|
||||
for platform, adapter in list(self.adapters.items()):
|
||||
try:
|
||||
@@ -2146,6 +2332,8 @@ class GatewayRunner:
|
||||
self._running_agents.clear()
|
||||
self._pending_messages.clear()
|
||||
self._pending_approvals.clear()
|
||||
if hasattr(self, '_busy_ack_ts'):
|
||||
self._busy_ack_ts.clear()
|
||||
self._shutdown_event.set()
|
||||
|
||||
# Global cleanup: kill any remaining tool subprocesses not tied
|
||||
@@ -2189,6 +2377,14 @@ class GatewayRunner:
|
||||
"active sessions."
|
||||
)
|
||||
|
||||
# Track sessions that were active at shutdown for stuck-loop
|
||||
# detection (#7536). On each restart, the counter increments
|
||||
# for sessions that were running. If a session hits the
|
||||
# threshold (3 consecutive restarts while active), the next
|
||||
# startup auto-suspends it — breaking the loop.
|
||||
if active_agents:
|
||||
self._increment_restart_failure_counts(set(active_agents.keys()))
|
||||
|
||||
if self._restart_requested and self._restart_via_service:
|
||||
self._exit_code = GATEWAY_SERVICE_RESTART_EXIT_CODE
|
||||
self._exit_reason = self._exit_reason or "Gateway restart requested"
|
||||
@@ -2622,6 +2818,7 @@ class GatewayRunner:
|
||||
)
|
||||
del self._running_agents[_quick_key]
|
||||
self._running_agents_ts.pop(_quick_key, None)
|
||||
self._busy_ack_ts.pop(_quick_key, None)
|
||||
|
||||
if _quick_key in self._running_agents:
|
||||
if event.get_command() == "status":
|
||||
@@ -3687,6 +3884,12 @@ class GatewayRunner:
|
||||
_response_time, _api_calls, _resp_len,
|
||||
)
|
||||
|
||||
# Successful turn — clear any stuck-loop counter for this session.
|
||||
# This ensures the counter only accumulates across CONSECUTIVE
|
||||
# restarts where the session was active (never completed).
|
||||
if session_key:
|
||||
self._clear_restart_failure_count(session_key)
|
||||
|
||||
# Surface error details when the agent failed silently (final_response=None)
|
||||
if not response and agent_result.get("failed"):
|
||||
error_detail = agent_result.get("error", "unknown error")
|
||||
@@ -3775,7 +3978,7 @@ class GatewayRunner:
|
||||
synth_text = _format_gateway_process_notification(evt)
|
||||
if synth_text:
|
||||
try:
|
||||
await self._inject_watch_notification(synth_text, event)
|
||||
await self._inject_watch_notification(synth_text, evt)
|
||||
except Exception as e2:
|
||||
logger.error("Watch notification injection error: %s", e2)
|
||||
except Exception as e:
|
||||
@@ -3793,14 +3996,11 @@ class GatewayRunner:
|
||||
# intermediate reasoning) so sessions can be resumed with full context
|
||||
# and transcripts are useful for debugging and training data.
|
||||
#
|
||||
# IMPORTANT: When the agent failed before producing any response
|
||||
# (e.g. context-overflow 400), do NOT persist the user's message.
|
||||
# IMPORTANT: When the agent failed (e.g. context-overflow 400,
|
||||
# compression exhausted), do NOT persist the user's message.
|
||||
# Persisting it would make the session even larger, causing the
|
||||
# same failure on the next attempt — an infinite loop. (#1630)
|
||||
agent_failed_early = (
|
||||
agent_result.get("failed")
|
||||
and not agent_result.get("final_response")
|
||||
)
|
||||
# same failure on the next attempt — an infinite loop. (#1630, #9893)
|
||||
agent_failed_early = bool(agent_result.get("failed"))
|
||||
if agent_failed_early:
|
||||
logger.info(
|
||||
"Skipping transcript persistence for failed request in "
|
||||
@@ -3808,6 +4008,24 @@ class GatewayRunner:
|
||||
session_entry.session_id,
|
||||
)
|
||||
|
||||
# When compression is exhausted, the session is permanently too
|
||||
# large to process. Auto-reset it so the next message starts
|
||||
# fresh instead of replaying the same oversized context in an
|
||||
# infinite fail loop. (#9893)
|
||||
if agent_result.get("compression_exhausted") and session_entry and session_key:
|
||||
logger.info(
|
||||
"Auto-resetting session %s after compression exhaustion.",
|
||||
session_entry.session_id,
|
||||
)
|
||||
self.session_store.reset_session(session_key)
|
||||
self._evict_cached_agent(session_key)
|
||||
self._session_model_overrides.pop(session_key, None)
|
||||
response = (response or "") + (
|
||||
"\n\n🔄 Session auto-reset — the conversation exceeded the "
|
||||
"maximum context size and could not be compressed further. "
|
||||
"Your next message will start a fresh session."
|
||||
)
|
||||
|
||||
ts = datetime.now().isoformat()
|
||||
|
||||
# If this is a fresh session (no history), write the full tool
|
||||
@@ -3915,6 +4133,8 @@ class GatewayRunner:
|
||||
_hist_len = len(history) if 'history' in locals() else 0
|
||||
if status_code == 401:
|
||||
status_hint = " Check your API key or run `claude /login` to refresh OAuth credentials."
|
||||
elif status_code == 402:
|
||||
status_hint = " Your API balance or quota is exhausted. Check your provider dashboard."
|
||||
elif status_code == 429:
|
||||
# Check if this is a plan usage limit (resets on a schedule) vs a transient rate limit
|
||||
_err_body = getattr(e, "response", None)
|
||||
@@ -7252,14 +7472,75 @@ class GatewayRunner:
|
||||
return prefix
|
||||
return user_text
|
||||
|
||||
async def _inject_watch_notification(self, synth_text: str, original_event) -> None:
|
||||
def _build_process_event_source(self, evt: dict):
|
||||
"""Resolve the canonical source for a synthetic background-process event.
|
||||
|
||||
Prefer the persisted session-store origin for the event's session key.
|
||||
Falling back to the currently active foreground event is what causes
|
||||
cross-topic bleed, so don't do that.
|
||||
"""
|
||||
from gateway.session import SessionSource
|
||||
|
||||
session_key = str(evt.get("session_key") or "").strip()
|
||||
derived_platform = ""
|
||||
derived_chat_type = ""
|
||||
derived_chat_id = ""
|
||||
|
||||
if session_key:
|
||||
try:
|
||||
self.session_store._ensure_loaded()
|
||||
entry = self.session_store._entries.get(session_key)
|
||||
if entry and getattr(entry, "origin", None):
|
||||
return entry.origin
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"Synthetic process-event session-store lookup failed for %s: %s",
|
||||
session_key,
|
||||
exc,
|
||||
)
|
||||
|
||||
_parsed = _parse_session_key(session_key)
|
||||
if _parsed:
|
||||
derived_platform = _parsed["platform"]
|
||||
derived_chat_type = _parsed["chat_type"]
|
||||
derived_chat_id = _parsed["chat_id"]
|
||||
|
||||
platform_name = str(evt.get("platform") or derived_platform or "").strip().lower()
|
||||
chat_type = str(evt.get("chat_type") or derived_chat_type or "").strip().lower()
|
||||
chat_id = str(evt.get("chat_id") or derived_chat_id or "").strip()
|
||||
if not platform_name or not chat_type or not chat_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
platform = Platform(platform_name)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Synthetic process event has invalid platform metadata: %r",
|
||||
platform_name,
|
||||
)
|
||||
return None
|
||||
|
||||
return SessionSource(
|
||||
platform=platform,
|
||||
chat_id=chat_id,
|
||||
chat_type=chat_type,
|
||||
thread_id=str(evt.get("thread_id") or "").strip() or None,
|
||||
user_id=str(evt.get("user_id") or "").strip() or None,
|
||||
user_name=str(evt.get("user_name") or "").strip() or None,
|
||||
)
|
||||
|
||||
async def _inject_watch_notification(self, synth_text: str, evt: dict) -> None:
|
||||
"""Inject a watch-pattern notification as a synthetic message event.
|
||||
|
||||
Uses the source from the original user event to route the notification
|
||||
back to the correct chat/adapter.
|
||||
Routing must come from the queued watch event itself, not from whatever
|
||||
foreground message happened to be active when the queue was drained.
|
||||
"""
|
||||
source = getattr(original_event, "source", None)
|
||||
source = self._build_process_event_source(evt)
|
||||
if not source:
|
||||
logger.warning(
|
||||
"Dropping watch notification with no routing metadata for process %s",
|
||||
evt.get("session_id", "unknown"),
|
||||
)
|
||||
return
|
||||
platform_name = source.platform.value if hasattr(source.platform, "value") else str(source.platform)
|
||||
adapter = None
|
||||
@@ -7277,7 +7558,12 @@ class GatewayRunner:
|
||||
source=source,
|
||||
internal=True,
|
||||
)
|
||||
logger.info("Watch pattern notification — injecting for %s", platform_name)
|
||||
logger.info(
|
||||
"Watch pattern notification — injecting for %s chat=%s thread=%s",
|
||||
platform_name,
|
||||
source.chat_id,
|
||||
source.thread_id,
|
||||
)
|
||||
await adapter.handle_message(synth_event)
|
||||
except Exception as e:
|
||||
logger.error("Watch notification injection error: %s", e)
|
||||
@@ -7347,33 +7633,42 @@ class GatewayRunner:
|
||||
f"Command: {session.command}\n"
|
||||
f"Output:\n{_out}]"
|
||||
)
|
||||
source = self._build_process_event_source({
|
||||
"session_id": session_id,
|
||||
"session_key": session_key,
|
||||
"platform": platform_name,
|
||||
"chat_id": chat_id,
|
||||
"thread_id": thread_id,
|
||||
"user_id": user_id,
|
||||
"user_name": user_name,
|
||||
})
|
||||
if not source:
|
||||
logger.warning(
|
||||
"Dropping completion notification with no routing metadata for process %s",
|
||||
session_id,
|
||||
)
|
||||
break
|
||||
|
||||
adapter = None
|
||||
for p, a in self.adapters.items():
|
||||
if p.value == platform_name:
|
||||
if p == source.platform:
|
||||
adapter = a
|
||||
break
|
||||
if adapter and chat_id:
|
||||
if adapter and source.chat_id:
|
||||
try:
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.session import SessionSource
|
||||
from gateway.config import Platform
|
||||
_platform_enum = Platform(platform_name)
|
||||
_source = SessionSource(
|
||||
platform=_platform_enum,
|
||||
chat_id=chat_id,
|
||||
thread_id=thread_id or None,
|
||||
user_id=user_id or None,
|
||||
user_name=user_name or None,
|
||||
)
|
||||
synth_event = MessageEvent(
|
||||
text=synth_text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=_source,
|
||||
source=source,
|
||||
internal=True,
|
||||
)
|
||||
logger.info(
|
||||
"Process %s finished — injecting agent notification for session %s",
|
||||
session_id, session_key,
|
||||
"Process %s finished — injecting agent notification for session %s chat=%s thread=%s",
|
||||
session_id,
|
||||
session_key,
|
||||
source.chat_id,
|
||||
source.thread_id,
|
||||
)
|
||||
await adapter.handle_message(synth_event)
|
||||
except Exception as e:
|
||||
@@ -8258,6 +8553,12 @@ class GatewayRunner:
|
||||
cached = _cache.get(session_key)
|
||||
if cached and cached[1] == _sig:
|
||||
agent = cached[0]
|
||||
# Reset activity timestamp so the inactivity timeout
|
||||
# handler doesn't see stale idle time from the previous
|
||||
# turn and immediately kill this agent. (#9051)
|
||||
agent._last_activity_ts = time.time()
|
||||
agent._last_activity_desc = "starting new turn (cached)"
|
||||
agent._api_call_count = 0
|
||||
logger.debug("Reusing cached agent for session %s", session_key)
|
||||
|
||||
if agent is None:
|
||||
@@ -8470,6 +8771,21 @@ class GatewayRunner:
|
||||
if _msn:
|
||||
message = _msn + "\n\n" + message
|
||||
|
||||
# Auto-continue: if the loaded history ends with a tool result,
|
||||
# the previous agent turn was interrupted mid-work (gateway
|
||||
# restart, crash, SIGTERM). Prepend a system note so the model
|
||||
# finishes processing the pending tool results before addressing
|
||||
# the user's new message. (#4493)
|
||||
if agent_history and agent_history[-1].get("role") == "tool":
|
||||
message = (
|
||||
"[System note: Your previous turn was interrupted before you could "
|
||||
"process the last tool result(s). The conversation history contains "
|
||||
"tool outputs you haven't responded to yet. Please finish processing "
|
||||
"those results and summarize what was accomplished, then address the "
|
||||
"user's new message below.]\n\n"
|
||||
+ message
|
||||
)
|
||||
|
||||
_approval_session_key = session_key or ""
|
||||
_approval_session_token = set_current_session_key(_approval_session_key)
|
||||
register_gateway_notify(_approval_session_key, _approval_notify_sync)
|
||||
@@ -8504,6 +8820,8 @@ class GatewayRunner:
|
||||
"final_response": error_msg,
|
||||
"messages": result.get("messages", []),
|
||||
"api_calls": result.get("api_calls", 0),
|
||||
"failed": result.get("failed", False),
|
||||
"compression_exhausted": result.get("compression_exhausted", False),
|
||||
"tools": tools_holder[0] or [],
|
||||
"history_offset": len(agent_history),
|
||||
"last_prompt_tokens": _last_prompt_toks,
|
||||
@@ -9008,15 +9326,11 @@ class GatewayRunner:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug("Stream consumer wait before queued message failed: %s", e)
|
||||
_response_previewed = bool(result.get("response_previewed"))
|
||||
_already_streamed = bool(
|
||||
_sc
|
||||
and (
|
||||
getattr(_sc, "final_response_sent", False)
|
||||
or (
|
||||
_response_previewed
|
||||
and getattr(_sc, "already_sent", False)
|
||||
)
|
||||
or getattr(_sc, "already_sent", False)
|
||||
)
|
||||
)
|
||||
first_response = result.get("final_response", "")
|
||||
@@ -9100,13 +9414,9 @@ class GatewayRunner:
|
||||
# them even if streaming had sent earlier partial output.
|
||||
_sc = stream_consumer_holder[0]
|
||||
if _sc and isinstance(response, dict) and not response.get("failed"):
|
||||
_response_previewed = bool(response.get("response_previewed"))
|
||||
if (
|
||||
getattr(_sc, "final_response_sent", False)
|
||||
or (
|
||||
_response_previewed
|
||||
and getattr(_sc, "already_sent", False)
|
||||
)
|
||||
or getattr(_sc, "already_sent", False)
|
||||
):
|
||||
response["already_sent"] = True
|
||||
|
||||
|
||||
+1
-135
@@ -27,7 +27,6 @@ _RUNTIME_STATUS_FILE = "gateway_state.json"
|
||||
_LOCKS_DIRNAME = "gateway-locks"
|
||||
_IS_WINDOWS = sys.platform == "win32"
|
||||
_UNSET = object()
|
||||
_VALID_STARTUP_CHECK_STATES = {"pending", "ready", "failed"}
|
||||
|
||||
|
||||
def _get_pid_path() -> Path:
|
||||
@@ -163,39 +162,11 @@ def _build_runtime_status_record() -> dict[str, Any]:
|
||||
"restart_requested": False,
|
||||
"active_agents": 0,
|
||||
"platforms": {},
|
||||
"startup_checks": {},
|
||||
"updated_at": _utc_now_iso(),
|
||||
})
|
||||
return payload
|
||||
|
||||
|
||||
def _normalize_startup_check_entries(
|
||||
startup_checks: Optional[dict[str, Any]],
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Normalize persisted startup readiness entries."""
|
||||
if not isinstance(startup_checks, dict):
|
||||
return {}
|
||||
|
||||
now = _utc_now_iso()
|
||||
normalized: dict[str, dict[str, Any]] = {}
|
||||
for raw_id, raw_payload in startup_checks.items():
|
||||
check_id = str(raw_id).strip()
|
||||
if not check_id:
|
||||
continue
|
||||
payload = raw_payload if isinstance(raw_payload, dict) else {}
|
||||
state = str(payload.get("state", "pending")).strip().lower()
|
||||
if state not in _VALID_STARTUP_CHECK_STATES:
|
||||
state = "pending"
|
||||
normalized[check_id] = {
|
||||
"state": state,
|
||||
"required": bool(payload.get("required", True)),
|
||||
"source": payload.get("source"),
|
||||
"detail": payload.get("detail"),
|
||||
"updated_at": payload.get("updated_at") or now,
|
||||
}
|
||||
return normalized
|
||||
|
||||
|
||||
def _read_json_file(path: Path) -> Optional[dict[str, Any]]:
|
||||
if not path.exists():
|
||||
return None
|
||||
@@ -252,7 +223,6 @@ def write_runtime_status(
|
||||
exit_reason: Any = _UNSET,
|
||||
restart_requested: Any = _UNSET,
|
||||
active_agents: Any = _UNSET,
|
||||
startup_checks: Any = _UNSET,
|
||||
platform: Any = _UNSET,
|
||||
platform_state: Any = _UNSET,
|
||||
error_code: Any = _UNSET,
|
||||
@@ -275,8 +245,6 @@ def write_runtime_status(
|
||||
payload["restart_requested"] = bool(restart_requested)
|
||||
if active_agents is not _UNSET:
|
||||
payload["active_agents"] = max(0, int(active_agents))
|
||||
if startup_checks is not _UNSET:
|
||||
payload["startup_checks"] = _normalize_startup_check_entries(startup_checks)
|
||||
|
||||
if platform is not _UNSET:
|
||||
platform_payload = payload["platforms"].get(platform, {})
|
||||
@@ -294,109 +262,7 @@ def write_runtime_status(
|
||||
|
||||
def read_runtime_status() -> Optional[dict[str, Any]]:
|
||||
"""Read the persisted gateway runtime health/status information."""
|
||||
payload = _read_json_file(_get_runtime_status_path())
|
||||
if payload is None:
|
||||
return None
|
||||
payload.setdefault("platforms", {})
|
||||
payload["startup_checks"] = _normalize_startup_check_entries(payload.get("startup_checks"))
|
||||
return payload
|
||||
|
||||
|
||||
def reset_startup_checks(checks: Optional[list[dict[str, Any]]] = None) -> dict[str, dict[str, Any]]:
|
||||
"""Replace persisted startup readiness checks for the current run."""
|
||||
normalized: dict[str, dict[str, Any]] = {}
|
||||
now = _utc_now_iso()
|
||||
|
||||
for hook in checks or []:
|
||||
if not isinstance(hook, dict):
|
||||
continue
|
||||
readiness = hook.get("startup_readiness")
|
||||
if not isinstance(readiness, dict):
|
||||
continue
|
||||
check_id = str(readiness.get("id", "")).strip()
|
||||
if not check_id:
|
||||
continue
|
||||
normalized[check_id] = {
|
||||
"state": "pending",
|
||||
"required": bool(readiness.get("required", True)),
|
||||
"source": hook.get("name"),
|
||||
"detail": None,
|
||||
"updated_at": now,
|
||||
}
|
||||
|
||||
write_runtime_status(startup_checks=normalized)
|
||||
return normalized
|
||||
|
||||
|
||||
def update_startup_check(
|
||||
check_id: str,
|
||||
state: str,
|
||||
*,
|
||||
detail: Any = _UNSET,
|
||||
required: Any = _UNSET,
|
||||
source: Any = _UNSET,
|
||||
) -> dict[str, Any]:
|
||||
"""Update a single startup readiness check in the runtime status file."""
|
||||
normalized_id = str(check_id).strip()
|
||||
if not normalized_id:
|
||||
raise ValueError("startup readiness check id is required")
|
||||
|
||||
normalized_state = str(state).strip().lower()
|
||||
if normalized_state not in _VALID_STARTUP_CHECK_STATES:
|
||||
raise ValueError(f"invalid startup readiness state: {state}")
|
||||
|
||||
path = _get_runtime_status_path()
|
||||
payload = _read_json_file(path) or _build_runtime_status_record()
|
||||
checks = _normalize_startup_check_entries(payload.get("startup_checks"))
|
||||
existing = checks.get(normalized_id, {})
|
||||
now = _utc_now_iso()
|
||||
|
||||
checks[normalized_id] = {
|
||||
"state": normalized_state,
|
||||
"required": bool(existing.get("required", True) if required is _UNSET else required),
|
||||
"source": existing.get("source") if source is _UNSET else source,
|
||||
"detail": existing.get("detail") if detail is _UNSET else detail,
|
||||
"updated_at": now,
|
||||
}
|
||||
|
||||
payload["startup_checks"] = checks
|
||||
payload.setdefault("platforms", {})
|
||||
payload.setdefault("kind", _GATEWAY_KIND)
|
||||
payload["pid"] = os.getpid()
|
||||
payload["start_time"] = _get_process_start_time(os.getpid())
|
||||
payload["updated_at"] = now
|
||||
_write_json_file(path, payload)
|
||||
return checks[normalized_id]
|
||||
|
||||
|
||||
def mark_startup_check_pending(
|
||||
check_id: str,
|
||||
*,
|
||||
detail: Any = _UNSET,
|
||||
required: Any = _UNSET,
|
||||
source: Any = _UNSET,
|
||||
) -> dict[str, Any]:
|
||||
return update_startup_check(check_id, "pending", detail=detail, required=required, source=source)
|
||||
|
||||
|
||||
def mark_startup_check_ready(
|
||||
check_id: str,
|
||||
*,
|
||||
detail: Any = _UNSET,
|
||||
required: Any = _UNSET,
|
||||
source: Any = _UNSET,
|
||||
) -> dict[str, Any]:
|
||||
return update_startup_check(check_id, "ready", detail=detail, required=required, source=source)
|
||||
|
||||
|
||||
def mark_startup_check_failed(
|
||||
check_id: str,
|
||||
*,
|
||||
detail: Any = _UNSET,
|
||||
required: Any = _UNSET,
|
||||
source: Any = _UNSET,
|
||||
) -> dict[str, Any]:
|
||||
return update_startup_check(check_id, "failed", detail=detail, required=required, source=source)
|
||||
return _read_json_file(_get_runtime_status_path())
|
||||
|
||||
|
||||
def remove_pid_file() -> None:
|
||||
|
||||
@@ -844,8 +844,7 @@ class SlashCommandCompleter(Completer):
|
||||
return None
|
||||
return word
|
||||
|
||||
@staticmethod
|
||||
def _context_completions(word: str, limit: int = 30):
|
||||
def _context_completions(self, word: str, limit: int = 30):
|
||||
"""Yield Claude Code-style @ context completions.
|
||||
|
||||
Bare ``@`` or ``@partial`` shows static references and matching
|
||||
|
||||
@@ -2766,6 +2766,47 @@ def sanitize_env_file() -> int:
|
||||
return fixes
|
||||
|
||||
|
||||
def _check_non_ascii_credential(key: str, value: str) -> str:
|
||||
"""Warn and strip non-ASCII characters from credential values.
|
||||
|
||||
API keys and tokens must be pure ASCII — they are sent as HTTP header
|
||||
values which httpx/httpcore encode as ASCII. Non-ASCII characters
|
||||
(commonly introduced by copy-pasting from rich-text editors or PDFs
|
||||
that substitute lookalike Unicode glyphs for ASCII letters) cause
|
||||
``UnicodeEncodeError: 'ascii' codec can't encode character`` at
|
||||
request time.
|
||||
|
||||
Returns the sanitized (ASCII-only) value. Prints a warning if any
|
||||
non-ASCII characters were found and removed.
|
||||
"""
|
||||
try:
|
||||
value.encode("ascii")
|
||||
return value # all ASCII — nothing to do
|
||||
except UnicodeEncodeError:
|
||||
pass
|
||||
|
||||
# Build a readable list of the offending characters
|
||||
bad_chars: list[str] = []
|
||||
for i, ch in enumerate(value):
|
||||
if ord(ch) > 127:
|
||||
bad_chars.append(f" position {i}: {ch!r} (U+{ord(ch):04X})")
|
||||
sanitized = value.encode("ascii", errors="ignore").decode("ascii")
|
||||
|
||||
import sys
|
||||
print(
|
||||
f"\n Warning: {key} contains non-ASCII characters that will break API requests.\n"
|
||||
f" This usually happens when copy-pasting from a PDF, rich-text editor,\n"
|
||||
f" or web page that substitutes lookalike Unicode glyphs for ASCII letters.\n"
|
||||
f"\n"
|
||||
+ "\n".join(f" {line}" for line in bad_chars[:5])
|
||||
+ ("\n ... and more" if len(bad_chars) > 5 else "")
|
||||
+ f"\n\n The non-ASCII characters have been stripped automatically.\n"
|
||||
f" If authentication fails, re-copy the key from the provider's dashboard.\n",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return sanitized
|
||||
|
||||
|
||||
def save_env_value(key: str, value: str):
|
||||
"""Save or update a value in ~/.hermes/.env."""
|
||||
if is_managed():
|
||||
@@ -2774,6 +2815,8 @@ def save_env_value(key: str, value: str):
|
||||
if not _ENV_VAR_NAME_RE.match(key):
|
||||
raise ValueError(f"Invalid environment variable name: {key!r}")
|
||||
value = value.replace("\n", "").replace("\r", "")
|
||||
# API keys / tokens must be ASCII — strip non-ASCII with a warning.
|
||||
value = _check_non_ascii_credential(key, value)
|
||||
ensure_hermes_home()
|
||||
env_path = get_env_path()
|
||||
|
||||
|
||||
+82
-1
@@ -8,6 +8,7 @@ import os
|
||||
import sys
|
||||
import subprocess
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_cli.config import get_project_root, get_hermes_home, get_env_path
|
||||
from hermes_constants import display_hermes_home
|
||||
@@ -513,7 +514,87 @@ def run_doctor(args):
|
||||
pass
|
||||
|
||||
_check_gateway_service_linger(issues)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Check: Command installation (hermes bin symlink)
|
||||
# =========================================================================
|
||||
if sys.platform != "win32":
|
||||
print()
|
||||
print(color("◆ Command Installation", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
# Determine the venv entry point location
|
||||
_venv_bin = None
|
||||
for _venv_name in ("venv", ".venv"):
|
||||
_candidate = PROJECT_ROOT / _venv_name / "bin" / "hermes"
|
||||
if _candidate.exists():
|
||||
_venv_bin = _candidate
|
||||
break
|
||||
|
||||
# Determine the expected command link directory (mirrors install.sh logic)
|
||||
_prefix = os.environ.get("PREFIX", "")
|
||||
_is_termux_env = bool(os.environ.get("TERMUX_VERSION")) or "com.termux/files/usr" in _prefix
|
||||
if _is_termux_env and _prefix:
|
||||
_cmd_link_dir = Path(_prefix) / "bin"
|
||||
_cmd_link_display = "$PREFIX/bin"
|
||||
else:
|
||||
_cmd_link_dir = Path.home() / ".local" / "bin"
|
||||
_cmd_link_display = "~/.local/bin"
|
||||
_cmd_link = _cmd_link_dir / "hermes"
|
||||
|
||||
if _venv_bin is None:
|
||||
check_warn(
|
||||
"Venv entry point not found",
|
||||
"(hermes not in venv/bin/ or .venv/bin/ — reinstall with pip install -e '.[all]')"
|
||||
)
|
||||
manual_issues.append(
|
||||
f"Reinstall entry point: cd {PROJECT_ROOT} && source venv/bin/activate && pip install -e '.[all]'"
|
||||
)
|
||||
else:
|
||||
check_ok(f"Venv entry point exists ({_venv_bin.relative_to(PROJECT_ROOT)})")
|
||||
|
||||
# Check the symlink at the command link location
|
||||
if _cmd_link.is_symlink():
|
||||
_target = _cmd_link.resolve()
|
||||
_expected = _venv_bin.resolve()
|
||||
if _target == _expected:
|
||||
check_ok(f"{_cmd_link_display}/hermes → correct target")
|
||||
else:
|
||||
check_warn(
|
||||
f"{_cmd_link_display}/hermes points to wrong target",
|
||||
f"(→ {_target}, expected → {_expected})"
|
||||
)
|
||||
if should_fix:
|
||||
_cmd_link.unlink()
|
||||
_cmd_link.symlink_to(_venv_bin)
|
||||
check_ok(f"Fixed symlink: {_cmd_link_display}/hermes → {_venv_bin}")
|
||||
fixed_count += 1
|
||||
else:
|
||||
issues.append(f"Broken symlink at {_cmd_link_display}/hermes — run 'hermes doctor --fix'")
|
||||
elif _cmd_link.exists():
|
||||
# It's a regular file, not a symlink — possibly a wrapper script
|
||||
check_ok(f"{_cmd_link_display}/hermes exists (non-symlink)")
|
||||
else:
|
||||
check_fail(
|
||||
f"{_cmd_link_display}/hermes not found",
|
||||
"(hermes command may not work outside the venv)"
|
||||
)
|
||||
if should_fix:
|
||||
_cmd_link_dir.mkdir(parents=True, exist_ok=True)
|
||||
_cmd_link.symlink_to(_venv_bin)
|
||||
check_ok(f"Created symlink: {_cmd_link_display}/hermes → {_venv_bin}")
|
||||
fixed_count += 1
|
||||
|
||||
# Check if the link dir is on PATH
|
||||
_path_dirs = os.environ.get("PATH", "").split(os.pathsep)
|
||||
if str(_cmd_link_dir) not in _path_dirs:
|
||||
check_warn(
|
||||
f"{_cmd_link_display} is not on your PATH",
|
||||
"(add it to your shell config: export PATH=\"$HOME/.local/bin:$PATH\")"
|
||||
)
|
||||
manual_issues.append(f"Add {_cmd_link_display} to your PATH")
|
||||
else:
|
||||
issues.append(f"Missing {_cmd_link_display}/hermes symlink — run 'hermes doctor --fix'")
|
||||
|
||||
# =========================================================================
|
||||
# Check: External tools
|
||||
# =========================================================================
|
||||
|
||||
@@ -8,11 +8,40 @@ from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
# Env var name suffixes that indicate credential values. These are the
|
||||
# only env vars whose values we sanitize on load — we must not silently
|
||||
# alter arbitrary user env vars, but credentials are known to require
|
||||
# pure ASCII (they become HTTP header values).
|
||||
_CREDENTIAL_SUFFIXES = ("_API_KEY", "_TOKEN", "_SECRET", "_KEY")
|
||||
|
||||
|
||||
def _sanitize_loaded_credentials() -> None:
|
||||
"""Strip non-ASCII characters from credential env vars in os.environ.
|
||||
|
||||
Called after dotenv loads so the rest of the codebase never sees
|
||||
non-ASCII API keys. Only touches env vars whose names end with
|
||||
known credential suffixes (``_API_KEY``, ``_TOKEN``, etc.).
|
||||
"""
|
||||
for key, value in list(os.environ.items()):
|
||||
if not any(key.endswith(suffix) for suffix in _CREDENTIAL_SUFFIXES):
|
||||
continue
|
||||
try:
|
||||
value.encode("ascii")
|
||||
except UnicodeEncodeError:
|
||||
os.environ[key] = value.encode("ascii", errors="ignore").decode("ascii")
|
||||
|
||||
|
||||
def _load_dotenv_with_fallback(path: Path, *, override: bool) -> None:
|
||||
try:
|
||||
load_dotenv(dotenv_path=path, override=override, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(dotenv_path=path, override=override, encoding="latin-1")
|
||||
# Strip non-ASCII characters from credential env vars that were just
|
||||
# loaded. API keys must be pure ASCII since they're sent as HTTP
|
||||
# header values (httpx encodes headers as ASCII). Non-ASCII chars
|
||||
# typically come from copy-pasting keys from PDFs or rich-text editors
|
||||
# that substitute Unicode lookalike glyphs (e.g. ʋ U+028B for v).
|
||||
_sanitize_loaded_credentials()
|
||||
|
||||
|
||||
def _sanitize_env_file_if_needed(path: Path) -> None:
|
||||
|
||||
+110
-126
@@ -10,7 +10,6 @@ import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
@@ -38,10 +37,6 @@ from hermes_cli.setup import (
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
|
||||
_SERVICE_READINESS_TIMEOUT = 30.0
|
||||
_SERVICE_READINESS_POLL_INTERVAL = 0.2
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Process Management (for manual gateway runs)
|
||||
# =============================================================================
|
||||
@@ -720,7 +715,9 @@ def _detect_venv_dir() -> Path | None:
|
||||
"""Detect the active virtualenv directory.
|
||||
|
||||
Checks ``sys.prefix`` first (works regardless of the directory name),
|
||||
then falls back to probing common directory names under PROJECT_ROOT.
|
||||
then ``VIRTUAL_ENV`` env var (covers uv-managed environments where
|
||||
sys.prefix == sys.base_prefix), then falls back to probing common
|
||||
directory names under PROJECT_ROOT.
|
||||
Returns ``None`` when no virtualenv can be found.
|
||||
"""
|
||||
# If we're running inside a virtualenv, sys.prefix points to it.
|
||||
@@ -729,6 +726,15 @@ def _detect_venv_dir() -> Path | None:
|
||||
if venv.is_dir():
|
||||
return venv
|
||||
|
||||
# uv and some other tools set VIRTUAL_ENV without changing sys.prefix.
|
||||
# This catches `uv run` where sys.prefix == sys.base_prefix but the
|
||||
# environment IS a venv. (#8620)
|
||||
_virtual_env = os.environ.get("VIRTUAL_ENV")
|
||||
if _virtual_env:
|
||||
venv = Path(_virtual_env)
|
||||
if venv.is_dir():
|
||||
return venv
|
||||
|
||||
# Fallback: check common virtualenv directory names under the project root.
|
||||
for candidate in (".venv", "venv"):
|
||||
venv = PROJECT_ROOT / candidate
|
||||
@@ -1105,123 +1111,12 @@ def systemd_uninstall(system: bool = False):
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service uninstalled")
|
||||
|
||||
|
||||
def _describe_startup_check(check_id: str, check: dict) -> str:
|
||||
source = check.get("source")
|
||||
detail = check.get("detail")
|
||||
label = f"{check_id} ({source})" if source and source != check_id else check_id
|
||||
return f"{label}: {detail}" if detail else label
|
||||
|
||||
|
||||
def _classify_startup_checks(state: dict | None) -> tuple[list[str], list[str], list[str]]:
|
||||
checks = (state or {}).get("startup_checks") or {}
|
||||
pending_required: list[str] = []
|
||||
failed_required: list[str] = []
|
||||
optional_warnings: list[str] = []
|
||||
|
||||
if not isinstance(checks, dict):
|
||||
return pending_required, failed_required, optional_warnings
|
||||
|
||||
for check_id, raw_check in checks.items():
|
||||
check = raw_check if isinstance(raw_check, dict) else {}
|
||||
label = _describe_startup_check(str(check_id), check)
|
||||
check_state = str(check.get("state", "pending")).strip().lower()
|
||||
required = bool(check.get("required", True))
|
||||
|
||||
if check_state == "ready":
|
||||
continue
|
||||
if required:
|
||||
if check_state == "failed":
|
||||
failed_required.append(label)
|
||||
else:
|
||||
pending_required.append(label)
|
||||
else:
|
||||
prefix = "failed" if check_state == "failed" else "pending"
|
||||
optional_warnings.append(f"{prefix}: {label}")
|
||||
|
||||
return pending_required, failed_required, optional_warnings
|
||||
|
||||
|
||||
def _wait_for_service_readiness(
|
||||
*,
|
||||
action: str,
|
||||
previous_pid: int | None = None,
|
||||
timeout: float = _SERVICE_READINESS_TIMEOUT,
|
||||
poll_interval: float = _SERVICE_READINESS_POLL_INTERVAL,
|
||||
) -> list[str]:
|
||||
from gateway.status import get_running_pid, read_runtime_status
|
||||
|
||||
deadline = time.monotonic() + timeout
|
||||
last_pending: list[str] = []
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
live_pid = get_running_pid()
|
||||
if live_pid is None or (previous_pid is not None and live_pid == previous_pid):
|
||||
time.sleep(poll_interval)
|
||||
continue
|
||||
|
||||
runtime = read_runtime_status() or {}
|
||||
try:
|
||||
runtime_pid = int(runtime.get("pid"))
|
||||
except (TypeError, ValueError):
|
||||
runtime_pid = None
|
||||
if runtime_pid != live_pid:
|
||||
time.sleep(poll_interval)
|
||||
continue
|
||||
|
||||
gateway_state = runtime.get("gateway_state")
|
||||
pending_required, failed_required, optional_warnings = _classify_startup_checks(runtime)
|
||||
last_pending = pending_required
|
||||
|
||||
if gateway_state == "startup_failed":
|
||||
reason = runtime.get("exit_reason") or f"gateway {action} failed during startup"
|
||||
raise RuntimeError(reason)
|
||||
if failed_required:
|
||||
raise RuntimeError(
|
||||
"required startup checks failed: " + "; ".join(failed_required)
|
||||
)
|
||||
if gateway_state == "running" and not pending_required:
|
||||
return optional_warnings
|
||||
|
||||
time.sleep(poll_interval)
|
||||
|
||||
if last_pending:
|
||||
raise RuntimeError(
|
||||
"timed out waiting for required startup checks: " + "; ".join(last_pending)
|
||||
)
|
||||
if previous_pid is not None:
|
||||
raise RuntimeError(
|
||||
f"timed out waiting for gateway {action}; previous process is still active or no new runtime became ready"
|
||||
)
|
||||
raise RuntimeError(f"timed out waiting for gateway {action} readiness")
|
||||
|
||||
|
||||
def _await_service_ready_or_exit(
|
||||
*,
|
||||
action: str,
|
||||
previous_pid: int | None = None,
|
||||
timeout: float = _SERVICE_READINESS_TIMEOUT,
|
||||
) -> None:
|
||||
try:
|
||||
optional_warnings = _wait_for_service_readiness(
|
||||
action=action,
|
||||
previous_pid=previous_pid,
|
||||
timeout=timeout,
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
print_error(f" Gateway {action} did not become ready: {exc}")
|
||||
raise SystemExit(1) from exc
|
||||
|
||||
for warning in optional_warnings:
|
||||
print_warning(f" Optional startup check {warning}")
|
||||
|
||||
|
||||
def systemd_start(system: bool = False):
|
||||
system = _select_systemd_scope(system)
|
||||
if system:
|
||||
_require_root_for_system_service("start")
|
||||
refresh_systemd_unit_if_needed(system=system)
|
||||
_run_systemctl(["start", get_service_name()], system=system, check=True, timeout=30)
|
||||
_await_service_ready_or_exit(action="start")
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service started")
|
||||
|
||||
|
||||
@@ -1244,11 +1139,64 @@ def systemd_restart(system: bool = False):
|
||||
|
||||
pid = get_running_pid()
|
||||
if pid is not None and _request_gateway_self_restart(pid):
|
||||
_await_service_ready_or_exit(action="restart", previous_pid=pid)
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service restarted")
|
||||
# SIGUSR1 sent — the gateway will drain active agents, exit with
|
||||
# code 75, and systemd will restart it after RestartSec (30s).
|
||||
# Wait for the old process to die and the new one to become active
|
||||
# so the CLI doesn't return while the service is still restarting.
|
||||
import time
|
||||
scope_label = _service_scope_label(system).capitalize()
|
||||
svc = get_service_name()
|
||||
scope_cmd = _systemctl_cmd(system)
|
||||
|
||||
# Phase 1: wait for old process to exit (drain + shutdown)
|
||||
print(f"⏳ {scope_label} service draining active work...")
|
||||
deadline = time.time() + 90
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
time.sleep(1)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
break # old process is gone
|
||||
else:
|
||||
print(f"⚠ Old process (PID {pid}) still alive after 90s")
|
||||
|
||||
# Phase 2: wait for systemd to start the new process
|
||||
print(f"⏳ Waiting for {svc} to restart...")
|
||||
deadline = time.time() + 60
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
scope_cmd + ["is-active", svc],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
if result.stdout.strip() == "active":
|
||||
# Verify it's a NEW process, not the old one somehow
|
||||
new_pid = get_running_pid()
|
||||
if new_pid and new_pid != pid:
|
||||
print(f"✓ {scope_label} service restarted (PID {new_pid})")
|
||||
return
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
pass
|
||||
time.sleep(2)
|
||||
|
||||
# Timed out — check final state
|
||||
try:
|
||||
result = subprocess.run(
|
||||
scope_cmd + ["is-active", svc],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
if result.stdout.strip() == "active":
|
||||
print(f"✓ {scope_label} service restarted")
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
print(
|
||||
f"⚠ {scope_label} service did not become active within 60s.\n"
|
||||
f" Check status: {'sudo ' if system else ''}hermes gateway status\n"
|
||||
f" Check logs: journalctl {'--user ' if not system else ''}-u {svc} --since '2 min ago'"
|
||||
)
|
||||
return
|
||||
_run_systemctl(["reload-or-restart", get_service_name()], system=system, check=True, timeout=90)
|
||||
_await_service_ready_or_exit(action="restart", previous_pid=pid)
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service restarted")
|
||||
|
||||
|
||||
@@ -1507,7 +1455,6 @@ def launchd_start():
|
||||
plist_path.write_text(generate_launchd_plist(), encoding="utf-8")
|
||||
subprocess.run(["launchctl", "bootstrap", _launchd_domain(), str(plist_path)], check=True, timeout=30)
|
||||
subprocess.run(["launchctl", "kickstart", f"{_launchd_domain()}/{label}"], check=True, timeout=30)
|
||||
_await_service_ready_or_exit(action="start")
|
||||
print("✓ Service started")
|
||||
return
|
||||
|
||||
@@ -1520,7 +1467,6 @@ def launchd_start():
|
||||
print("↻ launchd job was unloaded; reloading service definition")
|
||||
subprocess.run(["launchctl", "bootstrap", _launchd_domain(), str(plist_path)], check=True, timeout=30)
|
||||
subprocess.run(["launchctl", "kickstart", f"{_launchd_domain()}/{label}"], check=True, timeout=30)
|
||||
_await_service_ready_or_exit(action="start")
|
||||
print("✓ Service started")
|
||||
|
||||
def launchd_stop():
|
||||
@@ -1591,8 +1537,7 @@ def launchd_restart():
|
||||
try:
|
||||
pid = get_running_pid()
|
||||
if pid is not None and _request_gateway_self_restart(pid):
|
||||
_await_service_ready_or_exit(action="restart", previous_pid=pid)
|
||||
print("✓ Service restarted")
|
||||
print("✓ Service restart requested")
|
||||
return
|
||||
if pid is not None:
|
||||
try:
|
||||
@@ -1604,7 +1549,6 @@ def launchd_restart():
|
||||
if not exited:
|
||||
print(f"⚠ Gateway drain timed out after {drain_timeout:.0f}s — forcing launchd restart")
|
||||
subprocess.run(["launchctl", "kickstart", "-k", target], check=True, timeout=90)
|
||||
_await_service_ready_or_exit(action="restart", previous_pid=pid)
|
||||
print("✓ Service restarted")
|
||||
except subprocess.CalledProcessError as e:
|
||||
if e.returncode not in (3, 113):
|
||||
@@ -1614,7 +1558,6 @@ def launchd_restart():
|
||||
plist_path = get_launchd_plist_path()
|
||||
subprocess.run(["launchctl", "bootstrap", _launchd_domain(), str(plist_path)], check=True, timeout=30)
|
||||
subprocess.run(["launchctl", "kickstart", target], check=True, timeout=30)
|
||||
_await_service_ready_or_exit(action="restart", previous_pid=pid)
|
||||
print("✓ Service restarted")
|
||||
|
||||
def launchd_status(deep: bool = False):
|
||||
@@ -2987,6 +2930,15 @@ def gateway_command(args):
|
||||
|
||||
elif subcmd == "start":
|
||||
system = getattr(args, 'system', False)
|
||||
start_all = getattr(args, 'all', False)
|
||||
|
||||
if start_all:
|
||||
# Kill all stale gateway processes across all profiles before starting
|
||||
killed = kill_gateway_processes(all_profiles=True)
|
||||
if killed:
|
||||
print(f"✓ Killed {killed} stale gateway process(es) across all profiles")
|
||||
_wait_for_gateway_exit(timeout=10.0, force_after=5.0)
|
||||
|
||||
if is_termux():
|
||||
print("Gateway service start is not supported on Termux because there is no system service manager.")
|
||||
print("Run manually: hermes gateway")
|
||||
@@ -3072,7 +3024,39 @@ def gateway_command(args):
|
||||
# Try service first, fall back to killing and restarting
|
||||
service_available = False
|
||||
system = getattr(args, 'system', False)
|
||||
restart_all = getattr(args, 'all', False)
|
||||
service_configured = False
|
||||
|
||||
if restart_all:
|
||||
# --all: stop every gateway process across all profiles, then start fresh
|
||||
service_stopped = False
|
||||
if supports_systemd_services() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
try:
|
||||
systemd_stop(system=system)
|
||||
service_stopped = True
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
elif is_macos() and get_launchd_plist_path().exists():
|
||||
try:
|
||||
launchd_stop()
|
||||
service_stopped = True
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
killed = kill_gateway_processes(all_profiles=True)
|
||||
total = killed + (1 if service_stopped else 0)
|
||||
if total:
|
||||
print(f"✓ Stopped {total} gateway process(es) across all profiles")
|
||||
_wait_for_gateway_exit(timeout=10.0, force_after=5.0)
|
||||
|
||||
# Start the current profile's service fresh
|
||||
print("Starting gateway...")
|
||||
if supports_systemd_services() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
systemd_start(system=system)
|
||||
elif is_macos() and get_launchd_plist_path().exists():
|
||||
launchd_start()
|
||||
else:
|
||||
run_gateway(verbose=0)
|
||||
return
|
||||
|
||||
if supports_systemd_services() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
service_configured = True
|
||||
|
||||
+33
-1
@@ -4749,6 +4749,7 @@ For more help on a command:
|
||||
# gateway start
|
||||
gateway_start = gateway_subparsers.add_parser("start", help="Start the installed systemd/launchd background service")
|
||||
gateway_start.add_argument("--system", action="store_true", help="Target the Linux system-level gateway service")
|
||||
gateway_start.add_argument("--all", action="store_true", help="Kill ALL stale gateway processes across all profiles before starting")
|
||||
|
||||
# gateway stop
|
||||
gateway_stop = gateway_subparsers.add_parser("stop", help="Stop gateway service")
|
||||
@@ -4758,6 +4759,7 @@ For more help on a command:
|
||||
# gateway restart
|
||||
gateway_restart = gateway_subparsers.add_parser("restart", help="Restart gateway service")
|
||||
gateway_restart.add_argument("--system", action="store_true", help="Target the Linux system-level gateway service")
|
||||
gateway_restart.add_argument("--all", action="store_true", help="Kill ALL gateway processes across all profiles before restarting")
|
||||
|
||||
# gateway status
|
||||
gateway_status = gateway_subparsers.add_parser("status", help="Show gateway status")
|
||||
@@ -6044,7 +6046,37 @@ Examples:
|
||||
sys.exit(1)
|
||||
|
||||
_processed_argv = _coalesce_session_name_args(sys.argv[1:])
|
||||
args = parser.parse_args(_processed_argv)
|
||||
|
||||
# ── Defensive subparser routing (bpo-9338 workaround) ───────────
|
||||
# On some Python versions (notably <3.11), argparse fails to route
|
||||
# subcommand tokens when the parent parser has nargs='?' optional
|
||||
# arguments (--continue). The symptom: "unrecognized arguments: model"
|
||||
# even though 'model' is a registered subcommand.
|
||||
#
|
||||
# Fix: when argv contains a token matching a known subcommand, set
|
||||
# subparsers.required=True to force deterministic routing. If that
|
||||
# fails (e.g. 'hermes -c model' where 'model' is consumed as the
|
||||
# session name for --continue), fall back to the default behaviour.
|
||||
import io as _io
|
||||
_known_cmds = set(subparsers.choices.keys()) if hasattr(subparsers, "choices") else set()
|
||||
_has_cmd_token = any(t in _known_cmds for t in _processed_argv if not t.startswith("-"))
|
||||
|
||||
if _has_cmd_token:
|
||||
subparsers.required = True
|
||||
_saved_stderr = sys.stderr
|
||||
try:
|
||||
sys.stderr = _io.StringIO()
|
||||
args = parser.parse_args(_processed_argv)
|
||||
sys.stderr = _saved_stderr
|
||||
except SystemExit:
|
||||
sys.stderr = _saved_stderr
|
||||
# Subcommand name was consumed as a flag value (e.g. -c model).
|
||||
# Fall back to optional subparsers so argparse handles it normally.
|
||||
subparsers.required = False
|
||||
args = parser.parse_args(_processed_argv)
|
||||
else:
|
||||
subparsers.required = False
|
||||
args = parser.parse_args(_processed_argv)
|
||||
|
||||
# Handle --version flag
|
||||
if args.version:
|
||||
|
||||
+38
-16
@@ -63,6 +63,7 @@ CONFIGURABLE_TOOLSETS = [
|
||||
("clarify", "❓ Clarifying Questions", "clarify"),
|
||||
("delegation", "👥 Task Delegation", "delegate_task"),
|
||||
("cronjob", "⏰ Cron Jobs", "create/list/update/pause/resume/run, with optional attached skills"),
|
||||
("messaging", "📨 Cross-Platform Messaging", "send_message"),
|
||||
("rl", "🧪 RL Training", "Tinker-Atropos training tools"),
|
||||
("homeassistant", "🏠 Home Assistant", "smart home device control"),
|
||||
]
|
||||
@@ -121,6 +122,7 @@ TOOL_CATEGORIES = {
|
||||
"providers": [
|
||||
{
|
||||
"name": "Nous Subscription",
|
||||
"badge": "subscription",
|
||||
"tag": "Managed OpenAI TTS billed to your subscription",
|
||||
"env_vars": [],
|
||||
"tts_provider": "openai",
|
||||
@@ -130,13 +132,15 @@ TOOL_CATEGORIES = {
|
||||
},
|
||||
{
|
||||
"name": "Microsoft Edge TTS",
|
||||
"tag": "Free - no API key needed",
|
||||
"badge": "★ recommended · free",
|
||||
"tag": "Good quality, no API key needed",
|
||||
"env_vars": [],
|
||||
"tts_provider": "edge",
|
||||
},
|
||||
{
|
||||
"name": "OpenAI TTS",
|
||||
"tag": "Premium - high quality voices",
|
||||
"badge": "paid",
|
||||
"tag": "High quality voices",
|
||||
"env_vars": [
|
||||
{"key": "VOICE_TOOLS_OPENAI_KEY", "prompt": "OpenAI API key", "url": "https://platform.openai.com/api-keys"},
|
||||
],
|
||||
@@ -144,7 +148,8 @@ TOOL_CATEGORIES = {
|
||||
},
|
||||
{
|
||||
"name": "ElevenLabs",
|
||||
"tag": "Premium - most natural voices",
|
||||
"badge": "paid",
|
||||
"tag": "Most natural voices",
|
||||
"env_vars": [
|
||||
{"key": "ELEVENLABS_API_KEY", "prompt": "ElevenLabs API key", "url": "https://elevenlabs.io/app/settings/api-keys"},
|
||||
],
|
||||
@@ -152,7 +157,8 @@ TOOL_CATEGORIES = {
|
||||
},
|
||||
{
|
||||
"name": "Mistral (Voxtral TTS)",
|
||||
"tag": "Multilingual, native Opus, needs MISTRAL_API_KEY",
|
||||
"badge": "paid",
|
||||
"tag": "Multilingual, native Opus",
|
||||
"env_vars": [
|
||||
{"key": "MISTRAL_API_KEY", "prompt": "Mistral API key", "url": "https://console.mistral.ai/"},
|
||||
],
|
||||
@@ -168,6 +174,7 @@ TOOL_CATEGORIES = {
|
||||
"providers": [
|
||||
{
|
||||
"name": "Nous Subscription",
|
||||
"badge": "subscription",
|
||||
"tag": "Managed Firecrawl billed to your subscription",
|
||||
"web_backend": "firecrawl",
|
||||
"env_vars": [],
|
||||
@@ -177,7 +184,8 @@ TOOL_CATEGORIES = {
|
||||
},
|
||||
{
|
||||
"name": "Firecrawl Cloud",
|
||||
"tag": "Hosted service - search, extract, and crawl",
|
||||
"badge": "★ recommended",
|
||||
"tag": "Full-featured search, extract, and crawl",
|
||||
"web_backend": "firecrawl",
|
||||
"env_vars": [
|
||||
{"key": "FIRECRAWL_API_KEY", "prompt": "Firecrawl API key", "url": "https://firecrawl.dev"},
|
||||
@@ -185,7 +193,8 @@ TOOL_CATEGORIES = {
|
||||
},
|
||||
{
|
||||
"name": "Exa",
|
||||
"tag": "AI-native search and contents",
|
||||
"badge": "paid",
|
||||
"tag": "Neural search with semantic understanding",
|
||||
"web_backend": "exa",
|
||||
"env_vars": [
|
||||
{"key": "EXA_API_KEY", "prompt": "Exa API key", "url": "https://exa.ai"},
|
||||
@@ -193,7 +202,8 @@ TOOL_CATEGORIES = {
|
||||
},
|
||||
{
|
||||
"name": "Parallel",
|
||||
"tag": "AI-native search and extract",
|
||||
"badge": "paid",
|
||||
"tag": "AI-powered search and extract",
|
||||
"web_backend": "parallel",
|
||||
"env_vars": [
|
||||
{"key": "PARALLEL_API_KEY", "prompt": "Parallel API key", "url": "https://parallel.ai"},
|
||||
@@ -201,7 +211,8 @@ TOOL_CATEGORIES = {
|
||||
},
|
||||
{
|
||||
"name": "Tavily",
|
||||
"tag": "AI-native search, extract, and crawl",
|
||||
"badge": "free tier",
|
||||
"tag": "Search, extract, and crawl — 1000 free searches/mo",
|
||||
"web_backend": "tavily",
|
||||
"env_vars": [
|
||||
{"key": "TAVILY_API_KEY", "prompt": "Tavily API key", "url": "https://app.tavily.com/home"},
|
||||
@@ -209,7 +220,8 @@ TOOL_CATEGORIES = {
|
||||
},
|
||||
{
|
||||
"name": "Firecrawl Self-Hosted",
|
||||
"tag": "Free - run your own instance",
|
||||
"badge": "free · self-hosted",
|
||||
"tag": "Run your own Firecrawl instance (Docker)",
|
||||
"web_backend": "firecrawl",
|
||||
"env_vars": [
|
||||
{"key": "FIRECRAWL_API_URL", "prompt": "Your Firecrawl instance URL (e.g., http://localhost:3002)"},
|
||||
@@ -223,6 +235,7 @@ TOOL_CATEGORIES = {
|
||||
"providers": [
|
||||
{
|
||||
"name": "Nous Subscription",
|
||||
"badge": "subscription",
|
||||
"tag": "Managed FAL image generation billed to your subscription",
|
||||
"env_vars": [],
|
||||
"requires_nous_auth": True,
|
||||
@@ -231,6 +244,7 @@ TOOL_CATEGORIES = {
|
||||
},
|
||||
{
|
||||
"name": "FAL.ai",
|
||||
"badge": "paid",
|
||||
"tag": "FLUX 2 Pro with auto-upscaling",
|
||||
"env_vars": [
|
||||
{"key": "FAL_KEY", "prompt": "FAL API key", "url": "https://fal.ai/dashboard/keys"},
|
||||
@@ -244,6 +258,7 @@ TOOL_CATEGORIES = {
|
||||
"providers": [
|
||||
{
|
||||
"name": "Nous Subscription (Browser Use cloud)",
|
||||
"badge": "subscription",
|
||||
"tag": "Managed Browser Use billed to your subscription",
|
||||
"env_vars": [],
|
||||
"browser_provider": "browser-use",
|
||||
@@ -254,14 +269,16 @@ TOOL_CATEGORIES = {
|
||||
},
|
||||
{
|
||||
"name": "Local Browser",
|
||||
"tag": "Free headless Chromium (no API key needed)",
|
||||
"badge": "★ recommended · free",
|
||||
"tag": "Headless Chromium, no API key needed",
|
||||
"env_vars": [],
|
||||
"browser_provider": "local",
|
||||
"post_setup": "agent_browser",
|
||||
},
|
||||
{
|
||||
"name": "Browserbase",
|
||||
"tag": "Cloud browser with stealth & proxies",
|
||||
"badge": "paid",
|
||||
"tag": "Cloud browser with stealth and proxies",
|
||||
"env_vars": [
|
||||
{"key": "BROWSERBASE_API_KEY", "prompt": "Browserbase API key", "url": "https://browserbase.com"},
|
||||
{"key": "BROWSERBASE_PROJECT_ID", "prompt": "Browserbase project ID"},
|
||||
@@ -271,6 +288,7 @@ TOOL_CATEGORIES = {
|
||||
},
|
||||
{
|
||||
"name": "Browser Use",
|
||||
"badge": "paid",
|
||||
"tag": "Cloud browser with remote execution",
|
||||
"env_vars": [
|
||||
{"key": "BROWSER_USE_API_KEY", "prompt": "Browser Use API key", "url": "https://browser-use.com"},
|
||||
@@ -280,6 +298,7 @@ TOOL_CATEGORIES = {
|
||||
},
|
||||
{
|
||||
"name": "Firecrawl",
|
||||
"badge": "paid",
|
||||
"tag": "Cloud browser with remote execution",
|
||||
"env_vars": [
|
||||
{"key": "FIRECRAWL_API_KEY", "prompt": "Firecrawl API key", "url": "https://firecrawl.dev"},
|
||||
@@ -289,7 +308,8 @@ TOOL_CATEGORIES = {
|
||||
},
|
||||
{
|
||||
"name": "Camofox",
|
||||
"tag": "Local anti-detection browser (Firefox/Camoufox)",
|
||||
"badge": "free · local",
|
||||
"tag": "Anti-detection browser (Firefox/Camoufox)",
|
||||
"env_vars": [
|
||||
{"key": "CAMOFOX_URL", "prompt": "Camofox server URL", "default": "http://localhost:9377",
|
||||
"url": "https://github.com/jo-inc/camofox-browser"},
|
||||
@@ -838,7 +858,8 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict):
|
||||
# Plain text labels only (no ANSI codes in menu items)
|
||||
provider_choices = []
|
||||
for p in providers:
|
||||
tag = f" ({p['tag']})" if p.get("tag") else ""
|
||||
badge = f" [{p['badge']}]" if p.get("badge") else ""
|
||||
tag = f" — {p['tag']}" if p.get("tag") else ""
|
||||
configured = ""
|
||||
env_vars = p.get("env_vars", [])
|
||||
if not env_vars or all(get_env_value(v["key"]) for v in env_vars):
|
||||
@@ -848,7 +869,7 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict):
|
||||
configured = ""
|
||||
else:
|
||||
configured = " [configured]"
|
||||
provider_choices.append(f"{p['name']}{tag}{configured}")
|
||||
provider_choices.append(f"{p['name']}{badge}{tag}{configured}")
|
||||
|
||||
# Add skip option
|
||||
provider_choices.append("Skip — keep defaults / configure later")
|
||||
@@ -1104,7 +1125,8 @@ def _configure_tool_category_for_reconfig(ts_key: str, cat: dict, config: dict):
|
||||
|
||||
provider_choices = []
|
||||
for p in providers:
|
||||
tag = f" ({p['tag']})" if p.get("tag") else ""
|
||||
badge = f" [{p['badge']}]" if p.get("badge") else ""
|
||||
tag = f" — {p['tag']}" if p.get("tag") else ""
|
||||
configured = ""
|
||||
env_vars = p.get("env_vars", [])
|
||||
if not env_vars or all(get_env_value(v["key"]) for v in env_vars):
|
||||
@@ -1114,7 +1136,7 @@ def _configure_tool_category_for_reconfig(ts_key: str, cat: dict, config: dict):
|
||||
configured = ""
|
||||
else:
|
||||
configured = " [configured]"
|
||||
provider_choices.append(f"{p['name']}{tag}{configured}")
|
||||
provider_choices.append(f"{p['name']}{badge}{tag}{configured}")
|
||||
|
||||
default_idx = _detect_active_provider_index(providers, config)
|
||||
|
||||
|
||||
@@ -358,6 +358,7 @@ def _add_rotating_handler(
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
handler = _ManagedRotatingFileHandler(
|
||||
str(path), maxBytes=max_bytes, backupCount=backup_count,
|
||||
encoding="utf-8",
|
||||
)
|
||||
handler.setLevel(level)
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
+2
-40
@@ -26,7 +26,7 @@ import logging
|
||||
import threading
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
|
||||
from tools.registry import registry
|
||||
from tools.registry import discover_builtin_tools, registry
|
||||
from toolsets import resolve_toolset, validate_toolset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -129,45 +129,7 @@ def _run_async(coro):
|
||||
# Tool Discovery (importing each module triggers its registry.register calls)
|
||||
# =============================================================================
|
||||
|
||||
def _discover_tools():
|
||||
"""Import all tool modules to trigger their registry.register() calls.
|
||||
|
||||
Wrapped in a function so import errors in optional tools (e.g., fal_client
|
||||
not installed) don't prevent the rest from loading.
|
||||
"""
|
||||
_modules = [
|
||||
"tools.web_tools",
|
||||
"tools.terminal_tool",
|
||||
"tools.file_tools",
|
||||
"tools.vision_tools",
|
||||
"tools.mixture_of_agents_tool",
|
||||
"tools.image_generation_tool",
|
||||
"tools.skills_tool",
|
||||
"tools.skill_manager_tool",
|
||||
"tools.browser_tool",
|
||||
"tools.cronjob_tools",
|
||||
"tools.rl_training_tool",
|
||||
"tools.tts_tool",
|
||||
"tools.todo_tool",
|
||||
"tools.memory_tool",
|
||||
"tools.session_search_tool",
|
||||
"tools.clarify_tool",
|
||||
"tools.code_execution_tool",
|
||||
"tools.delegate_tool",
|
||||
"tools.process_registry",
|
||||
"tools.send_message_tool",
|
||||
# "tools.honcho_tools", # Removed — Honcho is now a memory provider plugin
|
||||
"tools.homeassistant_tool",
|
||||
]
|
||||
import importlib
|
||||
for mod_name in _modules:
|
||||
try:
|
||||
importlib.import_module(mod_name)
|
||||
except Exception as e:
|
||||
logger.warning("Could not import tool module %s: %s", mod_name, e)
|
||||
|
||||
|
||||
_discover_tools()
|
||||
discover_builtin_tools()
|
||||
|
||||
# MCP tool discovery (external MCP servers from config)
|
||||
try:
|
||||
|
||||
@@ -10,8 +10,9 @@ lifecycle instead of read-only search endpoints.
|
||||
Config via environment variables (profile-scoped via each profile's .env):
|
||||
OPENVIKING_ENDPOINT — Server URL (default: http://127.0.0.1:1933)
|
||||
OPENVIKING_API_KEY — API key (required for authenticated servers)
|
||||
OPENVIKING_ACCOUNT — Tenant account (default: root)
|
||||
OPENVIKING_ACCOUNT — Tenant account (default: default)
|
||||
OPENVIKING_USER — Tenant user (default: default)
|
||||
OPENVIKING_AGENT — Tenant agent (default: hermes)
|
||||
|
||||
Capabilities:
|
||||
- Automatic memory extraction on session commit (6 categories)
|
||||
@@ -80,11 +81,12 @@ class _VikingClient:
|
||||
"""Thin HTTP client for the OpenViking REST API."""
|
||||
|
||||
def __init__(self, endpoint: str, api_key: str = "",
|
||||
account: str = "", user: str = ""):
|
||||
account: str = "", user: str = "", agent: str = ""):
|
||||
self._endpoint = endpoint.rstrip("/")
|
||||
self._api_key = api_key
|
||||
self._account = account or os.environ.get("OPENVIKING_ACCOUNT", "root")
|
||||
self._account = account or os.environ.get("OPENVIKING_ACCOUNT", "default")
|
||||
self._user = user or os.environ.get("OPENVIKING_USER", "default")
|
||||
self._agent = agent or os.environ.get("OPENVIKING_AGENT", "hermes")
|
||||
self._httpx = _get_httpx()
|
||||
if self._httpx is None:
|
||||
raise ImportError("httpx is required for OpenViking: pip install httpx")
|
||||
@@ -94,6 +96,7 @@ class _VikingClient:
|
||||
"Content-Type": "application/json",
|
||||
"X-OpenViking-Account": self._account,
|
||||
"X-OpenViking-User": self._user,
|
||||
"X-OpenViking-Agent": self._agent,
|
||||
}
|
||||
if self._api_key:
|
||||
h["X-API-Key"] = self._api_key
|
||||
@@ -282,20 +285,44 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
||||
},
|
||||
{
|
||||
"key": "api_key",
|
||||
"description": "OpenViking API key",
|
||||
"description": "OpenViking API key (leave blank for local dev mode)",
|
||||
"secret": True,
|
||||
"env_var": "OPENVIKING_API_KEY",
|
||||
},
|
||||
{
|
||||
"key": "account",
|
||||
"description": "OpenViking tenant account ID ([default], used when local mode, OPENVIKING_API_KEY is empty)",
|
||||
"default": "default",
|
||||
"env_var": "OPENVIKING_ACCOUNT",
|
||||
},
|
||||
{
|
||||
"key": "user",
|
||||
"description": "OpenViking user ID within the account ([default], used when local mode, OPENVIKING_API_KEY is empty)",
|
||||
"default": "default",
|
||||
"env_var": "OPENVIKING_USER",
|
||||
},
|
||||
{
|
||||
"key": "agent",
|
||||
"description": "OpenViking agent ID within the account ([hermes], useful in multi-agent mode)",
|
||||
"default": "hermes",
|
||||
"env_var": "OPENVIKING_AGENT",
|
||||
},
|
||||
]
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
|
||||
self._endpoint = os.environ.get("OPENVIKING_ENDPOINT", _DEFAULT_ENDPOINT)
|
||||
self._api_key = os.environ.get("OPENVIKING_API_KEY", "")
|
||||
self._account = os.environ.get("OPENVIKING_ACCOUNT", "default")
|
||||
self._user = os.environ.get("OPENVIKING_USER", "default")
|
||||
self._agent = os.environ.get("OPENVIKING_AGENT", "hermes")
|
||||
self._session_id = session_id
|
||||
self._turn_count = 0
|
||||
|
||||
try:
|
||||
self._client = _VikingClient(self._endpoint, self._api_key)
|
||||
self._client = _VikingClient(
|
||||
self._endpoint, self._api_key,
|
||||
account=self._account, user=self._user, agent=self._agent,
|
||||
)
|
||||
if not self._client.health():
|
||||
logger.warning("OpenViking server at %s is not reachable", self._endpoint)
|
||||
self._client = None
|
||||
@@ -325,7 +352,8 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
||||
"(abstract/overview/full), viking_browse to explore.\n"
|
||||
"Use viking_remember to store facts, viking_add_resource to index URLs/docs."
|
||||
)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning("OpenViking system_prompt_block failed: %s", e)
|
||||
return (
|
||||
"# OpenViking Knowledge Base\n"
|
||||
f"Active. Endpoint: {self._endpoint}\n"
|
||||
@@ -351,7 +379,10 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
||||
|
||||
def _run():
|
||||
try:
|
||||
client = _VikingClient(self._endpoint, self._api_key)
|
||||
client = _VikingClient(
|
||||
self._endpoint, self._api_key,
|
||||
account=self._account, user=self._user, agent=self._agent,
|
||||
)
|
||||
resp = client.post("/api/v1/search/find", {
|
||||
"query": query,
|
||||
"top_k": 5,
|
||||
@@ -386,7 +417,10 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
||||
|
||||
def _sync():
|
||||
try:
|
||||
client = _VikingClient(self._endpoint, self._api_key)
|
||||
client = _VikingClient(
|
||||
self._endpoint, self._api_key,
|
||||
account=self._account, user=self._user, agent=self._agent,
|
||||
)
|
||||
sid = self._session_id
|
||||
|
||||
# Add user message
|
||||
@@ -442,7 +476,10 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
||||
|
||||
def _write():
|
||||
try:
|
||||
client = _VikingClient(self._endpoint, self._api_key)
|
||||
client = _VikingClient(
|
||||
self._endpoint, self._api_key,
|
||||
account=self._account, user=self._user, agent=self._agent,
|
||||
)
|
||||
# Add as a user message with memory context so the commit
|
||||
# picks it up as an explicit memory during extraction
|
||||
client.post(f"/api/v1/sessions/{self._session_id}/messages", {
|
||||
|
||||
+299
-34
@@ -754,6 +754,7 @@ class AIAgent:
|
||||
self._interrupt_requested = False
|
||||
self._interrupt_message = None # Optional message that triggered interrupt
|
||||
self._execution_thread_id: int | None = None # Set at run_conversation() start
|
||||
self._interrupt_thread_signal_pending = False
|
||||
self._client_lock = threading.RLock()
|
||||
|
||||
# Subagent delegation state
|
||||
@@ -1268,6 +1269,19 @@ class AIAgent:
|
||||
try:
|
||||
_config_context_length = int(_config_context_length)
|
||||
except (TypeError, ValueError):
|
||||
logger.warning(
|
||||
"Invalid model.context_length in config.yaml: %r — "
|
||||
"must be a plain integer (e.g. 256000, not '256K'). "
|
||||
"Falling back to auto-detection.",
|
||||
_config_context_length,
|
||||
)
|
||||
import sys
|
||||
print(
|
||||
f"\n⚠ Invalid model.context_length in config.yaml: {_config_context_length!r}\n"
|
||||
f" Must be a plain integer (e.g. 256000, not '256K').\n"
|
||||
f" Falling back to auto-detected context window.\n",
|
||||
file=sys.stderr,
|
||||
)
|
||||
_config_context_length = None
|
||||
|
||||
# Store for reuse in switch_model (so config override persists across model switches)
|
||||
@@ -1296,7 +1310,20 @@ class AIAgent:
|
||||
try:
|
||||
_config_context_length = int(_cp_ctx)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
logger.warning(
|
||||
"Invalid context_length for model %r in "
|
||||
"custom_providers: %r — must be a plain "
|
||||
"integer (e.g. 256000, not '256K'). "
|
||||
"Falling back to auto-detection.",
|
||||
self.model, _cp_ctx,
|
||||
)
|
||||
import sys
|
||||
print(
|
||||
f"\n⚠ Invalid context_length for model {self.model!r} in custom_providers: {_cp_ctx!r}\n"
|
||||
f" Must be a plain integer (e.g. 256000, not '256K').\n"
|
||||
f" Falling back to auto-detected context window.\n",
|
||||
file=sys.stderr,
|
||||
)
|
||||
break
|
||||
|
||||
# Select context engine: config-driven (like memory providers).
|
||||
@@ -2923,7 +2950,15 @@ class AIAgent:
|
||||
# Signal all tools to abort any in-flight operations immediately.
|
||||
# Scope the interrupt to this agent's execution thread so other
|
||||
# agents running in the same process (gateway) are not affected.
|
||||
_set_interrupt(True, self._execution_thread_id)
|
||||
if self._execution_thread_id is not None:
|
||||
_set_interrupt(True, self._execution_thread_id)
|
||||
self._interrupt_thread_signal_pending = False
|
||||
else:
|
||||
# The interrupt arrived before run_conversation() finished
|
||||
# binding the agent to its execution thread. Defer the tool-level
|
||||
# interrupt signal until startup completes instead of targeting
|
||||
# the caller thread by mistake.
|
||||
self._interrupt_thread_signal_pending = True
|
||||
# Propagate interrupt to any running child agents (subagent delegation)
|
||||
with self._active_children_lock:
|
||||
children_copy = list(self._active_children)
|
||||
@@ -2939,7 +2974,9 @@ class AIAgent:
|
||||
"""Clear any pending interrupt request and the per-thread tool interrupt signal."""
|
||||
self._interrupt_requested = False
|
||||
self._interrupt_message = None
|
||||
_set_interrupt(False, self._execution_thread_id)
|
||||
self._interrupt_thread_signal_pending = False
|
||||
if self._execution_thread_id is not None:
|
||||
_set_interrupt(False, self._execution_thread_id)
|
||||
|
||||
def _touch_activity(self, desc: str) -> None:
|
||||
"""Update the last-activity timestamp and description (thread-safe)."""
|
||||
@@ -3014,6 +3051,18 @@ class AIAgent:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def commit_memory_session(self, messages: list = None) -> None:
|
||||
"""Trigger end-of-session extraction without tearing providers down.
|
||||
Called when session_id rotates (e.g. /new, context compression);
|
||||
providers keep their state and continue running under the old
|
||||
session_id — they just flush pending extraction now."""
|
||||
if not self._memory_manager:
|
||||
return
|
||||
try:
|
||||
self._memory_manager.on_session_end(messages or [])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""Release all resources held by this agent instance.
|
||||
|
||||
@@ -3563,7 +3612,12 @@ class AIAgent:
|
||||
item_id = ri.get("id")
|
||||
if item_id and item_id in seen_item_ids:
|
||||
continue
|
||||
items.append(ri)
|
||||
# Strip the "id" field — with store=False the
|
||||
# Responses API cannot look up items by ID and
|
||||
# returns 404. The encrypted_content blob is
|
||||
# self-contained for reasoning chain continuity.
|
||||
replay_item = {k: v for k, v in ri.items() if k != "id"}
|
||||
items.append(replay_item)
|
||||
if item_id:
|
||||
seen_item_ids.add(item_id)
|
||||
has_codex_reasoning = True
|
||||
@@ -3704,8 +3758,10 @@ class AIAgent:
|
||||
continue
|
||||
seen_ids.add(item_id)
|
||||
reasoning_item = {"type": "reasoning", "encrypted_content": encrypted}
|
||||
if isinstance(item_id, str) and item_id:
|
||||
reasoning_item["id"] = item_id
|
||||
# Do NOT include the "id" in the outgoing item — with
|
||||
# store=False (our default) the API tries to resolve the
|
||||
# id server-side and returns 404. The id is still used
|
||||
# above for local deduplication via seen_ids.
|
||||
summary = item.get("summary")
|
||||
if isinstance(summary, list):
|
||||
reasoning_item["summary"] = summary
|
||||
@@ -5477,9 +5533,27 @@ class AIAgent:
|
||||
|
||||
t = threading.Thread(target=_call, daemon=True)
|
||||
t.start()
|
||||
_last_heartbeat = time.time()
|
||||
_HEARTBEAT_INTERVAL = 30.0 # seconds between gateway activity touches
|
||||
while t.is_alive():
|
||||
t.join(timeout=0.3)
|
||||
|
||||
# Periodic heartbeat: touch the agent's activity tracker so the
|
||||
# gateway's inactivity monitor knows we're alive while waiting
|
||||
# for stream chunks. Without this, long thinking pauses (e.g.
|
||||
# reasoning models) or slow prefill on local providers (Ollama)
|
||||
# trigger false inactivity timeouts. The _call thread touches
|
||||
# activity on each chunk, but the gap between API call start
|
||||
# and first chunk can exceed the gateway timeout — especially
|
||||
# when the stale-stream timeout is disabled (local providers).
|
||||
_hb_now = time.time()
|
||||
if _hb_now - _last_heartbeat >= _HEARTBEAT_INTERVAL:
|
||||
_last_heartbeat = _hb_now
|
||||
_waiting_secs = int(_hb_now - last_chunk_time["t"])
|
||||
self._touch_activity(
|
||||
f"waiting for stream response ({_waiting_secs}s, no chunks yet)"
|
||||
)
|
||||
|
||||
# Detect stale streams: connections kept alive by SSE pings
|
||||
# but delivering no real chunks. Kill the client so the
|
||||
# inner retry loop can start a fresh connection.
|
||||
@@ -6793,6 +6867,8 @@ class AIAgent:
|
||||
try:
|
||||
# Propagate title to the new session with auto-numbering
|
||||
old_title = self._session_db.get_session_title(self.session_id)
|
||||
# Trigger memory extraction on the old session before it rotates.
|
||||
self.commit_memory_session(messages)
|
||||
self._session_db.end_session(self.session_id, "compression")
|
||||
old_session_id = self.session_id
|
||||
self.session_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
|
||||
@@ -6975,6 +7051,31 @@ class AIAgent:
|
||||
skip_pre_tool_call_hook=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _wrap_verbose(label: str, text: str, indent: str = " ") -> str:
|
||||
"""Word-wrap verbose tool output to fit the terminal width.
|
||||
|
||||
Splits *text* on existing newlines and wraps each line individually,
|
||||
preserving intentional line breaks (e.g. pretty-printed JSON).
|
||||
Returns a ready-to-print string with *label* on the first line and
|
||||
continuation lines indented.
|
||||
"""
|
||||
import shutil as _shutil
|
||||
import textwrap as _tw
|
||||
cols = _shutil.get_terminal_size((120, 24)).columns
|
||||
wrap_width = max(40, cols - len(indent))
|
||||
out_lines: list[str] = []
|
||||
for raw_line in text.split("\n"):
|
||||
if len(raw_line) <= wrap_width:
|
||||
out_lines.append(raw_line)
|
||||
else:
|
||||
wrapped = _tw.wrap(raw_line, width=wrap_width,
|
||||
break_long_words=True,
|
||||
break_on_hyphens=False)
|
||||
out_lines.extend(wrapped or [raw_line])
|
||||
body = ("\n" + indent).join(out_lines)
|
||||
return f"{indent}{label}{body}"
|
||||
|
||||
def _execute_tool_calls_concurrent(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
|
||||
"""Execute multiple tool calls concurrently using a thread pool.
|
||||
|
||||
@@ -7045,7 +7146,7 @@ class AIAgent:
|
||||
args_str = json.dumps(args, ensure_ascii=False)
|
||||
if self.verbose_logging:
|
||||
print(f" 📞 Tool {i}: {name}({list(args.keys())})")
|
||||
print(f" Args: {args_str}")
|
||||
print(self._wrap_verbose("Args: ", json.dumps(args, indent=2, ensure_ascii=False)))
|
||||
else:
|
||||
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
|
||||
print(f" 📞 Tool {i}: {name}({list(args.keys())}) - {args_preview}")
|
||||
@@ -7069,8 +7170,22 @@ class AIAgent:
|
||||
# Each slot holds (function_name, function_args, function_result, duration, error_flag)
|
||||
results = [None] * num_tools
|
||||
|
||||
# Touch activity before launching workers so the gateway knows
|
||||
# we're executing tools (not stuck).
|
||||
self._current_tool = tool_names_str
|
||||
self._touch_activity(f"executing {num_tools} tools concurrently: {tool_names_str}")
|
||||
|
||||
def _run_tool(index, tool_call, function_name, function_args):
|
||||
"""Worker function executed in a thread."""
|
||||
# Set the activity callback on THIS worker thread so
|
||||
# _wait_for_process (terminal commands) can fire heartbeats.
|
||||
# The callback is thread-local; the main thread's callback
|
||||
# is invisible to worker threads.
|
||||
try:
|
||||
from tools.environments.base import set_activity_callback
|
||||
set_activity_callback(self._touch_activity)
|
||||
except Exception:
|
||||
pass
|
||||
start = time.time()
|
||||
try:
|
||||
result = self._invoke_tool(function_name, function_args, effective_task_id, tool_call.id)
|
||||
@@ -7100,8 +7215,26 @@ class AIAgent:
|
||||
f = executor.submit(_run_tool, i, tc, name, args)
|
||||
futures.append(f)
|
||||
|
||||
# Wait for all to complete (exceptions are captured inside _run_tool)
|
||||
concurrent.futures.wait(futures)
|
||||
# Wait for all to complete with periodic heartbeats so the
|
||||
# gateway's inactivity monitor doesn't kill us during long
|
||||
# concurrent tool batches.
|
||||
_conc_start = time.time()
|
||||
while True:
|
||||
done, not_done = concurrent.futures.wait(
|
||||
futures, timeout=30.0,
|
||||
)
|
||||
if not not_done:
|
||||
break
|
||||
_conc_elapsed = int(time.time() - _conc_start)
|
||||
_still_running = [
|
||||
parsed_calls[futures.index(f)][1]
|
||||
for f in not_done
|
||||
if f in futures
|
||||
]
|
||||
self._touch_activity(
|
||||
f"concurrent tools running ({_conc_elapsed}s, "
|
||||
f"{len(not_done)} remaining: {', '.join(_still_running[:3])})"
|
||||
)
|
||||
finally:
|
||||
if spinner:
|
||||
# Build a summary message for the spinner stop
|
||||
@@ -7143,7 +7276,7 @@ class AIAgent:
|
||||
elif not self.quiet_mode:
|
||||
if self.verbose_logging:
|
||||
print(f" ✅ Tool {i+1} completed in {tool_duration:.2f}s")
|
||||
print(f" Result: {function_result}")
|
||||
print(self._wrap_verbose("Result: ", function_result))
|
||||
else:
|
||||
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
|
||||
print(f" ✅ Tool {i+1} completed in {tool_duration:.2f}s - {response_preview}")
|
||||
@@ -7236,7 +7369,7 @@ class AIAgent:
|
||||
args_str = json.dumps(function_args, ensure_ascii=False)
|
||||
if self.verbose_logging:
|
||||
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())})")
|
||||
print(f" Args: {args_str}")
|
||||
print(self._wrap_verbose("Args: ", json.dumps(function_args, indent=2, ensure_ascii=False)))
|
||||
else:
|
||||
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
|
||||
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}")
|
||||
@@ -7333,6 +7466,16 @@ class AIAgent:
|
||||
old_text=function_args.get("old_text"),
|
||||
store=self._memory_store,
|
||||
)
|
||||
# Bridge: notify external memory provider of built-in memory writes
|
||||
if self._memory_manager and function_args.get("action") in ("add", "replace"):
|
||||
try:
|
||||
self._memory_manager.on_memory_write(
|
||||
function_args.get("action", ""),
|
||||
target,
|
||||
function_args.get("content", ""),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
tool_duration = time.time() - tool_start_time
|
||||
if self._should_emit_quiet_tool_messages():
|
||||
self._vprint(f" {_get_cute_tool_message_impl('memory', function_args, tool_duration, result=function_result)}")
|
||||
@@ -7524,7 +7667,7 @@ class AIAgent:
|
||||
if not self.quiet_mode:
|
||||
if self.verbose_logging:
|
||||
print(f" ✅ Tool {i} completed in {tool_duration:.2f}s")
|
||||
print(f" Result: {function_result}")
|
||||
print(self._wrap_verbose("Result: ", function_result))
|
||||
else:
|
||||
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
|
||||
print(f" ✅ Tool {i} completed in {tool_duration:.2f}s - {response_preview}")
|
||||
@@ -7807,7 +7950,9 @@ class AIAgent:
|
||||
self._incomplete_scratchpad_retries = 0
|
||||
self._codex_incomplete_retries = 0
|
||||
self._thinking_prefill_retries = 0
|
||||
self._post_tool_empty_retried = False
|
||||
self._last_content_with_tools = None
|
||||
self._last_content_tools_all_housekeeping = False
|
||||
self._mute_post_response = False
|
||||
self._unicode_sanitization_passes = 0
|
||||
|
||||
@@ -7987,6 +8132,16 @@ class AIAgent:
|
||||
# skipping them because conversation_history is still the
|
||||
# pre-compression length.
|
||||
conversation_history = None
|
||||
# Fix: reset retry counters after compression so the model
|
||||
# gets a fresh budget on the compressed context. Without
|
||||
# this, pre-compression retries carry over and the model
|
||||
# hits "(empty)" immediately after compression-induced
|
||||
# context loss.
|
||||
self._empty_content_retries = 0
|
||||
self._thinking_prefill_retries = 0
|
||||
self._last_content_with_tools = None
|
||||
self._last_content_tools_all_housekeeping = False
|
||||
self._mute_post_response = False
|
||||
# Re-estimate after compression
|
||||
_preflight_tokens = estimate_request_tokens_rough(
|
||||
messages,
|
||||
@@ -8045,11 +8200,19 @@ class AIAgent:
|
||||
|
||||
# Record the execution thread so interrupt()/clear_interrupt() can
|
||||
# scope the tool-level interrupt signal to THIS agent's thread only.
|
||||
# Must be set before clear_interrupt() which uses it.
|
||||
# Must be set before any thread-scoped interrupt syncing.
|
||||
self._execution_thread_id = threading.current_thread().ident
|
||||
|
||||
# Clear any stale interrupt state at start
|
||||
self.clear_interrupt()
|
||||
# Always clear stale per-thread state from a previous turn. If an
|
||||
# interrupt arrived before startup finished, preserve it and bind it
|
||||
# to this execution thread now instead of dropping it on the floor.
|
||||
_set_interrupt(False, self._execution_thread_id)
|
||||
if self._interrupt_requested:
|
||||
_set_interrupt(True, self._execution_thread_id)
|
||||
self._interrupt_thread_signal_pending = False
|
||||
else:
|
||||
self._interrupt_message = None
|
||||
self._interrupt_thread_signal_pending = False
|
||||
|
||||
# External memory provider: prefetch once before the tool loop.
|
||||
# Reuse the cached result on every iteration to avoid re-calling
|
||||
@@ -8962,12 +9125,40 @@ class AIAgent:
|
||||
if isinstance(_default_headers, dict):
|
||||
_headers_sanitized = _sanitize_structure_non_ascii(_default_headers)
|
||||
|
||||
# Sanitize the API key — non-ASCII characters in
|
||||
# credentials (e.g. ʋ instead of v from a bad
|
||||
# copy-paste) cause httpx to fail when encoding
|
||||
# the Authorization header as ASCII. This is the
|
||||
# most common cause of persistent UnicodeEncodeError
|
||||
# that survives message/tool sanitization (#6843).
|
||||
_credential_sanitized = False
|
||||
_raw_key = getattr(self, "api_key", None) or ""
|
||||
if _raw_key:
|
||||
_clean_key = _strip_non_ascii(_raw_key)
|
||||
if _clean_key != _raw_key:
|
||||
self.api_key = _clean_key
|
||||
if isinstance(getattr(self, "_client_kwargs", None), dict):
|
||||
self._client_kwargs["api_key"] = _clean_key
|
||||
# Also update the live client — it holds its
|
||||
# own copy of api_key which auth_headers reads
|
||||
# dynamically on every request.
|
||||
if getattr(self, "client", None) is not None and hasattr(self.client, "api_key"):
|
||||
self.client.api_key = _clean_key
|
||||
_credential_sanitized = True
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ API key contained non-ASCII characters "
|
||||
f"(bad copy-paste?) — stripped them. If auth fails, "
|
||||
f"re-copy the key from your provider's dashboard.",
|
||||
force=True,
|
||||
)
|
||||
|
||||
if (
|
||||
_messages_sanitized
|
||||
or _prefill_sanitized
|
||||
or _tools_sanitized
|
||||
or _system_sanitized
|
||||
or _headers_sanitized
|
||||
or _credential_sanitized
|
||||
):
|
||||
self._unicode_sanitization_passes += 1
|
||||
self._vprint(
|
||||
@@ -9255,7 +9446,9 @@ class AIAgent:
|
||||
"completed": False,
|
||||
"api_calls": api_call_count,
|
||||
"error": f"Request payload too large: max compression attempts ({max_compression_attempts}) reached.",
|
||||
"partial": True
|
||||
"partial": True,
|
||||
"failed": True,
|
||||
"compression_exhausted": True,
|
||||
}
|
||||
self._emit_status(f"⚠️ Request payload too large (413) — compression attempt {compression_attempts}/{max_compression_attempts}...")
|
||||
|
||||
@@ -9284,7 +9477,9 @@ class AIAgent:
|
||||
"completed": False,
|
||||
"api_calls": api_call_count,
|
||||
"error": "Request payload too large (413). Cannot compress further.",
|
||||
"partial": True
|
||||
"partial": True,
|
||||
"failed": True,
|
||||
"compression_exhausted": True,
|
||||
}
|
||||
|
||||
# Check for context-length errors BEFORE generic 4xx handler.
|
||||
@@ -9335,7 +9530,9 @@ class AIAgent:
|
||||
"completed": False,
|
||||
"api_calls": api_call_count,
|
||||
"error": f"Context length exceeded: max compression attempts ({max_compression_attempts}) reached.",
|
||||
"partial": True
|
||||
"partial": True,
|
||||
"failed": True,
|
||||
"compression_exhausted": True,
|
||||
}
|
||||
restart_with_compressed_messages = True
|
||||
break
|
||||
@@ -9385,7 +9582,9 @@ class AIAgent:
|
||||
"completed": False,
|
||||
"api_calls": api_call_count,
|
||||
"error": f"Context length exceeded: max compression attempts ({max_compression_attempts}) reached.",
|
||||
"partial": True
|
||||
"partial": True,
|
||||
"failed": True,
|
||||
"compression_exhausted": True,
|
||||
}
|
||||
self._emit_status(f"🗜️ Context too large (~{approx_tokens:,} tokens) — compressing ({compression_attempts}/{max_compression_attempts})...")
|
||||
|
||||
@@ -9416,7 +9615,9 @@ class AIAgent:
|
||||
"completed": False,
|
||||
"api_calls": api_call_count,
|
||||
"error": f"Context length exceeded ({approx_tokens:,} tokens). Cannot compress further.",
|
||||
"partial": True
|
||||
"partial": True,
|
||||
"failed": True,
|
||||
"compression_exhausted": True,
|
||||
}
|
||||
|
||||
# Check for non-retryable client errors. The classifier
|
||||
@@ -10010,6 +10211,7 @@ class AIAgent:
|
||||
tc.function.name in _HOUSEKEEPING_TOOLS
|
||||
for tc in assistant_message.tool_calls
|
||||
)
|
||||
self._last_content_tools_all_housekeeping = _all_housekeeping
|
||||
if _all_housekeeping and self._has_stream_consumers():
|
||||
self._mute_post_response = True
|
||||
elif self.quiet_mode:
|
||||
@@ -10038,6 +10240,10 @@ class AIAgent:
|
||||
if _had_prefill:
|
||||
self._thinking_prefill_retries = 0
|
||||
self._empty_content_retries = 0
|
||||
# Successful tool execution — reset the post-tool nudge
|
||||
# flag so it can fire again if the model goes empty on
|
||||
# a LATER tool round.
|
||||
self._post_tool_empty_retried = False
|
||||
|
||||
messages.append(assistant_msg)
|
||||
self._emit_interim_assistant_message(assistant_msg)
|
||||
@@ -10154,6 +10360,13 @@ class AIAgent:
|
||||
# No tool calls - this is the final response
|
||||
final_response = assistant_message.content or ""
|
||||
|
||||
# Fix: unmute output when entering the no-tool-call branch
|
||||
# so the user can see empty-response warnings and recovery
|
||||
# status messages. _mute_post_response was set during a
|
||||
# prior housekeeping tool turn and should not silence the
|
||||
# final response path.
|
||||
self._mute_post_response = False
|
||||
|
||||
# Check if response only has think block with no actual content after it
|
||||
if not self._has_content_after_think_block(final_response):
|
||||
# ── Partial stream recovery ─────────────────────
|
||||
@@ -10181,30 +10394,82 @@ class AIAgent:
|
||||
break
|
||||
|
||||
# If the previous turn already delivered real content alongside
|
||||
# tool calls (e.g. "You're welcome!" + memory save), the model
|
||||
# has nothing more to say. Use the earlier content immediately
|
||||
# instead of wasting API calls on retries that won't help.
|
||||
# HOUSEKEEPING tool calls (e.g. "You're welcome!" + memory save),
|
||||
# the model has nothing more to say. Use the earlier content
|
||||
# immediately instead of wasting API calls on retries.
|
||||
# NOTE: Only use this shortcut when ALL tools in that turn were
|
||||
# housekeeping (memory, todo, etc.). When substantive tools
|
||||
# were called (terminal, search_files, etc.), the content was
|
||||
# likely mid-task narration ("I'll scan the directory...") and
|
||||
# the empty follow-up means the model choked — let the
|
||||
# post-tool nudge below handle that instead of exiting early.
|
||||
fallback = getattr(self, '_last_content_with_tools', None)
|
||||
if fallback:
|
||||
if fallback and getattr(self, '_last_content_tools_all_housekeeping', False):
|
||||
_turn_exit_reason = "fallback_prior_turn_content"
|
||||
logger.info("Empty follow-up after tool calls — using prior turn content as final response")
|
||||
self._emit_status("↻ Empty response after tool calls — using earlier content as final answer")
|
||||
self._last_content_with_tools = None
|
||||
self._last_content_tools_all_housekeeping = False
|
||||
self._empty_content_retries = 0
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
msg = messages[i]
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
tool_names = []
|
||||
for tc in msg["tool_calls"]:
|
||||
if not tc or not isinstance(tc, dict): continue
|
||||
fn = tc.get("function", {})
|
||||
tool_names.append(fn.get("name", "unknown"))
|
||||
msg["content"] = f"Calling the {', '.join(tool_names)} tool{'s' if len(tool_names) > 1 else ''}..."
|
||||
break
|
||||
# Do NOT modify the assistant message content — the
|
||||
# old code injected "Calling the X tools..." which
|
||||
# poisoned the conversation history. Just use the
|
||||
# fallback text as the final response and break.
|
||||
final_response = self._strip_think_blocks(fallback).strip()
|
||||
self._response_was_previewed = True
|
||||
break
|
||||
|
||||
# ── Post-tool-call empty response nudge ───────────
|
||||
# The model returned empty after executing tool calls.
|
||||
# This covers two cases:
|
||||
# (a) No prior-turn content at all — model went silent
|
||||
# (b) Prior turn had content + SUBSTANTIVE tools (the
|
||||
# fallback above was skipped because the content
|
||||
# was mid-task narration, not a final answer)
|
||||
# Instead of giving up, nudge the model to continue by
|
||||
# appending a user-level hint. This is the #9400 case:
|
||||
# weaker models (mimo-v2-pro, GLM-5, etc.) sometimes
|
||||
# return empty after tool results instead of continuing
|
||||
# to the next step. One retry with a nudge usually
|
||||
# fixes it.
|
||||
_prior_was_tool = any(
|
||||
m.get("role") == "tool"
|
||||
for m in messages[-5:] # check recent messages
|
||||
)
|
||||
if (
|
||||
_prior_was_tool
|
||||
and not getattr(self, "_post_tool_empty_retried", False)
|
||||
):
|
||||
self._post_tool_empty_retried = True
|
||||
# Clear stale narration so it doesn't resurface
|
||||
# on a later empty response after the nudge.
|
||||
self._last_content_with_tools = None
|
||||
self._last_content_tools_all_housekeeping = False
|
||||
logger.info(
|
||||
"Empty response after tool calls — nudging model "
|
||||
"to continue processing"
|
||||
)
|
||||
self._emit_status(
|
||||
"⚠️ Model returned empty after tool calls — "
|
||||
"nudging to continue"
|
||||
)
|
||||
# Append the empty assistant message first so the
|
||||
# message sequence stays valid:
|
||||
# tool(result) → assistant("(empty)") → user(nudge)
|
||||
# Without this, we'd have tool → user which most
|
||||
# APIs reject as an invalid sequence.
|
||||
assistant_msg["content"] = "(empty)"
|
||||
messages.append(assistant_msg)
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": (
|
||||
"You just executed tool calls but returned an "
|
||||
"empty response. Please process the tool "
|
||||
"results above and continue with the task."
|
||||
),
|
||||
})
|
||||
continue
|
||||
|
||||
# ── Thinking-only prefill continuation ──────────
|
||||
# The model produced structured reasoning (via API
|
||||
# fields) but no visible text content. Rather than
|
||||
|
||||
@@ -95,7 +95,9 @@ AUTHOR_MAP = {
|
||||
"vincentcharlebois@gmail.com": "vincentcharlebois",
|
||||
"aryan@synvoid.com": "aryansingh",
|
||||
"johnsonblake1@gmail.com": "blakejohnson",
|
||||
"greer.guthrie@gmail.com": "g-guthrie",
|
||||
"kennyx102@gmail.com": "bobashopcashier",
|
||||
"shokatalishaikh95@gmail.com": "areu01or00",
|
||||
"bryan@intertwinesys.com": "bryanyoung",
|
||||
"christo.mitov@gmail.com": "christomitov",
|
||||
"hermes@nousresearch.com": "NousResearch",
|
||||
@@ -115,6 +117,8 @@ AUTHOR_MAP = {
|
||||
"m@statecraft.systems": "mbierling",
|
||||
"balyan.sid@gmail.com": "balyansid",
|
||||
"oluwadareab12@gmail.com": "bennytimz",
|
||||
"simon@simonmarcus.org": "simon-marcus",
|
||||
"1243352777@qq.com": "zons-zhaozhy",
|
||||
# ── bulk addition: 75 emails resolved via API, PR salvage bodies, noreply
|
||||
# crossref, and GH contributor list matching (April 2026 audit) ──
|
||||
"1115117931@qq.com": "aaronagent",
|
||||
@@ -193,6 +197,8 @@ AUTHOR_MAP = {
|
||||
"zhouboli@gmail.com": "zhouboli",
|
||||
"zqiao@microsoft.com": "tomqiaozc",
|
||||
"zzn+pa@zzn.im": "xinbenlv",
|
||||
"zaynjarvis@gmail.com": "ZaynJarvis",
|
||||
"zhiheng.liu@bytedance.com": "ZaynJarvis",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -650,9 +650,9 @@ registry.register(
|
||||
)
|
||||
```
|
||||
|
||||
**2. Add import** in `model_tools.py` → `_discover_tools()` list.
|
||||
**2. Add to `toolsets.py`** → `_HERMES_CORE_TOOLS` list.
|
||||
|
||||
**3. Add to `toolsets.py`** → `_HERMES_CORE_TOOLS` list.
|
||||
Auto-discovery: any `tools/*.py` file with a top-level `registry.register()` call is imported automatically — no manual list needed.
|
||||
|
||||
All handlers must return JSON strings. Use `get_hermes_home()` for paths, never hardcode `~/.hermes`.
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ def find_nearby(lat: float, lon: float, types: list[str], radius: int = 1500, li
|
||||
# Get coordinates (nodes have lat/lon directly, ways/relations use center)
|
||||
plat = el.get("lat") or (el.get("center", {}) or {}).get("lat")
|
||||
plon = el.get("lon") or (el.get("center", {}) or {}).get("lon")
|
||||
if not plat or not plon:
|
||||
if plat is None or plon is None:
|
||||
continue
|
||||
|
||||
dist = haversine(lat, lon, plat, plon)
|
||||
|
||||
@@ -1,35 +1,19 @@
|
||||
---
|
||||
name: google-workspace
|
||||
description: Gmail, Calendar, Drive, Contacts, Sheets, and Docs integration via gws CLI (googleworkspace/cli). Uses OAuth2 with automatic token refresh via bridge script. Requires gws binary.
|
||||
version: 2.0.0
|
||||
description: Gmail, Calendar, Drive, Contacts, Sheets, and Docs integration for Hermes. Uses Hermes-managed OAuth2 setup, prefers the Google Workspace CLI (`gws`) when available for broader API coverage, and falls back to the Python client libraries otherwise.
|
||||
version: 1.0.0
|
||||
author: Nous Research
|
||||
license: MIT
|
||||
required_credential_files:
|
||||
- path: google_token.json
|
||||
description: Google OAuth2 token (created by setup script)
|
||||
- path: google_client_secret.json
|
||||
description: Google OAuth2 client credentials (downloaded from Google Cloud Console)
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Google, Gmail, Calendar, Drive, Sheets, Docs, Contacts, Email, OAuth, gws]
|
||||
tags: [Google, Gmail, Calendar, Drive, Sheets, Docs, Contacts, Email, OAuth]
|
||||
homepage: https://github.com/NousResearch/hermes-agent
|
||||
related_skills: [himalaya]
|
||||
---
|
||||
|
||||
# Google Workspace
|
||||
|
||||
Gmail, Calendar, Drive, Contacts, Sheets, and Docs — powered by `gws` (Google's official Rust CLI). The skill provides a backward-compatible Python wrapper that handles OAuth token refresh and delegates to `gws`.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
google_api.py → gws_bridge.py → gws CLI
|
||||
(argparse compat) (token refresh) (Google APIs)
|
||||
```
|
||||
|
||||
- `setup.py` handles OAuth2 (headless-compatible, works on CLI/Telegram/Discord)
|
||||
- `gws_bridge.py` refreshes the Hermes token and injects it into `gws` via `GOOGLE_WORKSPACE_CLI_TOKEN`
|
||||
- `google_api.py` provides the same CLI interface as v1 but delegates to `gws`
|
||||
Gmail, Calendar, Drive, Contacts, Sheets, and Docs — through Hermes-managed OAuth and a thin CLI wrapper. When `gws` is installed, the skill uses it as the execution backend for broader Google Workspace coverage; otherwise it falls back to the bundled Python client implementation.
|
||||
|
||||
## References
|
||||
|
||||
@@ -38,22 +22,7 @@ google_api.py → gws_bridge.py → gws CLI
|
||||
## Scripts
|
||||
|
||||
- `scripts/setup.py` — OAuth2 setup (run once to authorize)
|
||||
- `scripts/gws_bridge.py` — Token refresh bridge to gws CLI
|
||||
- `scripts/google_api.py` — Backward-compatible API wrapper (delegates to gws)
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Install `gws`:
|
||||
|
||||
```bash
|
||||
cargo install google-workspace-cli
|
||||
# or via npm (recommended, downloads prebuilt binary):
|
||||
npm install -g @googleworkspace/cli
|
||||
# or via Homebrew:
|
||||
brew install googleworkspace-cli
|
||||
```
|
||||
|
||||
Verify: `gws --version`
|
||||
- `scripts/google_api.py` — compatibility wrapper CLI. It prefers `gws` for operations when available, while preserving Hermes' existing JSON output contract.
|
||||
|
||||
## First-Time Setup
|
||||
|
||||
@@ -63,13 +32,7 @@ on CLI, Telegram, Discord, or any platform.
|
||||
Define a shorthand first:
|
||||
|
||||
```bash
|
||||
HERMES_HOME="${HERMES_HOME:-$HOME/.hermes}"
|
||||
GWORKSPACE_SKILL_DIR="$HERMES_HOME/skills/productivity/google-workspace"
|
||||
PYTHON_BIN="${HERMES_PYTHON:-python3}"
|
||||
if [ -x "$HERMES_HOME/hermes-agent/venv/bin/python" ]; then
|
||||
PYTHON_BIN="$HERMES_HOME/hermes-agent/venv/bin/python"
|
||||
fi
|
||||
GSETUP="$PYTHON_BIN $GWORKSPACE_SKILL_DIR/scripts/setup.py"
|
||||
GSETUP="python ~/.hermes/skills/productivity/google-workspace/scripts/setup.py"
|
||||
```
|
||||
|
||||
### Step 0: Check if already set up
|
||||
@@ -82,88 +45,166 @@ If it prints `AUTHENTICATED`, skip to Usage — setup is already done.
|
||||
|
||||
### Step 1: Triage — ask the user what they need
|
||||
|
||||
Before starting OAuth setup, ask the user TWO questions:
|
||||
|
||||
**Question 1: "What Google services do you need? Just email, or also
|
||||
Calendar/Drive/Sheets/Docs?"**
|
||||
|
||||
- **Email only** → Use the `himalaya` skill instead — simpler setup.
|
||||
- **Calendar, Drive, Sheets, Docs (or email + these)** → Continue below.
|
||||
- **Email only** → They don't need this skill at all. Use the `himalaya` skill
|
||||
instead — it works with a Gmail App Password (Settings → Security → App
|
||||
Passwords) and takes 2 minutes to set up. No Google Cloud project needed.
|
||||
Load the himalaya skill and follow its setup instructions.
|
||||
|
||||
**Partial scopes**: Users can authorize only a subset of services. The setup
|
||||
script accepts partial scopes and warns about missing ones.
|
||||
- **Email + Calendar** → Continue with this skill, but use
|
||||
`--services email,calendar` during auth so the consent screen only asks for
|
||||
the scopes they actually need.
|
||||
|
||||
**Question 2: "Does your Google account use Advanced Protection?"**
|
||||
- **Calendar/Drive/Sheets/Docs only** → Continue with this skill and use a
|
||||
narrower `--services` set like `calendar,drive,sheets,docs`.
|
||||
|
||||
- **No / Not sure** → Normal setup.
|
||||
- **Yes** → Workspace admin must add the OAuth client ID to allowed apps first.
|
||||
- **Full Workspace access** → Continue with this skill and use the default
|
||||
`all` service set.
|
||||
|
||||
**Question 2: "Does your Google account use Advanced Protection (hardware
|
||||
security keys required to sign in)? If you're not sure, you probably don't
|
||||
— it's something you would have explicitly enrolled in."**
|
||||
|
||||
- **No / Not sure** → Normal setup. Continue below.
|
||||
- **Yes** → Their Workspace admin must add the OAuth client ID to the org's
|
||||
allowed apps list before Step 4 will work. Let them know upfront.
|
||||
|
||||
### Step 2: Create OAuth credentials (one-time, ~5 minutes)
|
||||
|
||||
Tell the user:
|
||||
|
||||
> 1. Go to https://console.cloud.google.com/apis/credentials
|
||||
> 2. Create a project (or use an existing one)
|
||||
> 3. Enable the APIs you need (Gmail, Calendar, Drive, Sheets, Docs, People)
|
||||
> 4. Credentials → Create Credentials → OAuth 2.0 Client ID → Desktop app
|
||||
> 5. Download JSON and tell me the file path
|
||||
> You need a Google Cloud OAuth client. This is a one-time setup:
|
||||
>
|
||||
> 1. Create or select a project:
|
||||
> https://console.cloud.google.com/projectselector2/home/dashboard
|
||||
> 2. Enable the required APIs from the API Library:
|
||||
> https://console.cloud.google.com/apis/library
|
||||
> Enable: Gmail API, Google Calendar API, Google Drive API,
|
||||
> Google Sheets API, Google Docs API, People API
|
||||
> 3. Create the OAuth client here:
|
||||
> https://console.cloud.google.com/apis/credentials
|
||||
> Credentials → Create Credentials → OAuth 2.0 Client ID
|
||||
> 4. Application type: "Desktop app" → Create
|
||||
> 5. If the app is still in Testing, add the user's Google account as a test user here:
|
||||
> https://console.cloud.google.com/auth/audience
|
||||
> Audience → Test users → Add users
|
||||
> 6. Download the JSON file and tell me the file path
|
||||
>
|
||||
> Important Hermes CLI note: if the file path starts with `/`, do NOT send only the bare path as its own message in the CLI, because it can be mistaken for a slash command. Send it in a sentence instead, like:
|
||||
> `The JSON file path is: /home/user/Downloads/client_secret_....json`
|
||||
|
||||
Once they provide the path:
|
||||
|
||||
```bash
|
||||
$GSETUP --client-secret /path/to/client_secret.json
|
||||
```
|
||||
|
||||
If they paste the raw client ID / client secret values instead of a file path,
|
||||
write a valid Desktop OAuth JSON file for them yourself, save it somewhere
|
||||
explicit (for example `~/Downloads/hermes-google-client-secret.json`), then run
|
||||
`--client-secret` against that file.
|
||||
|
||||
### Step 3: Get authorization URL
|
||||
|
||||
Use the service set chosen in Step 1. Examples:
|
||||
|
||||
```bash
|
||||
$GSETUP --auth-url
|
||||
$GSETUP --auth-url --services email,calendar --format json
|
||||
$GSETUP --auth-url --services calendar,drive,sheets,docs --format json
|
||||
$GSETUP --auth-url --services all --format json
|
||||
```
|
||||
|
||||
Send the URL to the user. After authorizing, they paste back the redirect URL or code.
|
||||
This returns JSON with an `auth_url` field and also saves the exact URL to
|
||||
`~/.hermes/google_oauth_last_url.txt`.
|
||||
|
||||
Agent rules for this step:
|
||||
- Extract the `auth_url` field and send that exact URL to the user as a single line.
|
||||
- Tell the user that the browser will likely fail on `http://localhost:1` after approval, and that this is expected.
|
||||
- Tell them to copy the ENTIRE redirected URL from the browser address bar.
|
||||
- If the user gets `Error 403: access_denied`, send them directly to `https://console.cloud.google.com/auth/audience` to add themselves as a test user.
|
||||
|
||||
### Step 4: Exchange the code
|
||||
|
||||
The user will paste back either a URL like `http://localhost:1/?code=4/0A...&scope=...`
|
||||
or just the code string. Either works. The `--auth-url` step stores a temporary
|
||||
pending OAuth session locally so `--auth-code` can complete the PKCE exchange
|
||||
later, even on headless systems:
|
||||
|
||||
```bash
|
||||
$GSETUP --auth-code "THE_URL_OR_CODE_THE_USER_PASTED"
|
||||
$GSETUP --auth-code "THE_URL_OR_CODE_THE_USER_PASTED" --format json
|
||||
```
|
||||
|
||||
If `--auth-code` fails because the code expired, was already used, or came from
|
||||
an older browser tab, it now returns a fresh `fresh_auth_url`. In that case,
|
||||
immediately send the new URL to the user and have them retry with the newest
|
||||
browser redirect only.
|
||||
|
||||
### Step 5: Verify
|
||||
|
||||
```bash
|
||||
$GSETUP --check
|
||||
```
|
||||
|
||||
Should print `AUTHENTICATED`. Token refreshes automatically from now on.
|
||||
Should print `AUTHENTICATED`. Setup is complete — token refreshes automatically from now on.
|
||||
|
||||
### Notes
|
||||
|
||||
- Token is stored at `~/.hermes/google_token.json` and auto-refreshes.
|
||||
- Pending OAuth session state/verifier are stored temporarily at `~/.hermes/google_oauth_pending.json` until exchange completes.
|
||||
- If `gws` is installed, `google_api.py` points it at the same `~/.hermes/google_token.json` credentials file. Users do not need to run a separate `gws auth login` flow.
|
||||
- To revoke: `$GSETUP --revoke`
|
||||
|
||||
## Usage
|
||||
|
||||
All commands go through the API script:
|
||||
All commands go through the API script. Set `GAPI` as a shorthand:
|
||||
|
||||
```bash
|
||||
HERMES_HOME="${HERMES_HOME:-$HOME/.hermes}"
|
||||
GWORKSPACE_SKILL_DIR="$HERMES_HOME/skills/productivity/google-workspace"
|
||||
PYTHON_BIN="${HERMES_PYTHON:-python3}"
|
||||
if [ -x "$HERMES_HOME/hermes-agent/venv/bin/python" ]; then
|
||||
PYTHON_BIN="$HERMES_HOME/hermes-agent/venv/bin/python"
|
||||
fi
|
||||
GAPI="$PYTHON_BIN $GWORKSPACE_SKILL_DIR/scripts/google_api.py"
|
||||
GAPI="python ~/.hermes/skills/productivity/google-workspace/scripts/google_api.py"
|
||||
```
|
||||
|
||||
### Gmail
|
||||
|
||||
```bash
|
||||
# Search (returns JSON array with id, from, subject, date, snippet)
|
||||
$GAPI gmail search "is:unread" --max 10
|
||||
$GAPI gmail search "from:boss@company.com newer_than:1d"
|
||||
$GAPI gmail search "has:attachment filename:pdf newer_than:7d"
|
||||
|
||||
# Read full message (returns JSON with body text)
|
||||
$GAPI gmail get MESSAGE_ID
|
||||
|
||||
# Send
|
||||
$GAPI gmail send --to user@example.com --subject "Hello" --body "Message text"
|
||||
$GAPI gmail send --to user@example.com --subject "Report" --body "<h1>Q4</h1>" --html
|
||||
$GAPI gmail send --to user@example.com --subject "Report" --body "<h1>Q4</h1><p>Details...</p>" --html
|
||||
$GAPI gmail send --to user@example.com --subject "Hello" --from '"Research Agent" <user@example.com>' --body "Message text"
|
||||
|
||||
# Reply (automatically threads and sets In-Reply-To)
|
||||
$GAPI gmail reply MESSAGE_ID --body "Thanks, that works for me."
|
||||
$GAPI gmail reply MESSAGE_ID --from '"Support Bot" <user@example.com>' --body "Thanks"
|
||||
|
||||
# Labels
|
||||
$GAPI gmail labels
|
||||
$GAPI gmail modify MESSAGE_ID --add-labels LABEL_ID
|
||||
$GAPI gmail modify MESSAGE_ID --remove-labels UNREAD
|
||||
```
|
||||
|
||||
### Calendar
|
||||
|
||||
```bash
|
||||
# List events (defaults to next 7 days)
|
||||
$GAPI calendar list
|
||||
$GAPI calendar create --summary "Standup" --start 2026-03-01T10:00:00+01:00 --end 2026-03-01T10:30:00+01:00
|
||||
$GAPI calendar create --summary "Review" --start ... --end ... --attendees "alice@co.com,bob@co.com"
|
||||
$GAPI calendar list --start 2026-03-01T00:00:00Z --end 2026-03-07T23:59:59Z
|
||||
|
||||
# Create event (ISO 8601 with timezone required)
|
||||
$GAPI calendar create --summary "Team Standup" --start 2026-03-01T10:00:00-06:00 --end 2026-03-01T10:30:00-06:00
|
||||
$GAPI calendar create --summary "Lunch" --start 2026-03-01T12:00:00Z --end 2026-03-01T13:00:00Z --location "Cafe"
|
||||
$GAPI calendar create --summary "Review" --start 2026-03-01T14:00:00Z --end 2026-03-01T15:00:00Z --attendees "alice@co.com,bob@co.com"
|
||||
|
||||
# Delete event
|
||||
$GAPI calendar delete EVENT_ID
|
||||
```
|
||||
|
||||
@@ -183,8 +224,13 @@ $GAPI contacts list --max 20
|
||||
### Sheets
|
||||
|
||||
```bash
|
||||
# Read
|
||||
$GAPI sheets get SHEET_ID "Sheet1!A1:D10"
|
||||
|
||||
# Write
|
||||
$GAPI sheets update SHEET_ID "Sheet1!A1:B2" --values '[["Name","Score"],["Alice","95"]]'
|
||||
|
||||
# Append rows
|
||||
$GAPI sheets append SHEET_ID "Sheet1!A:C" --values '[["new","row","data"]]'
|
||||
```
|
||||
|
||||
@@ -194,52 +240,37 @@ $GAPI sheets append SHEET_ID "Sheet1!A:C" --values '[["new","row","data"]]'
|
||||
$GAPI docs get DOC_ID
|
||||
```
|
||||
|
||||
### Direct gws access (advanced)
|
||||
|
||||
For operations not covered by the wrapper, use `gws_bridge.py` directly:
|
||||
|
||||
```bash
|
||||
GBRIDGE="$PYTHON_BIN $GWORKSPACE_SKILL_DIR/scripts/gws_bridge.py"
|
||||
$GBRIDGE calendar +agenda --today --format table
|
||||
$GBRIDGE gmail +triage --labels --format json
|
||||
$GBRIDGE drive +upload ./report.pdf
|
||||
$GBRIDGE sheets +read --spreadsheet SHEET_ID --range "Sheet1!A1:D10"
|
||||
```
|
||||
|
||||
## Output Format
|
||||
|
||||
All commands return JSON via `gws --format json`. Key output shapes:
|
||||
All commands return JSON. Parse with `jq` or read directly. Key fields:
|
||||
|
||||
- **Gmail search/triage**: Array of message summaries (sender, subject, date, snippet)
|
||||
- **Gmail get/read**: Message object with headers and body text
|
||||
- **Gmail send/reply**: Confirmation with message ID
|
||||
- **Calendar list/agenda**: Array of event objects (summary, start, end, location)
|
||||
- **Calendar create**: Confirmation with event ID and htmlLink
|
||||
- **Drive search**: Array of file objects (id, name, mimeType, webViewLink)
|
||||
- **Sheets get/read**: 2D array of cell values
|
||||
- **Docs get**: Full document JSON (use `body.content` for text extraction)
|
||||
- **Contacts list**: Array of person objects with names, emails, phones
|
||||
|
||||
Parse output with `jq` or read JSON directly.
|
||||
- **Gmail search**: `[{id, threadId, from, to, subject, date, snippet, labels}]`
|
||||
- **Gmail get**: `{id, threadId, from, to, subject, date, labels, body}`
|
||||
- **Gmail send/reply**: `{status: "sent", id, threadId}`
|
||||
- **Calendar list**: `[{id, summary, start, end, location, description, htmlLink}]`
|
||||
- **Calendar create**: `{status: "created", id, summary, htmlLink}`
|
||||
- **Drive search**: `[{id, name, mimeType, modifiedTime, webViewLink}]`
|
||||
- **Contacts list**: `[{name, emails: [...], phones: [...]}]`
|
||||
- **Sheets get**: `[[cell, cell, ...], ...]`
|
||||
|
||||
## Rules
|
||||
|
||||
1. **Never send email or create/delete events without confirming with the user first.**
|
||||
2. **Check auth before first use** — run `setup.py --check`.
|
||||
3. **Use the Gmail search syntax reference** for complex queries.
|
||||
4. **Calendar times must include timezone** — ISO 8601 with offset or UTC.
|
||||
5. **Respect rate limits** — avoid rapid-fire sequential API calls.
|
||||
1. **Never send email or create/delete events without confirming with the user first.** Show the draft content and ask for approval.
|
||||
2. **Check auth before first use** — run `setup.py --check`. If it fails, guide the user through setup.
|
||||
3. **Use the Gmail search syntax reference** for complex queries — load it with `skill_view("google-workspace", file_path="references/gmail-search-syntax.md")`.
|
||||
4. **Calendar times must include timezone** — always use ISO 8601 with offset (e.g., `2026-03-01T10:00:00-06:00`) or UTC (`Z`).
|
||||
5. **Respect rate limits** — avoid rapid-fire sequential API calls. Batch reads when possible.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Problem | Fix |
|
||||
|---------|-----|
|
||||
| `NOT_AUTHENTICATED` | Run setup Steps 2-5 |
|
||||
| `REFRESH_FAILED` | Token revoked — redo Steps 3-5 |
|
||||
| `gws: command not found` | Install: `npm install -g @googleworkspace/cli` |
|
||||
| `HttpError 403` | Missing scope — `$GSETUP --revoke` then redo Steps 3-5 |
|
||||
| `HttpError 403: Access Not Configured` | Enable API in Google Cloud Console |
|
||||
| Advanced Protection blocks auth | Admin must allowlist the OAuth client ID |
|
||||
| `NOT_AUTHENTICATED` | Run setup Steps 2-5 above |
|
||||
| `REFRESH_FAILED` | Token revoked or expired — redo Steps 3-5 |
|
||||
| `HttpError 403: Insufficient Permission` | Missing API scope — `$GSETUP --revoke` then redo Steps 3-5 |
|
||||
| `HttpError 403: Access Not Configured` | API not enabled — user needs to enable it in Google Cloud Console |
|
||||
| `ModuleNotFoundError` | Run `$GSETUP --install-deps` |
|
||||
| Advanced Protection blocks auth | Workspace admin must allowlist the OAuth client ID |
|
||||
|
||||
## Revoking Access
|
||||
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Google Workspace API CLI for Hermes Agent.
|
||||
|
||||
Thin wrapper that delegates to gws (googleworkspace/cli) via gws_bridge.py.
|
||||
Maintains the same CLI interface for backward compatibility with Hermes skills.
|
||||
Uses the Google Workspace CLI (`gws`) when available, but preserves the
|
||||
existing Hermes-facing JSON contract and falls back to the Python client
|
||||
libraries if `gws` is not installed.
|
||||
|
||||
Usage:
|
||||
python google_api.py gmail search "is:unread" [--max 10]
|
||||
python google_api.py gmail get MESSAGE_ID
|
||||
python google_api.py gmail send --to user@example.com --subject "Hi" --body "Hello"
|
||||
python google_api.py gmail reply MESSAGE_ID --body "Thanks"
|
||||
python google_api.py calendar list [--start DATE] [--end DATE] [--calendar primary]
|
||||
python google_api.py calendar list [--from DATE] [--to DATE] [--calendar primary]
|
||||
python google_api.py calendar create --summary "Meeting" --start DATETIME --end DATETIME
|
||||
python google_api.py calendar delete EVENT_ID
|
||||
python google_api.py drive search "budget report" [--max 10]
|
||||
python google_api.py contacts list [--max 20]
|
||||
python google_api.py sheets get SHEET_ID RANGE
|
||||
@@ -21,47 +21,396 @@ Usage:
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from email.mime.text import MIMEText
|
||||
from pathlib import Path
|
||||
|
||||
BRIDGE = Path(__file__).parent / "gws_bridge.py"
|
||||
PYTHON = sys.executable
|
||||
HERMES_HOME = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
TOKEN_PATH = HERMES_HOME / "google_token.json"
|
||||
CLIENT_SECRET_PATH = HERMES_HOME / "google_client_secret.json"
|
||||
|
||||
SCOPES = [
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/gmail.send",
|
||||
"https://www.googleapis.com/auth/gmail.modify",
|
||||
"https://www.googleapis.com/auth/calendar",
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/contacts.readonly",
|
||||
"https://www.googleapis.com/auth/spreadsheets",
|
||||
"https://www.googleapis.com/auth/documents.readonly",
|
||||
]
|
||||
|
||||
|
||||
def gws(*args: str) -> None:
|
||||
"""Call gws via the bridge and exit with its return code."""
|
||||
def _ensure_authenticated():
|
||||
if not TOKEN_PATH.exists():
|
||||
print("Not authenticated. Run the setup script first:", file=sys.stderr)
|
||||
print(f" python {Path(__file__).parent / 'setup.py'}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _stored_token_scopes() -> list[str]:
|
||||
try:
|
||||
data = json.loads(TOKEN_PATH.read_text())
|
||||
except Exception:
|
||||
return list(SCOPES)
|
||||
scopes = data.get("scopes")
|
||||
if isinstance(scopes, list) and scopes:
|
||||
return scopes
|
||||
return list(SCOPES)
|
||||
|
||||
|
||||
def _gws_binary() -> str | None:
|
||||
override = os.getenv("HERMES_GWS_BIN")
|
||||
if override:
|
||||
return override
|
||||
return shutil.which("gws")
|
||||
|
||||
|
||||
def _gws_env() -> dict[str, str]:
|
||||
env = os.environ.copy()
|
||||
env["GOOGLE_WORKSPACE_CLI_CREDENTIALS_FILE"] = str(TOKEN_PATH)
|
||||
return env
|
||||
|
||||
|
||||
def _run_gws(parts: list[str], *, params: dict | None = None, body: dict | None = None):
|
||||
binary = _gws_binary()
|
||||
if not binary:
|
||||
raise RuntimeError("gws not installed")
|
||||
|
||||
_ensure_authenticated()
|
||||
|
||||
cmd = [binary, *parts]
|
||||
if params is not None:
|
||||
cmd.extend(["--params", json.dumps(params)])
|
||||
if body is not None:
|
||||
cmd.extend(["--json", json.dumps(body)])
|
||||
|
||||
result = subprocess.run(
|
||||
[PYTHON, str(BRIDGE)] + list(args),
|
||||
env={**os.environ, "HERMES_HOME": os.environ.get("HERMES_HOME", str(Path.home() / ".hermes"))},
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env=_gws_env(),
|
||||
)
|
||||
sys.exit(result.returncode)
|
||||
if result.returncode != 0:
|
||||
err = result.stderr.strip() or result.stdout.strip() or "Unknown gws error"
|
||||
print(err, file=sys.stderr)
|
||||
sys.exit(result.returncode or 1)
|
||||
|
||||
stdout = result.stdout.strip()
|
||||
if not stdout:
|
||||
return {}
|
||||
|
||||
try:
|
||||
return json.loads(stdout)
|
||||
except json.JSONDecodeError:
|
||||
print("ERROR: Unexpected non-JSON output from gws:", file=sys.stderr)
|
||||
print(stdout, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# -- Gmail --
|
||||
def _headers_dict(msg: dict) -> dict[str, str]:
|
||||
return {h["name"]: h["value"] for h in msg.get("payload", {}).get("headers", [])}
|
||||
|
||||
|
||||
def _extract_message_body(msg: dict) -> str:
|
||||
body = ""
|
||||
payload = msg.get("payload", {})
|
||||
if payload.get("body", {}).get("data"):
|
||||
body = base64.urlsafe_b64decode(payload["body"]["data"]).decode("utf-8", errors="replace")
|
||||
elif payload.get("parts"):
|
||||
for part in payload["parts"]:
|
||||
if part.get("mimeType") == "text/plain" and part.get("body", {}).get("data"):
|
||||
body = base64.urlsafe_b64decode(part["body"]["data"]).decode("utf-8", errors="replace")
|
||||
break
|
||||
if not body:
|
||||
for part in payload["parts"]:
|
||||
if part.get("mimeType") == "text/html" and part.get("body", {}).get("data"):
|
||||
body = base64.urlsafe_b64decode(part["body"]["data"]).decode("utf-8", errors="replace")
|
||||
break
|
||||
return body
|
||||
|
||||
|
||||
def _extract_doc_text(doc: dict) -> str:
|
||||
text_parts = []
|
||||
for element in doc.get("body", {}).get("content", []):
|
||||
paragraph = element.get("paragraph", {})
|
||||
for pe in paragraph.get("elements", []):
|
||||
text_run = pe.get("textRun", {})
|
||||
if text_run.get("content"):
|
||||
text_parts.append(text_run["content"])
|
||||
return "".join(text_parts)
|
||||
|
||||
|
||||
def _datetime_with_timezone(value: str) -> str:
|
||||
if not value:
|
||||
return value
|
||||
if "T" not in value:
|
||||
return value
|
||||
if value.endswith("Z"):
|
||||
return value
|
||||
tail = value[10:]
|
||||
if "+" in tail or "-" in tail:
|
||||
return value
|
||||
return value + "Z"
|
||||
|
||||
|
||||
def get_credentials():
|
||||
"""Load and refresh credentials from token file."""
|
||||
_ensure_authenticated()
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google.auth.transport.requests import Request
|
||||
|
||||
creds = Credentials.from_authorized_user_file(str(TOKEN_PATH), _stored_token_scopes())
|
||||
if creds.expired and creds.refresh_token:
|
||||
creds.refresh(Request())
|
||||
TOKEN_PATH.write_text(creds.to_json())
|
||||
if not creds.valid:
|
||||
print("Token is invalid. Re-run setup.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
return creds
|
||||
|
||||
|
||||
def build_service(api, version):
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
return build(api, version, credentials=get_credentials())
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Gmail
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def gmail_search(args):
|
||||
cmd = ["gmail", "+triage", "--query", args.query, "--max", str(args.max), "--format", "json"]
|
||||
gws(*cmd)
|
||||
if _gws_binary():
|
||||
results = _run_gws(
|
||||
["gmail", "users", "messages", "list"],
|
||||
params={"userId": "me", "q": args.query, "maxResults": args.max},
|
||||
)
|
||||
messages = results.get("messages", [])
|
||||
output = []
|
||||
for msg_meta in messages:
|
||||
msg = _run_gws(
|
||||
["gmail", "users", "messages", "get"],
|
||||
params={
|
||||
"userId": "me",
|
||||
"id": msg_meta["id"],
|
||||
"format": "metadata",
|
||||
"metadataHeaders": ["From", "To", "Subject", "Date"],
|
||||
},
|
||||
)
|
||||
headers = _headers_dict(msg)
|
||||
output.append(
|
||||
{
|
||||
"id": msg["id"],
|
||||
"threadId": msg["threadId"],
|
||||
"from": headers.get("From", ""),
|
||||
"to": headers.get("To", ""),
|
||||
"subject": headers.get("Subject", ""),
|
||||
"date": headers.get("Date", ""),
|
||||
"snippet": msg.get("snippet", ""),
|
||||
"labels": msg.get("labelIds", []),
|
||||
}
|
||||
)
|
||||
print(json.dumps(output, indent=2, ensure_ascii=False))
|
||||
return
|
||||
|
||||
service = build_service("gmail", "v1")
|
||||
results = service.users().messages().list(
|
||||
userId="me", q=args.query, maxResults=args.max
|
||||
).execute()
|
||||
messages = results.get("messages", [])
|
||||
if not messages:
|
||||
print("No messages found.")
|
||||
return
|
||||
|
||||
output = []
|
||||
for msg_meta in messages:
|
||||
msg = service.users().messages().get(
|
||||
userId="me", id=msg_meta["id"], format="metadata",
|
||||
metadataHeaders=["From", "To", "Subject", "Date"],
|
||||
).execute()
|
||||
headers = _headers_dict(msg)
|
||||
output.append({
|
||||
"id": msg["id"],
|
||||
"threadId": msg["threadId"],
|
||||
"from": headers.get("From", ""),
|
||||
"to": headers.get("To", ""),
|
||||
"subject": headers.get("Subject", ""),
|
||||
"date": headers.get("Date", ""),
|
||||
"snippet": msg.get("snippet", ""),
|
||||
"labels": msg.get("labelIds", []),
|
||||
})
|
||||
print(json.dumps(output, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
|
||||
def gmail_get(args):
|
||||
gws("gmail", "+read", "--id", args.message_id, "--headers", "--format", "json")
|
||||
if _gws_binary():
|
||||
msg = _run_gws(
|
||||
["gmail", "users", "messages", "get"],
|
||||
params={"userId": "me", "id": args.message_id, "format": "full"},
|
||||
)
|
||||
headers = _headers_dict(msg)
|
||||
result = {
|
||||
"id": msg["id"],
|
||||
"threadId": msg["threadId"],
|
||||
"from": headers.get("From", ""),
|
||||
"to": headers.get("To", ""),
|
||||
"subject": headers.get("Subject", ""),
|
||||
"date": headers.get("Date", ""),
|
||||
"labels": msg.get("labelIds", []),
|
||||
"body": _extract_message_body(msg),
|
||||
}
|
||||
print(json.dumps(result, indent=2, ensure_ascii=False))
|
||||
return
|
||||
|
||||
service = build_service("gmail", "v1")
|
||||
msg = service.users().messages().get(
|
||||
userId="me", id=args.message_id, format="full"
|
||||
).execute()
|
||||
|
||||
headers = _headers_dict(msg)
|
||||
result = {
|
||||
"id": msg["id"],
|
||||
"threadId": msg["threadId"],
|
||||
"from": headers.get("From", ""),
|
||||
"to": headers.get("To", ""),
|
||||
"subject": headers.get("Subject", ""),
|
||||
"date": headers.get("Date", ""),
|
||||
"labels": msg.get("labelIds", []),
|
||||
"body": _extract_message_body(msg),
|
||||
}
|
||||
print(json.dumps(result, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
|
||||
def gmail_send(args):
|
||||
cmd = ["gmail", "+send", "--to", args.to, "--subject", args.subject, "--body", args.body, "--format", "json"]
|
||||
if _gws_binary():
|
||||
message = MIMEText(args.body, "html" if args.html else "plain")
|
||||
message["to"] = args.to
|
||||
message["subject"] = args.subject
|
||||
if args.cc:
|
||||
message["cc"] = args.cc
|
||||
if args.from_header:
|
||||
message["from"] = args.from_header
|
||||
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
body = {"raw": raw}
|
||||
if args.thread_id:
|
||||
body["threadId"] = args.thread_id
|
||||
|
||||
result = _run_gws(
|
||||
["gmail", "users", "messages", "send"],
|
||||
params={"userId": "me"},
|
||||
body=body,
|
||||
)
|
||||
print(json.dumps({"status": "sent", "id": result["id"], "threadId": result.get("threadId", "")}, indent=2))
|
||||
return
|
||||
|
||||
service = build_service("gmail", "v1")
|
||||
message = MIMEText(args.body, "html" if args.html else "plain")
|
||||
message["to"] = args.to
|
||||
message["subject"] = args.subject
|
||||
if args.cc:
|
||||
cmd += ["--cc", args.cc]
|
||||
if args.html:
|
||||
cmd.append("--html")
|
||||
gws(*cmd)
|
||||
message["cc"] = args.cc
|
||||
if args.from_header:
|
||||
message["from"] = args.from_header
|
||||
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
body = {"raw": raw}
|
||||
|
||||
if args.thread_id:
|
||||
body["threadId"] = args.thread_id
|
||||
|
||||
result = service.users().messages().send(userId="me", body=body).execute()
|
||||
print(json.dumps({"status": "sent", "id": result["id"], "threadId": result.get("threadId", "")}, indent=2))
|
||||
|
||||
|
||||
|
||||
def gmail_reply(args):
|
||||
gws("gmail", "+reply", "--message-id", args.message_id, "--body", args.body, "--format", "json")
|
||||
if _gws_binary():
|
||||
original = _run_gws(
|
||||
["gmail", "users", "messages", "get"],
|
||||
params={
|
||||
"userId": "me",
|
||||
"id": args.message_id,
|
||||
"format": "metadata",
|
||||
"metadataHeaders": ["From", "Subject", "Message-ID"],
|
||||
},
|
||||
)
|
||||
headers = _headers_dict(original)
|
||||
|
||||
subject = headers.get("Subject", "")
|
||||
if not subject.startswith("Re:"):
|
||||
subject = f"Re: {subject}"
|
||||
|
||||
message = MIMEText(args.body)
|
||||
message["to"] = headers.get("From", "")
|
||||
message["subject"] = subject
|
||||
if args.from_header:
|
||||
message["from"] = args.from_header
|
||||
if headers.get("Message-ID"):
|
||||
message["In-Reply-To"] = headers["Message-ID"]
|
||||
message["References"] = headers["Message-ID"]
|
||||
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
result = _run_gws(
|
||||
["gmail", "users", "messages", "send"],
|
||||
params={"userId": "me"},
|
||||
body={"raw": raw, "threadId": original["threadId"]},
|
||||
)
|
||||
print(json.dumps({"status": "sent", "id": result["id"], "threadId": result.get("threadId", "")}, indent=2))
|
||||
return
|
||||
|
||||
service = build_service("gmail", "v1")
|
||||
original = service.users().messages().get(
|
||||
userId="me", id=args.message_id, format="metadata",
|
||||
metadataHeaders=["From", "Subject", "Message-ID"],
|
||||
).execute()
|
||||
headers = _headers_dict(original)
|
||||
|
||||
subject = headers.get("Subject", "")
|
||||
if not subject.startswith("Re:"):
|
||||
subject = f"Re: {subject}"
|
||||
|
||||
message = MIMEText(args.body)
|
||||
message["to"] = headers.get("From", "")
|
||||
message["subject"] = subject
|
||||
if args.from_header:
|
||||
message["from"] = args.from_header
|
||||
if headers.get("Message-ID"):
|
||||
message["In-Reply-To"] = headers["Message-ID"]
|
||||
message["References"] = headers["Message-ID"]
|
||||
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
body = {"raw": raw, "threadId": original["threadId"]}
|
||||
|
||||
result = service.users().messages().send(userId="me", body=body).execute()
|
||||
print(json.dumps({"status": "sent", "id": result["id"], "threadId": result.get("threadId", "")}, indent=2))
|
||||
|
||||
|
||||
|
||||
def gmail_labels(args):
|
||||
gws("gmail", "users", "labels", "list", "--params", json.dumps({"userId": "me"}), "--format", "json")
|
||||
if _gws_binary():
|
||||
results = _run_gws(["gmail", "users", "labels", "list"], params={"userId": "me"})
|
||||
labels = [{"id": l["id"], "name": l["name"], "type": l.get("type", "")} for l in results.get("labels", [])]
|
||||
print(json.dumps(labels, indent=2))
|
||||
return
|
||||
|
||||
service = build_service("gmail", "v1")
|
||||
results = service.users().labels().list(userId="me").execute()
|
||||
labels = [{"id": l["id"], "name": l["name"], "type": l.get("type", "")} for l in results.get("labels", [])]
|
||||
print(json.dumps(labels, indent=2))
|
||||
|
||||
|
||||
|
||||
def gmail_modify(args):
|
||||
body = {}
|
||||
@@ -69,145 +418,310 @@ def gmail_modify(args):
|
||||
body["addLabelIds"] = args.add_labels.split(",")
|
||||
if args.remove_labels:
|
||||
body["removeLabelIds"] = args.remove_labels.split(",")
|
||||
gws(
|
||||
"gmail", "users", "messages", "modify",
|
||||
"--params", json.dumps({"userId": "me", "id": args.message_id}),
|
||||
"--json", json.dumps(body),
|
||||
"--format", "json",
|
||||
)
|
||||
|
||||
if _gws_binary():
|
||||
result = _run_gws(
|
||||
["gmail", "users", "messages", "modify"],
|
||||
params={"userId": "me", "id": args.message_id},
|
||||
body=body,
|
||||
)
|
||||
print(json.dumps({"id": result["id"], "labels": result.get("labelIds", [])}, indent=2))
|
||||
return
|
||||
|
||||
service = build_service("gmail", "v1")
|
||||
result = service.users().messages().modify(userId="me", id=args.message_id, body=body).execute()
|
||||
print(json.dumps({"id": result["id"], "labels": result.get("labelIds", [])}, indent=2))
|
||||
|
||||
|
||||
# -- Calendar --
|
||||
# =========================================================================
|
||||
# Calendar
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def calendar_list(args):
|
||||
if args.start or args.end:
|
||||
# Specific date range — use raw Calendar API for precise timeMin/timeMax
|
||||
from datetime import datetime, timedelta, timezone as tz
|
||||
now = datetime.now(tz.utc)
|
||||
time_min = args.start or now.isoformat()
|
||||
time_max = args.end or (now + timedelta(days=7)).isoformat()
|
||||
gws(
|
||||
"calendar", "events", "list",
|
||||
"--params", json.dumps({
|
||||
now = datetime.now(timezone.utc)
|
||||
time_min = _datetime_with_timezone(args.start or now.isoformat())
|
||||
time_max = _datetime_with_timezone(args.end or (now + timedelta(days=7)).isoformat())
|
||||
|
||||
if _gws_binary():
|
||||
results = _run_gws(
|
||||
["calendar", "events", "list"],
|
||||
params={
|
||||
"calendarId": args.calendar,
|
||||
"timeMin": time_min,
|
||||
"timeMax": time_max,
|
||||
"maxResults": args.max,
|
||||
"singleEvents": True,
|
||||
"orderBy": "startTime",
|
||||
}),
|
||||
"--format", "json",
|
||||
},
|
||||
)
|
||||
else:
|
||||
# No date range — use +agenda helper (defaults to 7 days)
|
||||
cmd = ["calendar", "+agenda", "--days", "7", "--format", "json"]
|
||||
if args.calendar != "primary":
|
||||
cmd += ["--calendar", args.calendar]
|
||||
gws(*cmd)
|
||||
events = []
|
||||
for e in results.get("items", []):
|
||||
events.append({
|
||||
"id": e["id"],
|
||||
"summary": e.get("summary", "(no title)"),
|
||||
"start": e.get("start", {}).get("dateTime", e.get("start", {}).get("date", "")),
|
||||
"end": e.get("end", {}).get("dateTime", e.get("end", {}).get("date", "")),
|
||||
"location": e.get("location", ""),
|
||||
"description": e.get("description", ""),
|
||||
"status": e.get("status", ""),
|
||||
"htmlLink": e.get("htmlLink", ""),
|
||||
})
|
||||
print(json.dumps(events, indent=2, ensure_ascii=False))
|
||||
return
|
||||
|
||||
service = build_service("calendar", "v3")
|
||||
results = service.events().list(
|
||||
calendarId=args.calendar, timeMin=time_min, timeMax=time_max,
|
||||
maxResults=args.max, singleEvents=True, orderBy="startTime",
|
||||
).execute()
|
||||
|
||||
events = []
|
||||
for e in results.get("items", []):
|
||||
events.append({
|
||||
"id": e["id"],
|
||||
"summary": e.get("summary", "(no title)"),
|
||||
"start": e.get("start", {}).get("dateTime", e.get("start", {}).get("date", "")),
|
||||
"end": e.get("end", {}).get("dateTime", e.get("end", {}).get("date", "")),
|
||||
"location": e.get("location", ""),
|
||||
"description": e.get("description", ""),
|
||||
"status": e.get("status", ""),
|
||||
"htmlLink": e.get("htmlLink", ""),
|
||||
})
|
||||
print(json.dumps(events, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
|
||||
def calendar_create(args):
|
||||
cmd = [
|
||||
"calendar", "+insert",
|
||||
"--summary", args.summary,
|
||||
"--start", args.start,
|
||||
"--end", args.end,
|
||||
"--format", "json",
|
||||
]
|
||||
event = {
|
||||
"summary": args.summary,
|
||||
"start": {"dateTime": args.start},
|
||||
"end": {"dateTime": args.end},
|
||||
}
|
||||
if args.location:
|
||||
cmd += ["--location", args.location]
|
||||
event["location"] = args.location
|
||||
if args.description:
|
||||
cmd += ["--description", args.description]
|
||||
event["description"] = args.description
|
||||
if args.attendees:
|
||||
for email in args.attendees.split(","):
|
||||
cmd += ["--attendee", email.strip()]
|
||||
if args.calendar != "primary":
|
||||
cmd += ["--calendar", args.calendar]
|
||||
gws(*cmd)
|
||||
event["attendees"] = [{"email": e.strip()} for e in args.attendees.split(",") if e.strip()]
|
||||
|
||||
if _gws_binary():
|
||||
result = _run_gws(
|
||||
["calendar", "events", "insert"],
|
||||
params={"calendarId": args.calendar},
|
||||
body=event,
|
||||
)
|
||||
print(json.dumps({
|
||||
"status": "created",
|
||||
"id": result["id"],
|
||||
"summary": result.get("summary", ""),
|
||||
"htmlLink": result.get("htmlLink", ""),
|
||||
}, indent=2))
|
||||
return
|
||||
|
||||
service = build_service("calendar", "v3")
|
||||
result = service.events().insert(calendarId=args.calendar, body=event).execute()
|
||||
print(json.dumps({
|
||||
"status": "created",
|
||||
"id": result["id"],
|
||||
"summary": result.get("summary", ""),
|
||||
"htmlLink": result.get("htmlLink", ""),
|
||||
}, indent=2))
|
||||
|
||||
|
||||
|
||||
def calendar_delete(args):
|
||||
gws(
|
||||
"calendar", "events", "delete",
|
||||
"--params", json.dumps({"calendarId": args.calendar, "eventId": args.event_id}),
|
||||
"--format", "json",
|
||||
)
|
||||
if _gws_binary():
|
||||
_run_gws(["calendar", "events", "delete"], params={"calendarId": args.calendar, "eventId": args.event_id})
|
||||
print(json.dumps({"status": "deleted", "eventId": args.event_id}))
|
||||
return
|
||||
|
||||
service = build_service("calendar", "v3")
|
||||
service.events().delete(calendarId=args.calendar, eventId=args.event_id).execute()
|
||||
print(json.dumps({"status": "deleted", "eventId": args.event_id}))
|
||||
|
||||
|
||||
# -- Drive --
|
||||
# =========================================================================
|
||||
# Drive
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def drive_search(args):
|
||||
query = args.query if args.raw_query else f"fullText contains '{args.query}'"
|
||||
gws(
|
||||
"drive", "files", "list",
|
||||
"--params", json.dumps({
|
||||
"q": query,
|
||||
"pageSize": args.max,
|
||||
"fields": "files(id,name,mimeType,modifiedTime,webViewLink)",
|
||||
}),
|
||||
"--format", "json",
|
||||
)
|
||||
if _gws_binary():
|
||||
results = _run_gws(
|
||||
["drive", "files", "list"],
|
||||
params={
|
||||
"q": query,
|
||||
"pageSize": args.max,
|
||||
"fields": "files(id, name, mimeType, modifiedTime, webViewLink)",
|
||||
},
|
||||
)
|
||||
print(json.dumps(results.get("files", []), indent=2, ensure_ascii=False))
|
||||
return
|
||||
|
||||
service = build_service("drive", "v3")
|
||||
results = service.files().list(
|
||||
q=query, pageSize=args.max, fields="files(id, name, mimeType, modifiedTime, webViewLink)",
|
||||
).execute()
|
||||
files = results.get("files", [])
|
||||
print(json.dumps(files, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
# -- Contacts --
|
||||
# =========================================================================
|
||||
# Contacts
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def contacts_list(args):
|
||||
gws(
|
||||
"people", "people", "connections", "list",
|
||||
"--params", json.dumps({
|
||||
"resourceName": "people/me",
|
||||
"pageSize": args.max,
|
||||
"personFields": "names,emailAddresses,phoneNumbers",
|
||||
}),
|
||||
"--format", "json",
|
||||
)
|
||||
if _gws_binary():
|
||||
results = _run_gws(
|
||||
["people", "people", "connections", "list"],
|
||||
params={
|
||||
"resourceName": "people/me",
|
||||
"pageSize": args.max,
|
||||
"personFields": "names,emailAddresses,phoneNumbers",
|
||||
},
|
||||
)
|
||||
contacts = []
|
||||
for person in results.get("connections", []):
|
||||
names = person.get("names", [{}])
|
||||
emails = person.get("emailAddresses", [])
|
||||
phones = person.get("phoneNumbers", [])
|
||||
contacts.append({
|
||||
"name": names[0].get("displayName", "") if names else "",
|
||||
"emails": [e.get("value", "") for e in emails],
|
||||
"phones": [p.get("value", "") for p in phones],
|
||||
})
|
||||
print(json.dumps(contacts, indent=2, ensure_ascii=False))
|
||||
return
|
||||
|
||||
service = build_service("people", "v1")
|
||||
results = service.people().connections().list(
|
||||
resourceName="people/me",
|
||||
pageSize=args.max,
|
||||
personFields="names,emailAddresses,phoneNumbers",
|
||||
).execute()
|
||||
contacts = []
|
||||
for person in results.get("connections", []):
|
||||
names = person.get("names", [{}])
|
||||
emails = person.get("emailAddresses", [])
|
||||
phones = person.get("phoneNumbers", [])
|
||||
contacts.append({
|
||||
"name": names[0].get("displayName", "") if names else "",
|
||||
"emails": [e.get("value", "") for e in emails],
|
||||
"phones": [p.get("value", "") for p in phones],
|
||||
})
|
||||
print(json.dumps(contacts, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
# -- Sheets --
|
||||
# =========================================================================
|
||||
# Sheets
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def sheets_get(args):
|
||||
gws(
|
||||
"sheets", "+read",
|
||||
"--spreadsheet", args.sheet_id,
|
||||
"--range", args.range,
|
||||
"--format", "json",
|
||||
)
|
||||
if _gws_binary():
|
||||
result = _run_gws(
|
||||
["sheets", "spreadsheets", "values", "get"],
|
||||
params={"spreadsheetId": args.sheet_id, "range": args.range},
|
||||
)
|
||||
print(json.dumps(result.get("values", []), indent=2, ensure_ascii=False))
|
||||
return
|
||||
|
||||
service = build_service("sheets", "v4")
|
||||
result = service.spreadsheets().values().get(
|
||||
spreadsheetId=args.sheet_id, range=args.range,
|
||||
).execute()
|
||||
print(json.dumps(result.get("values", []), indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
|
||||
def sheets_update(args):
|
||||
values = json.loads(args.values)
|
||||
gws(
|
||||
"sheets", "spreadsheets", "values", "update",
|
||||
"--params", json.dumps({
|
||||
"spreadsheetId": args.sheet_id,
|
||||
"range": args.range,
|
||||
"valueInputOption": "USER_ENTERED",
|
||||
}),
|
||||
"--json", json.dumps({"values": values}),
|
||||
"--format", "json",
|
||||
)
|
||||
body = {"values": values}
|
||||
|
||||
if _gws_binary():
|
||||
result = _run_gws(
|
||||
["sheets", "spreadsheets", "values", "update"],
|
||||
params={
|
||||
"spreadsheetId": args.sheet_id,
|
||||
"range": args.range,
|
||||
"valueInputOption": "USER_ENTERED",
|
||||
},
|
||||
body=body,
|
||||
)
|
||||
print(json.dumps({"updatedCells": result.get("updatedCells", 0), "updatedRange": result.get("updatedRange", "")}, indent=2))
|
||||
return
|
||||
|
||||
service = build_service("sheets", "v4")
|
||||
result = service.spreadsheets().values().update(
|
||||
spreadsheetId=args.sheet_id, range=args.range,
|
||||
valueInputOption="USER_ENTERED", body=body,
|
||||
).execute()
|
||||
print(json.dumps({"updatedCells": result.get("updatedCells", 0), "updatedRange": result.get("updatedRange", "")}, indent=2))
|
||||
|
||||
|
||||
|
||||
def sheets_append(args):
|
||||
values = json.loads(args.values)
|
||||
gws(
|
||||
"sheets", "+append",
|
||||
"--spreadsheet", args.sheet_id,
|
||||
"--json-values", json.dumps(values),
|
||||
"--format", "json",
|
||||
)
|
||||
body = {"values": values}
|
||||
|
||||
if _gws_binary():
|
||||
result = _run_gws(
|
||||
["sheets", "spreadsheets", "values", "append"],
|
||||
params={
|
||||
"spreadsheetId": args.sheet_id,
|
||||
"range": args.range,
|
||||
"valueInputOption": "USER_ENTERED",
|
||||
"insertDataOption": "INSERT_ROWS",
|
||||
},
|
||||
body=body,
|
||||
)
|
||||
print(json.dumps({"updatedCells": result.get("updates", {}).get("updatedCells", 0)}, indent=2))
|
||||
return
|
||||
|
||||
service = build_service("sheets", "v4")
|
||||
result = service.spreadsheets().values().append(
|
||||
spreadsheetId=args.sheet_id, range=args.range,
|
||||
valueInputOption="USER_ENTERED", insertDataOption="INSERT_ROWS", body=body,
|
||||
).execute()
|
||||
print(json.dumps({"updatedCells": result.get("updates", {}).get("updatedCells", 0)}, indent=2))
|
||||
|
||||
|
||||
# -- Docs --
|
||||
# =========================================================================
|
||||
# Docs
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def docs_get(args):
|
||||
gws(
|
||||
"docs", "documents", "get",
|
||||
"--params", json.dumps({"documentId": args.doc_id}),
|
||||
"--format", "json",
|
||||
)
|
||||
if _gws_binary():
|
||||
doc = _run_gws(["docs", "documents", "get"], params={"documentId": args.doc_id})
|
||||
result = {
|
||||
"title": doc.get("title", ""),
|
||||
"documentId": doc.get("documentId", ""),
|
||||
"body": _extract_doc_text(doc),
|
||||
}
|
||||
print(json.dumps(result, indent=2, ensure_ascii=False))
|
||||
return
|
||||
|
||||
service = build_service("docs", "v1")
|
||||
doc = service.documents().get(documentId=args.doc_id).execute()
|
||||
result = {
|
||||
"title": doc.get("title", ""),
|
||||
"documentId": doc.get("documentId", ""),
|
||||
"body": _extract_doc_text(doc),
|
||||
}
|
||||
print(json.dumps(result, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
# -- CLI parser (backward-compatible interface) --
|
||||
# =========================================================================
|
||||
# CLI parser
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Google Workspace API for Hermes Agent (gws backend)")
|
||||
parser = argparse.ArgumentParser(description="Google Workspace API for Hermes Agent")
|
||||
sub = parser.add_subparsers(dest="service", required=True)
|
||||
|
||||
# --- Gmail ---
|
||||
@@ -228,13 +742,15 @@ def main():
|
||||
p.add_argument("--subject", required=True)
|
||||
p.add_argument("--body", required=True)
|
||||
p.add_argument("--cc", default="")
|
||||
p.add_argument("--from", dest="from_header", default="", help="Custom From header (e.g. '\"Agent Name\" <user@example.com>')")
|
||||
p.add_argument("--html", action="store_true", help="Send body as HTML")
|
||||
p.add_argument("--thread-id", default="", help="Thread ID (unused with gws, kept for compat)")
|
||||
p.add_argument("--thread-id", default="", help="Thread ID for threading")
|
||||
p.set_defaults(func=gmail_send)
|
||||
|
||||
p = gmail_sub.add_parser("reply")
|
||||
p.add_argument("message_id", help="Message ID to reply to")
|
||||
p.add_argument("--body", required=True)
|
||||
p.add_argument("--from", dest="from_header", default="", help="Custom From header (e.g. '\"Agent Name\" <user@example.com>')")
|
||||
p.set_defaults(func=gmail_reply)
|
||||
|
||||
p = gmail_sub.add_parser("labels")
|
||||
|
||||
@@ -25,6 +25,13 @@ def refresh_token(token_data: dict) -> dict:
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
|
||||
required_keys = ["client_id", "client_secret", "refresh_token", "token_uri"]
|
||||
missing = [k for k in required_keys if k not in token_data]
|
||||
if missing:
|
||||
print(f"ERROR: google_token.json is missing required fields: {', '.join(missing)}", file=sys.stderr)
|
||||
print("Please re-authenticate by running the Google Workspace setup script.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
params = urllib.parse.urlencode({
|
||||
"client_id": token_data["client_id"],
|
||||
"client_secret": token_data["client_secret"],
|
||||
|
||||
@@ -695,3 +695,102 @@ class TestMemoryContextFencing:
|
||||
fence_end = combined.index("</memory-context>")
|
||||
assert "Alice" in combined[fence_start:fence_end]
|
||||
assert combined.index("weather") < fence_start
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AIAgent.commit_memory_session — routes to MemoryManager.on_session_end
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _CommitRecorder(FakeMemoryProvider):
|
||||
"""Provider that records on_session_end calls for assertions."""
|
||||
|
||||
def __init__(self, name="recorder"):
|
||||
super().__init__(name)
|
||||
self.end_calls = []
|
||||
|
||||
def on_session_end(self, messages):
|
||||
self.end_calls.append(list(messages or []))
|
||||
|
||||
|
||||
class TestCommitMemorySessionRouting:
|
||||
def test_on_session_end_fans_out(self):
|
||||
mgr = MemoryManager()
|
||||
builtin = _CommitRecorder("builtin")
|
||||
external = _CommitRecorder("openviking")
|
||||
mgr.add_provider(builtin)
|
||||
mgr.add_provider(external)
|
||||
|
||||
msgs = [{"role": "user", "content": "hi"}]
|
||||
mgr.on_session_end(msgs)
|
||||
|
||||
assert builtin.end_calls == [msgs]
|
||||
assert external.end_calls == [msgs]
|
||||
|
||||
def test_on_session_end_tolerates_failure(self):
|
||||
mgr = MemoryManager()
|
||||
builtin = FakeMemoryProvider("builtin")
|
||||
bad = _CommitRecorder("bad-provider")
|
||||
bad.on_session_end = lambda m: (_ for _ in ()).throw(RuntimeError("boom"))
|
||||
mgr.add_provider(builtin)
|
||||
mgr.add_provider(bad)
|
||||
|
||||
mgr.on_session_end([]) # must not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# on_memory_write bridge — must fire from both concurrent AND sequential paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOnMemoryWriteBridge:
|
||||
"""Verify that MemoryManager.on_memory_write is called when built-in
|
||||
memory writes happen. This is a regression test for #10174 where the
|
||||
sequential tool execution path (_execute_tool_calls_sequential) was
|
||||
missing the bridge call, so single memory tool calls never notified
|
||||
external memory providers.
|
||||
"""
|
||||
|
||||
def test_on_memory_write_add(self):
|
||||
"""on_memory_write fires for 'add' actions."""
|
||||
mgr = MemoryManager()
|
||||
p = FakeMemoryProvider("ext")
|
||||
mgr.add_provider(p)
|
||||
|
||||
mgr.on_memory_write("add", "memory", "new fact")
|
||||
assert p.memory_writes == [("add", "memory", "new fact")]
|
||||
|
||||
def test_on_memory_write_replace(self):
|
||||
"""on_memory_write fires for 'replace' actions."""
|
||||
mgr = MemoryManager()
|
||||
p = FakeMemoryProvider("ext")
|
||||
mgr.add_provider(p)
|
||||
|
||||
mgr.on_memory_write("replace", "user", "updated pref")
|
||||
assert p.memory_writes == [("replace", "user", "updated pref")]
|
||||
|
||||
def test_on_memory_write_remove_not_bridged(self):
|
||||
"""The bridge intentionally skips 'remove' — only add/replace notify."""
|
||||
# This tests the contract that run_agent.py checks:
|
||||
# function_args.get("action") in ("add", "replace")
|
||||
mgr = MemoryManager()
|
||||
p = FakeMemoryProvider("ext")
|
||||
mgr.add_provider(p)
|
||||
|
||||
# Manager itself doesn't filter — run_agent.py does.
|
||||
# But providers should handle remove gracefully.
|
||||
mgr.on_memory_write("remove", "memory", "old fact")
|
||||
assert p.memory_writes == [("remove", "memory", "old fact")]
|
||||
|
||||
def test_on_memory_write_tolerates_provider_failure(self):
|
||||
"""If a provider's on_memory_write raises, others still get notified."""
|
||||
mgr = MemoryManager()
|
||||
bad = FakeMemoryProvider("builtin")
|
||||
bad.on_memory_write = MagicMock(side_effect=RuntimeError("boom"))
|
||||
good = FakeMemoryProvider("good")
|
||||
mgr.add_provider(bad)
|
||||
mgr.add_provider(good)
|
||||
|
||||
mgr.on_memory_write("add", "user", "test")
|
||||
# Good provider still received the call despite bad provider crashing
|
||||
assert good.memory_writes == [("add", "user", "test")]
|
||||
|
||||
@@ -8,6 +8,8 @@ from unittest.mock import AsyncMock, patch, MagicMock
|
||||
import pytest
|
||||
|
||||
from cron.scheduler import _resolve_origin, _resolve_delivery_target, _deliver_result, _send_media_via_adapter, run_job, SILENT_MARKER, _build_job_prompt
|
||||
from tools.env_passthrough import clear_env_passthrough
|
||||
from tools.credential_files import clear_credential_files
|
||||
|
||||
|
||||
class TestResolveOrigin:
|
||||
@@ -233,9 +235,10 @@ class TestDeliverResultWrapping:
|
||||
send_mock.assert_called_once()
|
||||
sent_content = send_mock.call_args.kwargs.get("content") or send_mock.call_args[0][-1]
|
||||
assert "Cronjob Response: daily-report" in sent_content
|
||||
assert "(job_id: test-job)" in sent_content
|
||||
assert "-------------" in sent_content
|
||||
assert "Here is today's summary." in sent_content
|
||||
assert "The agent cannot see this message" in sent_content
|
||||
assert "To stop or manage this job" in sent_content
|
||||
|
||||
def test_delivery_uses_job_id_when_no_name(self):
|
||||
"""When a job has no name, the wrapper should fall back to job id."""
|
||||
@@ -876,6 +879,117 @@ class TestRunJobPerJobOverrides:
|
||||
|
||||
|
||||
class TestRunJobSkillBacked:
|
||||
def test_run_job_preserves_skill_env_passthrough_into_worker_thread(self, tmp_path):
|
||||
job = {
|
||||
"id": "skill-env-job",
|
||||
"name": "skill env test",
|
||||
"prompt": "Use the skill.",
|
||||
"skill": "notion",
|
||||
}
|
||||
|
||||
fake_db = MagicMock()
|
||||
|
||||
def _skill_view(name):
|
||||
assert name == "notion"
|
||||
from tools.env_passthrough import register_env_passthrough
|
||||
|
||||
register_env_passthrough(["NOTION_API_KEY"])
|
||||
return json.dumps({"success": True, "content": "# notion\nUse Notion."})
|
||||
|
||||
def _run_conversation(prompt):
|
||||
from tools.env_passthrough import get_all_passthrough
|
||||
|
||||
assert "NOTION_API_KEY" in get_all_passthrough()
|
||||
return {"final_response": "ok"}
|
||||
|
||||
with patch("cron.scheduler._hermes_home", tmp_path), \
|
||||
patch("cron.scheduler._resolve_origin", return_value=None), \
|
||||
patch("dotenv.load_dotenv"), \
|
||||
patch("hermes_state.SessionDB", return_value=fake_db), \
|
||||
patch(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
return_value={
|
||||
"api_key": "***",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
},
|
||||
), \
|
||||
patch("tools.skills_tool.skill_view", side_effect=_skill_view), \
|
||||
patch("run_agent.AIAgent") as mock_agent_cls:
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.run_conversation.side_effect = _run_conversation
|
||||
mock_agent_cls.return_value = mock_agent
|
||||
|
||||
try:
|
||||
success, output, final_response, error = run_job(job)
|
||||
finally:
|
||||
clear_env_passthrough()
|
||||
|
||||
assert success is True
|
||||
assert error is None
|
||||
assert final_response == "ok"
|
||||
|
||||
def test_run_job_preserves_credential_file_passthrough_into_worker_thread(self, tmp_path):
|
||||
"""copy_context() also propagates credential_files ContextVar."""
|
||||
job = {
|
||||
"id": "cred-env-job",
|
||||
"name": "cred file test",
|
||||
"prompt": "Use the skill.",
|
||||
"skill": "google-workspace",
|
||||
}
|
||||
|
||||
fake_db = MagicMock()
|
||||
|
||||
# Create a credential file so register_credential_file succeeds
|
||||
cred_dir = tmp_path / "credentials"
|
||||
cred_dir.mkdir()
|
||||
(cred_dir / "google_token.json").write_text('{"token": "t"}')
|
||||
|
||||
def _skill_view(name):
|
||||
assert name == "google-workspace"
|
||||
from tools.credential_files import register_credential_file
|
||||
|
||||
register_credential_file("credentials/google_token.json")
|
||||
return json.dumps({"success": True, "content": "# google-workspace\nUse Google."})
|
||||
|
||||
def _run_conversation(prompt):
|
||||
from tools.credential_files import _get_registered
|
||||
|
||||
registered = _get_registered()
|
||||
assert registered, "credential files must be visible in worker thread"
|
||||
assert any("google_token.json" in v for v in registered.values())
|
||||
return {"final_response": "ok"}
|
||||
|
||||
with patch("cron.scheduler._hermes_home", tmp_path), \
|
||||
patch("cron.scheduler._resolve_origin", return_value=None), \
|
||||
patch("tools.credential_files._resolve_hermes_home", return_value=tmp_path), \
|
||||
patch("dotenv.load_dotenv"), \
|
||||
patch("hermes_state.SessionDB", return_value=fake_db), \
|
||||
patch(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
return_value={
|
||||
"api_key": "***",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
},
|
||||
), \
|
||||
patch("tools.skills_tool.skill_view", side_effect=_skill_view), \
|
||||
patch("run_agent.AIAgent") as mock_agent_cls:
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.run_conversation.side_effect = _run_conversation
|
||||
mock_agent_cls.return_value = mock_agent
|
||||
|
||||
try:
|
||||
success, output, final_response, error = run_job(job)
|
||||
finally:
|
||||
clear_credential_files()
|
||||
|
||||
assert success is True
|
||||
assert error is None
|
||||
assert final_response == "ok"
|
||||
|
||||
def test_run_job_loads_skill_and_disables_recursive_cron_tools(self, tmp_path):
|
||||
job = {
|
||||
"id": "skill-job",
|
||||
|
||||
@@ -1016,6 +1016,47 @@ class TestResponsesEndpoint:
|
||||
assert len(call_kwargs["conversation_history"]) > 0
|
||||
assert call_kwargs["user_message"] == "Now add 1 more"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_previous_response_id_preserves_session(self, adapter):
|
||||
"""Chained responses via previous_response_id reuse the same session_id."""
|
||||
mock_result = {
|
||||
"final_response": "ok",
|
||||
"messages": [{"role": "assistant", "content": "ok"}],
|
||||
"api_calls": 1,
|
||||
}
|
||||
usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
# First request — establishes a session
|
||||
with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (mock_result, usage)
|
||||
resp1 = await cli.post(
|
||||
"/v1/responses",
|
||||
json={"model": "hermes-agent", "input": "Hello"},
|
||||
)
|
||||
assert resp1.status == 200
|
||||
first_session_id = mock_run.call_args.kwargs["session_id"]
|
||||
data1 = await resp1.json()
|
||||
response_id = data1["id"]
|
||||
|
||||
# Second request — chains from the first
|
||||
with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (mock_result, usage)
|
||||
resp2 = await cli.post(
|
||||
"/v1/responses",
|
||||
json={
|
||||
"model": "hermes-agent",
|
||||
"input": "Follow up",
|
||||
"previous_response_id": response_id,
|
||||
},
|
||||
)
|
||||
assert resp2.status == 200
|
||||
second_session_id = mock_run.call_args.kwargs["session_id"]
|
||||
|
||||
# Session must be the same across the chain
|
||||
assert first_session_id == second_session_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_previous_response_id_returns_404(self, adapter):
|
||||
app = _create_app(adapter)
|
||||
@@ -1115,6 +1156,134 @@ class TestResponsesEndpoint:
|
||||
assert resp.status == 400
|
||||
|
||||
|
||||
class TestResponsesStreaming:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_true_returns_responses_sse(self, adapter):
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
async def _mock_run_agent(**kwargs):
|
||||
cb = kwargs.get("stream_delta_callback")
|
||||
if cb:
|
||||
cb("Hello")
|
||||
cb(" world")
|
||||
return (
|
||||
{"final_response": "Hello world", "messages": [], "api_calls": 1},
|
||||
{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
|
||||
with patch.object(adapter, "_run_agent", side_effect=_mock_run_agent):
|
||||
resp = await cli.post(
|
||||
"/v1/responses",
|
||||
json={"model": "hermes-agent", "input": "hi", "stream": True},
|
||||
)
|
||||
assert resp.status == 200
|
||||
assert "text/event-stream" in resp.headers.get("Content-Type", "")
|
||||
body = await resp.text()
|
||||
assert "event: response.created" in body
|
||||
assert "event: response.output_text.delta" in body
|
||||
assert "event: response.output_text.done" in body
|
||||
assert "event: response.completed" in body
|
||||
assert '"sequence_number":' in body
|
||||
assert '"logprobs": []' in body
|
||||
assert "Hello" in body
|
||||
assert " world" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_emits_function_call_and_output_items(self, adapter):
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
async def _mock_run_agent(**kwargs):
|
||||
start_cb = kwargs.get("tool_start_callback")
|
||||
complete_cb = kwargs.get("tool_complete_callback")
|
||||
text_cb = kwargs.get("stream_delta_callback")
|
||||
if start_cb:
|
||||
start_cb("call_123", "read_file", {"path": "/tmp/test.txt"})
|
||||
if complete_cb:
|
||||
complete_cb("call_123", "read_file", {"path": "/tmp/test.txt"}, '{"content":"hello"}')
|
||||
if text_cb:
|
||||
text_cb("Done.")
|
||||
return (
|
||||
{
|
||||
"final_response": "Done.",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"arguments": '{"path":"/tmp/test.txt"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"content": '{"content":"hello"}',
|
||||
},
|
||||
],
|
||||
"api_calls": 1,
|
||||
},
|
||||
{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
|
||||
with patch.object(adapter, "_run_agent", side_effect=_mock_run_agent):
|
||||
resp = await cli.post(
|
||||
"/v1/responses",
|
||||
json={"model": "hermes-agent", "input": "read the file", "stream": True},
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.text()
|
||||
assert "event: response.output_item.added" in body
|
||||
assert "event: response.output_item.done" in body
|
||||
assert body.count("event: response.output_item.done") >= 2
|
||||
assert '"type": "function_call"' in body
|
||||
assert '"type": "function_call_output"' in body
|
||||
assert '"call_id": "call_123"' in body
|
||||
assert '"name": "read_file"' in body
|
||||
assert '"output": [{"type": "input_text", "text": "{\\"content\\":\\"hello\\"}"}]' in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streamed_response_is_stored_for_get(self, adapter):
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
async def _mock_run_agent(**kwargs):
|
||||
cb = kwargs.get("stream_delta_callback")
|
||||
if cb:
|
||||
cb("Stored response")
|
||||
return (
|
||||
{"final_response": "Stored response", "messages": [], "api_calls": 1},
|
||||
{"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
|
||||
)
|
||||
|
||||
with patch.object(adapter, "_run_agent", side_effect=_mock_run_agent):
|
||||
resp = await cli.post(
|
||||
"/v1/responses",
|
||||
json={"model": "hermes-agent", "input": "store this", "stream": True},
|
||||
)
|
||||
body = await resp.text()
|
||||
response_id = None
|
||||
for line in body.splitlines():
|
||||
if line.startswith("data: "):
|
||||
try:
|
||||
payload = json.loads(line[len("data: "):])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if payload.get("type") == "response.completed":
|
||||
response_id = payload["response"]["id"]
|
||||
break
|
||||
assert response_id
|
||||
|
||||
get_resp = await cli.get(f"/v1/responses/{response_id}")
|
||||
assert get_resp.status == 200
|
||||
data = await get_resp.json()
|
||||
assert data["id"] == response_id
|
||||
assert data["status"] == "completed"
|
||||
assert data["output"][-1]["content"][0]["text"] == "Stored response"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth on endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
"""Tests for the auto-continue feature (#4493).
|
||||
|
||||
When the gateway restarts mid-agent-work, the session transcript ends on a
|
||||
tool result that the agent never processed. The auto-continue logic detects
|
||||
this and prepends a system note to the next user message so the model
|
||||
finishes the interrupted work before addressing the new input.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _simulate_auto_continue(agent_history: list, user_message: str) -> str:
|
||||
"""Reproduce the auto-continue injection logic from _run_agent().
|
||||
|
||||
This mirrors the exact code in gateway/run.py so we can test the
|
||||
detection and message transformation without spinning up a full
|
||||
gateway runner.
|
||||
"""
|
||||
message = user_message
|
||||
if agent_history and agent_history[-1].get("role") == "tool":
|
||||
message = (
|
||||
"[System note: Your previous turn was interrupted before you could "
|
||||
"process the last tool result(s). The conversation history contains "
|
||||
"tool outputs you haven't responded to yet. Please finish processing "
|
||||
"those results and summarize what was accomplished, then address the "
|
||||
"user's new message below.]\n\n"
|
||||
+ message
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
class TestAutoDetection:
|
||||
"""Test that trailing tool results are correctly detected."""
|
||||
|
||||
def test_trailing_tool_result_triggers_note(self):
|
||||
history = [
|
||||
{"role": "user", "content": "deploy the app"},
|
||||
{"role": "assistant", "content": None, "tool_calls": [
|
||||
{"id": "call_1", "function": {"name": "terminal", "arguments": "{}"}}
|
||||
]},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "deployed successfully"},
|
||||
]
|
||||
result = _simulate_auto_continue(history, "what happened?")
|
||||
assert "[System note:" in result
|
||||
assert "interrupted" in result
|
||||
assert "what happened?" in result
|
||||
|
||||
def test_trailing_assistant_message_no_note(self):
|
||||
history = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
result = _simulate_auto_continue(history, "how are you?")
|
||||
assert "[System note:" not in result
|
||||
assert result == "how are you?"
|
||||
|
||||
def test_empty_history_no_note(self):
|
||||
result = _simulate_auto_continue([], "hello")
|
||||
assert result == "hello"
|
||||
|
||||
def test_trailing_user_message_no_note(self):
|
||||
"""Shouldn't happen in practice, but ensure no false positive."""
|
||||
history = [
|
||||
{"role": "user", "content": "hello"},
|
||||
]
|
||||
result = _simulate_auto_continue(history, "hello again")
|
||||
assert result == "hello again"
|
||||
|
||||
def test_multiple_tool_results_still_triggers(self):
|
||||
"""Multiple tool calls in a row — last one is still role=tool."""
|
||||
history = [
|
||||
{"role": "user", "content": "search and read"},
|
||||
{"role": "assistant", "content": None, "tool_calls": [
|
||||
{"id": "call_1", "function": {"name": "search", "arguments": "{}"}},
|
||||
{"id": "call_2", "function": {"name": "read", "arguments": "{}"}},
|
||||
]},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "found it"},
|
||||
{"role": "tool", "tool_call_id": "call_2", "content": "file content here"},
|
||||
]
|
||||
result = _simulate_auto_continue(history, "continue")
|
||||
assert "[System note:" in result
|
||||
|
||||
def test_original_message_preserved_after_note(self):
|
||||
"""The user's actual message must appear after the system note."""
|
||||
history = [
|
||||
{"role": "assistant", "content": None, "tool_calls": [
|
||||
{"id": "c1", "function": {"name": "t", "arguments": "{}"}}
|
||||
]},
|
||||
{"role": "tool", "tool_call_id": "c1", "content": "done"},
|
||||
]
|
||||
result = _simulate_auto_continue(history, "now do X")
|
||||
# System note comes first, then user's message
|
||||
note_end = result.index("]\n\n")
|
||||
user_msg_start = result.index("now do X")
|
||||
assert user_msg_start > note_end
|
||||
@@ -14,7 +14,7 @@ from unittest.mock import AsyncMock, patch
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.run import GatewayRunner, _parse_session_key
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -45,7 +45,7 @@ def _build_runner(monkeypatch, tmp_path, mode: str) -> GatewayRunner:
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
|
||||
runner = GatewayRunner(GatewayConfig())
|
||||
adapter = SimpleNamespace(send=AsyncMock())
|
||||
adapter = SimpleNamespace(send=AsyncMock(), handle_message=AsyncMock())
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
return runner
|
||||
|
||||
@@ -243,3 +243,162 @@ async def test_no_thread_id_sends_no_metadata(monkeypatch, tmp_path):
|
||||
assert adapter.send.await_count == 1
|
||||
_, kwargs = adapter.send.call_args
|
||||
assert kwargs["metadata"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_watch_notification_routes_from_session_store_origin(monkeypatch, tmp_path):
|
||||
from gateway.session import SessionSource
|
||||
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
runner.session_store._entries["agent:main:telegram:group:-100:42"] = SimpleNamespace(
|
||||
origin=SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-100",
|
||||
chat_type="group",
|
||||
thread_id="42",
|
||||
user_id="123",
|
||||
user_name="Emiliyan",
|
||||
)
|
||||
)
|
||||
|
||||
evt = {
|
||||
"session_id": "proc_watch",
|
||||
"session_key": "agent:main:telegram:group:-100:42",
|
||||
}
|
||||
|
||||
await runner._inject_watch_notification("[SYSTEM: Background process matched]", evt)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
synth_event = adapter.handle_message.await_args.args[0]
|
||||
assert synth_event.internal is True
|
||||
assert synth_event.source.platform == Platform.TELEGRAM
|
||||
assert synth_event.source.chat_id == "-100"
|
||||
assert synth_event.source.chat_type == "group"
|
||||
assert synth_event.source.thread_id == "42"
|
||||
assert synth_event.source.user_id == "123"
|
||||
assert synth_event.source.user_name == "Emiliyan"
|
||||
|
||||
|
||||
def test_build_process_event_source_falls_back_to_session_key_chat_type(monkeypatch, tmp_path):
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
|
||||
evt = {
|
||||
"session_id": "proc_watch",
|
||||
"session_key": "agent:main:telegram:group:-100:42",
|
||||
"platform": "telegram",
|
||||
"chat_id": "-100",
|
||||
"thread_id": "42",
|
||||
"user_id": "123",
|
||||
"user_name": "Emiliyan",
|
||||
}
|
||||
|
||||
source = runner._build_process_event_source(evt)
|
||||
|
||||
assert source is not None
|
||||
assert source.platform == Platform.TELEGRAM
|
||||
assert source.chat_id == "-100"
|
||||
assert source.chat_type == "group"
|
||||
assert source.thread_id == "42"
|
||||
assert source.user_id == "123"
|
||||
assert source.user_name == "Emiliyan"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_watch_notification_ignores_foreground_event_source(monkeypatch, tmp_path):
|
||||
"""Negative test: watch notification must NOT route to the foreground thread."""
|
||||
from gateway.session import SessionSource
|
||||
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
|
||||
# Session store has the process's original thread (thread 42)
|
||||
runner.session_store._entries["agent:main:telegram:group:-100:42"] = SimpleNamespace(
|
||||
origin=SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-100",
|
||||
chat_type="group",
|
||||
thread_id="42",
|
||||
user_id="proc_owner",
|
||||
user_name="alice",
|
||||
)
|
||||
)
|
||||
|
||||
# The evt dict carries the correct session_key — NOT a foreground event
|
||||
evt = {
|
||||
"session_id": "proc_cross_thread",
|
||||
"session_key": "agent:main:telegram:group:-100:42",
|
||||
}
|
||||
|
||||
await runner._inject_watch_notification("[SYSTEM: watch match]", evt)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
synth_event = adapter.handle_message.await_args.args[0]
|
||||
# Must route to thread 42 (process origin), NOT some other thread
|
||||
assert synth_event.source.thread_id == "42"
|
||||
assert synth_event.source.user_id == "proc_owner"
|
||||
|
||||
|
||||
def test_build_process_event_source_returns_none_for_empty_evt(monkeypatch, tmp_path):
|
||||
"""Missing session_key and no platform metadata → None (drop notification)."""
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
|
||||
source = runner._build_process_event_source({"session_id": "proc_orphan"})
|
||||
assert source is None
|
||||
|
||||
|
||||
def test_build_process_event_source_returns_none_for_invalid_platform(monkeypatch, tmp_path):
|
||||
"""Invalid platform string → None."""
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
|
||||
evt = {
|
||||
"session_id": "proc_bad",
|
||||
"platform": "not_a_real_platform",
|
||||
"chat_type": "dm",
|
||||
"chat_id": "123",
|
||||
}
|
||||
source = runner._build_process_event_source(evt)
|
||||
assert source is None
|
||||
|
||||
|
||||
def test_build_process_event_source_returns_none_for_short_session_key(monkeypatch, tmp_path):
|
||||
"""Session key with <5 parts doesn't parse, falls through to empty metadata → None."""
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
|
||||
evt = {
|
||||
"session_id": "proc_short",
|
||||
"session_key": "agent:main:telegram", # Too few parts
|
||||
}
|
||||
source = runner._build_process_event_source(evt)
|
||||
assert source is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_session_key helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_parse_session_key_valid():
|
||||
result = _parse_session_key("agent:main:telegram:group:-100")
|
||||
assert result == {"platform": "telegram", "chat_type": "group", "chat_id": "-100"}
|
||||
|
||||
|
||||
def test_parse_session_key_with_extra_parts():
|
||||
"""Thread ID (6th part) is extracted; further parts are ignored."""
|
||||
result = _parse_session_key("agent:main:discord:group:chan123:thread456")
|
||||
assert result == {"platform": "discord", "chat_type": "group", "chat_id": "chan123", "thread_id": "thread456"}
|
||||
|
||||
|
||||
def test_parse_session_key_with_user_id_part():
|
||||
"""7th part (user_id) is ignored — only up to thread_id is extracted."""
|
||||
result = _parse_session_key("agent:main:telegram:group:chat1:thread42:user99")
|
||||
assert result == {"platform": "telegram", "chat_type": "group", "chat_id": "chat1", "thread_id": "thread42"}
|
||||
|
||||
|
||||
def test_parse_session_key_too_short():
|
||||
assert _parse_session_key("agent:main:telegram") is None
|
||||
assert _parse_session_key("") is None
|
||||
|
||||
|
||||
def test_parse_session_key_wrong_prefix():
|
||||
assert _parse_session_key("cron:main:telegram:dm:123") is None
|
||||
assert _parse_session_key("agent:cron:telegram:dm:123") is None
|
||||
|
||||
@@ -0,0 +1,293 @@
|
||||
"""Tests for busy-session acknowledgment when user sends messages during active agent runs.
|
||||
|
||||
Verifies that users get an immediate status response instead of total silence
|
||||
when the agent is working on a task. See PR fix for the @Lonely__MH report.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal stubs so we can import gateway code without heavy deps
|
||||
# ---------------------------------------------------------------------------
|
||||
import sys, types
|
||||
|
||||
_tg = types.ModuleType("telegram")
|
||||
_tg.constants = types.ModuleType("telegram.constants")
|
||||
_ct = MagicMock()
|
||||
_ct.SUPERGROUP = "supergroup"
|
||||
_ct.GROUP = "group"
|
||||
_ct.PRIVATE = "private"
|
||||
_tg.constants.ChatType = _ct
|
||||
sys.modules.setdefault("telegram", _tg)
|
||||
sys.modules.setdefault("telegram.constants", _tg.constants)
|
||||
sys.modules.setdefault("telegram.ext", types.ModuleType("telegram.ext"))
|
||||
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SessionSource,
|
||||
build_session_key,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_event(text="hello", chat_id="123", platform_val="telegram"):
|
||||
"""Build a minimal MessageEvent."""
|
||||
source = SessionSource(
|
||||
platform=MagicMock(value=platform_val),
|
||||
chat_id=chat_id,
|
||||
chat_type="private",
|
||||
user_id="user1",
|
||||
)
|
||||
evt = MessageEvent(
|
||||
text=text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
message_id="msg1",
|
||||
)
|
||||
return evt
|
||||
|
||||
|
||||
def _make_runner():
|
||||
"""Build a minimal GatewayRunner-like object for testing."""
|
||||
from gateway.run import GatewayRunner, _AGENT_PENDING_SENTINEL
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner._pending_messages = {}
|
||||
runner._busy_ack_ts = {}
|
||||
runner._draining = False
|
||||
runner.adapters = {}
|
||||
runner.config = MagicMock()
|
||||
runner.session_store = None
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
return runner, _AGENT_PENDING_SENTINEL
|
||||
|
||||
|
||||
def _make_adapter(platform_val="telegram"):
|
||||
"""Build a minimal adapter mock."""
|
||||
adapter = MagicMock()
|
||||
adapter._pending_messages = {}
|
||||
adapter._send_with_retry = AsyncMock()
|
||||
adapter.config = MagicMock()
|
||||
adapter.config.extra = {}
|
||||
adapter.platform = MagicMock(value=platform_val)
|
||||
return adapter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBusySessionAck:
|
||||
"""User sends a message while agent is running — should get acknowledgment."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sends_ack_when_agent_running(self):
|
||||
"""First message during busy session should get a status ack."""
|
||||
runner, sentinel = _make_runner()
|
||||
adapter = _make_adapter()
|
||||
|
||||
event = _make_event(text="Are you working?")
|
||||
sk = build_session_key(event.source)
|
||||
|
||||
# Simulate running agent
|
||||
agent = MagicMock()
|
||||
agent.get_activity_summary.return_value = {
|
||||
"api_call_count": 21,
|
||||
"max_iterations": 60,
|
||||
"current_tool": "terminal",
|
||||
"last_activity_ts": time.time(),
|
||||
"last_activity_desc": "terminal",
|
||||
"seconds_since_activity": 1.0,
|
||||
}
|
||||
runner._running_agents[sk] = agent
|
||||
runner._running_agents_ts[sk] = time.time() - 600 # 10 min ago
|
||||
runner.adapters[event.source.platform] = adapter
|
||||
|
||||
result = await runner._handle_active_session_busy_message(event, sk)
|
||||
|
||||
assert result is True # handled
|
||||
# Verify ack was sent
|
||||
adapter._send_with_retry.assert_called_once()
|
||||
call_kwargs = adapter._send_with_retry.call_args
|
||||
content = call_kwargs.kwargs.get("content") or call_kwargs[1].get("content", "")
|
||||
if not content and call_kwargs.args:
|
||||
# positional args
|
||||
content = str(call_kwargs)
|
||||
assert "Interrupting" in content or "respond" in content
|
||||
assert "/stop" not in content # no need — we ARE interrupting
|
||||
|
||||
# Verify message was queued in adapter pending
|
||||
assert sk in adapter._pending_messages
|
||||
|
||||
# Verify agent interrupt was called
|
||||
agent.interrupt.assert_called_once_with("Are you working?")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debounce_suppresses_rapid_acks(self):
|
||||
"""Second message within 30s should NOT send another ack."""
|
||||
runner, sentinel = _make_runner()
|
||||
adapter = _make_adapter()
|
||||
|
||||
event1 = _make_event(text="hello?")
|
||||
# Reuse the same source so platform mock matches
|
||||
event2 = MessageEvent(
|
||||
text="still there?",
|
||||
message_type=MessageType.TEXT,
|
||||
source=event1.source,
|
||||
message_id="msg2",
|
||||
)
|
||||
sk = build_session_key(event1.source)
|
||||
|
||||
agent = MagicMock()
|
||||
agent.get_activity_summary.return_value = {
|
||||
"api_call_count": 5,
|
||||
"max_iterations": 60,
|
||||
"current_tool": None,
|
||||
"last_activity_ts": time.time(),
|
||||
"last_activity_desc": "api_call",
|
||||
"seconds_since_activity": 0.5,
|
||||
}
|
||||
runner._running_agents[sk] = agent
|
||||
runner._running_agents_ts[sk] = time.time() - 60
|
||||
runner.adapters[event1.source.platform] = adapter
|
||||
|
||||
# First message — should get ack
|
||||
result1 = await runner._handle_active_session_busy_message(event1, sk)
|
||||
assert result1 is True
|
||||
assert adapter._send_with_retry.call_count == 1
|
||||
|
||||
# Second message within cooldown — should be queued but no ack
|
||||
result2 = await runner._handle_active_session_busy_message(event2, sk)
|
||||
assert result2 is True
|
||||
assert adapter._send_with_retry.call_count == 1 # still 1, no new ack
|
||||
|
||||
# But interrupt should still be called for both
|
||||
assert agent.interrupt.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ack_after_cooldown_expires(self):
|
||||
"""After 30s cooldown, a new message should send a fresh ack."""
|
||||
runner, sentinel = _make_runner()
|
||||
adapter = _make_adapter()
|
||||
|
||||
event = _make_event(text="hello?")
|
||||
sk = build_session_key(event.source)
|
||||
|
||||
agent = MagicMock()
|
||||
agent.get_activity_summary.return_value = {
|
||||
"api_call_count": 10,
|
||||
"max_iterations": 60,
|
||||
"current_tool": "web_search",
|
||||
"last_activity_ts": time.time(),
|
||||
"last_activity_desc": "tool",
|
||||
"seconds_since_activity": 0.5,
|
||||
}
|
||||
runner._running_agents[sk] = agent
|
||||
runner._running_agents_ts[sk] = time.time() - 120
|
||||
runner.adapters[event.source.platform] = adapter
|
||||
|
||||
# First ack
|
||||
await runner._handle_active_session_busy_message(event, sk)
|
||||
assert adapter._send_with_retry.call_count == 1
|
||||
|
||||
# Fake that cooldown expired
|
||||
runner._busy_ack_ts[sk] = time.time() - 31
|
||||
|
||||
# Second ack should go through
|
||||
await runner._handle_active_session_busy_message(event, sk)
|
||||
assert adapter._send_with_retry.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_includes_status_detail(self):
|
||||
"""Ack message should include iteration and tool info when available."""
|
||||
runner, sentinel = _make_runner()
|
||||
adapter = _make_adapter()
|
||||
|
||||
event = _make_event(text="yo")
|
||||
sk = build_session_key(event.source)
|
||||
|
||||
agent = MagicMock()
|
||||
agent.get_activity_summary.return_value = {
|
||||
"api_call_count": 21,
|
||||
"max_iterations": 60,
|
||||
"current_tool": "terminal",
|
||||
"last_activity_ts": time.time(),
|
||||
"last_activity_desc": "terminal",
|
||||
"seconds_since_activity": 0.5,
|
||||
}
|
||||
runner._running_agents[sk] = agent
|
||||
runner._running_agents_ts[sk] = time.time() - 600 # 10 min
|
||||
runner.adapters[event.source.platform] = adapter
|
||||
|
||||
await runner._handle_active_session_busy_message(event, sk)
|
||||
|
||||
call_kwargs = adapter._send_with_retry.call_args
|
||||
content = call_kwargs.kwargs.get("content", "")
|
||||
assert "21/60" in content # iteration
|
||||
assert "terminal" in content # current tool
|
||||
assert "10 min" in content # elapsed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_draining_still_works(self):
|
||||
"""Draining case should still produce the drain-specific message."""
|
||||
runner, sentinel = _make_runner()
|
||||
runner._draining = True
|
||||
adapter = _make_adapter()
|
||||
|
||||
event = _make_event(text="hello")
|
||||
sk = build_session_key(event.source)
|
||||
runner.adapters[event.source.platform] = adapter
|
||||
|
||||
# Mock the drain-specific methods
|
||||
runner._queue_during_drain_enabled = lambda: False
|
||||
runner._status_action_gerund = lambda: "restarting"
|
||||
|
||||
result = await runner._handle_active_session_busy_message(event, sk)
|
||||
assert result is True
|
||||
|
||||
call_kwargs = adapter._send_with_retry.call_args
|
||||
content = call_kwargs.kwargs.get("content", "")
|
||||
assert "restarting" in content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pending_sentinel_no_interrupt(self):
|
||||
"""When agent is PENDING_SENTINEL, don't call interrupt (it has no method)."""
|
||||
runner, sentinel = _make_runner()
|
||||
adapter = _make_adapter()
|
||||
|
||||
event = _make_event(text="hey")
|
||||
sk = build_session_key(event.source)
|
||||
|
||||
runner._running_agents[sk] = sentinel
|
||||
runner._running_agents_ts[sk] = time.time()
|
||||
runner.adapters[event.source.platform] = adapter
|
||||
|
||||
result = await runner._handle_active_session_busy_message(event, sk)
|
||||
assert result is True
|
||||
# Should still send ack
|
||||
adapter._send_with_retry.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_adapter_falls_through(self):
|
||||
"""If adapter is missing, return False so default path handles it."""
|
||||
runner, sentinel = _make_runner()
|
||||
|
||||
event = _make_event(text="hello")
|
||||
sk = build_session_key(event.source)
|
||||
|
||||
# No adapter registered
|
||||
runner._running_agents[sk] = MagicMock()
|
||||
|
||||
result = await runner._handle_active_session_busy_message(event, sk)
|
||||
assert result is False # not handled, let default path try
|
||||
@@ -117,6 +117,23 @@ async def test_registers_native_thread_slash_command(adapter):
|
||||
adapter._handle_thread_create_slash.assert_awaited_once_with(interaction, "Planning", "", 1440)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registers_native_restart_slash_command(adapter):
|
||||
adapter._run_simple_slash = AsyncMock()
|
||||
adapter._register_slash_commands()
|
||||
|
||||
assert "restart" in adapter._client.tree.commands
|
||||
|
||||
interaction = SimpleNamespace()
|
||||
await adapter._client.tree.commands["restart"](interaction)
|
||||
|
||||
adapter._run_simple_slash.assert_awaited_once_with(
|
||||
interaction,
|
||||
"/restart",
|
||||
"Restart requested~",
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _handle_thread_create_slash — success, session dispatch, failure
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,291 @@
|
||||
"""Tests for duplicate reply suppression across the gateway stack.
|
||||
|
||||
Covers three fix paths:
|
||||
1. base.py: stale response suppressed when interrupt_event is set and a
|
||||
pending message exists (#8221 / #2483)
|
||||
2. run.py return path: already_sent propagated from stream consumer's
|
||||
already_sent flag without requiring response_previewed (#8375)
|
||||
3. run.py queued-message path: first response correctly detected as
|
||||
already-streamed when already_sent is True without response_previewed
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
ProcessingOutcome,
|
||||
SendResult,
|
||||
)
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class StubAdapter(BasePlatformAdapter):
|
||||
"""Minimal concrete adapter for testing."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="fake"), Platform.DISCORD)
|
||||
self.sent = []
|
||||
|
||||
async def connect(self):
|
||||
return True
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
self.sent.append({"chat_id": chat_id, "content": content})
|
||||
return SendResult(success=True, message_id="msg1")
|
||||
|
||||
async def send_typing(self, chat_id, metadata=None):
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
def _make_event(text="hello", chat_id="c1", user_id="u1"):
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
source=SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id=chat_id,
|
||||
chat_type="dm",
|
||||
user_id=user_id,
|
||||
),
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Test 1: base.py — stale response suppressed on interrupt (#8221)
|
||||
# ===================================================================
|
||||
|
||||
class TestBaseInterruptSuppression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stale_response_suppressed_when_interrupted(self):
|
||||
"""When interrupt_event is set AND a pending message exists,
|
||||
base.py should suppress the stale response instead of sending it."""
|
||||
adapter = StubAdapter()
|
||||
|
||||
stale_response = "This is the stale answer to the first question."
|
||||
pending_response = "This is the answer to the second question."
|
||||
call_count = 0
|
||||
|
||||
async def fake_handler(event):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return stale_response
|
||||
return pending_response
|
||||
|
||||
adapter.set_message_handler(fake_handler)
|
||||
|
||||
event_a = _make_event(text="first question")
|
||||
session_key = build_session_key(event_a.source)
|
||||
|
||||
# Simulate: message A is being processed, message B arrives
|
||||
# The interrupt event is set and B is in pending_messages
|
||||
interrupt_event = asyncio.Event()
|
||||
interrupt_event.set()
|
||||
adapter._active_sessions[session_key] = interrupt_event
|
||||
|
||||
event_b = _make_event(text="second question")
|
||||
adapter._pending_messages[session_key] = event_b
|
||||
|
||||
await adapter._process_message_background(event_a, session_key)
|
||||
|
||||
# The stale response should NOT have been sent.
|
||||
stale_sends = [s for s in adapter.sent if s["content"] == stale_response]
|
||||
assert len(stale_sends) == 0, (
|
||||
f"Stale response was sent {len(stale_sends)} time(s) — should be suppressed"
|
||||
)
|
||||
# The pending message's response SHOULD have been sent.
|
||||
pending_sends = [s for s in adapter.sent if s["content"] == pending_response]
|
||||
assert len(pending_sends) == 1, "Pending message response should be sent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_not_suppressed_without_interrupt(self):
|
||||
"""Normal case: no interrupt, response should be sent."""
|
||||
adapter = StubAdapter()
|
||||
|
||||
async def fake_handler(event):
|
||||
return "Normal response"
|
||||
|
||||
adapter.set_message_handler(fake_handler)
|
||||
event = _make_event()
|
||||
session_key = build_session_key(event.source)
|
||||
|
||||
await adapter._process_message_background(event, session_key)
|
||||
|
||||
assert any(s["content"] == "Normal response" for s in adapter.sent)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_not_suppressed_with_interrupt_but_no_pending(self):
|
||||
"""Interrupt event set but no pending message (race already resolved) —
|
||||
response should still be sent."""
|
||||
adapter = StubAdapter()
|
||||
|
||||
async def fake_handler(event):
|
||||
return "Valid response"
|
||||
|
||||
adapter.set_message_handler(fake_handler)
|
||||
event = _make_event()
|
||||
session_key = build_session_key(event.source)
|
||||
|
||||
# Set interrupt but no pending message
|
||||
interrupt_event = asyncio.Event()
|
||||
interrupt_event.set()
|
||||
adapter._active_sessions[session_key] = interrupt_event
|
||||
|
||||
await adapter._process_message_background(event, session_key)
|
||||
|
||||
assert any(s["content"] == "Valid response" for s in adapter.sent)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Test 2: run.py — already_sent without response_previewed (#8375)
|
||||
# ===================================================================
|
||||
|
||||
class TestAlreadySentWithoutResponsePreviewed:
|
||||
"""The already_sent flag on the response dict should be set when the
|
||||
stream consumer's already_sent is True, even if response_previewed is
|
||||
False. This prevents duplicate sends when streaming was interrupted
|
||||
by flood control."""
|
||||
|
||||
def _make_mock_stream_consumer(self, already_sent=False, final_response_sent=False):
|
||||
sc = SimpleNamespace(
|
||||
already_sent=already_sent,
|
||||
final_response_sent=final_response_sent,
|
||||
)
|
||||
return sc
|
||||
|
||||
def test_already_sent_set_without_response_previewed(self):
|
||||
"""Stream consumer already_sent=True should propagate to response
|
||||
dict even when response_previewed is False."""
|
||||
sc = self._make_mock_stream_consumer(already_sent=True, final_response_sent=False)
|
||||
response = {"final_response": "text", "response_previewed": False}
|
||||
|
||||
# Reproduce the logic from run.py return path (post-fix)
|
||||
if sc and isinstance(response, dict) and not response.get("failed"):
|
||||
if (
|
||||
getattr(sc, "final_response_sent", False)
|
||||
or getattr(sc, "already_sent", False)
|
||||
):
|
||||
response["already_sent"] = True
|
||||
|
||||
assert response.get("already_sent") is True
|
||||
|
||||
def test_already_sent_not_set_when_nothing_sent(self):
|
||||
"""When stream consumer hasn't sent anything, already_sent should
|
||||
not be set on the response."""
|
||||
sc = self._make_mock_stream_consumer(already_sent=False, final_response_sent=False)
|
||||
response = {"final_response": "text", "response_previewed": False}
|
||||
|
||||
if sc and isinstance(response, dict) and not response.get("failed"):
|
||||
if (
|
||||
getattr(sc, "final_response_sent", False)
|
||||
or getattr(sc, "already_sent", False)
|
||||
):
|
||||
response["already_sent"] = True
|
||||
|
||||
assert "already_sent" not in response
|
||||
|
||||
def test_already_sent_set_on_final_response_sent(self):
|
||||
"""final_response_sent=True should still work as before."""
|
||||
sc = self._make_mock_stream_consumer(already_sent=False, final_response_sent=True)
|
||||
response = {"final_response": "text"}
|
||||
|
||||
if sc and isinstance(response, dict) and not response.get("failed"):
|
||||
if (
|
||||
getattr(sc, "final_response_sent", False)
|
||||
or getattr(sc, "already_sent", False)
|
||||
):
|
||||
response["already_sent"] = True
|
||||
|
||||
assert response.get("already_sent") is True
|
||||
|
||||
def test_already_sent_not_set_on_failed_response(self):
|
||||
"""Failed responses should never be suppressed — user needs to see
|
||||
the error message even if streaming sent earlier partial output."""
|
||||
sc = self._make_mock_stream_consumer(already_sent=True, final_response_sent=False)
|
||||
response = {"final_response": "Error: something broke", "failed": True}
|
||||
|
||||
if sc and isinstance(response, dict) and not response.get("failed"):
|
||||
if (
|
||||
getattr(sc, "final_response_sent", False)
|
||||
or getattr(sc, "already_sent", False)
|
||||
):
|
||||
response["already_sent"] = True
|
||||
|
||||
assert "already_sent" not in response
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Test 3: run.py queued-message path — _already_streamed detection
|
||||
# ===================================================================
|
||||
|
||||
class TestQueuedMessageAlreadyStreamed:
|
||||
"""The queued-message path should detect that the first response was
|
||||
already streamed (already_sent=True) even without response_previewed."""
|
||||
|
||||
def _make_mock_sc(self, already_sent=False, final_response_sent=False):
|
||||
return SimpleNamespace(
|
||||
already_sent=already_sent,
|
||||
final_response_sent=final_response_sent,
|
||||
)
|
||||
|
||||
def test_queued_path_detects_already_streamed(self):
|
||||
"""already_sent=True on stream consumer means first response was
|
||||
streamed — skip re-sending before processing queued message."""
|
||||
_sc = self._make_mock_sc(already_sent=True)
|
||||
|
||||
# Reproduce the queued-message logic from run.py (post-fix)
|
||||
_already_streamed = bool(
|
||||
_sc
|
||||
and (
|
||||
getattr(_sc, "final_response_sent", False)
|
||||
or getattr(_sc, "already_sent", False)
|
||||
)
|
||||
)
|
||||
|
||||
assert _already_streamed is True
|
||||
|
||||
def test_queued_path_sends_when_not_streamed(self):
|
||||
"""Nothing was streamed — first response should be sent before
|
||||
processing the queued message."""
|
||||
_sc = self._make_mock_sc(already_sent=False)
|
||||
|
||||
_already_streamed = bool(
|
||||
_sc
|
||||
and (
|
||||
getattr(_sc, "final_response_sent", False)
|
||||
or getattr(_sc, "already_sent", False)
|
||||
)
|
||||
)
|
||||
|
||||
assert _already_streamed is False
|
||||
|
||||
def test_queued_path_with_no_stream_consumer(self):
|
||||
"""No stream consumer at all (streaming disabled) — not streamed."""
|
||||
_sc = None
|
||||
|
||||
_already_streamed = bool(
|
||||
_sc
|
||||
and (
|
||||
getattr(_sc, "final_response_sent", False)
|
||||
or getattr(_sc, "already_sent", False)
|
||||
)
|
||||
)
|
||||
|
||||
assert _already_streamed is False
|
||||
@@ -125,25 +125,6 @@ async def test_gateway_stop_service_restart_sets_named_exit_code():
|
||||
assert runner._exit_code == GATEWAY_SERVICE_RESTART_EXIT_CODE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_stop_emits_shutdown_hook_after_drain(monkeypatch):
|
||||
runner, adapter = make_restart_runner()
|
||||
adapter.disconnect = AsyncMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
|
||||
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
||||
await runner.stop(restart=True, service_restart=True)
|
||||
|
||||
runner.hooks.emit.assert_awaited_once_with(
|
||||
"gateway:shutdown",
|
||||
{
|
||||
"restart": True,
|
||||
"service_restart": True,
|
||||
"detached_restart": False,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_active_agents_throttles_status_updates():
|
||||
runner, _adapter = make_restart_runner()
|
||||
|
||||
@@ -9,7 +9,7 @@ import pytest
|
||||
from gateway.hooks import HookRegistry
|
||||
|
||||
|
||||
def _create_hook(hooks_dir, hook_name, events, handler_code, *, manifest_extra=""):
|
||||
def _create_hook(hooks_dir, hook_name, events, handler_code):
|
||||
"""Helper to create a hook directory with HOOK.yaml and handler.py."""
|
||||
hook_dir = hooks_dir / hook_name
|
||||
hook_dir.mkdir(parents=True)
|
||||
@@ -17,7 +17,6 @@ def _create_hook(hooks_dir, hook_name, events, handler_code, *, manifest_extra="
|
||||
f"name: {hook_name}\n"
|
||||
f"description: Test hook\n"
|
||||
f"events: {events}\n"
|
||||
f"{manifest_extra}"
|
||||
)
|
||||
(hook_dir / "handler.py").write_text(handler_code)
|
||||
return hook_dir
|
||||
@@ -113,24 +112,6 @@ class TestDiscoverAndLoad:
|
||||
|
||||
assert len(reg.loaded_hooks) == 2
|
||||
|
||||
def test_preserves_optional_startup_readiness_metadata(self, tmp_path):
|
||||
_create_hook(
|
||||
tmp_path,
|
||||
"ready-hook",
|
||||
'["gateway:startup"]',
|
||||
"def handle(e, c): pass\n",
|
||||
manifest_extra="startup_readiness:\n id: beam-runtime\n required: false\n",
|
||||
)
|
||||
|
||||
reg = HookRegistry()
|
||||
with patch("gateway.hooks.HOOKS_DIR", tmp_path), _patch_no_builtins(reg):
|
||||
reg.discover_and_load()
|
||||
|
||||
assert reg.loaded_hooks[0]["startup_readiness"] == {
|
||||
"id": "beam-runtime",
|
||||
"required": False,
|
||||
}
|
||||
|
||||
|
||||
class TestEmit:
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -230,6 +230,59 @@ async def test_notify_on_complete_preserves_user_identity(monkeypatch, tmp_path)
|
||||
assert event.source.user_name == "alice"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notify_on_complete_uses_session_store_origin_for_group_topic(monkeypatch, tmp_path):
|
||||
import tools.process_registry as pr_module
|
||||
from gateway.session import SessionSource
|
||||
|
||||
sessions = [
|
||||
SimpleNamespace(
|
||||
output_buffer="done\n", exited=True, exit_code=0, command="echo test"
|
||||
),
|
||||
]
|
||||
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
|
||||
|
||||
async def _instant_sleep(*_a, **_kw):
|
||||
pass
|
||||
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
|
||||
|
||||
runner = GatewayRunner(GatewayConfig())
|
||||
adapter = SimpleNamespace(send=AsyncMock(), handle_message=AsyncMock())
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
runner.session_store._entries["agent:main:telegram:group:-100:42"] = SimpleNamespace(
|
||||
origin=SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-100",
|
||||
chat_type="group",
|
||||
thread_id="42",
|
||||
user_id="user-42",
|
||||
user_name="alice",
|
||||
)
|
||||
)
|
||||
|
||||
watcher = {
|
||||
"session_id": "proc_test_internal",
|
||||
"check_interval": 0,
|
||||
"session_key": "agent:main:telegram:group:-100:42",
|
||||
"platform": "telegram",
|
||||
"chat_id": "-100",
|
||||
"thread_id": "42",
|
||||
"notify_on_complete": True,
|
||||
}
|
||||
|
||||
await runner._run_process_watcher(watcher)
|
||||
|
||||
assert adapter.handle_message.await_count == 1
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.internal is True
|
||||
assert event.source.platform == Platform.TELEGRAM
|
||||
assert event.source.chat_id == "-100"
|
||||
assert event.source.chat_type == "group"
|
||||
assert event.source.thread_id == "42"
|
||||
assert event.source.user_id == "user-42"
|
||||
assert event.source.user_name == "alice"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_user_id_skips_pairing(monkeypatch, tmp_path):
|
||||
"""A non-internal event with user_id=None should be silently dropped."""
|
||||
|
||||
@@ -335,6 +335,29 @@ def _make_adapter():
|
||||
return adapter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Typing indicator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatrixTypingIndicator:
|
||||
def setup_method(self):
|
||||
self.adapter = _make_adapter()
|
||||
self.adapter._client = MagicMock()
|
||||
self.adapter._client.set_typing = AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_typing_clears_matrix_typing_state(self):
|
||||
"""stop_typing() should send typing=false instead of waiting for timeout expiry."""
|
||||
from gateway.platforms.matrix import RoomID
|
||||
|
||||
await self.adapter.stop_typing("!room:example.org")
|
||||
|
||||
self.adapter._client.set_typing.assert_awaited_once_with(
|
||||
RoomID("!room:example.org"),
|
||||
timeout=0,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# mxc:// URL conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1831,4 +1854,3 @@ class TestMatrixPresence:
|
||||
assert result is False
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -193,7 +193,7 @@ async def test_shutdown_notification_says_restarting_when_restart_requested():
|
||||
|
||||
assert len(adapter.sent) == 1
|
||||
assert "restarting" in adapter.sent[0]
|
||||
assert "/retry" in adapter.sent[0]
|
||||
assert "resume" in adapter.sent[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -132,68 +132,6 @@ async def test_runner_records_connected_platform_state_on_success(monkeypatch, t
|
||||
assert state["platforms"]["discord"]["error_message"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_discovers_plugins_before_loading_hooks(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.DISCORD: PlatformConfig(enabled=True, token="***")
|
||||
},
|
||||
sessions_dir=tmp_path / "sessions",
|
||||
)
|
||||
runner = GatewayRunner(config)
|
||||
order: list[str] = []
|
||||
|
||||
monkeypatch.setattr(runner, "_create_adapter", lambda platform, platform_config: _SuccessfulAdapter())
|
||||
monkeypatch.setattr("hermes_cli.plugins.discover_plugins", lambda: order.append("plugins"))
|
||||
monkeypatch.setattr(runner.hooks, "discover_and_load", lambda: order.append("hooks"))
|
||||
monkeypatch.setattr(runner.hooks, "emit", AsyncMock())
|
||||
|
||||
ok = await runner.start()
|
||||
|
||||
assert ok is True
|
||||
assert order == ["plugins", "hooks"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_initializes_startup_checks_before_gateway_startup_emit(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.DISCORD: PlatformConfig(enabled=True, token="***")
|
||||
},
|
||||
sessions_dir=tmp_path / "sessions",
|
||||
)
|
||||
runner = GatewayRunner(config)
|
||||
|
||||
runner.hooks._loaded_hooks = [
|
||||
{
|
||||
"name": "beam-runtime",
|
||||
"events": ["gateway:startup"],
|
||||
"path": str(tmp_path / "hook"),
|
||||
"startup_readiness": {
|
||||
"id": "beam-runtime",
|
||||
"required": True,
|
||||
},
|
||||
}
|
||||
]
|
||||
monkeypatch.setattr(runner, "_create_adapter", lambda platform, platform_config: _SuccessfulAdapter())
|
||||
monkeypatch.setattr("hermes_cli.plugins.discover_plugins", lambda: None)
|
||||
monkeypatch.setattr(runner.hooks, "discover_and_load", lambda: None)
|
||||
|
||||
async def _assert_checks(event_type, context):
|
||||
state = read_runtime_status()
|
||||
assert event_type == "gateway:startup"
|
||||
assert state["startup_checks"]["beam-runtime"]["state"] == "pending"
|
||||
assert state["startup_checks"]["beam-runtime"]["required"] is True
|
||||
|
||||
monkeypatch.setattr(runner.hooks, "emit", _assert_checks)
|
||||
|
||||
ok = await runner.start()
|
||||
|
||||
assert ok is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_gateway_verbosity_imports_redacting_formatter(monkeypatch, tmp_path):
|
||||
"""Verbosity != None must not crash with NameError on RedactingFormatter (#8044)."""
|
||||
|
||||
@@ -132,72 +132,6 @@ class TestGatewayRuntimeStatus:
|
||||
assert payload["platforms"]["discord"]["error_code"] is None
|
||||
assert payload["platforms"]["discord"]["error_message"] is None
|
||||
|
||||
def test_reset_startup_checks_replaces_previous_run_entries(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
status.write_runtime_status(
|
||||
gateway_state="running",
|
||||
startup_checks={
|
||||
"old-check": {
|
||||
"state": "ready",
|
||||
"required": True,
|
||||
"source": "old-hook",
|
||||
"detail": None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
status.reset_startup_checks([
|
||||
{
|
||||
"name": "new-hook",
|
||||
"startup_readiness": {
|
||||
"id": "new-check",
|
||||
"required": False,
|
||||
},
|
||||
}
|
||||
])
|
||||
|
||||
payload = status.read_runtime_status()
|
||||
assert set(payload["startup_checks"]) == {"new-check"}
|
||||
assert payload["startup_checks"]["new-check"]["state"] == "pending"
|
||||
assert payload["startup_checks"]["new-check"]["required"] is False
|
||||
assert payload["startup_checks"]["new-check"]["source"] == "new-hook"
|
||||
|
||||
def test_mark_startup_check_ready_persists_detail(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
status.reset_startup_checks([
|
||||
{
|
||||
"name": "beam",
|
||||
"startup_readiness": {
|
||||
"id": "beam-runtime",
|
||||
"required": True,
|
||||
},
|
||||
}
|
||||
])
|
||||
|
||||
status.mark_startup_check_ready("beam-runtime", detail="ready for RPC")
|
||||
|
||||
payload = status.read_runtime_status()
|
||||
assert payload["startup_checks"]["beam-runtime"]["state"] == "ready"
|
||||
assert payload["startup_checks"]["beam-runtime"]["detail"] == "ready for RPC"
|
||||
|
||||
def test_mark_startup_check_failed_creates_missing_entry(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
status.mark_startup_check_failed(
|
||||
"late-hook",
|
||||
detail="startup hook crashed",
|
||||
required=False,
|
||||
source="late-hook",
|
||||
)
|
||||
|
||||
payload = status.read_runtime_status()
|
||||
assert payload["startup_checks"]["late-hook"]["state"] == "failed"
|
||||
assert payload["startup_checks"]["late-hook"]["required"] is False
|
||||
assert payload["startup_checks"]["late-hook"]["source"] == "late-hook"
|
||||
assert payload["startup_checks"]["late-hook"]["detail"] == "startup hook crashed"
|
||||
|
||||
|
||||
class TestTerminatePid:
|
||||
def test_force_uses_taskkill_on_windows(self, monkeypatch):
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
"""Tests for stuck-session loop detection (#7536).
|
||||
|
||||
When a session is active across 3+ consecutive gateway restarts (the agent
|
||||
gets stuck, gateway restarts, same session gets stuck again), the session
|
||||
is auto-suspended on startup so the user gets a clean slate.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.gateway.restart_test_helpers import make_restart_runner
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner_with_home(tmp_path, monkeypatch):
|
||||
"""Create a runner with a writable HERMES_HOME."""
|
||||
monkeypatch.setattr("gateway.run._hermes_home", tmp_path)
|
||||
runner, adapter = make_restart_runner()
|
||||
return runner, tmp_path
|
||||
|
||||
|
||||
class TestStuckLoopDetection:
|
||||
|
||||
def test_increment_creates_file(self, runner_with_home):
|
||||
runner, home = runner_with_home
|
||||
runner._increment_restart_failure_counts({"session:a", "session:b"})
|
||||
path = home / runner._STUCK_LOOP_FILE
|
||||
assert path.exists()
|
||||
counts = json.loads(path.read_text())
|
||||
assert counts["session:a"] == 1
|
||||
assert counts["session:b"] == 1
|
||||
|
||||
def test_increment_accumulates(self, runner_with_home):
|
||||
runner, home = runner_with_home
|
||||
runner._increment_restart_failure_counts({"session:a"})
|
||||
runner._increment_restart_failure_counts({"session:a"})
|
||||
runner._increment_restart_failure_counts({"session:a"})
|
||||
counts = json.loads((home / runner._STUCK_LOOP_FILE).read_text())
|
||||
assert counts["session:a"] == 3
|
||||
|
||||
def test_increment_drops_inactive_sessions(self, runner_with_home):
|
||||
runner, home = runner_with_home
|
||||
runner._increment_restart_failure_counts({"session:a", "session:b"})
|
||||
runner._increment_restart_failure_counts({"session:a"}) # b not active
|
||||
counts = json.loads((home / runner._STUCK_LOOP_FILE).read_text())
|
||||
assert "session:a" in counts
|
||||
assert "session:b" not in counts
|
||||
|
||||
def test_suspend_at_threshold(self, runner_with_home):
|
||||
runner, home = runner_with_home
|
||||
# Simulate 3 restarts with session:a active each time
|
||||
for _ in range(3):
|
||||
runner._increment_restart_failure_counts({"session:a"})
|
||||
|
||||
# Create a mock session entry
|
||||
mock_entry = MagicMock()
|
||||
mock_entry.suspended = False
|
||||
runner.session_store._entries = {"session:a": mock_entry}
|
||||
runner.session_store._save = MagicMock()
|
||||
|
||||
suspended = runner._suspend_stuck_loop_sessions()
|
||||
assert suspended == 1
|
||||
assert mock_entry.suspended is True
|
||||
|
||||
def test_no_suspend_below_threshold(self, runner_with_home):
|
||||
runner, home = runner_with_home
|
||||
runner._increment_restart_failure_counts({"session:a"})
|
||||
runner._increment_restart_failure_counts({"session:a"})
|
||||
# Only 2 restarts — below threshold of 3
|
||||
|
||||
mock_entry = MagicMock()
|
||||
mock_entry.suspended = False
|
||||
runner.session_store._entries = {"session:a": mock_entry}
|
||||
|
||||
suspended = runner._suspend_stuck_loop_sessions()
|
||||
assert suspended == 0
|
||||
assert mock_entry.suspended is False
|
||||
|
||||
def test_clear_on_success(self, runner_with_home):
|
||||
runner, home = runner_with_home
|
||||
runner._increment_restart_failure_counts({"session:a", "session:b"})
|
||||
runner._clear_restart_failure_count("session:a")
|
||||
|
||||
path = home / runner._STUCK_LOOP_FILE
|
||||
counts = json.loads(path.read_text())
|
||||
assert "session:a" not in counts
|
||||
assert "session:b" in counts
|
||||
|
||||
def test_clear_removes_file_when_empty(self, runner_with_home):
|
||||
runner, home = runner_with_home
|
||||
runner._increment_restart_failure_counts({"session:a"})
|
||||
runner._clear_restart_failure_count("session:a")
|
||||
assert not (home / runner._STUCK_LOOP_FILE).exists()
|
||||
|
||||
def test_suspend_clears_file(self, runner_with_home):
|
||||
runner, home = runner_with_home
|
||||
for _ in range(3):
|
||||
runner._increment_restart_failure_counts({"session:a"})
|
||||
|
||||
mock_entry = MagicMock()
|
||||
mock_entry.suspended = False
|
||||
runner.session_store._entries = {"session:a": mock_entry}
|
||||
runner.session_store._save = MagicMock()
|
||||
|
||||
runner._suspend_stuck_loop_sessions()
|
||||
assert not (home / runner._STUCK_LOOP_FILE).exists()
|
||||
|
||||
def test_no_file_no_crash(self, runner_with_home):
|
||||
runner, home = runner_with_home
|
||||
# No file exists — should return 0 and not crash
|
||||
assert runner._suspend_stuck_loop_sessions() == 0
|
||||
# Clear on nonexistent file — should not crash
|
||||
runner._clear_restart_failure_count("nonexistent")
|
||||
@@ -0,0 +1,275 @@
|
||||
"""Tests for the Command Installation check in hermes doctor."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
import hermes_cli.doctor as doctor_mod
|
||||
|
||||
|
||||
def _setup_doctor_env(monkeypatch, tmp_path, venv_name="venv"):
|
||||
"""Create a minimal HERMES_HOME + PROJECT_ROOT for doctor tests."""
|
||||
home = tmp_path / ".hermes"
|
||||
home.mkdir(parents=True, exist_ok=True)
|
||||
(home / "config.yaml").write_text("memory: {}\n", encoding="utf-8")
|
||||
|
||||
project = tmp_path / "project"
|
||||
project.mkdir(exist_ok=True)
|
||||
|
||||
# Create a fake venv entry point
|
||||
venv_bin_dir = project / venv_name / "bin"
|
||||
venv_bin_dir.mkdir(parents=True, exist_ok=True)
|
||||
hermes_bin = venv_bin_dir / "hermes"
|
||||
hermes_bin.write_text("#!/usr/bin/env python\n# entry point\n")
|
||||
hermes_bin.chmod(0o755)
|
||||
|
||||
monkeypatch.setattr(doctor_mod, "HERMES_HOME", home)
|
||||
monkeypatch.setattr(doctor_mod, "PROJECT_ROOT", project)
|
||||
monkeypatch.setattr(doctor_mod, "_DHH", str(home))
|
||||
|
||||
# Stub model_tools so doctor doesn't fail on import
|
||||
fake_model_tools = types.SimpleNamespace(
|
||||
check_tool_availability=lambda *a, **kw: ([], []),
|
||||
TOOLSET_REQUIREMENTS={},
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "model_tools", fake_model_tools)
|
||||
|
||||
# Stub auth checks
|
||||
try:
|
||||
from hermes_cli import auth as _auth_mod
|
||||
monkeypatch.setattr(_auth_mod, "get_nous_auth_status", lambda: {})
|
||||
monkeypatch.setattr(_auth_mod, "get_codex_auth_status", lambda: {})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Stub httpx.get to avoid network calls
|
||||
try:
|
||||
import httpx
|
||||
monkeypatch.setattr(httpx, "get", lambda *a, **kw: types.SimpleNamespace(status_code=200))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return home, project, hermes_bin
|
||||
|
||||
|
||||
def _run_doctor(fix=False):
|
||||
"""Run doctor and capture stdout."""
|
||||
import io
|
||||
import contextlib
|
||||
|
||||
buf = io.StringIO()
|
||||
with contextlib.redirect_stdout(buf):
|
||||
doctor_mod.run_doctor(Namespace(fix=fix))
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
class TestDoctorCommandInstallation:
|
||||
"""Tests for the ◆ Command Installation section."""
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Symlink check is Unix-only")
|
||||
def test_correct_symlink_shows_ok(self, monkeypatch, tmp_path):
|
||||
home, project, hermes_bin = _setup_doctor_env(monkeypatch, tmp_path)
|
||||
|
||||
# Create the command link dir with correct symlink
|
||||
cmd_link_dir = tmp_path / ".local" / "bin"
|
||||
cmd_link_dir.mkdir(parents=True)
|
||||
cmd_link = cmd_link_dir / "hermes"
|
||||
cmd_link.symlink_to(hermes_bin)
|
||||
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
out = _run_doctor(fix=False)
|
||||
assert "Command Installation" in out
|
||||
assert "Venv entry point exists" in out
|
||||
assert "correct target" in out
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Symlink check is Unix-only")
|
||||
def test_missing_symlink_shows_fail(self, monkeypatch, tmp_path):
|
||||
home, project, hermes_bin = _setup_doctor_env(monkeypatch, tmp_path)
|
||||
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
# Don't create the symlink — it should be missing
|
||||
|
||||
out = _run_doctor(fix=False)
|
||||
assert "Command Installation" in out
|
||||
assert "Venv entry point exists" in out
|
||||
assert "not found" in out
|
||||
assert "hermes doctor --fix" in out
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Symlink check is Unix-only")
|
||||
def test_fix_creates_missing_symlink(self, monkeypatch, tmp_path):
|
||||
home, project, hermes_bin = _setup_doctor_env(monkeypatch, tmp_path)
|
||||
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
out = _run_doctor(fix=True)
|
||||
assert "Command Installation" in out
|
||||
assert "Created symlink" in out
|
||||
|
||||
# Verify the symlink was actually created
|
||||
cmd_link = tmp_path / ".local" / "bin" / "hermes"
|
||||
assert cmd_link.is_symlink()
|
||||
assert cmd_link.resolve() == hermes_bin.resolve()
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Symlink check is Unix-only")
|
||||
def test_wrong_target_symlink_shows_warn(self, monkeypatch, tmp_path):
|
||||
home, project, hermes_bin = _setup_doctor_env(monkeypatch, tmp_path)
|
||||
|
||||
# Create a symlink pointing to the wrong target
|
||||
cmd_link_dir = tmp_path / ".local" / "bin"
|
||||
cmd_link_dir.mkdir(parents=True)
|
||||
cmd_link = cmd_link_dir / "hermes"
|
||||
wrong_target = tmp_path / "wrong_hermes"
|
||||
wrong_target.write_text("#!/usr/bin/env python\n")
|
||||
cmd_link.symlink_to(wrong_target)
|
||||
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
out = _run_doctor(fix=False)
|
||||
assert "Command Installation" in out
|
||||
assert "wrong target" in out
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Symlink check is Unix-only")
|
||||
def test_fix_repairs_wrong_symlink(self, monkeypatch, tmp_path):
|
||||
home, project, hermes_bin = _setup_doctor_env(monkeypatch, tmp_path)
|
||||
|
||||
# Create a symlink pointing to wrong target
|
||||
cmd_link_dir = tmp_path / ".local" / "bin"
|
||||
cmd_link_dir.mkdir(parents=True)
|
||||
cmd_link = cmd_link_dir / "hermes"
|
||||
wrong_target = tmp_path / "wrong_hermes"
|
||||
wrong_target.write_text("#!/usr/bin/env python\n")
|
||||
cmd_link.symlink_to(wrong_target)
|
||||
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
out = _run_doctor(fix=True)
|
||||
assert "Fixed symlink" in out
|
||||
|
||||
# Verify the symlink now points to the correct target
|
||||
assert cmd_link.is_symlink()
|
||||
assert cmd_link.resolve() == hermes_bin.resolve()
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Symlink check is Unix-only")
|
||||
def test_missing_venv_entry_point_shows_warn(self, monkeypatch, tmp_path):
|
||||
home = tmp_path / ".hermes"
|
||||
home.mkdir(parents=True, exist_ok=True)
|
||||
(home / "config.yaml").write_text("memory: {}\n", encoding="utf-8")
|
||||
|
||||
project = tmp_path / "project"
|
||||
project.mkdir(exist_ok=True)
|
||||
# Do NOT create any venv entry point
|
||||
|
||||
monkeypatch.setattr(doctor_mod, "HERMES_HOME", home)
|
||||
monkeypatch.setattr(doctor_mod, "PROJECT_ROOT", project)
|
||||
monkeypatch.setattr(doctor_mod, "_DHH", str(home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
fake_model_tools = types.SimpleNamespace(
|
||||
check_tool_availability=lambda *a, **kw: ([], []),
|
||||
TOOLSET_REQUIREMENTS={},
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "model_tools", fake_model_tools)
|
||||
try:
|
||||
from hermes_cli import auth as _auth_mod
|
||||
monkeypatch.setattr(_auth_mod, "get_nous_auth_status", lambda: {})
|
||||
monkeypatch.setattr(_auth_mod, "get_codex_auth_status", lambda: {})
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
import httpx
|
||||
monkeypatch.setattr(httpx, "get", lambda *a, **kw: types.SimpleNamespace(status_code=200))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
out = _run_doctor(fix=False)
|
||||
assert "Command Installation" in out
|
||||
assert "Venv entry point not found" in out
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Symlink check is Unix-only")
|
||||
def test_dot_venv_dir_is_found(self, monkeypatch, tmp_path):
|
||||
"""The check finds entry points in .venv/ as well as venv/."""
|
||||
home, project, _ = _setup_doctor_env(monkeypatch, tmp_path, venv_name=".venv")
|
||||
|
||||
# Create the command link with correct symlink
|
||||
hermes_bin = project / ".venv" / "bin" / "hermes"
|
||||
cmd_link_dir = tmp_path / ".local" / "bin"
|
||||
cmd_link_dir.mkdir(parents=True)
|
||||
cmd_link = cmd_link_dir / "hermes"
|
||||
cmd_link.symlink_to(hermes_bin)
|
||||
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
out = _run_doctor(fix=False)
|
||||
assert "Venv entry point exists" in out
|
||||
assert ".venv/bin/hermes" in out
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Symlink check is Unix-only")
|
||||
def test_non_symlink_regular_file_shows_ok(self, monkeypatch, tmp_path):
|
||||
"""If ~/.local/bin/hermes is a regular file (not symlink), accept it."""
|
||||
home, project, hermes_bin = _setup_doctor_env(monkeypatch, tmp_path)
|
||||
|
||||
cmd_link_dir = tmp_path / ".local" / "bin"
|
||||
cmd_link_dir.mkdir(parents=True)
|
||||
cmd_link = cmd_link_dir / "hermes"
|
||||
cmd_link.write_text("#!/bin/sh\nexec python -m hermes_cli.main \"$@\"\n")
|
||||
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
out = _run_doctor(fix=False)
|
||||
assert "non-symlink" in out
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Symlink check is Unix-only")
|
||||
def test_termux_uses_prefix_bin(self, monkeypatch, tmp_path):
|
||||
"""On Termux, the command link dir is $PREFIX/bin."""
|
||||
prefix_dir = tmp_path / "termux_prefix"
|
||||
prefix_bin = prefix_dir / "bin"
|
||||
prefix_bin.mkdir(parents=True)
|
||||
|
||||
home, project, hermes_bin = _setup_doctor_env(monkeypatch, tmp_path)
|
||||
|
||||
monkeypatch.setenv("TERMUX_VERSION", "0.118.3")
|
||||
monkeypatch.setenv("PREFIX", str(prefix_dir))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
out = _run_doctor(fix=False)
|
||||
assert "Command Installation" in out
|
||||
assert "$PREFIX/bin" in out
|
||||
|
||||
def test_windows_skips_check(self, monkeypatch, tmp_path):
|
||||
"""On Windows, the Command Installation section is skipped."""
|
||||
home = tmp_path / ".hermes"
|
||||
home.mkdir(parents=True, exist_ok=True)
|
||||
(home / "config.yaml").write_text("memory: {}\n", encoding="utf-8")
|
||||
|
||||
project = tmp_path / "project"
|
||||
project.mkdir(exist_ok=True)
|
||||
|
||||
monkeypatch.setattr(doctor_mod, "HERMES_HOME", home)
|
||||
monkeypatch.setattr(doctor_mod, "PROJECT_ROOT", project)
|
||||
monkeypatch.setattr(doctor_mod, "_DHH", str(home))
|
||||
monkeypatch.setattr(sys, "platform", "win32")
|
||||
|
||||
fake_model_tools = types.SimpleNamespace(
|
||||
check_tool_availability=lambda *a, **kw: ([], []),
|
||||
TOOLSET_REQUIREMENTS={},
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "model_tools", fake_model_tools)
|
||||
try:
|
||||
from hermes_cli import auth as _auth_mod
|
||||
monkeypatch.setattr(_auth_mod, "get_nous_auth_status", lambda: {})
|
||||
monkeypatch.setattr(_auth_mod, "get_codex_auth_status", lambda: {})
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
import httpx
|
||||
monkeypatch.setattr(httpx, "get", lambda *a, **kw: types.SimpleNamespace(status_code=200))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
out = _run_doctor(fix=False)
|
||||
assert "Command Installation" not in out
|
||||
@@ -6,21 +6,12 @@ from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import hermes_cli.gateway as gateway_cli
|
||||
import pytest
|
||||
from gateway.restart import (
|
||||
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
|
||||
GATEWAY_SERVICE_RESTART_EXIT_CODE,
|
||||
)
|
||||
|
||||
|
||||
_REAL_AWAIT_SERVICE_READY = gateway_cli._await_service_ready_or_exit
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _stub_service_readiness(monkeypatch):
|
||||
monkeypatch.setattr(gateway_cli, "_await_service_ready_or_exit", lambda **kwargs: None)
|
||||
|
||||
|
||||
class TestSystemdServiceRefresh:
|
||||
def test_systemd_install_repairs_outdated_unit_without_force(self, tmp_path, monkeypatch):
|
||||
unit_path = tmp_path / "hermes-gateway.service"
|
||||
@@ -91,30 +82,6 @@ class TestSystemdServiceRefresh:
|
||||
["systemctl", "--user", "reload-or-restart", gateway_cli.get_service_name()],
|
||||
]
|
||||
|
||||
def test_systemd_start_waits_for_readiness_before_reporting_success(self, monkeypatch):
|
||||
calls = []
|
||||
|
||||
monkeypatch.setattr(gateway_cli, "_select_systemd_scope", lambda system=False: False)
|
||||
monkeypatch.setattr(gateway_cli, "refresh_systemd_unit_if_needed", lambda system=False: calls.append(("refresh", system)))
|
||||
monkeypatch.setattr(
|
||||
gateway_cli,
|
||||
"_run_systemctl",
|
||||
lambda cmd, system=False, check=True, timeout=30, **kwargs: calls.append((tuple(cmd), system, timeout)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli,
|
||||
"_await_service_ready_or_exit",
|
||||
lambda **kwargs: calls.append(("ready", kwargs)),
|
||||
)
|
||||
|
||||
gateway_cli.systemd_start()
|
||||
|
||||
assert calls == [
|
||||
("refresh", False),
|
||||
(("start", gateway_cli.get_service_name()), False, 30),
|
||||
("ready", {"action": "start"}),
|
||||
]
|
||||
|
||||
|
||||
class TestGeneratedSystemdUnits:
|
||||
def test_user_unit_avoids_recursive_execstop_and_uses_extended_stop_timeout(self):
|
||||
@@ -301,32 +268,6 @@ class TestLaunchdServiceRecovery:
|
||||
["launchctl", "kickstart", target],
|
||||
]
|
||||
|
||||
def test_launchd_start_waits_for_readiness_before_reporting_success(self, tmp_path, monkeypatch):
|
||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||
plist_path.write_text(gateway_cli.generate_launchd_plist(), encoding="utf-8")
|
||||
label = gateway_cli.get_launchd_label()
|
||||
calls = []
|
||||
|
||||
monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: plist_path)
|
||||
monkeypatch.setattr(gateway_cli, "refresh_launchd_plist_if_needed", lambda: None)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli.subprocess,
|
||||
"run",
|
||||
lambda cmd, check=False, **kwargs: calls.append(cmd) or SimpleNamespace(returncode=0, stdout="", stderr=""),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli,
|
||||
"_await_service_ready_or_exit",
|
||||
lambda **kwargs: calls.append(("ready", kwargs)),
|
||||
)
|
||||
|
||||
gateway_cli.launchd_start()
|
||||
|
||||
assert calls == [
|
||||
["launchctl", "kickstart", f"{gateway_cli._launchd_domain()}/{label}"],
|
||||
("ready", {"action": "start"}),
|
||||
]
|
||||
|
||||
def test_launchd_restart_drains_running_gateway_before_kickstart(self, monkeypatch):
|
||||
calls = []
|
||||
target = f"{gateway_cli._launchd_domain()}/{gateway_cli.get_launchd_label()}"
|
||||
@@ -374,7 +315,7 @@ class TestLaunchdServiceRecovery:
|
||||
gateway_cli.launchd_restart()
|
||||
|
||||
assert calls == [("self", 321)]
|
||||
assert "service restarted" in capsys.readouterr().out.lower()
|
||||
assert "restart requested" in capsys.readouterr().out.lower()
|
||||
|
||||
def test_launchd_stop_uses_bootout_not_kill(self, monkeypatch):
|
||||
"""launchd_stop must bootout the service so KeepAlive doesn't respawn it."""
|
||||
@@ -452,109 +393,6 @@ class TestLaunchdServiceRecovery:
|
||||
assert "not loaded" in output.lower()
|
||||
|
||||
|
||||
class TestGatewayServiceReadiness:
|
||||
def test_wait_for_service_readiness_accepts_running_gateway_without_checks(self, monkeypatch):
|
||||
monkeypatch.setattr("gateway.status.get_running_pid", lambda: 123)
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.read_runtime_status",
|
||||
lambda: {"pid": 123, "gateway_state": "running", "startup_checks": {}},
|
||||
)
|
||||
|
||||
warnings = gateway_cli._wait_for_service_readiness(action="start", timeout=0.1, poll_interval=0.0)
|
||||
|
||||
assert warnings == []
|
||||
|
||||
def test_wait_for_service_readiness_ignores_stale_runtime_state_until_pid_matches(self, monkeypatch):
|
||||
runtime_states = iter(
|
||||
[
|
||||
{"pid": 999, "gateway_state": "running", "startup_checks": {}},
|
||||
{"pid": 123, "gateway_state": "running", "startup_checks": {}},
|
||||
]
|
||||
)
|
||||
|
||||
monkeypatch.setattr("gateway.status.get_running_pid", lambda: 123)
|
||||
monkeypatch.setattr("gateway.status.read_runtime_status", lambda: next(runtime_states))
|
||||
|
||||
warnings = gateway_cli._wait_for_service_readiness(action="start", timeout=0.1, poll_interval=0.0)
|
||||
|
||||
assert warnings == []
|
||||
|
||||
def test_wait_for_service_readiness_returns_optional_pending_warnings(self, monkeypatch):
|
||||
monkeypatch.setattr("gateway.status.get_running_pid", lambda: 123)
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.read_runtime_status",
|
||||
lambda: {
|
||||
"pid": 123,
|
||||
"gateway_state": "running",
|
||||
"startup_checks": {
|
||||
"optional-check": {
|
||||
"state": "pending",
|
||||
"required": False,
|
||||
"source": "test-hook",
|
||||
"detail": "still warming",
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
warnings = gateway_cli._wait_for_service_readiness(action="start", timeout=0.1, poll_interval=0.0)
|
||||
|
||||
assert warnings == ["pending: optional-check (test-hook): still warming"]
|
||||
|
||||
def test_wait_for_service_readiness_fails_when_required_check_fails(self, monkeypatch):
|
||||
monkeypatch.setattr("gateway.status.get_running_pid", lambda: 123)
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.read_runtime_status",
|
||||
lambda: {
|
||||
"pid": 123,
|
||||
"gateway_state": "running",
|
||||
"startup_checks": {
|
||||
"beam-runtime": {
|
||||
"state": "failed",
|
||||
"required": True,
|
||||
"source": "beam",
|
||||
"detail": "RPC boot failed",
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match=r"required startup checks failed: beam-runtime \(beam\): RPC boot failed"):
|
||||
gateway_cli._wait_for_service_readiness(action="start", timeout=0.1, poll_interval=0.0)
|
||||
|
||||
def test_wait_for_service_readiness_times_out_on_pending_required_check(self, monkeypatch):
|
||||
monkeypatch.setattr("gateway.status.get_running_pid", lambda: 123)
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.read_runtime_status",
|
||||
lambda: {
|
||||
"pid": 123,
|
||||
"gateway_state": "running",
|
||||
"startup_checks": {
|
||||
"beam-runtime": {
|
||||
"state": "pending",
|
||||
"required": True,
|
||||
"source": "beam",
|
||||
"detail": "waiting for runtime",
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match=r"timed out waiting for required startup checks: beam-runtime \(beam\): waiting for runtime"):
|
||||
gateway_cli._wait_for_service_readiness(action="start", timeout=0.01, poll_interval=0.0)
|
||||
|
||||
def test_await_service_ready_or_exit_raises_system_exit_when_not_ready(self, monkeypatch):
|
||||
monkeypatch.setattr(gateway_cli, "_await_service_ready_or_exit", _REAL_AWAIT_SERVICE_READY)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli,
|
||||
"_wait_for_service_readiness",
|
||||
lambda **kwargs: (_ for _ in ()).throw(RuntimeError("not ready")),
|
||||
)
|
||||
|
||||
with pytest.raises(SystemExit, match="1"):
|
||||
gateway_cli._await_service_ready_or_exit(action="start")
|
||||
|
||||
|
||||
class TestGatewayServiceDetection:
|
||||
def test_supports_systemd_services_requires_systemctl_binary(self, monkeypatch):
|
||||
monkeypatch.setattr(gateway_cli, "is_linux", lambda: True)
|
||||
@@ -614,7 +452,7 @@ class TestGatewayServiceDetection:
|
||||
|
||||
|
||||
class TestGatewaySystemServiceRouting:
|
||||
def test_systemd_restart_self_requests_graceful_restart_without_reload_or_restart(self, monkeypatch, capsys):
|
||||
def test_systemd_restart_self_requests_graceful_restart_and_waits(self, monkeypatch, capsys):
|
||||
calls = []
|
||||
|
||||
monkeypatch.setattr(gateway_cli, "_select_systemd_scope", lambda system=False: False)
|
||||
@@ -628,16 +466,37 @@ class TestGatewaySystemServiceRouting:
|
||||
"_request_gateway_self_restart",
|
||||
lambda pid: calls.append(("self", pid)) or True,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli.subprocess,
|
||||
"run",
|
||||
lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("systemctl should not run")),
|
||||
)
|
||||
|
||||
# Simulate: old process dies immediately, new process becomes active
|
||||
kill_call_count = [0]
|
||||
def fake_kill(pid, sig):
|
||||
kill_call_count[0] += 1
|
||||
if kill_call_count[0] >= 2: # first call checks, second = dead
|
||||
raise ProcessLookupError()
|
||||
monkeypatch.setattr(os, "kill", fake_kill)
|
||||
|
||||
# Simulate systemctl is-active returning "active" with a new PID
|
||||
new_pid = [None]
|
||||
def fake_subprocess_run(cmd, **kwargs):
|
||||
if "is-active" in cmd:
|
||||
result = SimpleNamespace(stdout="active\n", returncode=0)
|
||||
new_pid[0] = 999 # new PID
|
||||
return result
|
||||
raise AssertionError(f"Unexpected systemctl call: {cmd}")
|
||||
|
||||
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_subprocess_run)
|
||||
# get_running_pid returns new PID after restart
|
||||
pid_calls = [0]
|
||||
def fake_get_pid():
|
||||
pid_calls[0] += 1
|
||||
return 999 if pid_calls[0] > 1 else 654
|
||||
monkeypatch.setattr("gateway.status.get_running_pid", fake_get_pid)
|
||||
|
||||
gateway_cli.systemd_restart()
|
||||
|
||||
assert calls == [("refresh", False), ("self", 654)]
|
||||
assert "service restarted" in capsys.readouterr().out.lower()
|
||||
assert ("self", 654) in calls
|
||||
out = capsys.readouterr().out.lower()
|
||||
assert "restarted" in out
|
||||
|
||||
def test_gateway_install_passes_system_flags(self, monkeypatch):
|
||||
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: True)
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
"""Tests for non-ASCII credential detection and sanitization.
|
||||
|
||||
Covers the fix for issue #6843 — API keys containing Unicode lookalike
|
||||
characters (e.g. ʋ U+028B instead of v) cause UnicodeEncodeError when
|
||||
httpx tries to encode the Authorization header as ASCII.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.config import _check_non_ascii_credential
|
||||
|
||||
|
||||
class TestCheckNonAsciiCredential:
|
||||
"""Tests for _check_non_ascii_credential()."""
|
||||
|
||||
def test_ascii_key_unchanged(self):
|
||||
key = "sk-proj-" + "a" * 100
|
||||
result = _check_non_ascii_credential("TEST_API_KEY", key)
|
||||
assert result == key
|
||||
|
||||
def test_strips_unicode_v_lookalike(self, capsys):
|
||||
"""The exact scenario from issue #6843: ʋ instead of v."""
|
||||
key = "sk-proj-abc" + "ʋ" + "def" # \u028b
|
||||
result = _check_non_ascii_credential("OPENROUTER_API_KEY", key)
|
||||
assert result == "sk-proj-abcdef"
|
||||
assert "ʋ" not in result
|
||||
# Should print a warning
|
||||
captured = capsys.readouterr()
|
||||
assert "non-ASCII" in captured.err
|
||||
|
||||
def test_strips_multiple_non_ascii(self, capsys):
|
||||
key = "sk-proj-aʋbécd"
|
||||
result = _check_non_ascii_credential("OPENAI_API_KEY", key)
|
||||
assert result == "sk-proj-abcd"
|
||||
captured = capsys.readouterr()
|
||||
assert "U+028B" in captured.err # reports the char
|
||||
|
||||
def test_empty_key(self):
|
||||
result = _check_non_ascii_credential("TEST_KEY", "")
|
||||
assert result == ""
|
||||
|
||||
def test_all_ascii_no_warning(self, capsys):
|
||||
result = _check_non_ascii_credential("KEY", "all-ascii-value-123")
|
||||
assert result == "all-ascii-value-123"
|
||||
captured = capsys.readouterr()
|
||||
assert captured.err == ""
|
||||
|
||||
|
||||
class TestEnvLoaderSanitization:
|
||||
"""Tests for _sanitize_loaded_credentials in env_loader."""
|
||||
|
||||
def test_strips_non_ascii_from_api_key(self, monkeypatch):
|
||||
from hermes_cli.env_loader import _sanitize_loaded_credentials
|
||||
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "sk-proj-abcʋdef")
|
||||
_sanitize_loaded_credentials()
|
||||
assert os.environ["OPENROUTER_API_KEY"] == "sk-proj-abcdef"
|
||||
|
||||
def test_strips_non_ascii_from_token(self, monkeypatch):
|
||||
from hermes_cli.env_loader import _sanitize_loaded_credentials
|
||||
|
||||
monkeypatch.setenv("DISCORD_BOT_TOKEN", "tokénvalue")
|
||||
_sanitize_loaded_credentials()
|
||||
assert os.environ["DISCORD_BOT_TOKEN"] == "toknvalue"
|
||||
|
||||
def test_ignores_non_credential_vars(self, monkeypatch):
|
||||
from hermes_cli.env_loader import _sanitize_loaded_credentials
|
||||
|
||||
monkeypatch.setenv("MY_UNICODE_VAR", "héllo wörld")
|
||||
_sanitize_loaded_credentials()
|
||||
# Not a credential suffix — should be left alone
|
||||
assert os.environ["MY_UNICODE_VAR"] == "héllo wörld"
|
||||
|
||||
def test_ascii_credentials_untouched(self, monkeypatch):
|
||||
from hermes_cli.env_loader import _sanitize_loaded_credentials
|
||||
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-proj-allascii123")
|
||||
_sanitize_loaded_credentials()
|
||||
assert os.environ["OPENAI_API_KEY"] == "sk-proj-allascii123"
|
||||
@@ -0,0 +1,148 @@
|
||||
"""Tests for the defensive subparser routing workaround (bpo-9338).
|
||||
|
||||
The main() function in hermes_cli/main.py sets subparsers.required=True
|
||||
when argv contains a known subcommand name. This forces deterministic
|
||||
routing on Python versions where argparse fails to match subcommand tokens
|
||||
when the parent parser has nargs='?' optional arguments (--continue).
|
||||
|
||||
If the subcommand token is consumed as a flag value (e.g. `hermes -c model`
|
||||
to resume a session named 'model'), the required=True parse raises
|
||||
SystemExit and the code falls back to the default required=False behaviour.
|
||||
"""
|
||||
import argparse
|
||||
import io
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _build_parser():
|
||||
"""Build a minimal replica of the hermes top-level parser."""
|
||||
parser = argparse.ArgumentParser(prog="hermes")
|
||||
parser.add_argument("--version", "-V", action="store_true")
|
||||
parser.add_argument("--resume", "-r", metavar="SESSION", default=None)
|
||||
parser.add_argument(
|
||||
"--continue", "-c",
|
||||
dest="continue_last",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=None,
|
||||
metavar="SESSION_NAME",
|
||||
)
|
||||
parser.add_argument("--worktree", "-w", action="store_true", default=False)
|
||||
parser.add_argument("--skills", "-s", action="append", default=None)
|
||||
parser.add_argument("--yolo", action="store_true", default=False)
|
||||
parser.add_argument("--pass-session-id", action="store_true", default=False)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", help="Command to run")
|
||||
chat_p = subparsers.add_parser("chat")
|
||||
chat_p.add_argument("-q", "--query", default=None)
|
||||
subparsers.add_parser("model")
|
||||
subparsers.add_parser("gateway")
|
||||
subparsers.add_parser("setup")
|
||||
return parser, subparsers
|
||||
|
||||
|
||||
def _safe_parse(parser, subparsers, argv):
|
||||
"""Replica of the defensive parsing logic from main()."""
|
||||
known_cmds = set(subparsers.choices.keys()) if hasattr(subparsers, "choices") else set()
|
||||
has_cmd_token = any(t in known_cmds for t in argv if not t.startswith("-"))
|
||||
|
||||
if has_cmd_token:
|
||||
subparsers.required = True
|
||||
saved_stderr = sys.stderr
|
||||
try:
|
||||
sys.stderr = io.StringIO()
|
||||
args = parser.parse_args(argv)
|
||||
sys.stderr = saved_stderr
|
||||
return args
|
||||
except SystemExit:
|
||||
sys.stderr = saved_stderr
|
||||
subparsers.required = False
|
||||
return parser.parse_args(argv)
|
||||
else:
|
||||
subparsers.required = False
|
||||
return parser.parse_args(argv)
|
||||
|
||||
|
||||
class TestSubparserRoutingFallback:
|
||||
"""Verify the bpo-9338 defensive routing works for all key cases."""
|
||||
|
||||
def test_direct_subcommand(self):
|
||||
parser, sub = _build_parser()
|
||||
args = _safe_parse(parser, sub, ["model"])
|
||||
assert args.command == "model"
|
||||
|
||||
def test_subcommand_with_flags(self):
|
||||
parser, sub = _build_parser()
|
||||
args = _safe_parse(parser, sub, ["--yolo", "model"])
|
||||
assert args.command == "model"
|
||||
assert args.yolo is True
|
||||
|
||||
def test_bare_hermes_defaults_to_none(self):
|
||||
parser, sub = _build_parser()
|
||||
args = _safe_parse(parser, sub, [])
|
||||
assert args.command is None
|
||||
|
||||
def test_flags_only_defaults_to_none(self):
|
||||
parser, sub = _build_parser()
|
||||
args = _safe_parse(parser, sub, ["--yolo"])
|
||||
assert args.command is None
|
||||
assert args.yolo is True
|
||||
|
||||
def test_continue_flag_alone(self):
|
||||
parser, sub = _build_parser()
|
||||
args = _safe_parse(parser, sub, ["-c"])
|
||||
assert args.command is None
|
||||
assert args.continue_last is True
|
||||
|
||||
def test_continue_with_session_name(self):
|
||||
parser, sub = _build_parser()
|
||||
args = _safe_parse(parser, sub, ["-c", "myproject"])
|
||||
assert args.command is None
|
||||
assert args.continue_last == "myproject"
|
||||
|
||||
def test_continue_with_subcommand_name_as_session(self):
|
||||
"""Edge case: session named 'model' — should be treated as session name, not subcommand."""
|
||||
parser, sub = _build_parser()
|
||||
args = _safe_parse(parser, sub, ["-c", "model"])
|
||||
assert args.command is None
|
||||
assert args.continue_last == "model"
|
||||
|
||||
def test_continue_with_session_then_subcommand(self):
|
||||
parser, sub = _build_parser()
|
||||
args = _safe_parse(parser, sub, ["-c", "myproject", "model"])
|
||||
assert args.command == "model"
|
||||
assert args.continue_last == "myproject"
|
||||
|
||||
def test_chat_with_query(self):
|
||||
parser, sub = _build_parser()
|
||||
args = _safe_parse(parser, sub, ["chat", "-q", "hello"])
|
||||
assert args.command == "chat"
|
||||
assert args.query == "hello"
|
||||
|
||||
def test_resume_flag(self):
|
||||
parser, sub = _build_parser()
|
||||
args = _safe_parse(parser, sub, ["-r", "abc123"])
|
||||
assert args.command is None
|
||||
assert args.resume == "abc123"
|
||||
|
||||
def test_resume_with_subcommand(self):
|
||||
parser, sub = _build_parser()
|
||||
args = _safe_parse(parser, sub, ["-r", "abc123", "chat"])
|
||||
assert args.command == "chat"
|
||||
assert args.resume == "abc123"
|
||||
|
||||
def test_skills_flag_with_subcommand(self):
|
||||
parser, sub = _build_parser()
|
||||
args = _safe_parse(parser, sub, ["-s", "myskill", "chat"])
|
||||
assert args.command == "chat"
|
||||
assert args.skills == ["myskill"]
|
||||
|
||||
def test_all_flags_with_subcommand(self):
|
||||
parser, sub = _build_parser()
|
||||
args = _safe_parse(parser, sub, ["--yolo", "-w", "-s", "myskill", "model"])
|
||||
assert args.command == "model"
|
||||
assert args.yolo is True
|
||||
assert args.worktree is True
|
||||
assert args.skills == ["myskill"]
|
||||
@@ -8,6 +8,7 @@ from hermes_cli.tools_config import (
|
||||
_platform_toolset_summary,
|
||||
_save_platform_tools,
|
||||
_toolset_has_keys,
|
||||
CONFIGURABLE_TOOLSETS,
|
||||
TOOL_CATEGORIES,
|
||||
_visible_providers,
|
||||
tools_command,
|
||||
@@ -22,6 +23,15 @@ def test_get_platform_tools_uses_default_when_platform_not_configured():
|
||||
assert enabled
|
||||
|
||||
|
||||
def test_configurable_toolsets_include_messaging():
|
||||
assert any(ts_key == "messaging" for ts_key, _, _ in CONFIGURABLE_TOOLSETS)
|
||||
|
||||
def test_get_platform_tools_default_telegram_includes_messaging():
|
||||
enabled = _get_platform_tools({}, "telegram")
|
||||
|
||||
assert "messaging" in enabled
|
||||
|
||||
|
||||
def test_get_platform_tools_preserves_explicit_empty_selection():
|
||||
config = {"platform_toolsets": {"cli": []}}
|
||||
|
||||
|
||||
@@ -136,33 +136,29 @@ class TestGatewaySkipsPersistenceOnFailure:
|
||||
the gateway should NOT persist messages to the transcript."""
|
||||
|
||||
def test_agent_failed_early_detected(self):
|
||||
"""The agent_failed_early flag is True when failed=True and
|
||||
no final_response."""
|
||||
"""The agent_failed_early flag is True when failed=True,
|
||||
regardless of final_response."""
|
||||
agent_result = {
|
||||
"failed": True,
|
||||
"final_response": None,
|
||||
"messages": [],
|
||||
"error": "Non-retryable client error",
|
||||
}
|
||||
agent_failed_early = (
|
||||
agent_result.get("failed")
|
||||
and not agent_result.get("final_response")
|
||||
)
|
||||
agent_failed_early = bool(agent_result.get("failed"))
|
||||
assert agent_failed_early
|
||||
|
||||
def test_agent_with_response_not_failed_early(self):
|
||||
"""When the agent has a final_response, it's not a failed-early
|
||||
scenario even if failed=True."""
|
||||
def test_agent_failed_with_error_response_still_detected(self):
|
||||
"""When _run_agent_blocking converts an error to final_response,
|
||||
the failed flag should still trigger agent_failed_early. This
|
||||
was the core bug in #9893 — the old guard checked
|
||||
``not final_response`` which was always truthy after conversion."""
|
||||
agent_result = {
|
||||
"failed": True,
|
||||
"final_response": "Here is a partial response",
|
||||
"final_response": "⚠️ Request payload too large: max compression attempts reached.",
|
||||
"messages": [],
|
||||
}
|
||||
agent_failed_early = (
|
||||
agent_result.get("failed")
|
||||
and not agent_result.get("final_response")
|
||||
)
|
||||
assert not agent_failed_early
|
||||
agent_failed_early = bool(agent_result.get("failed"))
|
||||
assert agent_failed_early
|
||||
|
||||
def test_successful_agent_not_failed_early(self):
|
||||
"""A successful agent result should not trigger skip."""
|
||||
@@ -170,13 +166,41 @@ class TestGatewaySkipsPersistenceOnFailure:
|
||||
"final_response": "Hello!",
|
||||
"messages": [{"role": "assistant", "content": "Hello!"}],
|
||||
}
|
||||
agent_failed_early = (
|
||||
agent_result.get("failed")
|
||||
and not agent_result.get("final_response")
|
||||
)
|
||||
agent_failed_early = bool(agent_result.get("failed"))
|
||||
assert not agent_failed_early
|
||||
|
||||
|
||||
class TestCompressionExhaustedFlag:
|
||||
"""When compression is exhausted, the agent should set both
|
||||
failed=True and compression_exhausted=True so the gateway can
|
||||
auto-reset the session. (#9893)"""
|
||||
|
||||
def test_compression_exhausted_returns_carry_flag(self):
|
||||
"""Simulate the return dict from a compression-exhausted agent."""
|
||||
agent_result = {
|
||||
"messages": [],
|
||||
"completed": False,
|
||||
"api_calls": 3,
|
||||
"error": "Request payload too large: max compression attempts (3) reached.",
|
||||
"partial": True,
|
||||
"failed": True,
|
||||
"compression_exhausted": True,
|
||||
}
|
||||
assert agent_result.get("failed")
|
||||
assert agent_result.get("compression_exhausted")
|
||||
|
||||
def test_normal_failure_not_compression_exhausted(self):
|
||||
"""Non-compression failures should not have compression_exhausted."""
|
||||
agent_result = {
|
||||
"messages": [],
|
||||
"completed": False,
|
||||
"failed": True,
|
||||
"error": "Invalid API response after 3 retries",
|
||||
}
|
||||
assert agent_result.get("failed")
|
||||
assert not agent_result.get("compression_exhausted")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 3: Context-overflow error messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -103,7 +103,7 @@ class TestCleanupStaleAsyncClients:
|
||||
mock_client._client = MagicMock()
|
||||
mock_client._client.is_closed = False
|
||||
|
||||
key = ("test_stale", True, "", "", id(loop))
|
||||
key = ("test_stale", True, "", "", "", ())
|
||||
with _client_cache_lock:
|
||||
_client_cache[key] = (mock_client, "test-model", loop)
|
||||
|
||||
@@ -127,7 +127,7 @@ class TestCleanupStaleAsyncClients:
|
||||
loop = asyncio.new_event_loop() # NOT closed
|
||||
|
||||
mock_client = MagicMock()
|
||||
key = ("test_live", True, "", "", id(loop))
|
||||
key = ("test_live", True, "", "", "", ())
|
||||
with _client_cache_lock:
|
||||
_client_cache[key] = (mock_client, "test-model", loop)
|
||||
|
||||
@@ -149,7 +149,7 @@ class TestCleanupStaleAsyncClients:
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
key = ("test_sync", False, "", "", 0)
|
||||
key = ("test_sync", False, "", "", "", ())
|
||||
with _client_cache_lock:
|
||||
_client_cache[key] = (mock_client, "test-model", None)
|
||||
|
||||
@@ -160,3 +160,131 @@ class TestCleanupStaleAsyncClients:
|
||||
finally:
|
||||
with _client_cache_lock:
|
||||
_client_cache.pop(key, None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cache bounded growth (#10200)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestClientCacheBoundedGrowth:
|
||||
"""Verify the cache stays bounded when loops change (fix for #10200).
|
||||
|
||||
Previously, loop_id was part of the cache key, so every new event loop
|
||||
created a new entry for the same provider config. Now loop identity is
|
||||
validated at hit time and stale entries are replaced in-place.
|
||||
"""
|
||||
|
||||
def test_same_key_replaces_stale_loop_entry(self):
|
||||
"""When the loop changes, the old entry should be replaced, not duplicated."""
|
||||
from agent.auxiliary_client import (
|
||||
_client_cache,
|
||||
_client_cache_lock,
|
||||
_get_cached_client,
|
||||
)
|
||||
|
||||
key = ("test_replace", True, "", "", "", ())
|
||||
|
||||
# Simulate a stale entry from a closed loop
|
||||
old_loop = asyncio.new_event_loop()
|
||||
old_loop.close()
|
||||
old_client = MagicMock()
|
||||
old_client._client = MagicMock()
|
||||
old_client._client.is_closed = False
|
||||
|
||||
with _client_cache_lock:
|
||||
_client_cache[key] = (old_client, "old-model", old_loop)
|
||||
|
||||
try:
|
||||
# Now call _get_cached_client — should detect stale loop and evict
|
||||
with patch("agent.auxiliary_client.resolve_provider_client") as mock_resolve:
|
||||
mock_resolve.return_value = (MagicMock(), "new-model")
|
||||
client, model = _get_cached_client(
|
||||
"test_replace", async_mode=True,
|
||||
)
|
||||
# The old entry should have been replaced
|
||||
with _client_cache_lock:
|
||||
assert key in _client_cache, "Key should still exist (replaced)"
|
||||
entry = _client_cache[key]
|
||||
assert entry[1] == "new-model", "Should have the new model"
|
||||
finally:
|
||||
with _client_cache_lock:
|
||||
_client_cache.pop(key, None)
|
||||
|
||||
def test_different_loops_do_not_grow_cache(self):
|
||||
"""Multiple event loops for the same provider should NOT create multiple entries."""
|
||||
from agent.auxiliary_client import (
|
||||
_client_cache,
|
||||
_client_cache_lock,
|
||||
)
|
||||
|
||||
key = ("test_no_grow", True, "", "", "", ())
|
||||
|
||||
loops = []
|
||||
try:
|
||||
for i in range(5):
|
||||
loop = asyncio.new_event_loop()
|
||||
loops.append(loop)
|
||||
mock_client = MagicMock()
|
||||
mock_client._client = MagicMock()
|
||||
mock_client._client.is_closed = False
|
||||
|
||||
# Close previous loop entries (simulating worker thread recycling)
|
||||
if i > 0:
|
||||
loops[i - 1].close()
|
||||
|
||||
with _client_cache_lock:
|
||||
# Simulate what _get_cached_client does: replace on loop mismatch
|
||||
if key in _client_cache:
|
||||
old_entry = _client_cache[key]
|
||||
del _client_cache[key]
|
||||
_client_cache[key] = (mock_client, f"model-{i}", loop)
|
||||
|
||||
# Only one entry should exist for this key
|
||||
with _client_cache_lock:
|
||||
count = sum(1 for k in _client_cache if k == key)
|
||||
assert count == 1, f"Expected 1 entry, got {count}"
|
||||
finally:
|
||||
for loop in loops:
|
||||
if not loop.is_closed():
|
||||
loop.close()
|
||||
with _client_cache_lock:
|
||||
_client_cache.pop(key, None)
|
||||
|
||||
def test_max_cache_size_eviction(self):
|
||||
"""Cache should not exceed _CLIENT_CACHE_MAX_SIZE."""
|
||||
from agent.auxiliary_client import (
|
||||
_client_cache,
|
||||
_client_cache_lock,
|
||||
_CLIENT_CACHE_MAX_SIZE,
|
||||
)
|
||||
|
||||
# Save existing cache state
|
||||
with _client_cache_lock:
|
||||
saved = dict(_client_cache)
|
||||
_client_cache.clear()
|
||||
|
||||
try:
|
||||
# Fill to max + 5
|
||||
for i in range(_CLIENT_CACHE_MAX_SIZE + 5):
|
||||
mock_client = MagicMock()
|
||||
mock_client._client = MagicMock()
|
||||
mock_client._client.is_closed = False
|
||||
key = (f"evict_test_{i}", False, "", "", "", ())
|
||||
with _client_cache_lock:
|
||||
# Inline the eviction logic (same as _get_cached_client)
|
||||
while len(_client_cache) >= _CLIENT_CACHE_MAX_SIZE:
|
||||
evict_key = next(iter(_client_cache))
|
||||
del _client_cache[evict_key]
|
||||
_client_cache[key] = (mock_client, f"model-{i}", None)
|
||||
|
||||
with _client_cache_lock:
|
||||
assert len(_client_cache) <= _CLIENT_CACHE_MAX_SIZE, \
|
||||
f"Cache size {len(_client_cache)} exceeds max {_CLIENT_CACHE_MAX_SIZE}"
|
||||
# The earliest entries should have been evicted
|
||||
assert ("evict_test_0", False, "", "", "", ()) not in _client_cache
|
||||
# The latest entries should be present
|
||||
assert (f"evict_test_{_CLIENT_CACHE_MAX_SIZE + 4}", False, "", "", "", ()) in _client_cache
|
||||
finally:
|
||||
with _client_cache_lock:
|
||||
_client_cache.clear()
|
||||
_client_cache.update(saved)
|
||||
|
||||
@@ -28,7 +28,8 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
agent._interrupt_requested = False
|
||||
agent._interrupt_message = None
|
||||
agent._execution_thread_id = None # defaults to current thread in set_interrupt
|
||||
agent._execution_thread_id = None
|
||||
agent._interrupt_thread_signal_pending = False
|
||||
agent._active_children = []
|
||||
agent._active_children_lock = threading.Lock()
|
||||
agent.quiet_mode = True
|
||||
@@ -46,15 +47,17 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
||||
assert parent._interrupt_requested is True
|
||||
assert child._interrupt_requested is True
|
||||
assert child._interrupt_message == "new user message"
|
||||
assert is_interrupted() is True
|
||||
assert is_interrupted() is False
|
||||
assert parent._interrupt_thread_signal_pending is True
|
||||
|
||||
def test_child_clear_interrupt_at_start_clears_thread(self):
|
||||
"""child.clear_interrupt() at start of run_conversation clears the
|
||||
per-thread interrupt flag for the current thread.
|
||||
bound execution thread's interrupt flag.
|
||||
"""
|
||||
child = self._make_bare_agent()
|
||||
child._interrupt_requested = True
|
||||
child._interrupt_message = "msg"
|
||||
child._execution_thread_id = threading.current_thread().ident
|
||||
|
||||
# Interrupt for current thread is set
|
||||
set_interrupt(True)
|
||||
@@ -128,6 +131,36 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
||||
child_thread.join(timeout=1)
|
||||
set_interrupt(False)
|
||||
|
||||
def test_prestart_interrupt_binds_to_execution_thread(self):
|
||||
"""An interrupt that arrives before startup should bind to the agent thread."""
|
||||
agent = self._make_bare_agent()
|
||||
barrier = threading.Barrier(2)
|
||||
result = {}
|
||||
|
||||
agent.interrupt("stop before start")
|
||||
assert agent._interrupt_requested is True
|
||||
assert agent._interrupt_thread_signal_pending is True
|
||||
assert is_interrupted() is False
|
||||
|
||||
def run_thread():
|
||||
from tools.interrupt import set_interrupt as _set_interrupt_for_test
|
||||
|
||||
agent._execution_thread_id = threading.current_thread().ident
|
||||
_set_interrupt_for_test(False, agent._execution_thread_id)
|
||||
if agent._interrupt_requested:
|
||||
_set_interrupt_for_test(True, agent._execution_thread_id)
|
||||
agent._interrupt_thread_signal_pending = False
|
||||
barrier.wait(timeout=5)
|
||||
result["thread_interrupted"] = is_interrupted()
|
||||
|
||||
t = threading.Thread(target=run_thread)
|
||||
t.start()
|
||||
barrier.wait(timeout=5)
|
||||
t.join(timeout=2)
|
||||
|
||||
assert result["thread_interrupted"] is True
|
||||
assert agent._interrupt_thread_signal_pending is False
|
||||
|
||||
|
||||
class TestPerThreadInterruptIsolation(unittest.TestCase):
|
||||
"""Verify that interrupting one agent does NOT affect another agent's thread.
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
"""Tests that invalid context_length values in config produce visible warnings."""
|
||||
|
||||
from unittest.mock import patch, MagicMock, call
|
||||
|
||||
|
||||
def _build_agent(model_cfg, custom_providers=None, model="anthropic/claude-opus-4.6"):
|
||||
"""Build an AIAgent with the given model config."""
|
||||
cfg = {"model": model_cfg}
|
||||
if custom_providers is not None:
|
||||
cfg["custom_providers"] = custom_providers
|
||||
|
||||
with (
|
||||
patch("hermes_cli.config.load_config", return_value=cfg),
|
||||
patch("agent.model_metadata.get_model_context_length", return_value=128_000),
|
||||
patch("run_agent.get_tool_definitions", return_value=[]),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
model=model,
|
||||
api_key="test-key-1234567890",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
return agent
|
||||
|
||||
|
||||
def test_valid_integer_context_length_no_warning():
|
||||
"""Plain integer context_length should work silently."""
|
||||
with patch("run_agent.logger") as mock_logger:
|
||||
agent = _build_agent({"default": "gpt5.4", "provider": "custom",
|
||||
"base_url": "http://localhost:4000/v1",
|
||||
"context_length": 256000})
|
||||
assert agent._config_context_length == 256000
|
||||
# No warning about invalid context_length
|
||||
for c in mock_logger.warning.call_args_list:
|
||||
assert "Invalid" not in str(c)
|
||||
|
||||
|
||||
def test_string_k_suffix_context_length_warns():
|
||||
"""context_length: '256K' should warn the user clearly."""
|
||||
with patch("run_agent.logger") as mock_logger:
|
||||
agent = _build_agent({"default": "gpt5.4", "provider": "custom",
|
||||
"base_url": "http://localhost:4000/v1",
|
||||
"context_length": "256K"})
|
||||
assert agent._config_context_length is None
|
||||
# Should have warned
|
||||
warning_calls = [c for c in mock_logger.warning.call_args_list
|
||||
if "Invalid" in str(c) and "256K" in str(c)]
|
||||
assert len(warning_calls) == 1
|
||||
assert "plain integer" in str(warning_calls[0])
|
||||
|
||||
|
||||
def test_string_numeric_context_length_works():
|
||||
"""context_length: '256000' (string) should parse fine via int()."""
|
||||
with patch("run_agent.logger") as mock_logger:
|
||||
agent = _build_agent({"default": "gpt5.4", "provider": "custom",
|
||||
"base_url": "http://localhost:4000/v1",
|
||||
"context_length": "256000"})
|
||||
assert agent._config_context_length == 256000
|
||||
for c in mock_logger.warning.call_args_list:
|
||||
assert "Invalid" not in str(c)
|
||||
|
||||
|
||||
def test_custom_providers_invalid_context_length_warns():
|
||||
"""Invalid context_length in custom_providers should warn."""
|
||||
custom_providers = [
|
||||
{
|
||||
"name": "LiteLLM",
|
||||
"base_url": "http://localhost:4000/v1",
|
||||
"models": {
|
||||
"gpt5.4": {"context_length": "256K"}
|
||||
},
|
||||
}
|
||||
]
|
||||
with patch("run_agent.logger") as mock_logger:
|
||||
agent = _build_agent(
|
||||
{"default": "gpt5.4", "provider": "custom",
|
||||
"base_url": "http://localhost:4000/v1"},
|
||||
custom_providers=custom_providers,
|
||||
model="gpt5.4",
|
||||
)
|
||||
warning_calls = [c for c in mock_logger.warning.call_args_list
|
||||
if "Invalid" in str(c) and "256K" in str(c)]
|
||||
assert len(warning_calls) == 1
|
||||
assert "custom_providers" in str(warning_calls[0])
|
||||
|
||||
|
||||
def test_custom_providers_valid_context_length():
|
||||
"""Valid integer in custom_providers should work silently."""
|
||||
custom_providers = [
|
||||
{
|
||||
"name": "LiteLLM",
|
||||
"base_url": "http://localhost:4000/v1",
|
||||
"models": {
|
||||
"gpt5.4": {"context_length": 256000}
|
||||
},
|
||||
}
|
||||
]
|
||||
with patch("run_agent.logger") as mock_logger:
|
||||
agent = _build_agent(
|
||||
{"default": "gpt5.4", "provider": "custom",
|
||||
"base_url": "http://localhost:4000/v1"},
|
||||
custom_providers=custom_providers,
|
||||
model="gpt5.4",
|
||||
)
|
||||
for c in mock_logger.warning.call_args_list:
|
||||
assert "Invalid" not in str(c)
|
||||
@@ -1249,13 +1249,17 @@ def test_chat_messages_to_responses_input_deduplicates_reasoning_ids(monkeypatch
|
||||
]
|
||||
items = agent._chat_messages_to_responses_input(messages)
|
||||
|
||||
reasoning_ids = [it["id"] for it in items if it.get("type") == "reasoning"]
|
||||
# rs_aaa should appear only once (first occurrence kept)
|
||||
assert reasoning_ids.count("rs_aaa") == 1
|
||||
# rs_bbb and rs_ccc should each appear once
|
||||
assert reasoning_ids.count("rs_bbb") == 1
|
||||
assert reasoning_ids.count("rs_ccc") == 1
|
||||
assert len(reasoning_ids) == 3
|
||||
reasoning_items = [it for it in items if it.get("type") == "reasoning"]
|
||||
# Dedup: rs_aaa appears in both turns but should only be emitted once.
|
||||
# 3 unique items total: enc_1 (from rs_aaa), enc_2 (rs_bbb), enc_3 (rs_ccc).
|
||||
assert len(reasoning_items) == 3
|
||||
encrypted = [it["encrypted_content"] for it in reasoning_items]
|
||||
assert encrypted.count("enc_1") == 1
|
||||
assert "enc_2" in encrypted
|
||||
assert "enc_3" in encrypted
|
||||
# IDs must be stripped — with store=False the API 404s on id lookups.
|
||||
for it in reasoning_items:
|
||||
assert "id" not in it
|
||||
|
||||
|
||||
def test_preflight_codex_input_deduplicates_reasoning_ids(monkeypatch):
|
||||
@@ -1272,7 +1276,11 @@ def test_preflight_codex_input_deduplicates_reasoning_ids(monkeypatch):
|
||||
normalized = agent._preflight_codex_input_items(raw_input)
|
||||
|
||||
reasoning_items = [it for it in normalized if it.get("type") == "reasoning"]
|
||||
reasoning_ids = [it["id"] for it in reasoning_items]
|
||||
assert reasoning_ids.count("rs_xyz") == 1
|
||||
assert reasoning_ids.count("rs_zzz") == 1
|
||||
# rs_xyz duplicate should be collapsed to one item; rs_zzz kept.
|
||||
assert len(reasoning_items) == 2
|
||||
encrypted = [it["encrypted_content"] for it in reasoning_items]
|
||||
assert encrypted.count("enc_a") == 1
|
||||
assert "enc_b" in encrypted
|
||||
# IDs must be stripped — with store=False the API 404s on id lookups.
|
||||
for it in reasoning_items:
|
||||
assert "id" not in it
|
||||
|
||||
@@ -142,6 +142,33 @@ class TestSurrogateVsAsciiSanitization:
|
||||
assert _sanitize_messages_surrogates(messages) is False
|
||||
|
||||
|
||||
class TestApiKeyNonAsciiSanitization:
|
||||
"""Tests for API key sanitization in the UnicodeEncodeError recovery.
|
||||
|
||||
Covers the root cause of issue #6843: a non-ASCII character (ʋ U+028B)
|
||||
in the API key causes httpx to fail when encoding the Authorization
|
||||
header as ASCII. The recovery block must strip non-ASCII from the key.
|
||||
"""
|
||||
|
||||
def test_strip_non_ascii_from_api_key(self):
|
||||
"""_strip_non_ascii removes ʋ from an API key string."""
|
||||
key = "sk-proj-abc" + "ʋ" + "def"
|
||||
assert _strip_non_ascii(key) == "sk-proj-abcdef"
|
||||
|
||||
def test_api_key_at_position_153(self):
|
||||
"""Reproduce the exact error: ʋ at position 153 in 'Bearer <key>'."""
|
||||
key = "sk-proj-" + "a" * 138 + "ʋ" + "bcd"
|
||||
auth_value = f"Bearer {key}"
|
||||
# This is what httpx does — and it fails:
|
||||
with pytest.raises(UnicodeEncodeError) as exc_info:
|
||||
auth_value.encode("ascii")
|
||||
assert exc_info.value.start == 153
|
||||
# After sanitization, it should work:
|
||||
sanitized_key = _strip_non_ascii(key)
|
||||
sanitized_auth = f"Bearer {sanitized_key}"
|
||||
sanitized_auth.encode("ascii") # should not raise
|
||||
|
||||
|
||||
class TestSanitizeToolsNonAscii:
|
||||
"""Tests for _sanitize_tools_non_ascii."""
|
||||
|
||||
@@ -203,3 +230,67 @@ class TestSanitizeStructureNonAscii:
|
||||
assert _sanitize_structure_non_ascii(payload) is True
|
||||
assert payload["default_headers"]["X-Title"] == "Hermes Agent"
|
||||
assert payload["default_headers"]["User-Agent"] == "Hermes/1.0 "
|
||||
|
||||
|
||||
class TestApiKeyClientSync:
|
||||
"""Verify that ASCII recovery updates the live OpenAI client's api_key.
|
||||
|
||||
The OpenAI SDK stores its own copy of api_key which auth_headers reads
|
||||
dynamically. If only self.api_key is updated but self.client.api_key
|
||||
is not, the next request still sends the corrupted key in the
|
||||
Authorization header.
|
||||
"""
|
||||
|
||||
def test_client_api_key_updated_on_sanitize(self):
|
||||
"""Simulate the recovery path and verify client.api_key is synced."""
|
||||
from unittest.mock import MagicMock
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
bad_key = "sk-proj-abc\u028bdef" # ʋ lookalike at position 11
|
||||
agent.api_key = bad_key
|
||||
agent._client_kwargs = {"api_key": bad_key}
|
||||
agent.quiet_mode = True
|
||||
|
||||
# Mock client with its own api_key attribute (like the real OpenAI client)
|
||||
mock_client = MagicMock()
|
||||
mock_client.api_key = bad_key
|
||||
agent.client = mock_client
|
||||
|
||||
# --- replicate the recovery logic from run_agent.py ---
|
||||
_raw_key = agent.api_key
|
||||
_clean_key = _strip_non_ascii(_raw_key)
|
||||
assert _clean_key != _raw_key, "test precondition: key should have non-ASCII"
|
||||
|
||||
agent.api_key = _clean_key
|
||||
agent._client_kwargs["api_key"] = _clean_key
|
||||
if getattr(agent, "client", None) is not None and hasattr(agent.client, "api_key"):
|
||||
agent.client.api_key = _clean_key
|
||||
|
||||
# All three locations should now hold the clean key
|
||||
assert agent.api_key == "sk-proj-abcdef"
|
||||
assert agent._client_kwargs["api_key"] == "sk-proj-abcdef"
|
||||
assert agent.client.api_key == "sk-proj-abcdef"
|
||||
# The bad char should be gone from all of them
|
||||
assert "\u028b" not in agent.api_key
|
||||
assert "\u028b" not in agent._client_kwargs["api_key"]
|
||||
assert "\u028b" not in agent.client.api_key
|
||||
|
||||
def test_client_none_does_not_crash(self):
|
||||
"""Recovery should not crash when client is None (pre-init)."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
bad_key = "sk-proj-\u028b"
|
||||
agent.api_key = bad_key
|
||||
agent._client_kwargs = {"api_key": bad_key}
|
||||
agent.client = None
|
||||
|
||||
_clean_key = _strip_non_ascii(bad_key)
|
||||
agent.api_key = _clean_key
|
||||
agent._client_kwargs["api_key"] = _clean_key
|
||||
if getattr(agent, "client", None) is not None and hasattr(agent.client, "api_key"):
|
||||
agent.client.api_key = _clean_key
|
||||
|
||||
assert agent.api_key == "sk-proj-"
|
||||
assert agent.client is None # should not have been touched
|
||||
|
||||
@@ -116,6 +116,22 @@ class TestValidateToolset:
|
||||
def test_invalid(self):
|
||||
assert validate_toolset("nonexistent") is False
|
||||
|
||||
def test_mcp_alias_uses_live_registry(self, monkeypatch):
|
||||
reg = ToolRegistry()
|
||||
reg.register(
|
||||
name="mcp_dynserver_ping",
|
||||
toolset="mcp-dynserver",
|
||||
schema=_make_schema("mcp_dynserver_ping", "Ping"),
|
||||
handler=_dummy_handler,
|
||||
)
|
||||
reg.register_toolset_alias("dynserver", "mcp-dynserver")
|
||||
|
||||
monkeypatch.setattr("tools.registry.registry", reg)
|
||||
|
||||
assert validate_toolset("dynserver") is True
|
||||
assert validate_toolset("mcp-dynserver") is True
|
||||
assert "mcp_dynserver_ping" in resolve_toolset("dynserver")
|
||||
|
||||
|
||||
class TestGetToolsetInfo:
|
||||
def test_leaf(self):
|
||||
@@ -150,6 +166,23 @@ class TestCreateCustomToolset:
|
||||
del TOOLSETS["_test_custom"]
|
||||
|
||||
|
||||
class TestRegistryOwnedToolsets:
|
||||
def test_registry_membership_is_live(self, monkeypatch):
|
||||
reg = ToolRegistry()
|
||||
reg.register(
|
||||
name="test_live_toolset_tool",
|
||||
toolset="test-live-toolset",
|
||||
schema=_make_schema("test_live_toolset_tool", "Live"),
|
||||
handler=_dummy_handler,
|
||||
)
|
||||
|
||||
monkeypatch.setattr("tools.registry.registry", reg)
|
||||
|
||||
assert validate_toolset("test-live-toolset") is True
|
||||
assert get_toolset("test-live-toolset")["tools"] == ["test_live_toolset_tool"]
|
||||
assert resolve_toolset("test-live-toolset") == ["test_live_toolset_tool"]
|
||||
|
||||
|
||||
class TestToolsetConsistency:
|
||||
"""Verify structural integrity of the built-in TOOLSETS dict."""
|
||||
|
||||
|
||||
@@ -31,18 +31,25 @@ def _clear_browser_caches():
|
||||
|
||||
|
||||
class TestSanePath:
|
||||
"""Verify _SANE_PATH includes Homebrew directories."""
|
||||
"""Verify _SANE_PATH includes fallback directories used by browser_tool."""
|
||||
|
||||
def test_includes_termux_bin(self):
|
||||
assert "/data/data/com.termux/files/usr/bin" in _SANE_PATH.split(os.pathsep)
|
||||
|
||||
def test_includes_termux_sbin(self):
|
||||
assert "/data/data/com.termux/files/usr/sbin" in _SANE_PATH.split(os.pathsep)
|
||||
|
||||
def test_includes_homebrew_bin(self):
|
||||
assert "/opt/homebrew/bin" in _SANE_PATH
|
||||
assert "/opt/homebrew/bin" in _SANE_PATH.split(os.pathsep)
|
||||
|
||||
def test_includes_homebrew_sbin(self):
|
||||
assert "/opt/homebrew/sbin" in _SANE_PATH
|
||||
assert "/opt/homebrew/sbin" in _SANE_PATH.split(os.pathsep)
|
||||
|
||||
def test_includes_standard_dirs(self):
|
||||
assert "/usr/local/bin" in _SANE_PATH
|
||||
assert "/usr/bin" in _SANE_PATH
|
||||
assert "/bin" in _SANE_PATH
|
||||
path_parts = _SANE_PATH.split(os.pathsep)
|
||||
assert "/usr/local/bin" in path_parts
|
||||
assert "/usr/bin" in path_parts
|
||||
assert "/bin" in path_parts
|
||||
|
||||
|
||||
class TestDiscoverHomebrewNodeDirs:
|
||||
@@ -143,6 +150,44 @@ class TestFindAgentBrowser:
|
||||
result = _find_agent_browser()
|
||||
assert result == "npx agent-browser"
|
||||
|
||||
def test_finds_npx_in_termux_fallback_path(self):
|
||||
"""Should find npx when only Termux fallback dirs are available."""
|
||||
def mock_which(cmd, path=None):
|
||||
if cmd == "agent-browser":
|
||||
return None
|
||||
if cmd == "npx":
|
||||
if path and "/data/data/com.termux/files/usr/bin" in path:
|
||||
return "/data/data/com.termux/files/usr/bin/npx"
|
||||
return None
|
||||
return None
|
||||
|
||||
original_path_exists = Path.exists
|
||||
|
||||
def mock_path_exists(self):
|
||||
if "node_modules" in str(self) and "agent-browser" in str(self):
|
||||
return False
|
||||
return original_path_exists(self)
|
||||
|
||||
real_isdir = os.path.isdir
|
||||
|
||||
def selective_isdir(path):
|
||||
if path in (
|
||||
"/data/data/com.termux/files/usr/bin",
|
||||
"/data/data/com.termux/files/usr/sbin",
|
||||
):
|
||||
return True
|
||||
return real_isdir(path)
|
||||
|
||||
with patch("shutil.which", side_effect=mock_which), \
|
||||
patch("os.path.isdir", side_effect=selective_isdir), \
|
||||
patch.object(Path, "exists", mock_path_exists), \
|
||||
patch(
|
||||
"tools.browser_tool._discover_homebrew_node_dirs",
|
||||
return_value=[],
|
||||
):
|
||||
result = _find_agent_browser()
|
||||
assert result == "npx agent-browser"
|
||||
|
||||
def test_raises_when_not_found(self):
|
||||
"""Should raise FileNotFoundError when nothing works."""
|
||||
original_path_exists = Path.exists
|
||||
@@ -399,3 +444,51 @@ class TestRunBrowserCommandPathConstruction:
|
||||
result_path = captured_env.get("PATH", "")
|
||||
assert "/opt/homebrew/bin" in result_path
|
||||
assert "/opt/homebrew/sbin" in result_path
|
||||
|
||||
def test_subprocess_path_includes_termux_fallback_dirs(self, tmp_path):
|
||||
"""Termux fallback dirs should survive browser PATH rebuilding."""
|
||||
captured_env = {}
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait.return_value = 0
|
||||
|
||||
def capture_popen(cmd, **kwargs):
|
||||
captured_env.update(kwargs.get("env", {}))
|
||||
return mock_proc
|
||||
|
||||
fake_session = {
|
||||
"session_name": "test-session",
|
||||
"session_id": "test-id",
|
||||
"cdp_url": None,
|
||||
}
|
||||
|
||||
fake_json = json.dumps({"success": True})
|
||||
real_isdir = os.path.isdir
|
||||
|
||||
def selective_isdir(path):
|
||||
if path in (
|
||||
"/data/data/com.termux/files/usr/bin",
|
||||
"/data/data/com.termux/files/usr/sbin",
|
||||
):
|
||||
return True
|
||||
if path.startswith(str(tmp_path)):
|
||||
return True
|
||||
return real_isdir(path)
|
||||
|
||||
with patch("tools.browser_tool._find_agent_browser", return_value="/usr/local/bin/agent-browser"), \
|
||||
patch("tools.browser_tool._get_session_info", return_value=fake_session), \
|
||||
patch("tools.browser_tool._socket_safe_tmpdir", return_value=str(tmp_path)), \
|
||||
patch("tools.browser_tool._discover_homebrew_node_dirs", return_value=[]), \
|
||||
patch("os.path.isdir", side_effect=selective_isdir), \
|
||||
patch("subprocess.Popen", side_effect=capture_popen), \
|
||||
patch("os.open", return_value=99), \
|
||||
patch("os.close"), \
|
||||
patch("tools.interrupt.is_interrupted", return_value=False), \
|
||||
patch.dict(os.environ, {"PATH": "/usr/bin:/bin", "HOME": "/home/test"}, clear=True):
|
||||
with patch("builtins.open", mock_open(read_data=fake_json)):
|
||||
_run_browser_command("test-task", "navigate", ["https://example.com"])
|
||||
|
||||
result_path = captured_env.get("PATH", "")
|
||||
assert "/data/data/com.termux/files/usr/bin" in result_path
|
||||
assert "/data/data/com.termux/files/usr/sbin" in result_path
|
||||
|
||||
@@ -46,3 +46,59 @@ class TestFindDocker:
|
||||
with patch("tools.environments.docker.shutil.which", return_value=None):
|
||||
second = docker_mod.find_docker()
|
||||
assert first == second == "/usr/local/bin/docker"
|
||||
|
||||
def test_env_var_override_takes_precedence(self, tmp_path):
|
||||
"""HERMES_DOCKER_BINARY overrides PATH and known-location discovery."""
|
||||
fake_binary = tmp_path / "podman"
|
||||
fake_binary.write_text("#!/bin/sh\n")
|
||||
fake_binary.chmod(0o755)
|
||||
|
||||
with patch.dict(os.environ, {"HERMES_DOCKER_BINARY": str(fake_binary)}), \
|
||||
patch("tools.environments.docker.shutil.which", return_value="/usr/bin/docker"):
|
||||
result = docker_mod.find_docker()
|
||||
assert result == str(fake_binary)
|
||||
|
||||
def test_env_var_override_ignored_if_not_executable(self, tmp_path):
|
||||
"""Non-executable HERMES_DOCKER_BINARY falls through to normal discovery."""
|
||||
fake_binary = tmp_path / "podman"
|
||||
fake_binary.write_text("#!/bin/sh\n")
|
||||
fake_binary.chmod(0o644) # not executable
|
||||
|
||||
with patch.dict(os.environ, {"HERMES_DOCKER_BINARY": str(fake_binary)}), \
|
||||
patch("tools.environments.docker.shutil.which", return_value="/usr/bin/docker"):
|
||||
result = docker_mod.find_docker()
|
||||
assert result == "/usr/bin/docker"
|
||||
|
||||
def test_env_var_override_ignored_if_nonexistent(self):
|
||||
"""Non-existent HERMES_DOCKER_BINARY path falls through."""
|
||||
with patch.dict(os.environ, {"HERMES_DOCKER_BINARY": "/nonexistent/podman"}), \
|
||||
patch("tools.environments.docker.shutil.which", return_value="/usr/bin/docker"):
|
||||
result = docker_mod.find_docker()
|
||||
assert result == "/usr/bin/docker"
|
||||
|
||||
def test_podman_on_path_used_when_docker_missing(self):
|
||||
"""When docker is not on PATH, podman is tried next."""
|
||||
def which_side_effect(name):
|
||||
if name == "docker":
|
||||
return None
|
||||
if name == "podman":
|
||||
return "/usr/bin/podman"
|
||||
return None
|
||||
|
||||
with patch("tools.environments.docker.shutil.which", side_effect=which_side_effect), \
|
||||
patch("tools.environments.docker._DOCKER_SEARCH_PATHS", []):
|
||||
result = docker_mod.find_docker()
|
||||
assert result == "/usr/bin/podman"
|
||||
|
||||
def test_docker_preferred_over_podman(self):
|
||||
"""When both docker and podman are on PATH, docker wins."""
|
||||
def which_side_effect(name):
|
||||
if name == "docker":
|
||||
return "/usr/bin/docker"
|
||||
if name == "podman":
|
||||
return "/usr/bin/podman"
|
||||
return None
|
||||
|
||||
with patch("tools.environments.docker.shutil.which", side_effect=which_side_effect):
|
||||
result = docker_mod.find_docker()
|
||||
assert result == "/usr/bin/docker"
|
||||
|
||||
@@ -296,7 +296,7 @@ def test_managed_modal_execute_times_out_and_cancels(monkeypatch):
|
||||
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||
|
||||
calls = []
|
||||
monotonic_values = iter([0.0, 12.5])
|
||||
monotonic_values = iter([0.0, 0.0, 0.0, 12.5, 12.5])
|
||||
|
||||
def fake_request(method, url, headers=None, json=None, timeout=None):
|
||||
calls.append((method, url, json, timeout))
|
||||
|
||||
@@ -21,34 +21,19 @@ class TestRegisterServerTools:
|
||||
def mock_registry(self):
|
||||
return ToolRegistry()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_toolsets(self):
|
||||
return {
|
||||
"hermes-cli": {"tools": ["terminal"], "description": "CLI", "includes": []},
|
||||
"hermes-telegram": {"tools": ["terminal"], "description": "TG", "includes": []},
|
||||
"custom-toolset": {"tools": [], "description": "Other", "includes": []},
|
||||
}
|
||||
|
||||
def test_injects_hermes_toolsets(self, mock_registry, mock_toolsets):
|
||||
"""Tools are injected into hermes-* toolsets but not custom ones."""
|
||||
def test_exposes_live_server_aliases(self, mock_registry):
|
||||
"""Registered MCP tools are reachable via live raw-server aliases."""
|
||||
server = MCPServerTask("my_srv")
|
||||
server._tools = [_make_mcp_tool("my_tool", "desc")]
|
||||
server.session = MagicMock()
|
||||
from toolsets import resolve_toolset, validate_toolset
|
||||
|
||||
with patch("tools.registry.registry", mock_registry), \
|
||||
patch("toolsets.create_custom_toolset"), \
|
||||
patch.dict("toolsets.TOOLSETS", mock_toolsets, clear=True):
|
||||
|
||||
with patch("tools.registry.registry", mock_registry):
|
||||
registered = _register_server_tools("my_srv", server, {})
|
||||
|
||||
assert "mcp_my_srv_my_tool" in registered
|
||||
assert "mcp_my_srv_my_tool" in mock_registry.get_all_tool_names()
|
||||
|
||||
# Injected into hermes-* toolsets
|
||||
assert "mcp_my_srv_my_tool" in mock_toolsets["hermes-cli"]["tools"]
|
||||
assert "mcp_my_srv_my_tool" in mock_toolsets["hermes-telegram"]["tools"]
|
||||
# NOT into non-hermes toolsets
|
||||
assert "mcp_my_srv_my_tool" not in mock_toolsets["custom-toolset"]["tools"]
|
||||
assert "mcp_my_srv_my_tool" in registered
|
||||
assert "mcp_my_srv_my_tool" in mock_registry.get_all_tool_names()
|
||||
assert validate_toolset("my_srv") is True
|
||||
assert "mcp_my_srv_my_tool" in resolve_toolset("my_srv")
|
||||
|
||||
|
||||
class TestRefreshTools:
|
||||
@@ -58,19 +43,13 @@ class TestRefreshTools:
|
||||
def mock_registry(self):
|
||||
return ToolRegistry()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_toolsets(self):
|
||||
return {
|
||||
"hermes-cli": {"tools": ["terminal"], "description": "CLI", "includes": []},
|
||||
"hermes-telegram": {"tools": ["terminal"], "description": "TG", "includes": []},
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nuke_and_repave(self, mock_registry, mock_toolsets):
|
||||
async def test_nuke_and_repave(self, mock_registry):
|
||||
"""Old tools are removed and new tools registered on refresh."""
|
||||
server = MCPServerTask("live_srv")
|
||||
server._refresh_lock = asyncio.Lock()
|
||||
server._config = {}
|
||||
from toolsets import resolve_toolset
|
||||
|
||||
# Seed initial state: one old tool registered
|
||||
mock_registry.register(
|
||||
@@ -79,7 +58,6 @@ class TestRefreshTools:
|
||||
description="", emoji="",
|
||||
)
|
||||
server._registered_tool_names = ["mcp_live_srv_old_tool"]
|
||||
mock_toolsets["hermes-cli"]["tools"].append("mcp_live_srv_old_tool")
|
||||
|
||||
# New tool list from server
|
||||
new_tool = _make_mcp_tool("new_tool", "new behavior")
|
||||
@@ -89,20 +67,13 @@ class TestRefreshTools:
|
||||
)
|
||||
)
|
||||
|
||||
with patch("tools.registry.registry", mock_registry), \
|
||||
patch("toolsets.create_custom_toolset"), \
|
||||
patch.dict("toolsets.TOOLSETS", mock_toolsets, clear=True):
|
||||
|
||||
with patch("tools.registry.registry", mock_registry):
|
||||
await server._refresh_tools()
|
||||
|
||||
# Old tool completely gone
|
||||
assert "mcp_live_srv_old_tool" not in mock_registry.get_all_tool_names()
|
||||
assert "mcp_live_srv_old_tool" not in mock_toolsets["hermes-cli"]["tools"]
|
||||
|
||||
# New tool registered
|
||||
assert "mcp_live_srv_new_tool" in mock_registry.get_all_tool_names()
|
||||
assert "mcp_live_srv_new_tool" in mock_toolsets["hermes-cli"]["tools"]
|
||||
assert server._registered_tool_names == ["mcp_live_srv_new_tool"]
|
||||
assert "mcp_live_srv_old_tool" not in mock_registry.get_all_tool_names()
|
||||
assert "mcp_live_srv_old_tool" not in resolve_toolset("live_srv")
|
||||
assert "mcp_live_srv_new_tool" in mock_registry.get_all_tool_names()
|
||||
assert "mcp_live_srv_new_tool" in resolve_toolset("live_srv")
|
||||
assert server._registered_tool_names == ["mcp_live_srv_new_tool"]
|
||||
|
||||
|
||||
class TestMessageHandler:
|
||||
@@ -165,6 +136,25 @@ class TestDeregister:
|
||||
# bar still in ts1, so check should remain
|
||||
assert "ts1" in reg._toolset_checks
|
||||
|
||||
def test_removes_toolset_alias_when_last_tool_is_removed(self):
|
||||
reg = ToolRegistry()
|
||||
reg.register(name="foo", toolset="mcp-srv", schema={}, handler=lambda x: x)
|
||||
reg.register_toolset_alias("srv", "mcp-srv")
|
||||
|
||||
reg.deregister("foo")
|
||||
|
||||
assert reg.get_toolset_alias_target("srv") is None
|
||||
|
||||
def test_preserves_toolset_alias_while_toolset_still_exists(self):
|
||||
reg = ToolRegistry()
|
||||
reg.register(name="foo", toolset="mcp-srv", schema={}, handler=lambda x: x)
|
||||
reg.register(name="bar", toolset="mcp-srv", schema={}, handler=lambda x: x)
|
||||
reg.register_toolset_alias("srv", "mcp-srv")
|
||||
|
||||
reg.deregister("foo")
|
||||
|
||||
assert reg.get_toolset_alias_target("srv") == "mcp-srv"
|
||||
|
||||
def test_noop_for_unknown_tool(self):
|
||||
reg = ToolRegistry()
|
||||
reg.deregister("nonexistent") # Should not raise
|
||||
|
||||
@@ -184,11 +184,7 @@ class TestToolHandler:
|
||||
def _patch_mcp_loop(self, coro_side_effect=None):
|
||||
"""Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
|
||||
def fake_run(coro, timeout=30):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
return asyncio.run(coro)
|
||||
if coro_side_effect:
|
||||
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=coro_side_effect)
|
||||
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run)
|
||||
@@ -365,10 +361,13 @@ class TestDiscoverAndRegister:
|
||||
|
||||
_servers.pop("fs", None)
|
||||
|
||||
def test_toolset_created(self):
|
||||
"""A custom toolset is created for the MCP server."""
|
||||
def test_toolset_resolves_live_from_registry(self):
|
||||
"""MCP toolsets resolve through the live registry without TOOLSETS mutation."""
|
||||
from tools.registry import ToolRegistry
|
||||
from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask
|
||||
from toolsets import resolve_toolset, validate_toolset
|
||||
|
||||
mock_registry = ToolRegistry()
|
||||
mock_tools = [_make_mcp_tool("ping", "Ping")]
|
||||
mock_session = MagicMock()
|
||||
|
||||
@@ -378,16 +377,16 @@ class TestDiscoverAndRegister:
|
||||
server._tools = mock_tools
|
||||
return server
|
||||
|
||||
mock_create = MagicMock()
|
||||
with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch("toolsets.create_custom_toolset", mock_create):
|
||||
patch("tools.registry.registry", mock_registry):
|
||||
asyncio.run(
|
||||
_discover_and_register_server("myserver", {"command": "test"})
|
||||
)
|
||||
|
||||
mock_create.assert_called_once()
|
||||
call_kwargs = mock_create.call_args
|
||||
assert call_kwargs[1]["name"] == "mcp-myserver" or call_kwargs[0][0] == "mcp-myserver"
|
||||
assert validate_toolset("myserver") is True
|
||||
assert validate_toolset("mcp-myserver") is True
|
||||
assert "mcp_myserver_ping" in resolve_toolset("myserver")
|
||||
assert "mcp_myserver_ping" in resolve_toolset("mcp-myserver")
|
||||
|
||||
_servers.pop("myserver", None)
|
||||
|
||||
@@ -550,12 +549,15 @@ class TestMCPServerTask:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestToolsetInjection:
|
||||
def test_mcp_tools_added_to_all_hermes_toolsets(self):
|
||||
"""Discovered MCP tools are dynamically injected into all hermes-* toolsets."""
|
||||
def test_mcp_tools_resolve_through_server_aliases(self):
|
||||
"""Discovered MCP tools resolve through raw server-name aliases."""
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
from tools.registry import ToolRegistry
|
||||
from toolsets import resolve_toolset, validate_toolset
|
||||
|
||||
mock_tools = [_make_mcp_tool("list_files", "List files")]
|
||||
mock_session = MagicMock()
|
||||
mock_registry = ToolRegistry()
|
||||
|
||||
fresh_servers = {}
|
||||
|
||||
@@ -565,43 +567,32 @@ class TestToolsetInjection:
|
||||
server._tools = mock_tools
|
||||
return server
|
||||
|
||||
fake_toolsets = {
|
||||
"hermes-cli": {"tools": ["terminal"], "description": "CLI", "includes": []},
|
||||
"hermes-telegram": {"tools": ["terminal"], "description": "TG", "includes": []},
|
||||
"hermes-gateway": {"tools": [], "description": "GW", "includes": []},
|
||||
"non-hermes": {"tools": [], "description": "other", "includes": []},
|
||||
}
|
||||
fake_config = {"fs": {"command": "npx", "args": []}}
|
||||
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._servers", fresh_servers), \
|
||||
patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
|
||||
patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch("toolsets.TOOLSETS", fake_toolsets):
|
||||
patch("tools.registry.registry", mock_registry):
|
||||
from tools.mcp_tool import discover_mcp_tools
|
||||
result = discover_mcp_tools()
|
||||
|
||||
assert "mcp_fs_list_files" in result
|
||||
# All hermes-* toolsets get injection
|
||||
assert "mcp_fs_list_files" in fake_toolsets["hermes-cli"]["tools"]
|
||||
assert "mcp_fs_list_files" in fake_toolsets["hermes-telegram"]["tools"]
|
||||
assert "mcp_fs_list_files" in fake_toolsets["hermes-gateway"]["tools"]
|
||||
# Non-hermes toolset should NOT get injection
|
||||
assert "mcp_fs_list_files" not in fake_toolsets["non-hermes"]["tools"]
|
||||
# Original tools preserved
|
||||
assert "terminal" in fake_toolsets["hermes-cli"]["tools"]
|
||||
# Server name becomes a standalone toolset
|
||||
assert "fs" in fake_toolsets
|
||||
assert "mcp_fs_list_files" in fake_toolsets["fs"]["tools"]
|
||||
assert fake_toolsets["fs"]["description"].startswith("MCP server '")
|
||||
assert "mcp_fs_list_files" in result
|
||||
assert validate_toolset("fs") is True
|
||||
assert validate_toolset("mcp-fs") is True
|
||||
assert "mcp_fs_list_files" in resolve_toolset("fs")
|
||||
assert "mcp_fs_list_files" in resolve_toolset("mcp-fs")
|
||||
|
||||
def test_server_toolset_skips_builtin_collision(self):
|
||||
"""MCP server named after a built-in toolset shouldn't overwrite it."""
|
||||
"""MCP raw aliases never overwrite a built-in toolset name."""
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
from tools.registry import ToolRegistry
|
||||
from toolsets import resolve_toolset, validate_toolset
|
||||
|
||||
mock_tools = [_make_mcp_tool("run", "Run command")]
|
||||
mock_session = MagicMock()
|
||||
fresh_servers = {}
|
||||
mock_registry = ToolRegistry()
|
||||
|
||||
async def fake_connect(name, config):
|
||||
server = MCPServerTask(name)
|
||||
@@ -620,12 +611,15 @@ class TestToolsetInjection:
|
||||
patch("tools.mcp_tool._servers", fresh_servers), \
|
||||
patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
|
||||
patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch("tools.registry.registry", mock_registry), \
|
||||
patch("toolsets.TOOLSETS", fake_toolsets):
|
||||
from tools.mcp_tool import discover_mcp_tools
|
||||
discover_mcp_tools()
|
||||
|
||||
# Built-in toolset preserved — description unchanged
|
||||
assert fake_toolsets["terminal"]["description"] == "Terminal tools"
|
||||
assert fake_toolsets["terminal"]["description"] == "Terminal tools"
|
||||
assert "mcp_terminal_run" not in resolve_toolset("terminal")
|
||||
assert validate_toolset("mcp-terminal") is True
|
||||
assert "mcp_terminal_run" in resolve_toolset("mcp-terminal")
|
||||
|
||||
def test_server_connection_failure_skipped(self):
|
||||
"""If one server fails to connect, others still proceed."""
|
||||
@@ -776,6 +770,42 @@ class TestShutdown:
|
||||
assert len(_servers) == 0
|
||||
mock_server.shutdown.assert_called_once()
|
||||
|
||||
def test_shutdown_deregisters_registered_tools(self):
|
||||
"""shutdown_mcp_servers removes MCP tools and their raw alias."""
|
||||
import tools.mcp_tool as mcp_mod
|
||||
from tools.mcp_tool import MCPServerTask, shutdown_mcp_servers, _servers
|
||||
from tools.registry import registry
|
||||
from toolsets import resolve_toolset, validate_toolset
|
||||
|
||||
_servers.clear()
|
||||
registry.register(
|
||||
name="mcp_test_ping",
|
||||
toolset="mcp-test",
|
||||
schema={
|
||||
"name": "mcp_test_ping",
|
||||
"description": "Ping",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
handler=lambda *_args, **_kwargs: "{}",
|
||||
)
|
||||
registry.register_toolset_alias("test", "mcp-test")
|
||||
|
||||
server = MCPServerTask("test")
|
||||
server._registered_tool_names = ["mcp_test_ping"]
|
||||
_servers["test"] = server
|
||||
|
||||
mcp_mod._ensure_mcp_loop()
|
||||
try:
|
||||
assert validate_toolset("test") is True
|
||||
assert "mcp_test_ping" in resolve_toolset("test")
|
||||
shutdown_mcp_servers()
|
||||
finally:
|
||||
mcp_mod._mcp_loop = None
|
||||
mcp_mod._mcp_thread = None
|
||||
|
||||
assert "mcp_test_ping" not in registry.get_all_tool_names()
|
||||
assert validate_toolset("test") is False
|
||||
|
||||
def test_shutdown_handles_errors(self):
|
||||
"""shutdown_mcp_servers handles errors during close gracefully."""
|
||||
import tools.mcp_tool as mcp_mod
|
||||
@@ -1179,7 +1209,11 @@ class TestConfigurableTimeouts:
|
||||
try:
|
||||
handler = _make_tool_handler("test_srv", "my_tool", 180)
|
||||
with patch("tools.mcp_tool._run_on_mcp_loop") as mock_run:
|
||||
mock_run.return_value = json.dumps({"result": "ok"})
|
||||
def fake_run(coro, timeout=30):
|
||||
coro.close()
|
||||
return json.dumps({"result": "ok"})
|
||||
|
||||
mock_run.side_effect = fake_run
|
||||
handler({})
|
||||
# Verify timeout=180 was passed
|
||||
call_kwargs = mock_run.call_args
|
||||
@@ -1279,11 +1313,7 @@ class TestUtilityHandlers:
|
||||
def _patch_mcp_loop(self):
|
||||
"""Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
|
||||
def fake_run(coro, timeout=30):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
return asyncio.run(coro)
|
||||
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run)
|
||||
|
||||
# -- list_resources --
|
||||
@@ -3038,14 +3068,23 @@ class TestSanitizeMcpNameComponent:
|
||||
assert "/" not in name
|
||||
assert "." not in name
|
||||
|
||||
def test_slash_in_sync_mcp_toolsets(self):
|
||||
"""_sync_mcp_toolsets uses sanitize consistently with _convert_mcp_schema."""
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
def test_slash_in_server_alias_resolution(self):
|
||||
"""Server names with slashes resolve through their live MCP alias."""
|
||||
from tools.registry import ToolRegistry
|
||||
from toolsets import resolve_toolset, validate_toolset
|
||||
|
||||
# Verify the prefix generation matches what _convert_mcp_schema produces
|
||||
server_name = "ai.exa/exa"
|
||||
safe_prefix = f"mcp_{sanitize_mcp_name_component(server_name)}_"
|
||||
assert safe_prefix == "mcp_ai_exa_exa_"
|
||||
reg = ToolRegistry()
|
||||
reg.register(
|
||||
name="mcp_ai_exa_exa_search",
|
||||
toolset="mcp-ai.exa/exa",
|
||||
schema={"name": "mcp_ai_exa_exa_search", "description": "Search", "parameters": {"type": "object", "properties": {}}},
|
||||
handler=lambda *_args, **_kwargs: "{}",
|
||||
)
|
||||
reg.register_toolset_alias("ai.exa/exa", "mcp-ai.exa/exa")
|
||||
|
||||
with patch("tools.registry.registry", reg):
|
||||
assert validate_toolset("ai.exa/exa") is True
|
||||
assert "mcp_ai_exa_exa_search" in resolve_toolset("ai.exa/exa")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
|
||||
import json
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from tools.registry import ToolRegistry
|
||||
from tools.registry import ToolRegistry, discover_builtin_tools
|
||||
|
||||
|
||||
def _dummy_handler(args, **kwargs):
|
||||
@@ -286,6 +288,74 @@ class TestCheckFnExceptionHandling:
|
||||
assert any(u["name"] == "crashes" for u in unavailable)
|
||||
|
||||
|
||||
class TestBuiltinDiscovery:
|
||||
def test_matches_previous_manual_builtin_tool_set(self):
|
||||
expected = {
|
||||
"tools.browser_tool",
|
||||
"tools.clarify_tool",
|
||||
"tools.code_execution_tool",
|
||||
"tools.cronjob_tools",
|
||||
"tools.delegate_tool",
|
||||
"tools.file_tools",
|
||||
"tools.homeassistant_tool",
|
||||
"tools.image_generation_tool",
|
||||
"tools.memory_tool",
|
||||
"tools.mixture_of_agents_tool",
|
||||
"tools.process_registry",
|
||||
"tools.rl_training_tool",
|
||||
"tools.send_message_tool",
|
||||
"tools.session_search_tool",
|
||||
"tools.skill_manager_tool",
|
||||
"tools.skills_tool",
|
||||
"tools.terminal_tool",
|
||||
"tools.todo_tool",
|
||||
"tools.tts_tool",
|
||||
"tools.vision_tools",
|
||||
"tools.web_tools",
|
||||
}
|
||||
|
||||
with patch("tools.registry.importlib.import_module"):
|
||||
imported = discover_builtin_tools(Path(__file__).resolve().parents[2] / "tools")
|
||||
|
||||
assert set(imported) == expected
|
||||
|
||||
def test_imports_only_self_registering_modules(self, tmp_path):
|
||||
tools_dir = tmp_path / "tools"
|
||||
tools_dir.mkdir()
|
||||
(tools_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
(tools_dir / "registry.py").write_text("", encoding="utf-8")
|
||||
(tools_dir / "alpha.py").write_text(
|
||||
"from tools.registry import registry\nregistry.register(name='alpha', toolset='x', schema={}, handler=lambda *_a, **_k: '{}')\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(tools_dir / "beta.py").write_text("VALUE = 1\n", encoding="utf-8")
|
||||
|
||||
with patch("tools.registry.importlib.import_module") as mock_import:
|
||||
imported = discover_builtin_tools(tools_dir)
|
||||
|
||||
assert imported == ["tools.alpha"]
|
||||
mock_import.assert_called_once_with("tools.alpha")
|
||||
|
||||
def test_skips_mcp_tool_even_if_it_registers(self, tmp_path):
|
||||
tools_dir = tmp_path / "tools"
|
||||
tools_dir.mkdir()
|
||||
(tools_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
(tools_dir / "mcp_tool.py").write_text(
|
||||
"from tools.registry import registry\nregistry.register(name='mcp_alpha', toolset='mcp-test', schema={}, handler=lambda *_a, **_k: '{}')\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(tools_dir / "alpha.py").write_text(
|
||||
"from tools.registry import registry\nregistry.register(name='alpha', toolset='x', schema={}, handler=lambda *_a, **_k: '{}')\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with patch("tools.registry.importlib.import_module") as mock_import:
|
||||
imported = discover_builtin_tools(tools_dir)
|
||||
|
||||
assert imported == ["tools.alpha"]
|
||||
mock_import.assert_called_once_with("tools.alpha")
|
||||
|
||||
|
||||
class TestEmojiMetadata:
|
||||
"""Verify per-tool emoji registration and lookup."""
|
||||
|
||||
|
||||
@@ -752,6 +752,38 @@ class TestParseTargetRefDiscord:
|
||||
assert is_explicit is True
|
||||
|
||||
|
||||
class TestParseTargetRefMatrix:
|
||||
"""_parse_target_ref correctly handles Matrix room IDs and user MXIDs."""
|
||||
|
||||
def test_matrix_room_id_is_explicit(self):
|
||||
"""Matrix room IDs (!) are recognized as explicit targets."""
|
||||
chat_id, thread_id, is_explicit = _parse_target_ref("matrix", "!HLOQwxYGgFPMPJUSNR:matrix.org")
|
||||
assert chat_id == "!HLOQwxYGgFPMPJUSNR:matrix.org"
|
||||
assert thread_id is None
|
||||
assert is_explicit is True
|
||||
|
||||
def test_matrix_user_mxid_is_explicit(self):
|
||||
"""Matrix user MXIDs (@) are recognized as explicit targets."""
|
||||
chat_id, thread_id, is_explicit = _parse_target_ref("matrix", "@hermes:matrix.org")
|
||||
assert chat_id == "@hermes:matrix.org"
|
||||
assert thread_id is None
|
||||
assert is_explicit is True
|
||||
|
||||
def test_matrix_alias_is_not_explicit(self):
|
||||
"""Matrix room aliases (#) are NOT explicit — they need resolution."""
|
||||
chat_id, thread_id, is_explicit = _parse_target_ref("matrix", "#general:matrix.org")
|
||||
assert chat_id is None
|
||||
assert is_explicit is False
|
||||
|
||||
def test_matrix_prefix_only_matches_matrix_platform(self):
|
||||
"""! and @ prefixes are only treated as explicit for the matrix platform."""
|
||||
chat_id, _, is_explicit = _parse_target_ref("telegram", "!something")
|
||||
assert is_explicit is False
|
||||
|
||||
chat_id, _, is_explicit = _parse_target_ref("discord", "@someone")
|
||||
assert is_explicit is False
|
||||
|
||||
|
||||
class TestSendDiscordThreadId:
|
||||
"""_send_discord uses thread_id when provided."""
|
||||
|
||||
@@ -854,3 +886,225 @@ class TestSendToPlatformDiscordThread:
|
||||
send_mock.assert_awaited_once()
|
||||
_, call_kwargs = send_mock.await_args
|
||||
assert call_kwargs["thread_id"] is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discord media attachment support
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendDiscordMedia:
|
||||
"""_send_discord uploads media files via multipart/form-data."""
|
||||
|
||||
@staticmethod
|
||||
def _build_mock(response_status, response_data=None, response_text="error body"):
|
||||
"""Build a properly-structured aiohttp mock chain."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status = response_status
|
||||
mock_resp.json = AsyncMock(return_value=response_data or {"id": "msg123"})
|
||||
mock_resp.text = AsyncMock(return_value=response_text)
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_session.post = MagicMock(return_value=mock_resp)
|
||||
|
||||
return mock_session, mock_resp
|
||||
|
||||
def test_text_and_media_sends_both(self, tmp_path):
|
||||
"""Text message is sent first, then each media file as multipart."""
|
||||
img = tmp_path / "photo.png"
|
||||
img.write_bytes(b"\x89PNG fake image data")
|
||||
|
||||
mock_session, _ = self._build_mock(200, {"id": "msg999"})
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = asyncio.run(
|
||||
_send_discord("tok", "111", "hello", media_files=[(str(img), False)])
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["message_id"] == "msg999"
|
||||
# Two POSTs: one text JSON, one multipart upload
|
||||
assert mock_session.post.call_count == 2
|
||||
|
||||
def test_media_only_skips_text_post(self, tmp_path):
|
||||
"""When message is empty and media is present, text POST is skipped."""
|
||||
img = tmp_path / "photo.png"
|
||||
img.write_bytes(b"\x89PNG fake image data")
|
||||
|
||||
mock_session, _ = self._build_mock(200, {"id": "media_only"})
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = asyncio.run(
|
||||
_send_discord("tok", "222", " ", media_files=[(str(img), False)])
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
# Only one POST: the media upload (text was whitespace-only)
|
||||
assert mock_session.post.call_count == 1
|
||||
|
||||
def test_missing_media_file_collected_as_warning(self):
|
||||
"""Non-existent media paths produce warnings but don't fail."""
|
||||
mock_session, _ = self._build_mock(200, {"id": "txt_ok"})
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = asyncio.run(
|
||||
_send_discord("tok", "333", "hello", media_files=[("/nonexistent/file.png", False)])
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert "warnings" in result
|
||||
assert any("not found" in w for w in result["warnings"])
|
||||
# Only the text POST was made, media was skipped
|
||||
assert mock_session.post.call_count == 1
|
||||
|
||||
def test_media_upload_failure_collected_as_warning(self, tmp_path):
|
||||
"""Failed media upload becomes a warning, text still succeeds."""
|
||||
img = tmp_path / "photo.png"
|
||||
img.write_bytes(b"\x89PNG fake image data")
|
||||
|
||||
# First call (text) succeeds, second call (media) returns 413
|
||||
text_resp = MagicMock()
|
||||
text_resp.status = 200
|
||||
text_resp.json = AsyncMock(return_value={"id": "txt_ok"})
|
||||
text_resp.__aenter__ = AsyncMock(return_value=text_resp)
|
||||
text_resp.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
media_resp = MagicMock()
|
||||
media_resp.status = 413
|
||||
media_resp.text = AsyncMock(return_value="Request Entity Too Large")
|
||||
media_resp.__aenter__ = AsyncMock(return_value=media_resp)
|
||||
media_resp.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_session.post = MagicMock(side_effect=[text_resp, media_resp])
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = asyncio.run(
|
||||
_send_discord("tok", "444", "hello", media_files=[(str(img), False)])
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["message_id"] == "txt_ok"
|
||||
assert "warnings" in result
|
||||
assert any("413" in w for w in result["warnings"])
|
||||
|
||||
def test_no_text_no_media_returns_error(self):
|
||||
"""Empty text with no media returns error dict."""
|
||||
mock_session, _ = self._build_mock(200)
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = asyncio.run(
|
||||
_send_discord("tok", "555", "", media_files=[])
|
||||
)
|
||||
|
||||
# Text is empty but media_files is empty, so text POST fires
|
||||
# (the "skip text if media present" condition isn't met)
|
||||
assert result["success"] is True
|
||||
|
||||
def test_multiple_media_files_uploaded_separately(self, tmp_path):
|
||||
"""Each media file gets its own multipart POST."""
|
||||
img1 = tmp_path / "a.png"
|
||||
img1.write_bytes(b"img1")
|
||||
img2 = tmp_path / "b.jpg"
|
||||
img2.write_bytes(b"img2")
|
||||
|
||||
mock_session, _ = self._build_mock(200, {"id": "last"})
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = asyncio.run(
|
||||
_send_discord("tok", "666", "hi", media_files=[
|
||||
(str(img1), False), (str(img2), False)
|
||||
])
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
# 1 text POST + 2 media POSTs = 3
|
||||
assert mock_session.post.call_count == 3
|
||||
|
||||
|
||||
class TestSendToPlatformDiscordMedia:
|
||||
"""_send_to_platform routes Discord media correctly."""
|
||||
|
||||
def test_media_files_passed_on_last_chunk_only(self):
|
||||
"""Discord media_files are only passed on the final chunk."""
|
||||
call_log = []
|
||||
|
||||
async def mock_send_discord(token, chat_id, message, thread_id=None, media_files=None):
|
||||
call_log.append({"message": message, "media_files": media_files or []})
|
||||
return {"success": True, "platform": "discord", "chat_id": chat_id, "message_id": "1"}
|
||||
|
||||
# A message long enough to get chunked (Discord limit is 2000)
|
||||
long_msg = "A" * 1900 + " " + "B" * 1900
|
||||
|
||||
with patch("tools.send_message_tool._send_discord", side_effect=mock_send_discord):
|
||||
result = asyncio.run(
|
||||
_send_to_platform(
|
||||
Platform.DISCORD,
|
||||
SimpleNamespace(enabled=True, token="tok", extra={}),
|
||||
"999",
|
||||
long_msg,
|
||||
media_files=[("/fake/img.png", False)],
|
||||
)
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert len(call_log) == 2 # Message was chunked
|
||||
assert call_log[0]["media_files"] == [] # First chunk: no media
|
||||
assert call_log[1]["media_files"] == [("/fake/img.png", False)] # Last chunk: media attached
|
||||
|
||||
def test_single_chunk_gets_media(self):
|
||||
"""Short message (single chunk) gets media_files directly."""
|
||||
send_mock = AsyncMock(return_value={"success": True, "message_id": "1"})
|
||||
|
||||
with patch("tools.send_message_tool._send_discord", send_mock):
|
||||
result = asyncio.run(
|
||||
_send_to_platform(
|
||||
Platform.DISCORD,
|
||||
SimpleNamespace(enabled=True, token="tok", extra={}),
|
||||
"888",
|
||||
"short message",
|
||||
media_files=[("/fake/img.png", False)],
|
||||
)
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
send_mock.assert_awaited_once()
|
||||
call_kwargs = send_mock.await_args.kwargs
|
||||
assert call_kwargs["media_files"] == [("/fake/img.png", False)]
|
||||
|
||||
|
||||
class TestSendMatrixUrlEncoding:
|
||||
"""_send_matrix URL-encodes Matrix room IDs in the API path."""
|
||||
|
||||
def test_room_id_is_percent_encoded_in_url(self):
|
||||
"""Matrix room IDs with ! and : are percent-encoded in the PUT URL."""
|
||||
import aiohttp
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(return_value={"event_id": "$evt123"})
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.put = MagicMock(return_value=mock_resp)
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
from tools.send_message_tool import _send_matrix
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
_send_matrix(
|
||||
"test_token",
|
||||
{"homeserver": "https://matrix.example.org"},
|
||||
"!HLOQwxYGgFPMPJUSNR:matrix.org",
|
||||
"hello",
|
||||
)
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
# Verify the URL was called with percent-encoded room ID
|
||||
put_url = mock_session.put.call_args[0][0]
|
||||
assert "%21HLOQwxYGgFPMPJUSNR%3Amatrix.org" in put_url
|
||||
assert "!HLOQwxYGgFPMPJUSNR:matrix.org" not in put_url
|
||||
|
||||
@@ -92,6 +92,25 @@ class TestCheckWatchPatterns:
|
||||
assert "disk full" in evt["output"]
|
||||
assert evt["session_id"] == "proc_test_watch"
|
||||
|
||||
def test_match_carries_session_key_and_watcher_routing_metadata(self, registry):
|
||||
session = _make_session(watch_patterns=["ERROR"])
|
||||
session.session_key = "agent:main:telegram:group:-100:42"
|
||||
session.watcher_platform = "telegram"
|
||||
session.watcher_chat_id = "-100"
|
||||
session.watcher_user_id = "u123"
|
||||
session.watcher_user_name = "alice"
|
||||
session.watcher_thread_id = "42"
|
||||
|
||||
registry._check_watch_patterns(session, "ERROR: disk full\n")
|
||||
evt = registry.completion_queue.get_nowait()
|
||||
|
||||
assert evt["session_key"] == "agent:main:telegram:group:-100:42"
|
||||
assert evt["platform"] == "telegram"
|
||||
assert evt["chat_id"] == "-100"
|
||||
assert evt["user_id"] == "u123"
|
||||
assert evt["user_name"] == "alice"
|
||||
assert evt["thread_id"] == "42"
|
||||
|
||||
def test_multiple_patterns(self, registry):
|
||||
"""First matching pattern is reported."""
|
||||
session = _make_session(watch_patterns=["WARN", "ERROR"])
|
||||
|
||||
+46
-40
@@ -94,11 +94,21 @@ except ImportError:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Standard PATH entries for environments with minimal PATH (e.g. systemd services).
|
||||
# Includes macOS Homebrew paths (/opt/homebrew/* for Apple Silicon).
|
||||
_SANE_PATH = (
|
||||
"/opt/homebrew/bin:/opt/homebrew/sbin:"
|
||||
"/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
# Includes Android/Termux and macOS Homebrew locations needed for agent-browser,
|
||||
# npx, node, and Android's glibc runner (grun).
|
||||
_SANE_PATH_DIRS = (
|
||||
"/data/data/com.termux/files/usr/bin",
|
||||
"/data/data/com.termux/files/usr/sbin",
|
||||
"/opt/homebrew/bin",
|
||||
"/opt/homebrew/sbin",
|
||||
"/usr/local/sbin",
|
||||
"/usr/local/bin",
|
||||
"/usr/sbin",
|
||||
"/usr/bin",
|
||||
"/sbin",
|
||||
"/bin",
|
||||
)
|
||||
_SANE_PATH = os.pathsep.join(_SANE_PATH_DIRS)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
@@ -123,6 +133,28 @@ def _discover_homebrew_node_dirs() -> tuple[str, ...]:
|
||||
pass
|
||||
return tuple(dirs)
|
||||
|
||||
|
||||
def _browser_candidate_path_dirs() -> list[str]:
|
||||
"""Return ordered browser CLI PATH candidates shared by discovery and execution."""
|
||||
hermes_home = get_hermes_home()
|
||||
hermes_node_bin = str(hermes_home / "node" / "bin")
|
||||
return [hermes_node_bin, *list(_discover_homebrew_node_dirs()), *_SANE_PATH_DIRS]
|
||||
|
||||
|
||||
def _merge_browser_path(existing_path: str = "") -> str:
|
||||
"""Prepend browser-specific PATH fallbacks without reordering existing entries."""
|
||||
path_parts = [p for p in (existing_path or "").split(os.pathsep) if p]
|
||||
existing_parts = set(path_parts)
|
||||
prefix_parts: list[str] = []
|
||||
|
||||
for part in _browser_candidate_path_dirs():
|
||||
if not part or part in existing_parts or part in prefix_parts:
|
||||
continue
|
||||
if os.path.isdir(part):
|
||||
prefix_parts.append(part)
|
||||
|
||||
return os.pathsep.join(prefix_parts + path_parts)
|
||||
|
||||
# Throttle screenshot cleanup to avoid repeated full directory scans.
|
||||
_last_screenshot_cleanup_by_dir: dict[str, float] = {}
|
||||
|
||||
@@ -895,21 +927,10 @@ def _find_agent_browser() -> str:
|
||||
_agent_browser_resolved = True
|
||||
return which_result
|
||||
|
||||
# Build an extended search PATH including Homebrew and Hermes-managed dirs.
|
||||
# This covers macOS where the process PATH may not include Homebrew paths.
|
||||
extra_dirs: list[str] = []
|
||||
for d in ["/opt/homebrew/bin", "/usr/local/bin"]:
|
||||
if os.path.isdir(d):
|
||||
extra_dirs.append(d)
|
||||
extra_dirs.extend(_discover_homebrew_node_dirs())
|
||||
|
||||
hermes_home = get_hermes_home()
|
||||
hermes_node_bin = str(hermes_home / "node" / "bin")
|
||||
if os.path.isdir(hermes_node_bin):
|
||||
extra_dirs.append(hermes_node_bin)
|
||||
|
||||
if extra_dirs:
|
||||
extended_path = os.pathsep.join(extra_dirs)
|
||||
# Build an extended search PATH including Hermes-managed Node, macOS
|
||||
# versioned Homebrew installs, and fallback system dirs like Termux.
|
||||
extended_path = _merge_browser_path("")
|
||||
if extended_path:
|
||||
which_result = shutil.which("agent-browser", path=extended_path)
|
||||
if which_result:
|
||||
_cached_agent_browser = which_result
|
||||
@@ -924,10 +945,10 @@ def _find_agent_browser() -> str:
|
||||
_agent_browser_resolved = True
|
||||
return _cached_agent_browser
|
||||
|
||||
# Check common npx locations (also search extended dirs)
|
||||
# Check common npx locations (also search the extended fallback PATH)
|
||||
npx_path = shutil.which("npx")
|
||||
if not npx_path and extra_dirs:
|
||||
npx_path = shutil.which("npx", path=os.pathsep.join(extra_dirs))
|
||||
if not npx_path and extended_path:
|
||||
npx_path = shutil.which("npx", path=extended_path)
|
||||
if npx_path:
|
||||
_cached_agent_browser = "npx agent-browser"
|
||||
_agent_browser_resolved = True
|
||||
@@ -1046,24 +1067,9 @@ def _run_browser_command(
|
||||
|
||||
browser_env = {**os.environ}
|
||||
|
||||
# Ensure PATH includes Hermes-managed Node first, Homebrew versioned
|
||||
# node dirs (for macOS ``brew install node@24``), then standard system dirs.
|
||||
hermes_home = get_hermes_home()
|
||||
hermes_node_bin = str(hermes_home / "node" / "bin")
|
||||
|
||||
existing_path = browser_env.get("PATH", "")
|
||||
path_parts = [p for p in existing_path.split(":") if p]
|
||||
candidate_dirs = (
|
||||
[hermes_node_bin]
|
||||
+ list(_discover_homebrew_node_dirs())
|
||||
+ [p for p in _SANE_PATH.split(":") if p]
|
||||
)
|
||||
|
||||
for part in reversed(candidate_dirs):
|
||||
if os.path.isdir(part) and part not in path_parts:
|
||||
path_parts.insert(0, part)
|
||||
|
||||
browser_env["PATH"] = ":".join(path_parts)
|
||||
# Ensure subprocesses inherit the same browser-specific PATH fallbacks
|
||||
# used during CLI discovery.
|
||||
browser_env["PATH"] = _merge_browser_path(browser_env.get("PATH", ""))
|
||||
browser_env["AGENT_BROWSER_SOCKET_DIR"] = task_socket_dir
|
||||
|
||||
# Use temp files for stdout/stderr instead of pipes.
|
||||
|
||||
@@ -13,6 +13,8 @@ import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from hermes_constants import display_hermes_home
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import from cron module (will be available when properly installed)
|
||||
@@ -391,6 +393,8 @@ Use action='create' to schedule a new job from a prompt or one or more skills.
|
||||
Use action='list' to inspect jobs.
|
||||
Use action='update', 'pause', 'resume', 'remove', or 'run' to manage an existing job.
|
||||
|
||||
To stop a job the user no longer wants: first action='list' to find the job_id, then action='remove' with that job_id. Never guess job IDs — always list first.
|
||||
|
||||
Jobs run in a fresh session with no current-chat context, so prompts must be self-contained.
|
||||
If skills are provided on create, the future cron run loads those skills in order, then follows the prompt as the task instruction.
|
||||
On update, passing skills=[] clears attached skills.
|
||||
@@ -453,7 +457,7 @@ Important safety rule: cron-run sessions should not recursively schedule more cr
|
||||
},
|
||||
"script": {
|
||||
"type": "string",
|
||||
"description": "Optional path to a Python script that runs before each cron job execution. Its stdout is injected into the prompt as context. Use for data collection and change detection. Relative paths resolve under ~/.hermes/scripts/. On update, pass empty string to clear."
|
||||
"description": f"Optional path to a Python script that runs before each cron job execution. Its stdout is injected into the prompt as context. Use for data collection and change detection. Relative paths resolve under {display_hermes_home()}/scripts/. On update, pass empty string to clear."
|
||||
},
|
||||
},
|
||||
"required": ["action"]
|
||||
|
||||
@@ -99,23 +99,41 @@ def _load_hermes_env_vars() -> dict[str, str]:
|
||||
|
||||
|
||||
def find_docker() -> Optional[str]:
|
||||
"""Locate the docker CLI binary.
|
||||
"""Locate the docker (or podman) CLI binary.
|
||||
|
||||
Checks ``shutil.which`` first (respects PATH), then probes well-known
|
||||
install locations on macOS where Docker Desktop may not be in PATH
|
||||
(e.g. when running as a gateway service via launchd).
|
||||
Resolution order:
|
||||
1. ``HERMES_DOCKER_BINARY`` env var — explicit override (e.g. ``/usr/bin/podman``)
|
||||
2. ``docker`` on PATH via ``shutil.which``
|
||||
3. ``podman`` on PATH via ``shutil.which``
|
||||
4. Well-known macOS Docker Desktop install locations
|
||||
|
||||
Returns the absolute path, or ``None`` if docker cannot be found.
|
||||
Returns the absolute path, or ``None`` if neither runtime can be found.
|
||||
"""
|
||||
global _docker_executable
|
||||
if _docker_executable is not None:
|
||||
return _docker_executable
|
||||
|
||||
# 1. Explicit override via env var (e.g. for Podman on immutable distros)
|
||||
override = os.getenv("HERMES_DOCKER_BINARY")
|
||||
if override and os.path.isfile(override) and os.access(override, os.X_OK):
|
||||
_docker_executable = override
|
||||
logger.info("Using HERMES_DOCKER_BINARY override: %s", override)
|
||||
return override
|
||||
|
||||
# 2. docker on PATH
|
||||
found = shutil.which("docker")
|
||||
if found:
|
||||
_docker_executable = found
|
||||
return found
|
||||
|
||||
# 3. podman on PATH (drop-in compatible for our use case)
|
||||
found = shutil.which("podman")
|
||||
if found:
|
||||
_docker_executable = found
|
||||
logger.info("Using podman as container runtime: %s", found)
|
||||
return found
|
||||
|
||||
# 4. Well-known macOS Docker Desktop locations
|
||||
for path in _DOCKER_SEARCH_PATHS:
|
||||
if os.path.isfile(path) and os.access(path, os.X_OK):
|
||||
_docker_executable = path
|
||||
|
||||
@@ -105,6 +105,10 @@ class BaseModalExecutionEnvironment(BaseEnvironment):
|
||||
if self._client_timeout_grace_seconds is not None:
|
||||
deadline = time.monotonic() + prepared.timeout + self._client_timeout_grace_seconds
|
||||
|
||||
_last_activity_touch = time.monotonic()
|
||||
_modal_exec_start = time.monotonic()
|
||||
_ACTIVITY_INTERVAL = 10.0 # match _wait_for_process cadence
|
||||
|
||||
while True:
|
||||
if is_interrupted():
|
||||
try:
|
||||
@@ -128,6 +132,22 @@ class BaseModalExecutionEnvironment(BaseEnvironment):
|
||||
pass
|
||||
return self._timeout_result_for_modal(prepared.timeout)
|
||||
|
||||
# Periodic activity touch so the gateway knows we're alive
|
||||
_now = time.monotonic()
|
||||
if _now - _last_activity_touch >= _ACTIVITY_INTERVAL:
|
||||
_last_activity_touch = _now
|
||||
try:
|
||||
from tools.environments.base import _get_activity_callback
|
||||
_cb = _get_activity_callback()
|
||||
except Exception:
|
||||
_cb = None
|
||||
if _cb:
|
||||
try:
|
||||
_elapsed = int(_now - _modal_exec_start)
|
||||
_cb(f"modal command running ({_elapsed}s elapsed)")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time.sleep(self._poll_interval_seconds)
|
||||
|
||||
def _before_execute(self) -> None:
|
||||
|
||||
+14
-80
@@ -846,8 +846,7 @@ class MCPServerTask:
|
||||
After the initial ``await`` (list_tools), all mutations are synchronous
|
||||
— atomic from the event loop's perspective.
|
||||
"""
|
||||
from tools.registry import registry, tool_error
|
||||
from toolsets import TOOLSETS
|
||||
from tools.registry import registry
|
||||
|
||||
async with self._refresh_lock:
|
||||
# Capture old tool names for change diff
|
||||
@@ -857,16 +856,11 @@ class MCPServerTask:
|
||||
tools_result = await self.session.list_tools()
|
||||
new_mcp_tools = tools_result.tools if hasattr(tools_result, "tools") else []
|
||||
|
||||
# 2. Remove old tools from hermes-* umbrella toolsets
|
||||
for ts_name, ts in TOOLSETS.items():
|
||||
if ts_name.startswith("hermes-"):
|
||||
ts["tools"] = [t for t in ts["tools"] if t not in self._registered_tool_names]
|
||||
|
||||
# 3. Deregister old tools from the central registry
|
||||
# 2. Deregister old tools from the central registry
|
||||
for prefixed_name in self._registered_tool_names:
|
||||
registry.deregister(prefixed_name)
|
||||
|
||||
# 4. Re-register with fresh tool list
|
||||
# 3. Re-register with fresh tool list
|
||||
self._tools = new_mcp_tools
|
||||
self._registered_tool_names = _register_server_tools(
|
||||
self.name, self, self._config
|
||||
@@ -1144,6 +1138,8 @@ class MCPServerTask:
|
||||
|
||||
async def shutdown(self):
|
||||
"""Signal the Task to exit and wait for clean resource teardown."""
|
||||
from tools.registry import registry
|
||||
|
||||
self._shutdown_event.set()
|
||||
if self._task and not self._task.done():
|
||||
try:
|
||||
@@ -1158,6 +1154,9 @@ class MCPServerTask:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
for tool_name in list(getattr(self, "_registered_tool_names", [])):
|
||||
registry.deregister(tool_name)
|
||||
self._registered_tool_names = []
|
||||
self.session = None
|
||||
|
||||
|
||||
@@ -1671,57 +1670,6 @@ def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def _sync_mcp_toolsets(server_names: Optional[List[str]] = None) -> None:
|
||||
"""Expose each MCP server as a standalone toolset and inject into hermes-* sets.
|
||||
|
||||
Creates a real toolset entry in TOOLSETS for each server name (e.g.
|
||||
TOOLSETS["github"] = {"tools": ["mcp_github_list_files", ...]}). This
|
||||
makes raw server names resolvable in platform_toolsets overrides.
|
||||
|
||||
Also injects all MCP tools into hermes-* umbrella toolsets for the
|
||||
default behavior.
|
||||
|
||||
Skips server names that collide with built-in toolsets.
|
||||
"""
|
||||
from toolsets import TOOLSETS
|
||||
|
||||
if server_names is None:
|
||||
server_names = list(_load_mcp_config().keys())
|
||||
|
||||
existing = _existing_tool_names()
|
||||
all_mcp_tools: List[str] = []
|
||||
|
||||
for server_name in server_names:
|
||||
safe_prefix = f"mcp_{sanitize_mcp_name_component(server_name)}_"
|
||||
server_tools = sorted(
|
||||
t for t in existing if t.startswith(safe_prefix)
|
||||
)
|
||||
all_mcp_tools.extend(server_tools)
|
||||
|
||||
# Don't overwrite a built-in toolset that happens to share the name.
|
||||
existing_ts = TOOLSETS.get(server_name)
|
||||
if existing_ts and not str(existing_ts.get("description", "")).startswith("MCP server '"):
|
||||
logger.warning(
|
||||
"Skipping MCP toolset alias '%s' — a built-in toolset already uses that name",
|
||||
server_name,
|
||||
)
|
||||
continue
|
||||
|
||||
TOOLSETS[server_name] = {
|
||||
"description": f"MCP server '{server_name}' tools",
|
||||
"tools": server_tools,
|
||||
"includes": [],
|
||||
}
|
||||
|
||||
# Also inject into hermes-* umbrella toolsets for default behavior.
|
||||
for ts_name, ts in TOOLSETS.items():
|
||||
if not ts_name.startswith("hermes-"):
|
||||
continue
|
||||
for tool_name in all_mcp_tools:
|
||||
if tool_name not in ts["tools"]:
|
||||
ts["tools"].append(tool_name)
|
||||
|
||||
|
||||
def _build_utility_schemas(server_name: str) -> List[dict]:
|
||||
"""Build schemas for the MCP utility tools (resources & prompts).
|
||||
|
||||
@@ -1874,16 +1822,16 @@ def _existing_tool_names() -> List[str]:
|
||||
def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> List[str]:
|
||||
"""Register tools from an already-connected server into the registry.
|
||||
|
||||
Handles include/exclude filtering, utility tools, toolset creation,
|
||||
and hermes-* umbrella toolset injection.
|
||||
Handles include/exclude filtering and utility tools. Toolset resolution
|
||||
for ``mcp-{server}`` and raw server-name aliases is derived from the live
|
||||
registry, rather than mutating ``toolsets.TOOLSETS`` at runtime.
|
||||
|
||||
Used by both initial discovery and dynamic refresh (list_changed).
|
||||
|
||||
Returns:
|
||||
List of registered prefixed tool names.
|
||||
"""
|
||||
from tools.registry import registry, tool_error
|
||||
from toolsets import create_custom_toolset, TOOLSETS
|
||||
from tools.registry import registry
|
||||
|
||||
registered_names: List[str] = []
|
||||
toolset_name = f"mcp-{name}"
|
||||
@@ -1973,19 +1921,8 @@ def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> Li
|
||||
)
|
||||
registered_names.append(util_name)
|
||||
|
||||
# Create a custom toolset so these tools are discoverable
|
||||
if registered_names:
|
||||
create_custom_toolset(
|
||||
name=toolset_name,
|
||||
description=f"MCP tools from {name} server",
|
||||
tools=registered_names,
|
||||
)
|
||||
# Inject into hermes-* umbrella toolsets for default behavior
|
||||
for ts_name, ts in TOOLSETS.items():
|
||||
if ts_name.startswith("hermes-"):
|
||||
for tool_name in registered_names:
|
||||
if tool_name not in ts["tools"]:
|
||||
ts["tools"].append(tool_name)
|
||||
registry.register_toolset_alias(name, toolset_name)
|
||||
|
||||
return registered_names
|
||||
|
||||
@@ -2049,7 +1986,6 @@ def register_mcp_servers(servers: Dict[str, dict]) -> List[str]:
|
||||
}
|
||||
|
||||
if not new_servers:
|
||||
_sync_mcp_toolsets(list(servers.keys()))
|
||||
return _existing_tool_names()
|
||||
|
||||
# Start the background event loop for MCP connections
|
||||
@@ -2080,8 +2016,6 @@ def register_mcp_servers(servers: Dict[str, dict]) -> List[str]:
|
||||
# The outer timeout is generous: 120s total for parallel discovery.
|
||||
_run_on_mcp_loop(_discover_all(), timeout=120)
|
||||
|
||||
_sync_mcp_toolsets(list(servers.keys()))
|
||||
|
||||
# Log a summary so ACP callers get visibility into what was registered.
|
||||
with _lock:
|
||||
connected = [n for n in new_servers if n in _servers]
|
||||
@@ -2102,7 +2036,7 @@ def register_mcp_servers(servers: Dict[str, dict]) -> List[str]:
|
||||
def discover_mcp_tools() -> List[str]:
|
||||
"""Entry point: load config, connect to MCP servers, register tools.
|
||||
|
||||
Called from ``model_tools._discover_tools()``. Safe to call even when
|
||||
Called from ``model_tools`` after ``discover_builtin_tools()``. Safe to call even when
|
||||
the ``mcp`` package is not installed (returns empty list).
|
||||
|
||||
Idempotent for already-connected servers. If some servers failed on a
|
||||
|
||||
@@ -191,9 +191,15 @@ class ProcessRegistry:
|
||||
session._watch_disabled = True
|
||||
self.completion_queue.put({
|
||||
"session_id": session.id,
|
||||
"session_key": session.session_key,
|
||||
"command": session.command,
|
||||
"type": "watch_disabled",
|
||||
"suppressed": session._watch_suppressed,
|
||||
"platform": session.watcher_platform,
|
||||
"chat_id": session.watcher_chat_id,
|
||||
"user_id": session.watcher_user_id,
|
||||
"user_name": session.watcher_user_name,
|
||||
"thread_id": session.watcher_thread_id,
|
||||
"message": (
|
||||
f"Watch patterns disabled for process {session.id} — "
|
||||
f"too many matches ({session._watch_suppressed} suppressed). "
|
||||
@@ -219,11 +225,17 @@ class ProcessRegistry:
|
||||
|
||||
self.completion_queue.put({
|
||||
"session_id": session.id,
|
||||
"session_key": session.session_key,
|
||||
"command": session.command,
|
||||
"type": "watch_match",
|
||||
"pattern": matched_pattern,
|
||||
"output": output,
|
||||
"suppressed": suppressed,
|
||||
"platform": session.watcher_platform,
|
||||
"chat_id": session.watcher_chat_id,
|
||||
"user_id": session.watcher_user_id,
|
||||
"user_name": session.watcher_user_name,
|
||||
"thread_id": session.watcher_thread_id,
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
|
||||
+83
-3
@@ -14,14 +14,65 @@ Import chain (circular-import safe):
|
||||
run_agent.py, cli.py, batch_runner.py, etc.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_registry_register_call(node: ast.AST) -> bool:
|
||||
"""Return True when *node* is a ``registry.register(...)`` call expression."""
|
||||
if not isinstance(node, ast.Expr) or not isinstance(node.value, ast.Call):
|
||||
return False
|
||||
func = node.value.func
|
||||
return (
|
||||
isinstance(func, ast.Attribute)
|
||||
and func.attr == "register"
|
||||
and isinstance(func.value, ast.Name)
|
||||
and func.value.id == "registry"
|
||||
)
|
||||
|
||||
|
||||
def _module_registers_tools(module_path: Path) -> bool:
|
||||
"""Return True when the module contains a top-level ``registry.register(...)`` call.
|
||||
|
||||
Only inspects module-body statements so that helper modules which happen
|
||||
to call ``registry.register()`` inside a function are not picked up.
|
||||
"""
|
||||
try:
|
||||
source = module_path.read_text(encoding="utf-8")
|
||||
tree = ast.parse(source, filename=str(module_path))
|
||||
except (OSError, SyntaxError):
|
||||
return False
|
||||
|
||||
return any(_is_registry_register_call(stmt) for stmt in tree.body)
|
||||
|
||||
|
||||
def discover_builtin_tools(tools_dir: Optional[Path] = None) -> List[str]:
|
||||
"""Import built-in self-registering tool modules and return their module names."""
|
||||
tools_path = Path(tools_dir) if tools_dir is not None else Path(__file__).resolve().parent
|
||||
module_names = [
|
||||
f"tools.{path.stem}"
|
||||
for path in sorted(tools_path.glob("*.py"))
|
||||
if path.name not in {"__init__.py", "registry.py", "mcp_tool.py"}
|
||||
and _module_registers_tools(path)
|
||||
]
|
||||
|
||||
imported: List[str] = []
|
||||
for mod_name in module_names:
|
||||
try:
|
||||
importlib.import_module(mod_name)
|
||||
imported.append(mod_name)
|
||||
except Exception as e:
|
||||
logger.warning("Could not import tool module %s: %s", mod_name, e)
|
||||
return imported
|
||||
|
||||
|
||||
class ToolEntry:
|
||||
"""Metadata for a single registered tool."""
|
||||
|
||||
@@ -52,6 +103,7 @@ class ToolRegistry:
|
||||
def __init__(self):
|
||||
self._tools: Dict[str, ToolEntry] = {}
|
||||
self._toolset_checks: Dict[str, Callable] = {}
|
||||
self._toolset_aliases: Dict[str, str] = {}
|
||||
# MCP dynamic refresh can mutate the registry while other threads are
|
||||
# reading tool metadata, so keep mutations serialized and readers on
|
||||
# stable snapshots.
|
||||
@@ -96,6 +148,27 @@ class ToolRegistry:
|
||||
if entry.toolset == toolset
|
||||
)
|
||||
|
||||
def register_toolset_alias(self, alias: str, toolset: str) -> None:
|
||||
"""Register an explicit alias for a canonical toolset name."""
|
||||
with self._lock:
|
||||
existing = self._toolset_aliases.get(alias)
|
||||
if existing and existing != toolset:
|
||||
logger.warning(
|
||||
"Toolset alias collision: '%s' (%s) overwritten by %s",
|
||||
alias, existing, toolset,
|
||||
)
|
||||
self._toolset_aliases[alias] = toolset
|
||||
|
||||
def get_registered_toolset_aliases(self) -> Dict[str, str]:
|
||||
"""Return a snapshot of ``{alias: canonical_toolset}`` mappings."""
|
||||
with self._lock:
|
||||
return dict(self._toolset_aliases)
|
||||
|
||||
def get_toolset_alias_target(self, alias: str) -> Optional[str]:
|
||||
"""Return the canonical toolset name for an alias, or None."""
|
||||
with self._lock:
|
||||
return self._toolset_aliases.get(alias)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Registration
|
||||
# ------------------------------------------------------------------
|
||||
@@ -164,11 +237,18 @@ class ToolRegistry:
|
||||
entry = self._tools.pop(name, None)
|
||||
if entry is None:
|
||||
return
|
||||
# Drop the toolset check if this was the last tool in that toolset
|
||||
if entry.toolset in self._toolset_checks and not any(
|
||||
# Drop the toolset check and aliases if this was the last tool in
|
||||
# that toolset.
|
||||
toolset_still_exists = any(
|
||||
e.toolset == entry.toolset for e in self._tools.values()
|
||||
):
|
||||
)
|
||||
if not toolset_still_exists:
|
||||
self._toolset_checks.pop(entry.toolset, None)
|
||||
self._toolset_aliases = {
|
||||
alias: target
|
||||
for alias, target in self._toolset_aliases.items()
|
||||
if target != entry.toolset
|
||||
}
|
||||
logger.debug("Deregistered tool: %s", name)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
+81
-16
@@ -68,7 +68,7 @@ SEND_MESSAGE_SCHEMA = {
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', 'platform:chat_id', or 'platform:chat_id:thread_id' for Telegram topics and Discord threads. Examples: 'telegram', 'telegram:-1001234567890:17585', 'discord:999888777:555444333', 'discord:#bot-home', 'slack:#engineering', 'signal:+155****4567'"
|
||||
"description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', 'platform:chat_id', or 'platform:chat_id:thread_id' for Telegram topics and Discord threads. Examples: 'telegram', 'telegram:-1001234567890:17585', 'discord:999888777:555444333', 'discord:#bot-home', 'slack:#engineering', 'signal:+155****4567', 'matrix:!roomid:server.org', 'matrix:@user:server.org'"
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
@@ -248,6 +248,9 @@ def _parse_target_ref(platform_name: str, target_ref: str):
|
||||
return match.group(1), None, True
|
||||
if target_ref.lstrip("-").isdigit():
|
||||
return target_ref, None, True
|
||||
# Matrix room IDs (start with !) and user IDs (start with @) are explicit
|
||||
if platform_name == "matrix" and (target_ref.startswith("!") or target_ref.startswith("@")):
|
||||
return target_ref, None, True
|
||||
return None, None, False
|
||||
|
||||
|
||||
@@ -384,11 +387,28 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None,
|
||||
if platform == Platform.WEIXIN:
|
||||
return await _send_weixin(pconfig, chat_id, message, media_files=media_files)
|
||||
|
||||
# --- Non-Telegram platforms ---
|
||||
# --- Discord: special handling for media attachments ---
|
||||
if platform == Platform.DISCORD:
|
||||
last_result = None
|
||||
for i, chunk in enumerate(chunks):
|
||||
is_last = (i == len(chunks) - 1)
|
||||
result = await _send_discord(
|
||||
pconfig.token,
|
||||
chat_id,
|
||||
chunk,
|
||||
media_files=media_files if is_last else [],
|
||||
thread_id=thread_id,
|
||||
)
|
||||
if isinstance(result, dict) and result.get("error"):
|
||||
return result
|
||||
last_result = result
|
||||
return last_result
|
||||
|
||||
# --- Non-Telegram/Discord platforms ---
|
||||
if media_files and not message.strip():
|
||||
return {
|
||||
"error": (
|
||||
f"send_message MEDIA delivery is currently only supported for telegram; "
|
||||
f"send_message MEDIA delivery is currently only supported for telegram, discord, and weixin; "
|
||||
f"target {platform.value} had only media attachments"
|
||||
)
|
||||
}
|
||||
@@ -396,14 +416,12 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None,
|
||||
if media_files:
|
||||
warning = (
|
||||
f"MEDIA attachments were omitted for {platform.value}; "
|
||||
"native send_message media delivery is currently only supported for telegram"
|
||||
"native send_message media delivery is currently only supported for telegram, discord, and weixin"
|
||||
)
|
||||
|
||||
last_result = None
|
||||
for chunk in chunks:
|
||||
if platform == Platform.DISCORD:
|
||||
result = await _send_discord(pconfig.token, chat_id, chunk, thread_id=thread_id)
|
||||
elif platform == Platform.SLACK:
|
||||
if platform == Platform.SLACK:
|
||||
result = await _send_slack(pconfig.token, chat_id, chunk)
|
||||
elif platform == Platform.WHATSAPP:
|
||||
result = await _send_whatsapp(pconfig.extra, chat_id, chunk)
|
||||
@@ -568,13 +586,16 @@ async def _send_telegram(token, chat_id, message, media_files=None, thread_id=No
|
||||
return _error(f"Telegram send failed: {e}")
|
||||
|
||||
|
||||
async def _send_discord(token, chat_id, message, thread_id=None):
|
||||
async def _send_discord(token, chat_id, message, thread_id=None, media_files=None):
|
||||
"""Send a single message via Discord REST API (no websocket client needed).
|
||||
|
||||
Chunking is handled by _send_to_platform() before this is called.
|
||||
|
||||
When thread_id is provided, the message is sent directly to that thread
|
||||
via the /channels/{thread_id}/messages endpoint.
|
||||
|
||||
Media files are uploaded one-by-one via multipart/form-data after the
|
||||
text message is sent (same pattern as Telegram).
|
||||
"""
|
||||
try:
|
||||
import aiohttp
|
||||
@@ -589,14 +610,56 @@ async def _send_discord(token, chat_id, message, thread_id=None):
|
||||
url = f"https://discord.com/api/v10/channels/{thread_id}/messages"
|
||||
else:
|
||||
url = f"https://discord.com/api/v10/channels/{chat_id}/messages"
|
||||
headers = {"Authorization": f"Bot {token}", "Content-Type": "application/json"}
|
||||
auth_headers = {"Authorization": f"Bot {token}"}
|
||||
media_files = media_files or []
|
||||
last_data = None
|
||||
warnings = []
|
||||
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30), **_sess_kw) as session:
|
||||
async with session.post(url, headers=headers, json={"content": message}, **_req_kw) as resp:
|
||||
if resp.status not in (200, 201):
|
||||
body = await resp.text()
|
||||
return _error(f"Discord API error ({resp.status}): {body}")
|
||||
data = await resp.json()
|
||||
return {"success": True, "platform": "discord", "chat_id": chat_id, "message_id": data.get("id")}
|
||||
# Send text message (skip if empty and media is present)
|
||||
if message.strip() or not media_files:
|
||||
headers = {**auth_headers, "Content-Type": "application/json"}
|
||||
async with session.post(url, headers=headers, json={"content": message}, **_req_kw) as resp:
|
||||
if resp.status not in (200, 201):
|
||||
body = await resp.text()
|
||||
return _error(f"Discord API error ({resp.status}): {body}")
|
||||
last_data = await resp.json()
|
||||
|
||||
# Send each media file as a separate multipart upload
|
||||
for media_path, _is_voice in media_files:
|
||||
if not os.path.exists(media_path):
|
||||
warning = f"Media file not found, skipping: {media_path}"
|
||||
logger.warning(warning)
|
||||
warnings.append(warning)
|
||||
continue
|
||||
try:
|
||||
form = aiohttp.FormData()
|
||||
filename = os.path.basename(media_path)
|
||||
with open(media_path, "rb") as f:
|
||||
form.add_field("files[0]", f, filename=filename)
|
||||
async with session.post(url, headers=auth_headers, data=form, **_req_kw) as resp:
|
||||
if resp.status not in (200, 201):
|
||||
body = await resp.text()
|
||||
warning = _sanitize_error_text(f"Failed to send media {media_path}: Discord API error ({resp.status}): {body}")
|
||||
logger.error(warning)
|
||||
warnings.append(warning)
|
||||
continue
|
||||
last_data = await resp.json()
|
||||
except Exception as e:
|
||||
warning = _sanitize_error_text(f"Failed to send media {media_path}: {e}")
|
||||
logger.error(warning)
|
||||
warnings.append(warning)
|
||||
|
||||
if last_data is None:
|
||||
error = "No deliverable text or media remained after processing"
|
||||
if warnings:
|
||||
return {"error": error, "warnings": warnings}
|
||||
return {"error": error}
|
||||
|
||||
result = {"success": True, "platform": "discord", "chat_id": chat_id, "message_id": last_data.get("id")}
|
||||
if warnings:
|
||||
result["warnings"] = warnings
|
||||
return result
|
||||
except Exception as e:
|
||||
return _error(f"Discord send failed: {e}")
|
||||
|
||||
@@ -816,7 +879,9 @@ async def _send_matrix(token, extra, chat_id, message):
|
||||
if not homeserver or not token:
|
||||
return {"error": "Matrix not configured (MATRIX_HOMESERVER, MATRIX_ACCESS_TOKEN required)"}
|
||||
txn_id = f"hermes_{int(time.time() * 1000)}_{os.urandom(4).hex()}"
|
||||
url = f"{homeserver}/_matrix/client/v3/rooms/{chat_id}/send/m.room.message/{txn_id}"
|
||||
from urllib.parse import quote
|
||||
encoded_room = quote(chat_id, safe="")
|
||||
url = f"{homeserver}/_matrix/client/v3/rooms/{encoded_room}/send/m.room.message/{txn_id}"
|
||||
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
|
||||
|
||||
# Build message payload with optional HTML formatted_body.
|
||||
|
||||
@@ -39,7 +39,7 @@ import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from hermes_constants import get_hermes_home
|
||||
from hermes_constants import get_hermes_home, display_hermes_home
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -655,7 +655,7 @@ SKILL_MANAGE_SCHEMA = {
|
||||
"description": (
|
||||
"Manage skills (create, update, delete). Skills are your procedural "
|
||||
"memory — reusable approaches for recurring task types. "
|
||||
"New skills go to ~/.hermes/skills/; existing skills can be modified wherever they live.\n\n"
|
||||
f"New skills go to {display_hermes_home()}/skills/; existing skills can be modified wherever they live.\n\n"
|
||||
"Actions: create (full SKILL.md + optional category), "
|
||||
"patch (old_string/new_string — preferred for fixes), "
|
||||
"edit (full SKILL.md rewrite — major overhauls only), "
|
||||
|
||||
@@ -69,7 +69,7 @@ Usage:
|
||||
import json
|
||||
import logging
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from hermes_constants import get_hermes_home, display_hermes_home
|
||||
import os
|
||||
import re
|
||||
from enum import Enum
|
||||
@@ -408,7 +408,7 @@ def _gateway_setup_hint() -> str:
|
||||
|
||||
return GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE
|
||||
except Exception:
|
||||
return "Secure secret entry is not available. Load this skill in the local CLI to be prompted, or add the key to ~/.hermes/.env manually."
|
||||
return f"Secure secret entry is not available. Load this skill in the local CLI to be prompted, or add the key to {display_hermes_home()}/.env manually."
|
||||
|
||||
|
||||
def _build_setup_note(
|
||||
@@ -666,7 +666,7 @@ def skills_list(category: str = None, task_id: str = None) -> str:
|
||||
"success": True,
|
||||
"skills": [],
|
||||
"categories": [],
|
||||
"message": "No skills found. Skills directory created at ~/.hermes/skills/",
|
||||
"message": f"No skills found. Skills directory created at {display_hermes_home()}/skills/",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
+19
-13
@@ -1384,14 +1384,10 @@ def terminal_tool(
|
||||
if pty_disabled_reason:
|
||||
result_data["pty_note"] = pty_disabled_reason
|
||||
|
||||
# Mark for agent notification on completion
|
||||
if notify_on_complete and background:
|
||||
proc_session.notify_on_complete = True
|
||||
result_data["notify_on_complete"] = True
|
||||
|
||||
# In gateway mode, auto-register a fast watcher so the
|
||||
# gateway can detect completion and trigger a new agent
|
||||
# turn. CLI mode uses the completion_queue directly.
|
||||
# Populate routing metadata on the session so that
|
||||
# watch-pattern and completion notifications can be
|
||||
# routed back to the correct chat/thread.
|
||||
if background and (notify_on_complete or watch_patterns):
|
||||
from gateway.session_context import get_session_env as _gse
|
||||
_gw_platform = _gse("HERMES_SESSION_PLATFORM", "")
|
||||
if _gw_platform:
|
||||
@@ -1404,16 +1400,26 @@ def terminal_tool(
|
||||
proc_session.watcher_user_id = _gw_user_id
|
||||
proc_session.watcher_user_name = _gw_user_name
|
||||
proc_session.watcher_thread_id = _gw_thread_id
|
||||
|
||||
# Mark for agent notification on completion
|
||||
if notify_on_complete and background:
|
||||
proc_session.notify_on_complete = True
|
||||
result_data["notify_on_complete"] = True
|
||||
|
||||
# In gateway mode, auto-register a fast watcher so the
|
||||
# gateway can detect completion and trigger a new agent
|
||||
# turn. CLI mode uses the completion_queue directly.
|
||||
if proc_session.watcher_platform:
|
||||
proc_session.watcher_interval = 5
|
||||
process_registry.pending_watchers.append({
|
||||
"session_id": proc_session.id,
|
||||
"check_interval": 5,
|
||||
"session_key": session_key,
|
||||
"platform": _gw_platform,
|
||||
"chat_id": _gw_chat_id,
|
||||
"user_id": _gw_user_id,
|
||||
"user_name": _gw_user_name,
|
||||
"thread_id": _gw_thread_id,
|
||||
"platform": proc_session.watcher_platform,
|
||||
"chat_id": proc_session.watcher_chat_id,
|
||||
"user_id": proc_session.watcher_user_id,
|
||||
"user_name": proc_session.watcher_user_name,
|
||||
"thread_id": proc_session.watcher_thread_id,
|
||||
"notify_on_complete": True,
|
||||
})
|
||||
|
||||
|
||||
+3
-1
@@ -40,6 +40,8 @@ from pathlib import Path
|
||||
from typing import Callable, Dict, Any, Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from hermes_constants import display_hermes_home
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from tools.managed_tool_gateway import resolve_managed_tool_gateway
|
||||
from tools.tool_backend_helpers import managed_nous_tools_enabled, resolve_openai_audio_api_key
|
||||
@@ -1050,7 +1052,7 @@ TTS_SCHEMA = {
|
||||
},
|
||||
"output_path": {
|
||||
"type": "string",
|
||||
"description": "Optional custom file path to save the audio. Defaults to ~/.hermes/audio_cache/<timestamp>.mp3"
|
||||
"description": f"Optional custom file path to save the audio. Defaults to {display_hermes_home()}/audio_cache/<timestamp>.mp3"
|
||||
}
|
||||
},
|
||||
"required": ["text"]
|
||||
|
||||
+69
-28
@@ -409,8 +409,39 @@ def get_toolset(name: str) -> Optional[Dict[str, Any]]:
|
||||
Dict: Toolset definition with description, tools, and includes
|
||||
None: If toolset not found
|
||||
"""
|
||||
# Return toolset definition
|
||||
return TOOLSETS.get(name)
|
||||
toolset = TOOLSETS.get(name)
|
||||
if toolset:
|
||||
return toolset
|
||||
|
||||
try:
|
||||
from tools.registry import registry
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
registry_toolset = name
|
||||
description = f"Plugin toolset: {name}"
|
||||
alias_target = registry.get_toolset_alias_target(name)
|
||||
|
||||
if name not in _get_plugin_toolset_names():
|
||||
registry_toolset = alias_target
|
||||
if not registry_toolset:
|
||||
return None
|
||||
description = f"MCP server '{name}' tools"
|
||||
else:
|
||||
reverse_aliases = {
|
||||
canonical: alias
|
||||
for alias, canonical in _get_registry_toolset_aliases().items()
|
||||
if alias not in TOOLSETS
|
||||
}
|
||||
alias = reverse_aliases.get(name)
|
||||
if alias:
|
||||
description = f"MCP server '{alias}' tools"
|
||||
|
||||
return {
|
||||
"description": description,
|
||||
"tools": registry.get_tool_names_for_toolset(registry_toolset),
|
||||
"includes": [],
|
||||
}
|
||||
|
||||
|
||||
def resolve_toolset(name: str, visited: Set[str] = None) -> List[str]:
|
||||
@@ -438,7 +469,7 @@ def resolve_toolset(name: str, visited: Set[str] = None) -> List[str]:
|
||||
# Use a fresh visited set per branch to avoid cross-branch contamination
|
||||
resolved = resolve_toolset(toolset_name, visited.copy())
|
||||
all_tools.update(resolved)
|
||||
return list(all_tools)
|
||||
return sorted(all_tools)
|
||||
|
||||
# Check for cycles / already-resolved (diamond deps).
|
||||
# Silently return [] — either this is a diamond (not a bug, tools already
|
||||
@@ -449,15 +480,8 @@ def resolve_toolset(name: str, visited: Set[str] = None) -> List[str]:
|
||||
visited.add(name)
|
||||
|
||||
# Get toolset definition
|
||||
toolset = TOOLSETS.get(name)
|
||||
toolset = get_toolset(name)
|
||||
if not toolset:
|
||||
# Fall back to tool registry for plugin-provided toolsets
|
||||
if name in _get_plugin_toolset_names():
|
||||
try:
|
||||
from tools.registry import registry
|
||||
return registry.get_tool_names_for_toolset(name)
|
||||
except Exception:
|
||||
pass
|
||||
return []
|
||||
|
||||
# Collect direct tools
|
||||
@@ -470,7 +494,7 @@ def resolve_toolset(name: str, visited: Set[str] = None) -> List[str]:
|
||||
included_tools = resolve_toolset(included_name, visited)
|
||||
tools.update(included_tools)
|
||||
|
||||
return list(tools)
|
||||
return sorted(tools)
|
||||
|
||||
|
||||
def resolve_multiple_toolsets(toolset_names: List[str]) -> List[str]:
|
||||
@@ -489,7 +513,7 @@ def resolve_multiple_toolsets(toolset_names: List[str]) -> List[str]:
|
||||
tools = resolve_toolset(name)
|
||||
all_tools.update(tools)
|
||||
|
||||
return list(all_tools)
|
||||
return sorted(all_tools)
|
||||
|
||||
|
||||
def _get_plugin_toolset_names() -> Set[str]:
|
||||
@@ -509,6 +533,15 @@ def _get_plugin_toolset_names() -> Set[str]:
|
||||
return set()
|
||||
|
||||
|
||||
def _get_registry_toolset_aliases() -> Dict[str, str]:
|
||||
"""Return explicit toolset aliases registered in the live registry."""
|
||||
try:
|
||||
from tools.registry import registry
|
||||
return registry.get_registered_toolset_aliases()
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def get_all_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Get all available toolsets with their definitions.
|
||||
@@ -518,19 +551,19 @@ def get_all_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||
Returns:
|
||||
Dict: All toolset definitions
|
||||
"""
|
||||
result = TOOLSETS.copy()
|
||||
# Add plugin-provided toolsets (synthetic entries)
|
||||
result = dict(TOOLSETS)
|
||||
aliases = _get_registry_toolset_aliases()
|
||||
for ts_name in _get_plugin_toolset_names():
|
||||
if ts_name not in result:
|
||||
try:
|
||||
from tools.registry import registry
|
||||
tools = registry.get_tool_names_for_toolset(ts_name)
|
||||
result[ts_name] = {
|
||||
"description": f"Plugin toolset: {ts_name}",
|
||||
"tools": tools,
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
display_name = ts_name
|
||||
for alias, canonical in aliases.items():
|
||||
if canonical == ts_name and alias not in TOOLSETS:
|
||||
display_name = alias
|
||||
break
|
||||
if display_name in result:
|
||||
continue
|
||||
toolset = get_toolset(display_name)
|
||||
if toolset:
|
||||
result[display_name] = toolset
|
||||
return result
|
||||
|
||||
|
||||
@@ -544,7 +577,14 @@ def get_toolset_names() -> List[str]:
|
||||
List[str]: List of toolset names
|
||||
"""
|
||||
names = set(TOOLSETS.keys())
|
||||
names |= _get_plugin_toolset_names()
|
||||
aliases = _get_registry_toolset_aliases()
|
||||
for ts_name in _get_plugin_toolset_names():
|
||||
for alias, canonical in aliases.items():
|
||||
if canonical == ts_name and alias not in TOOLSETS:
|
||||
names.add(alias)
|
||||
break
|
||||
else:
|
||||
names.add(ts_name)
|
||||
return sorted(names)
|
||||
|
||||
|
||||
@@ -565,8 +605,9 @@ def validate_toolset(name: str) -> bool:
|
||||
return True
|
||||
if name in TOOLSETS:
|
||||
return True
|
||||
# Check tool registry for plugin-provided toolsets
|
||||
return name in _get_plugin_toolset_names()
|
||||
if name in _get_plugin_toolset_names():
|
||||
return True
|
||||
return name in _get_registry_toolset_aliases()
|
||||
|
||||
|
||||
def create_custom_toolset(
|
||||
|
||||
@@ -14,11 +14,12 @@ Make it a **Tool** when it requires end-to-end integration with API keys, custom
|
||||
|
||||
## Overview
|
||||
|
||||
Adding a tool touches **3 files**:
|
||||
Adding a tool touches **2 files**:
|
||||
|
||||
1. **`tools/your_tool.py`** — handler, schema, check function, `registry.register()` call
|
||||
2. **`toolsets.py`** — add tool name to `_HERMES_CORE_TOOLS` (or a specific toolset)
|
||||
3. **`model_tools.py`** — add `"tools.your_tool"` to the `_discover_tools()` list
|
||||
|
||||
Any `tools/*.py` file with a top-level `registry.register()` call is auto-discovered at startup — no manual import list required.
|
||||
|
||||
## Step 1: Create the Tool File
|
||||
|
||||
@@ -124,19 +125,9 @@ _HERMES_CORE_TOOLS = [
|
||||
},
|
||||
```
|
||||
|
||||
## Step 3: Add Discovery Import
|
||||
## ~~Step 3: Add Discovery Import~~ (No longer needed)
|
||||
|
||||
In `model_tools.py`, add the module to the `_discover_tools()` list:
|
||||
|
||||
```python
|
||||
def _discover_tools():
|
||||
_modules = [
|
||||
...
|
||||
"tools.weather_tool", # <-- add here
|
||||
]
|
||||
```
|
||||
|
||||
This import triggers the `registry.register()` call at the bottom of your tool file.
|
||||
Tool modules with a top-level `registry.register()` call are auto-discovered by `discover_builtin_tools()` in `tools/registry.py`. No manual import list to maintain — just create your file in `tools/` and it's picked up at startup.
|
||||
|
||||
## Async Handlers
|
||||
|
||||
|
||||
@@ -275,4 +275,4 @@ model_tools.py (imports tools/registry + triggers tool discovery)
|
||||
run_agent.py, cli.py, batch_runner.py, environments/
|
||||
```
|
||||
|
||||
This chain means tool registration happens at import time, before any agent instance is created. Adding a new tool requires an import in `model_tools.py`'s `_discover_tools()` list.
|
||||
This chain means tool registration happens at import time, before any agent instance is created. Any `tools/*.py` file with a top-level `registry.register()` call is auto-discovered — no manual import list needed.
|
||||
|
||||
@@ -42,37 +42,23 @@ registry.register(
|
||||
|
||||
Each call creates a `ToolEntry` stored in the singleton `ToolRegistry._tools` dict keyed by tool name. If a name collision occurs across toolsets, a warning is logged and the later registration wins.
|
||||
|
||||
### Discovery: `_discover_tools()`
|
||||
### Discovery: `discover_builtin_tools()`
|
||||
|
||||
When `model_tools.py` is imported, it calls `_discover_tools()` which imports every tool module in order:
|
||||
When `model_tools.py` is imported, it calls `discover_builtin_tools()` from `tools/registry.py`. This function scans every `tools/*.py` file using AST parsing to find modules that contain top-level `registry.register()` calls, then imports them:
|
||||
|
||||
```python
|
||||
_modules = [
|
||||
"tools.web_tools",
|
||||
"tools.terminal_tool",
|
||||
"tools.file_tools",
|
||||
"tools.vision_tools",
|
||||
"tools.mixture_of_agents_tool",
|
||||
"tools.image_generation_tool",
|
||||
"tools.skills_tool",
|
||||
"tools.skill_manager_tool",
|
||||
"tools.browser_tool",
|
||||
"tools.cronjob_tools",
|
||||
"tools.rl_training_tool",
|
||||
"tools.tts_tool",
|
||||
"tools.todo_tool",
|
||||
"tools.memory_tool",
|
||||
"tools.session_search_tool",
|
||||
"tools.clarify_tool",
|
||||
"tools.code_execution_tool",
|
||||
"tools.delegate_tool",
|
||||
"tools.process_registry",
|
||||
"tools.send_message_tool",
|
||||
# "tools.honcho_tools", # Removed — Honcho is now a memory provider plugin
|
||||
"tools.homeassistant_tool",
|
||||
]
|
||||
# tools/registry.py (simplified)
|
||||
def discover_builtin_tools(tools_dir=None):
|
||||
tools_path = Path(tools_dir) if tools_dir else Path(__file__).parent
|
||||
for path in sorted(tools_path.glob("*.py")):
|
||||
if path.name in {"__init__.py", "registry.py", "mcp_tool.py"}:
|
||||
continue
|
||||
if _module_registers_tools(path): # AST check for top-level registry.register()
|
||||
importlib.import_module(f"tools.{path.stem}")
|
||||
```
|
||||
|
||||
This auto-discovery means new tool files are picked up automatically — no manual list to maintain. The AST check only matches top-level `registry.register()` calls (not calls inside functions), so helper modules in `tools/` are not imported.
|
||||
|
||||
Each import triggers the module's `registry.register()` calls. Errors in optional tools (e.g., missing `fal_client` for image generation) are caught and logged — they don't prevent other tools from loading.
|
||||
|
||||
After core tool discovery, MCP tools and plugin tools are also discovered:
|
||||
|
||||
@@ -152,12 +152,15 @@ hermes setup
|
||||
|
||||
### Install optional Node dependencies manually
|
||||
|
||||
The tested Termux path skips Node/browser bootstrap on purpose. If you want to experiment later:
|
||||
The tested Termux path skips Node/browser bootstrap on purpose. If you want to experiment with browser tooling later:
|
||||
|
||||
```bash
|
||||
pkg install nodejs-lts
|
||||
npm install
|
||||
```
|
||||
|
||||
The browser tool automatically includes Termux directories (`/data/data/com.termux/files/usr/bin`) in its PATH search, so `agent-browser` and `npx` are discovered without any extra PATH configuration.
|
||||
|
||||
Treat browser / WhatsApp tooling on Android as experimental until documented otherwise.
|
||||
|
||||
---
|
||||
|
||||
@@ -49,6 +49,17 @@ The OpenAI Codex provider authenticates via device code (open a URL, enter a cod
|
||||
Even when using Nous Portal, Codex, or a custom endpoint, some tools (vision, web summarization, MoA) use a separate "auxiliary" model — by default Gemini Flash via OpenRouter. An `OPENROUTER_API_KEY` enables these tools automatically. You can also configure which model and provider these tools use — see [Auxiliary Models](/docs/user-guide/configuration#auxiliary-models).
|
||||
:::
|
||||
|
||||
### Two Commands for Model Management
|
||||
|
||||
Hermes has **two** model commands that serve different purposes:
|
||||
|
||||
| Command | Where to run | What it does |
|
||||
|---------|-------------|--------------|
|
||||
| **`hermes model`** | Your terminal (outside any session) | Full setup wizard — add providers, run OAuth, enter API keys, configure endpoints |
|
||||
| **`/model`** | Inside a Hermes chat session | Quick switch between **already-configured** providers and models |
|
||||
|
||||
If you're trying to switch to a provider you haven't set up yet (e.g. you only have OpenRouter configured and want to use Anthropic), you need `hermes model`, not `/model`. Exit your session first (`Ctrl+C` or `/quit`), run `hermes model`, complete the provider setup, then start a new session.
|
||||
|
||||
### Anthropic (Native)
|
||||
|
||||
Use Claude models directly through the Anthropic API — no OpenRouter proxy needed. Supports three auth methods:
|
||||
@@ -252,7 +263,15 @@ Both approaches persist to `config.yaml`, which is the source of truth for model
|
||||
|
||||
### Switching Models with `/model`
|
||||
|
||||
Once a custom endpoint is configured, you can switch models mid-session:
|
||||
:::warning hermes model vs /model
|
||||
**`hermes model`** (run from your terminal, outside any chat session) is the **full provider setup wizard**. Use it to add new providers, run OAuth flows, enter API keys, and configure custom endpoints.
|
||||
|
||||
**`/model`** (typed inside an active Hermes chat session) can only **switch between providers and models you've already set up**. It cannot add new providers, run OAuth, or prompt for API keys. If you've only configured one provider (e.g. OpenRouter), `/model` will only show models for that provider.
|
||||
|
||||
**To add a new provider:** Exit your session (`Ctrl+C` or `/quit`), run `hermes model`, set up the new provider, then start a new session.
|
||||
:::
|
||||
|
||||
Once you have at least one custom endpoint configured, you can switch models mid-session:
|
||||
|
||||
```
|
||||
/model custom:qwen-2.5 # Switch to a model on your custom endpoint
|
||||
|
||||
@@ -109,22 +109,31 @@ hermes chat --worktree -q "Review this repo and open a PR"
|
||||
|
||||
## `hermes model`
|
||||
|
||||
Interactive provider + model selector.
|
||||
Interactive provider + model selector. **This is the command for adding new providers, setting up API keys, and running OAuth flows.** Run it from your terminal — not from inside an active Hermes chat session.
|
||||
|
||||
```bash
|
||||
hermes model
|
||||
```
|
||||
|
||||
Use this when you want to:
|
||||
- switch default providers
|
||||
- log into OAuth-backed providers during model selection
|
||||
- **add a new provider** (OpenRouter, Anthropic, Copilot, DeepSeek, custom, etc.)
|
||||
- log into OAuth-backed providers (Anthropic, Copilot, Codex, Nous Portal)
|
||||
- enter or update API keys
|
||||
- pick from provider-specific model lists
|
||||
- configure a custom/self-hosted endpoint
|
||||
- save the new default into config
|
||||
|
||||
:::warning hermes model vs /model — know the difference
|
||||
**`hermes model`** (run from your terminal, outside any Hermes session) is the **full provider setup wizard**. It can add new providers, run OAuth flows, prompt for API keys, and configure endpoints.
|
||||
|
||||
**`/model`** (typed inside an active Hermes chat session) can only **switch between providers and models you've already set up**. It cannot add new providers, run OAuth, or prompt for API keys.
|
||||
|
||||
**If you need to add a new provider:** Exit your Hermes session first (`Ctrl+C` or `/quit`), then run `hermes model` from your terminal prompt.
|
||||
:::
|
||||
|
||||
### `/model` slash command (mid-session)
|
||||
|
||||
Switch models without leaving a session:
|
||||
Switch between already-configured models without leaving a session:
|
||||
|
||||
```
|
||||
/model # Show current model and available options
|
||||
@@ -136,6 +145,16 @@ Switch models without leaving a session:
|
||||
/model openrouter:anthropic/claude-sonnet-4 # Switch back to cloud
|
||||
```
|
||||
|
||||
By default, `/model` changes apply **to the current session only**. Add `--global` to persist the change to `config.yaml`:
|
||||
|
||||
```
|
||||
/model claude-sonnet-4 --global # Switch and save as new default
|
||||
```
|
||||
|
||||
:::info What if I only see OpenRouter models?
|
||||
If you've only configured OpenRouter, `/model` will only show OpenRouter models. To add another provider (Anthropic, DeepSeek, Copilot, etc.), exit your session and run `hermes model` from the terminal.
|
||||
:::
|
||||
|
||||
Provider and base URL changes are persisted to `config.yaml` automatically. When switching away from a custom endpoint, the stale base URL is cleared to prevent it leaking into other providers.
|
||||
|
||||
## `hermes gateway`
|
||||
|
||||
@@ -187,6 +187,32 @@ curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scri
|
||||
|
||||
### Provider & Model Issues
|
||||
|
||||
#### `/model` only shows one provider / can't switch providers
|
||||
|
||||
**Cause:** `/model` (inside a chat session) can only switch between providers you've **already configured**. If you've only set up OpenRouter, that's all `/model` will show.
|
||||
|
||||
**Solution:** Exit your session and use `hermes model` from your terminal to add new providers:
|
||||
|
||||
```bash
|
||||
# Exit the Hermes chat session first (Ctrl+C or /quit)
|
||||
|
||||
# Run the full provider setup wizard
|
||||
hermes model
|
||||
|
||||
# This lets you: add providers, run OAuth, enter API keys, configure endpoints
|
||||
```
|
||||
|
||||
After adding a new provider via `hermes model`, start a new chat session — `/model` will now show all your configured providers.
|
||||
|
||||
:::tip Quick reference
|
||||
| Want to... | Use |
|
||||
|-----------|-----|
|
||||
| Add a new provider | `hermes model` (from terminal) |
|
||||
| Enter/change API keys | `hermes model` (from terminal) |
|
||||
| Switch model mid-session | `/model <name>` (inside session) |
|
||||
| Switch to different configured provider | `/model provider:model` (inside session) |
|
||||
:::
|
||||
|
||||
#### API key not working
|
||||
|
||||
**Cause:** Key is missing, expired, incorrectly set, or for the wrong provider.
|
||||
|
||||
@@ -46,7 +46,7 @@ Type `/` in the CLI to open the autocomplete menu. Built-in commands are case-in
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/config` | Show current configuration |
|
||||
| `/model [model-name]` | Show or change the current model. Supports: `/model claude-sonnet-4`, `/model provider:model` (switch providers), `/model custom:model` (custom endpoint), `/model custom:name:model` (named custom provider), `/model custom` (auto-detect from endpoint). Use `--global` to persist the change to config.yaml. |
|
||||
| `/model [model-name]` | Show or change the current model. Supports: `/model claude-sonnet-4`, `/model provider:model` (switch providers), `/model custom:model` (custom endpoint), `/model custom:name:model` (named custom provider), `/model custom` (auto-detect from endpoint). Use `--global` to persist the change to config.yaml. **Note:** `/model` can only switch between already-configured providers. To add a new provider, exit the session and run `hermes model` from your terminal. |
|
||||
| `/provider` | Show available providers and current provider |
|
||||
| `/personality` | Set a predefined personality |
|
||||
| `/verbose` | Cycle tool progress display: off → new → all → verbose. Can be [enabled for messaging](#notes) via config. |
|
||||
@@ -124,7 +124,7 @@ The messaging gateway supports the following built-in commands inside Telegram,
|
||||
| `/reset` | Reset conversation history. |
|
||||
| `/status` | Show session info. |
|
||||
| `/stop` | Kill all running background processes and interrupt the running agent. |
|
||||
| `/model [provider:model]` | Show or change the model. Supports provider switches (`/model zai:glm-5`), custom endpoints (`/model custom:model`), named custom providers (`/model custom:local:qwen`), and auto-detect (`/model custom`). Use `--global` to persist the change to config.yaml. |
|
||||
| `/model [provider:model]` | Show or change the model. Supports provider switches (`/model zai:glm-5`), custom endpoints (`/model custom:model`), named custom providers (`/model custom:local:qwen`), and auto-detect (`/model custom`). Use `--global` to persist the change to config.yaml. **Note:** `/model` can only switch between already-configured providers. To add a new provider or set up API keys, use `hermes model` from your terminal (outside the chat session). |
|
||||
| `/provider` | Show provider availability and auth status. |
|
||||
| `/personality [name]` | Set a personality overlay for the session. |
|
||||
| `/fast [normal\|fast\|status]` | Toggle fast mode — OpenAI Priority Processing / Anthropic Fast Mode. |
|
||||
|
||||
@@ -83,9 +83,11 @@ Standard OpenAI Chat Completions format. Stateless — the full conversation is
|
||||
}
|
||||
```
|
||||
|
||||
**Streaming** (`"stream": true`): Returns Server-Sent Events (SSE) with token-by-token response chunks. When streaming is enabled in config, tokens are emitted live as the LLM generates them. When disabled, the full response is sent as a single SSE chunk.
|
||||
**Streaming** (`"stream": true`): Returns Server-Sent Events (SSE) with token-by-token response chunks. For **Chat Completions**, the stream uses standard `chat.completion.chunk` events plus Hermes' custom `hermes.tool.progress` event for tool-start UX. For **Responses**, the stream uses OpenAI Responses event types such as `response.created`, `response.output_text.delta`, `response.output_item.added`, `response.output_item.done`, and `response.completed`.
|
||||
|
||||
**Tool progress in streams**: When the agent calls tools during a streaming request, brief progress indicators are injected into the content stream as the tools start executing (e.g. `` `💻 pwd` ``, `` `🔍 Python docs` ``). These appear as inline markdown before the agent's response text, giving frontends like Open WebUI real-time visibility into tool execution.
|
||||
**Tool progress in streams**:
|
||||
- **Chat Completions**: Hermes emits `event: hermes.tool.progress` for tool-start visibility without polluting persisted assistant text.
|
||||
- **Responses**: Hermes emits spec-native `function_call` and `function_call_output` output items during the SSE stream, so clients can render structured tool UI in real time.
|
||||
|
||||
### POST /v1/responses
|
||||
|
||||
@@ -128,7 +130,7 @@ Chain responses to maintain full context (including tool calls) across turns:
|
||||
}
|
||||
```
|
||||
|
||||
The server reconstructs the full conversation from the stored response chain — all previous tool calls and results are preserved.
|
||||
The server reconstructs the full conversation from the stored response chain — all previous tool calls and results are preserved. Chained requests also share the same session, so multi-turn conversations appear as a single entry in the dashboard and session history.
|
||||
|
||||
#### Named conversations
|
||||
|
||||
|
||||
@@ -134,10 +134,10 @@ To use the Responses API mode:
|
||||
3. Change **API Type** from "Chat Completions" to **"Responses (Experimental)"**
|
||||
4. Save
|
||||
|
||||
With the Responses API, Open WebUI sends requests in the Responses format (`input` array + `instructions`), and Hermes Agent can preserve full tool call history across turns via `previous_response_id`.
|
||||
With the Responses API, Open WebUI sends requests in the Responses format (`input` array + `instructions`), and Hermes Agent can preserve full tool call history across turns via `previous_response_id`. When `stream: true`, Hermes also streams spec-native `function_call` and `function_call_output` items, which enables custom structured tool-call UI in clients that render Responses events.
|
||||
|
||||
:::note
|
||||
Open WebUI currently manages conversation history client-side even in Responses mode — it sends the full message history in each request rather than using `previous_response_id`. The Responses API mode is mainly useful for future compatibility as frontends evolve.
|
||||
Open WebUI currently manages conversation history client-side even in Responses mode — it sends the full message history in each request rather than using `previous_response_id`. The main advantage of Responses mode today is the structured event stream: text deltas, `function_call`, and `function_call_output` items arrive as OpenAI Responses SSE events instead of Chat Completions chunks.
|
||||
:::
|
||||
|
||||
## How It Works
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
---
|
||||
sidebar_position: 2
|
||||
sidebar_label: "Google Workspace"
|
||||
title: "Google Workspace — Gmail, Calendar, Drive, Sheets & Docs"
|
||||
description: "Send email, manage calendar events, search Drive, read/write Sheets, and access Docs — all through OAuth2-authenticated Google APIs"
|
||||
---
|
||||
|
||||
# Google Workspace Skill
|
||||
|
||||
Gmail, Calendar, Drive, Contacts, Sheets, and Docs integration for Hermes. Uses OAuth2 with automatic token refresh. Prefers the [Google Workspace CLI (`gws`)](https://github.com/nicholasgasior/gws) when available for broader coverage, and falls back to Google's Python client libraries otherwise.
|
||||
|
||||
**Skill path:** `skills/productivity/google-workspace/`
|
||||
|
||||
## Setup
|
||||
|
||||
The setup is fully agent-driven — ask Hermes to set up Google Workspace and it walks you through each step. The flow:
|
||||
|
||||
1. **Create a Google Cloud project** and enable the required APIs (Gmail, Calendar, Drive, Sheets, Docs, People)
|
||||
2. **Create OAuth 2.0 credentials** (Desktop app type) and download the client secret JSON
|
||||
3. **Authorize** — Hermes generates an auth URL, you approve in the browser, paste back the redirect URL
|
||||
4. **Done** — token auto-refreshes from that point on
|
||||
|
||||
:::tip Email-only users
|
||||
If you only need email (no Calendar/Drive/Sheets), use the **himalaya** skill instead — it works with a Gmail App Password and takes 2 minutes. No Google Cloud project needed.
|
||||
:::
|
||||
|
||||
## Gmail
|
||||
|
||||
### Searching
|
||||
|
||||
```bash
|
||||
$GAPI gmail search "is:unread" --max 10
|
||||
$GAPI gmail search "from:boss@company.com newer_than:1d"
|
||||
$GAPI gmail search "has:attachment filename:pdf newer_than:7d"
|
||||
```
|
||||
|
||||
Returns JSON with `id`, `from`, `subject`, `date`, `snippet`, and `labels` for each message.
|
||||
|
||||
### Reading
|
||||
|
||||
```bash
|
||||
$GAPI gmail get MESSAGE_ID
|
||||
```
|
||||
|
||||
Returns the full message body as text (prefers plain text, falls back to HTML).
|
||||
|
||||
### Sending
|
||||
|
||||
```bash
|
||||
# Basic send
|
||||
$GAPI gmail send --to user@example.com --subject "Hello" --body "Message text"
|
||||
|
||||
# HTML email
|
||||
$GAPI gmail send --to user@example.com --subject "Report" \
|
||||
--body "<h1>Q4 Results</h1><p>Details here</p>" --html
|
||||
|
||||
# Custom From header (display name + email)
|
||||
$GAPI gmail send --to user@example.com --subject "Hello" \
|
||||
--from '"Research Agent" <user@example.com>' --body "Message text"
|
||||
|
||||
# With CC
|
||||
$GAPI gmail send --to user@example.com --cc "team@example.com" \
|
||||
--subject "Update" --body "FYI"
|
||||
```
|
||||
|
||||
### Custom From Header
|
||||
|
||||
The `--from` flag lets you customize the sender display name on outgoing emails. This is useful when multiple agents share the same Gmail account but you want recipients to see different names:
|
||||
|
||||
```bash
|
||||
# Agent 1
|
||||
$GAPI gmail send --to client@co.com --subject "Research Summary" \
|
||||
--from '"Research Agent" <shared@company.com>' --body "..."
|
||||
|
||||
# Agent 2
|
||||
$GAPI gmail send --to client@co.com --subject "Code Review" \
|
||||
--from '"Code Assistant" <shared@company.com>' --body "..."
|
||||
```
|
||||
|
||||
**How it works:** The `--from` value is set as the RFC 5322 `From` header on the MIME message. Gmail allows customizing the display name on your own authenticated email address without any additional configuration. Recipients see the custom display name (e.g. "Research Agent") while the email address stays the same.
|
||||
|
||||
**Important:** If you use a *different email address* in `--from` (not the authenticated account), Gmail requires that address to be configured as a [Send As alias](https://support.google.com/mail/answer/22370) in Gmail Settings → Accounts → Send mail as.
|
||||
|
||||
The `--from` flag works on both `send` and `reply`:
|
||||
|
||||
```bash
|
||||
$GAPI gmail reply MESSAGE_ID \
|
||||
--from '"Support Bot" <shared@company.com>' --body "We're on it"
|
||||
```
|
||||
|
||||
### Replying
|
||||
|
||||
```bash
|
||||
$GAPI gmail reply MESSAGE_ID --body "Thanks, that works for me."
|
||||
```
|
||||
|
||||
Automatically threads the reply (sets `In-Reply-To` and `References` headers) and uses the original message's thread ID.
|
||||
|
||||
### Labels
|
||||
|
||||
```bash
|
||||
# List all labels
|
||||
$GAPI gmail labels
|
||||
|
||||
# Add/remove labels
|
||||
$GAPI gmail modify MESSAGE_ID --add-labels LABEL_ID
|
||||
$GAPI gmail modify MESSAGE_ID --remove-labels UNREAD
|
||||
```
|
||||
|
||||
## Calendar
|
||||
|
||||
```bash
|
||||
# List events (defaults to next 7 days)
|
||||
$GAPI calendar list
|
||||
$GAPI calendar list --start 2026-03-01T00:00:00Z --end 2026-03-07T23:59:59Z
|
||||
|
||||
# Create event (timezone required)
|
||||
$GAPI calendar create --summary "Team Standup" \
|
||||
--start 2026-03-01T10:00:00-07:00 --end 2026-03-01T10:30:00-07:00
|
||||
|
||||
# With location and attendees
|
||||
$GAPI calendar create --summary "Lunch" \
|
||||
--start 2026-03-01T12:00:00Z --end 2026-03-01T13:00:00Z \
|
||||
--location "Cafe" --attendees "alice@co.com,bob@co.com"
|
||||
|
||||
# Delete event
|
||||
$GAPI calendar delete EVENT_ID
|
||||
```
|
||||
|
||||
:::warning
|
||||
Calendar times **must** include a timezone offset (e.g. `-07:00`) or use UTC (`Z`). Bare datetimes like `2026-03-01T10:00:00` are ambiguous and will be treated as UTC.
|
||||
:::
|
||||
|
||||
## Drive
|
||||
|
||||
```bash
|
||||
$GAPI drive search "quarterly report" --max 10
|
||||
$GAPI drive search "mimeType='application/pdf'" --raw-query --max 5
|
||||
```
|
||||
|
||||
## Sheets
|
||||
|
||||
```bash
|
||||
# Read a range
|
||||
$GAPI sheets get SHEET_ID "Sheet1!A1:D10"
|
||||
|
||||
# Write to a range
|
||||
$GAPI sheets update SHEET_ID "Sheet1!A1:B2" --values '[["Name","Score"],["Alice","95"]]'
|
||||
|
||||
# Append rows
|
||||
$GAPI sheets append SHEET_ID "Sheet1!A:C" --values '[["new","row","data"]]'
|
||||
```
|
||||
|
||||
## Docs
|
||||
|
||||
```bash
|
||||
$GAPI docs get DOC_ID
|
||||
```
|
||||
|
||||
Returns the document title and full text content.
|
||||
|
||||
## Contacts
|
||||
|
||||
```bash
|
||||
$GAPI contacts list --max 20
|
||||
```
|
||||
|
||||
## Output Format
|
||||
|
||||
All commands return JSON. Key fields per service:
|
||||
|
||||
| Command | Fields |
|
||||
|---------|--------|
|
||||
| `gmail search` | `id`, `threadId`, `from`, `to`, `subject`, `date`, `snippet`, `labels` |
|
||||
| `gmail get` | `id`, `threadId`, `from`, `to`, `subject`, `date`, `labels`, `body` |
|
||||
| `gmail send/reply` | `status`, `id`, `threadId` |
|
||||
| `calendar list` | `id`, `summary`, `start`, `end`, `location`, `description`, `htmlLink` |
|
||||
| `calendar create` | `status`, `id`, `summary`, `htmlLink` |
|
||||
| `drive search` | `id`, `name`, `mimeType`, `modifiedTime`, `webViewLink` |
|
||||
| `contacts list` | `name`, `emails`, `phones` |
|
||||
| `sheets get` | 2D array of cell values |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Problem | Fix |
|
||||
|---------|-----|
|
||||
| `NOT_AUTHENTICATED` | Run setup (ask Hermes to set up Google Workspace) |
|
||||
| `REFRESH_FAILED` | Token revoked — re-run authorization steps |
|
||||
| `HttpError 403: Insufficient Permission` | Missing scope — revoke and re-authorize with the right services |
|
||||
| `HttpError 403: Access Not Configured` | API not enabled in Google Cloud Console |
|
||||
| `ModuleNotFoundError` | Run setup script with `--install-deps` |
|
||||
@@ -92,6 +92,7 @@ const sidebars: SidebarsConfig = {
|
||||
label: 'Skills',
|
||||
items: [
|
||||
'user-guide/skills/godmode',
|
||||
'user-guide/skills/google-workspace',
|
||||
],
|
||||
},
|
||||
],
|
||||
|
||||
@@ -8,20 +8,24 @@
|
||||
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap');
|
||||
|
||||
:root {
|
||||
/* Gold/Amber palette from landing page */
|
||||
--ifm-color-primary: #FFD700;
|
||||
--ifm-color-primary-dark: #E6C200;
|
||||
--ifm-color-primary-darker: #D9B700;
|
||||
--ifm-color-primary-darkest: #B39600;
|
||||
--ifm-color-primary-light: #FFDD33;
|
||||
--ifm-color-primary-lighter: #FFE14D;
|
||||
--ifm-color-primary-lightest: #FFEB80;
|
||||
/* Dark amber palette for light mode — readable on white (WCAG AA compliant)
|
||||
Current gold #FFD700 has only 1.4:1 contrast on white; these tones pass 4.5:1+ */
|
||||
--ifm-color-primary: #8B6508;
|
||||
--ifm-color-primary-dark: #7A5800;
|
||||
--ifm-color-primary-darker: #6E4F00;
|
||||
--ifm-color-primary-darkest: #5A4100;
|
||||
--ifm-color-primary-light: #9E7410;
|
||||
--ifm-color-primary-lighter: #B38319;
|
||||
--ifm-color-primary-lightest: #C89222;
|
||||
|
||||
--ifm-font-family-base: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
|
||||
--ifm-font-family-monospace: 'JetBrains Mono', 'Fira Code', 'Cascadia Code', monospace;
|
||||
|
||||
--ifm-code-font-size: 90%;
|
||||
--ifm-heading-font-weight: 600;
|
||||
|
||||
--ifm-link-color: #7A5800;
|
||||
--ifm-link-hover-color: #5A4100;
|
||||
}
|
||||
|
||||
/* Dark mode — the PRIMARY mode, matches landing page */
|
||||
@@ -91,6 +95,13 @@
|
||||
padding-left: calc(var(--ifm-menu-link-padding-horizontal) - 3px);
|
||||
}
|
||||
|
||||
/* Light mode sidebar active */
|
||||
[data-theme='light'] .menu__link--active:not(.menu__link--sublist) {
|
||||
background-color: rgba(139, 101, 8, 0.08);
|
||||
border-left: 3px solid #8B6508;
|
||||
padding-left: calc(var(--ifm-menu-link-padding-horizontal) - 3px);
|
||||
}
|
||||
|
||||
/* Code blocks */
|
||||
[data-theme='dark'] .prism-code {
|
||||
background-color: #0a0a12 !important;
|
||||
@@ -167,6 +178,16 @@ pre.prism-code.language-ascii code {
|
||||
border-color: rgba(255, 215, 0, 0.06);
|
||||
}
|
||||
|
||||
/* Light mode table styling */
|
||||
[data-theme='light'] table th {
|
||||
background-color: rgba(139, 101, 8, 0.06);
|
||||
border-color: rgba(139, 101, 8, 0.15);
|
||||
}
|
||||
|
||||
[data-theme='light'] table td {
|
||||
border-color: rgba(139, 101, 8, 0.10);
|
||||
}
|
||||
|
||||
/* Footer */
|
||||
.footer {
|
||||
border-top: 1px solid rgba(255, 215, 0, 0.08);
|
||||
@@ -177,11 +198,16 @@ pre.prism-code.language-ascii code {
|
||||
transition: color 0.2s;
|
||||
}
|
||||
|
||||
.footer a:hover {
|
||||
[data-theme='dark'] .footer a:hover {
|
||||
color: #FFD700;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
[data-theme='light'] .footer a:hover {
|
||||
color: #7A5800;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
/* Scrollbar */
|
||||
[data-theme='dark'] ::-webkit-scrollbar {
|
||||
width: 8px;
|
||||
|
||||
Reference in New Issue
Block a user