Compare commits
75 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ab6abc2c13 | |||
| aafe86d81a | |||
| 1aa7027be1 | |||
| f961937097 | |||
| 7a427d7b03 | |||
| 66a1942524 | |||
| 1173adbe86 | |||
| a5beb6d8f0 | |||
| 0e3b7b6a39 | |||
| 5e705bc31b | |||
| 55ce601502 | |||
| 8f6ecd5c64 | |||
| a51a767407 | |||
| 2ea4dd30c6 | |||
| 80e578d3e3 | |||
| c52353cf8a | |||
| d76ebf0ec3 | |||
| 4be5070427 | |||
| e140c02d51 | |||
| 88643a1ba9 | |||
| b7b585656b | |||
| 4494c0b033 | |||
| aa6416399e | |||
| b313751acf | |||
| b1d05dfe8b | |||
| f8899af113 | |||
| cf29cba084 | |||
| ec9b868aea | |||
| 3ec6c71e43 | |||
| 4ad0083118 | |||
| 1055d4356a | |||
| 5822711ae6 | |||
| b19f5133c3 | |||
| 471ea81a7d | |||
| b1832faaae | |||
| 3a9a1bbb84 | |||
| d8081790f3 | |||
| 493bf8db7e | |||
| d9eba2a44f | |||
| fc061c2fee | |||
| aaa96713d4 | |||
| 02954c1a10 | |||
| 4355f30422 | |||
| 2f07df3177 | |||
| 672e9752a0 | |||
| df0f684c34 | |||
| 21afa134f0 | |||
| 6bcec1ac25 | |||
| fe331ed9bd | |||
| 746abf5e28 | |||
| 4d2c93a04f | |||
| 3959e3cadb | |||
| ec5fdb8b92 | |||
| c030ac1d85 | |||
| d223f7388d | |||
| 816d1344ee | |||
| 4c0c7f4c6e | |||
| 04b6ecadc4 | |||
| e84d952dc0 | |||
| 388130a122 | |||
| bb59057d5d | |||
| 36a4481152 | |||
| efa753678c | |||
| 7f3a567259 | |||
| defbe0f9e9 | |||
| 18862145e4 | |||
| 35558dadf4 | |||
| ae8059ca24 | |||
| 116984feb7 | |||
| 219af75704 | |||
| d76fa7fc37 | |||
| 7b6d14e62a | |||
| 67d707e851 | |||
| e648863d52 | |||
| a7cc1cf309 |
@@ -5,7 +5,7 @@ Instructions for AI coding assistants and developers working on the hermes-agent
|
||||
## Development Environment
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate # ALWAYS activate before running Python
|
||||
source venv/bin/activate # ALWAYS activate before running Python
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
@@ -23,6 +23,7 @@ hermes-agent/
|
||||
│ ├── prompt_caching.py # Anthropic prompt caching
|
||||
│ ├── auxiliary_client.py # Auxiliary LLM client (vision, summarization)
|
||||
│ ├── model_metadata.py # Model context lengths, token estimation
|
||||
│ ├── models_dev.py # models.dev registry integration (provider-aware context)
|
||||
│ ├── display.py # KawaiiSpinner, tool preview formatting
|
||||
│ ├── skill_commands.py # Skill slash commands (shared CLI/gateway)
|
||||
│ └── trajectory.py # Trajectory saving helpers
|
||||
@@ -366,6 +367,9 @@ Leaks as literal `?[K` text under `prompt_toolkit`'s `patch_stdout`. Use space-p
|
||||
### `_last_resolved_tool_names` is a process-global in `model_tools.py`
|
||||
`_run_single_child()` in `delegate_tool.py` saves and restores this global around subagent execution. If you add new code that reads this global, be aware it may be temporarily stale during child agent runs.
|
||||
|
||||
### DO NOT hardcode cross-tool references in schema descriptions
|
||||
Tool schema descriptions must not mention tools from other toolsets by name (e.g., `browser_navigate` saying "prefer web_search"). Those tools may be unavailable (missing API keys, disabled toolset), causing the model to hallucinate calls to non-existent tools. If a cross-reference is needed, add it dynamically in `get_tool_definitions()` in `model_tools.py` — see the `browser_navigate` / `execute_code` post-processing blocks for the pattern.
|
||||
|
||||
### Tests must not write to `~/.hermes/`
|
||||
The `_isolate_hermes_home` autouse fixture in `tests/conftest.py` redirects `HERMES_HOME` to a temp dir. Never hardcode `~/.hermes/` paths in tests.
|
||||
|
||||
@@ -374,7 +378,7 @@ The `_isolate_hermes_home` autouse fixture in `tests/conftest.py` redirects `HER
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
source venv/bin/activate
|
||||
python -m pytest tests/ -q # Full suite (~3000 tests, ~3 min)
|
||||
python -m pytest tests/test_model_tools.py -q # Toolset resolution
|
||||
python -m pytest tests/test_cli_init.py -q # CLI config loading
|
||||
|
||||
@@ -146,8 +146,8 @@ git clone https://github.com/NousResearch/hermes-agent.git
|
||||
cd hermes-agent
|
||||
git submodule update --init mini-swe-agent # required terminal backend
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
uv venv .venv --python 3.11
|
||||
source .venv/bin/activate
|
||||
uv venv venv --python 3.11
|
||||
source venv/bin/activate
|
||||
uv pip install -e ".[all,dev]"
|
||||
uv pip install -e "./mini-swe-agent"
|
||||
python -m pytest tests/ -q
|
||||
|
||||
@@ -304,6 +304,8 @@ class HermesACPAgent(acp.Agent):
|
||||
|
||||
if result.get("messages"):
|
||||
state.history = result["messages"]
|
||||
# Persist updated history so sessions survive process restarts.
|
||||
self.session_manager.save_session(session_id)
|
||||
|
||||
final_response = result.get("final_response", "")
|
||||
if final_response and conn:
|
||||
@@ -400,6 +402,7 @@ class HermesACPAgent(acp.Agent):
|
||||
cwd=state.cwd,
|
||||
model=new_model,
|
||||
)
|
||||
self.session_manager.save_session(state.session_id)
|
||||
provider_label = target_provider or getattr(state.agent, "provider", "auto")
|
||||
logger.info("Session %s: model switched to %s", state.session_id, new_model)
|
||||
return f"Model switched to: {new_model}\nProvider: {provider_label}"
|
||||
@@ -444,6 +447,7 @@ class HermesACPAgent(acp.Agent):
|
||||
|
||||
def _cmd_reset(self, args: str, state: SessionState) -> str:
|
||||
state.history.clear()
|
||||
self.session_manager.save_session(state.session_id)
|
||||
return "Conversation history cleared."
|
||||
|
||||
def _cmd_compact(self, args: str, state: SessionState) -> str:
|
||||
@@ -453,6 +457,7 @@ class HermesACPAgent(acp.Agent):
|
||||
agent = state.agent
|
||||
if hasattr(agent, "compress_context"):
|
||||
agent.compress_context(state.history)
|
||||
self.session_manager.save_session(state.session_id)
|
||||
return f"Context compressed. Messages: {len(state.history)}"
|
||||
return "Context compression not available for this agent."
|
||||
except Exception as e:
|
||||
@@ -475,5 +480,6 @@ class HermesACPAgent(acp.Agent):
|
||||
cwd=state.cwd,
|
||||
model=model_id,
|
||||
)
|
||||
self.session_manager.save_session(session_id)
|
||||
logger.info("Session %s: model switched to %s", session_id, model_id)
|
||||
return None
|
||||
|
||||
+260
-34
@@ -1,7 +1,15 @@
|
||||
"""ACP session manager — maps ACP sessions to Hermes AIAgent instances."""
|
||||
"""ACP session manager — maps ACP sessions to Hermes AIAgent instances.
|
||||
|
||||
Sessions are persisted to the shared SessionDB (``~/.hermes/state.db``) so they
|
||||
survive process restarts and appear in ``session_search``. When the editor
|
||||
reconnects after idle/restart, the ``load_session`` / ``resume_session`` calls
|
||||
find the persisted session in the database and restore the full conversation
|
||||
history.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
@@ -46,18 +54,26 @@ class SessionState:
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""Thread-safe manager for ACP sessions backed by Hermes AIAgent instances."""
|
||||
"""Thread-safe manager for ACP sessions backed by Hermes AIAgent instances.
|
||||
|
||||
def __init__(self, agent_factory=None):
|
||||
Sessions are held in-memory for fast access **and** persisted to the
|
||||
shared SessionDB so they survive process restarts and are searchable
|
||||
via ``session_search``.
|
||||
"""
|
||||
|
||||
def __init__(self, agent_factory=None, db=None):
|
||||
"""
|
||||
Args:
|
||||
agent_factory: Optional callable that creates an AIAgent-like object.
|
||||
Used by tests. When omitted, a real AIAgent is created
|
||||
using the current Hermes runtime provider configuration.
|
||||
db: Optional SessionDB instance. When omitted, the default
|
||||
SessionDB (``~/.hermes/state.db``) is lazily created.
|
||||
"""
|
||||
self._sessions: Dict[str, SessionState] = {}
|
||||
self._lock = Lock()
|
||||
self._agent_factory = agent_factory
|
||||
self._db_instance = db # None → lazy-init on first use
|
||||
|
||||
# ---- public API ---------------------------------------------------------
|
||||
|
||||
@@ -77,54 +93,67 @@ class SessionManager:
|
||||
with self._lock:
|
||||
self._sessions[session_id] = state
|
||||
_register_task_cwd(session_id, cwd)
|
||||
self._persist(state)
|
||||
logger.info("Created ACP session %s (cwd=%s)", session_id, cwd)
|
||||
return state
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[SessionState]:
|
||||
"""Return the session for *session_id*, or ``None``."""
|
||||
"""Return the session for *session_id*, or ``None``.
|
||||
|
||||
If the session is not in memory but exists in the database (e.g. after
|
||||
a process restart), it is transparently restored.
|
||||
"""
|
||||
with self._lock:
|
||||
return self._sessions.get(session_id)
|
||||
state = self._sessions.get(session_id)
|
||||
if state is not None:
|
||||
return state
|
||||
# Attempt to restore from database.
|
||||
return self._restore(session_id)
|
||||
|
||||
def remove_session(self, session_id: str) -> bool:
|
||||
"""Remove a session. Returns True if it existed."""
|
||||
"""Remove a session from memory and database. Returns True if it existed."""
|
||||
with self._lock:
|
||||
existed = self._sessions.pop(session_id, None) is not None
|
||||
if existed:
|
||||
db_existed = self._delete_persisted(session_id)
|
||||
if existed or db_existed:
|
||||
_clear_task_cwd(session_id)
|
||||
return existed
|
||||
return existed or db_existed
|
||||
|
||||
def fork_session(self, session_id: str, cwd: str = ".") -> Optional[SessionState]:
|
||||
"""Deep-copy a session's history into a new session."""
|
||||
import threading
|
||||
|
||||
with self._lock:
|
||||
original = self._sessions.get(session_id)
|
||||
if original is None:
|
||||
return None
|
||||
original = self.get_session(session_id) # checks DB too
|
||||
if original is None:
|
||||
return None
|
||||
|
||||
new_id = str(uuid.uuid4())
|
||||
agent = self._make_agent(
|
||||
session_id=new_id,
|
||||
cwd=cwd,
|
||||
model=original.model or None,
|
||||
)
|
||||
state = SessionState(
|
||||
session_id=new_id,
|
||||
agent=agent,
|
||||
cwd=cwd,
|
||||
model=getattr(agent, "model", original.model) or original.model,
|
||||
history=copy.deepcopy(original.history),
|
||||
cancel_event=threading.Event(),
|
||||
)
|
||||
new_id = str(uuid.uuid4())
|
||||
agent = self._make_agent(
|
||||
session_id=new_id,
|
||||
cwd=cwd,
|
||||
model=original.model or None,
|
||||
)
|
||||
state = SessionState(
|
||||
session_id=new_id,
|
||||
agent=agent,
|
||||
cwd=cwd,
|
||||
model=getattr(agent, "model", original.model) or original.model,
|
||||
history=copy.deepcopy(original.history),
|
||||
cancel_event=threading.Event(),
|
||||
)
|
||||
with self._lock:
|
||||
self._sessions[new_id] = state
|
||||
_register_task_cwd(new_id, cwd)
|
||||
self._persist(state)
|
||||
logger.info("Forked ACP session %s -> %s", session_id, new_id)
|
||||
return state
|
||||
|
||||
def list_sessions(self) -> List[Dict[str, Any]]:
|
||||
"""Return lightweight info dicts for all sessions."""
|
||||
"""Return lightweight info dicts for all sessions (memory + database)."""
|
||||
# Collect in-memory sessions first.
|
||||
with self._lock:
|
||||
return [
|
||||
seen_ids = set(self._sessions.keys())
|
||||
results = [
|
||||
{
|
||||
"session_id": s.session_id,
|
||||
"cwd": s.cwd,
|
||||
@@ -134,23 +163,220 @@ class SessionManager:
|
||||
for s in self._sessions.values()
|
||||
]
|
||||
|
||||
# Merge any persisted sessions not currently in memory.
|
||||
db = self._get_db()
|
||||
if db is not None:
|
||||
try:
|
||||
rows = db.search_sessions(source="acp", limit=1000)
|
||||
for row in rows:
|
||||
sid = row["id"]
|
||||
if sid in seen_ids:
|
||||
continue
|
||||
# Extract cwd from model_config JSON.
|
||||
cwd = "."
|
||||
mc = row.get("model_config")
|
||||
if mc:
|
||||
try:
|
||||
cwd = json.loads(mc).get("cwd", ".")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
results.append({
|
||||
"session_id": sid,
|
||||
"cwd": cwd,
|
||||
"model": row.get("model") or "",
|
||||
"history_len": row.get("message_count") or 0,
|
||||
})
|
||||
except Exception:
|
||||
logger.debug("Failed to list ACP sessions from DB", exc_info=True)
|
||||
|
||||
return results
|
||||
|
||||
def update_cwd(self, session_id: str, cwd: str) -> Optional[SessionState]:
|
||||
"""Update the working directory for a session and its tool overrides."""
|
||||
with self._lock:
|
||||
state = self._sessions.get(session_id)
|
||||
if state is None:
|
||||
return None
|
||||
state.cwd = cwd
|
||||
state = self.get_session(session_id) # checks DB too
|
||||
if state is None:
|
||||
return None
|
||||
state.cwd = cwd
|
||||
_register_task_cwd(session_id, cwd)
|
||||
self._persist(state)
|
||||
return state
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Remove all sessions and clear task-specific cwd overrides."""
|
||||
"""Remove all sessions (memory and database) and clear task-specific cwd overrides."""
|
||||
with self._lock:
|
||||
session_ids = list(self._sessions.keys())
|
||||
self._sessions.clear()
|
||||
for session_id in session_ids:
|
||||
_clear_task_cwd(session_id)
|
||||
self._delete_persisted(session_id)
|
||||
# Also remove any DB-only ACP sessions not currently in memory.
|
||||
db = self._get_db()
|
||||
if db is not None:
|
||||
try:
|
||||
rows = db.search_sessions(source="acp", limit=10000)
|
||||
for row in rows:
|
||||
sid = row["id"]
|
||||
_clear_task_cwd(sid)
|
||||
db.delete_session(sid)
|
||||
except Exception:
|
||||
logger.debug("Failed to cleanup ACP sessions from DB", exc_info=True)
|
||||
|
||||
def save_session(self, session_id: str) -> None:
|
||||
"""Persist the current state of a session to the database.
|
||||
|
||||
Called by the server after prompt completion, slash commands that
|
||||
mutate history, and model switches.
|
||||
"""
|
||||
with self._lock:
|
||||
state = self._sessions.get(session_id)
|
||||
if state is not None:
|
||||
self._persist(state)
|
||||
|
||||
# ---- persistence via SessionDB ------------------------------------------
|
||||
|
||||
def _get_db(self):
|
||||
"""Lazily initialise and return the SessionDB instance.
|
||||
|
||||
Returns ``None`` if the DB is unavailable (e.g. import error in a
|
||||
minimal test environment).
|
||||
|
||||
Note: we resolve ``HERMES_HOME`` dynamically rather than relying on
|
||||
the module-level ``DEFAULT_DB_PATH`` constant, because that constant
|
||||
is evaluated at import time and won't reflect env-var changes made
|
||||
later (e.g. by the test fixture ``_isolate_hermes_home``).
|
||||
"""
|
||||
if self._db_instance is not None:
|
||||
return self._db_instance
|
||||
try:
|
||||
import os
|
||||
from pathlib import Path
|
||||
from hermes_state import SessionDB
|
||||
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
self._db_instance = SessionDB(db_path=hermes_home / "state.db")
|
||||
return self._db_instance
|
||||
except Exception:
|
||||
logger.debug("SessionDB unavailable for ACP persistence", exc_info=True)
|
||||
return None
|
||||
|
||||
def _persist(self, state: SessionState) -> None:
|
||||
"""Write session state to the database.
|
||||
|
||||
Creates the session record if it doesn't exist, then replaces all
|
||||
stored messages with the current in-memory history.
|
||||
"""
|
||||
db = self._get_db()
|
||||
if db is None:
|
||||
return
|
||||
|
||||
# Ensure model is a plain string (not a MagicMock or other proxy).
|
||||
model_str = str(state.model) if state.model else None
|
||||
cwd_json = json.dumps({"cwd": state.cwd})
|
||||
|
||||
try:
|
||||
# Ensure the session record exists.
|
||||
existing = db.get_session(state.session_id)
|
||||
if existing is None:
|
||||
db.create_session(
|
||||
session_id=state.session_id,
|
||||
source="acp",
|
||||
model=model_str,
|
||||
model_config={"cwd": state.cwd},
|
||||
)
|
||||
else:
|
||||
# Update model_config (contains cwd) if changed.
|
||||
try:
|
||||
with db._lock:
|
||||
db._conn.execute(
|
||||
"UPDATE sessions SET model_config = ?, model = COALESCE(?, model) WHERE id = ?",
|
||||
(cwd_json, model_str, state.session_id),
|
||||
)
|
||||
db._conn.commit()
|
||||
except Exception:
|
||||
logger.debug("Failed to update ACP session metadata", exc_info=True)
|
||||
|
||||
# Replace stored messages with current history.
|
||||
db.clear_messages(state.session_id)
|
||||
for msg in state.history:
|
||||
db.append_message(
|
||||
session_id=state.session_id,
|
||||
role=msg.get("role", "user"),
|
||||
content=msg.get("content"),
|
||||
tool_name=msg.get("tool_name") or msg.get("name"),
|
||||
tool_calls=msg.get("tool_calls"),
|
||||
tool_call_id=msg.get("tool_call_id"),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist ACP session %s", state.session_id, exc_info=True)
|
||||
|
||||
def _restore(self, session_id: str) -> Optional[SessionState]:
|
||||
"""Load a session from the database into memory, recreating the AIAgent."""
|
||||
import threading
|
||||
|
||||
db = self._get_db()
|
||||
if db is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
row = db.get_session(session_id)
|
||||
except Exception:
|
||||
logger.debug("Failed to query DB for ACP session %s", session_id, exc_info=True)
|
||||
return None
|
||||
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
# Only restore ACP sessions.
|
||||
if row.get("source") != "acp":
|
||||
return None
|
||||
|
||||
# Extract cwd from model_config.
|
||||
cwd = "."
|
||||
mc = row.get("model_config")
|
||||
if mc:
|
||||
try:
|
||||
cwd = json.loads(mc).get("cwd", ".")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
model = row.get("model") or None
|
||||
|
||||
# Load conversation history.
|
||||
try:
|
||||
history = db.get_messages_as_conversation(session_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to load messages for ACP session %s", session_id, exc_info=True)
|
||||
history = []
|
||||
|
||||
try:
|
||||
agent = self._make_agent(session_id=session_id, cwd=cwd, model=model)
|
||||
except Exception:
|
||||
logger.warning("Failed to recreate agent for ACP session %s", session_id, exc_info=True)
|
||||
return None
|
||||
|
||||
state = SessionState(
|
||||
session_id=session_id,
|
||||
agent=agent,
|
||||
cwd=cwd,
|
||||
model=model or getattr(agent, "model", "") or "",
|
||||
history=history,
|
||||
cancel_event=threading.Event(),
|
||||
)
|
||||
with self._lock:
|
||||
self._sessions[session_id] = state
|
||||
_register_task_cwd(session_id, cwd)
|
||||
logger.info("Restored ACP session %s from DB (%d messages)", session_id, len(history))
|
||||
return state
|
||||
|
||||
def _delete_persisted(self, session_id: str) -> bool:
|
||||
"""Delete a session from the database. Returns True if it existed."""
|
||||
db = self._get_db()
|
||||
if db is None:
|
||||
return False
|
||||
try:
|
||||
return db.delete_session(session_id)
|
||||
except Exception:
|
||||
logger.debug("Failed to delete ACP session %s from DB", session_id, exc_info=True)
|
||||
return False
|
||||
|
||||
# ---- internal -----------------------------------------------------------
|
||||
|
||||
|
||||
@@ -935,6 +935,26 @@ def convert_messages_to_anthropic(
|
||||
if not m["content"]:
|
||||
m["content"] = [{"type": "text", "text": "(tool call removed)"}]
|
||||
|
||||
# Strip orphaned tool_result blocks (no matching tool_use precedes them).
|
||||
# This is the mirror of the above: context compression or session truncation
|
||||
# can remove an assistant message containing a tool_use while leaving the
|
||||
# subsequent tool_result intact. Anthropic rejects these with a 400.
|
||||
tool_use_ids = set()
|
||||
for m in result:
|
||||
if m["role"] == "assistant" and isinstance(m["content"], list):
|
||||
for block in m["content"]:
|
||||
if block.get("type") == "tool_use":
|
||||
tool_use_ids.add(block.get("id"))
|
||||
for m in result:
|
||||
if m["role"] == "user" and isinstance(m["content"], list):
|
||||
m["content"] = [
|
||||
b
|
||||
for b in m["content"]
|
||||
if b.get("type") != "tool_result" or b.get("tool_use_id") in tool_use_ids
|
||||
]
|
||||
if not m["content"]:
|
||||
m["content"] = [{"type": "text", "text": "(tool result removed)"}]
|
||||
|
||||
# Enforce strict role alternation (Anthropic rejects consecutive same-role messages)
|
||||
fixed = []
|
||||
for m in result:
|
||||
|
||||
@@ -654,10 +654,23 @@ def _try_anthropic() -> Tuple[Optional[Any], Optional[str]]:
|
||||
if not token:
|
||||
return None, None
|
||||
|
||||
# Allow base URL override from config.yaml model.base_url
|
||||
base_url = _ANTHROPIC_DEFAULT_BASE_URL
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
cfg = load_config()
|
||||
model_cfg = cfg.get("model")
|
||||
if isinstance(model_cfg, dict):
|
||||
cfg_base_url = (model_cfg.get("base_url") or "").strip().rstrip("/")
|
||||
if cfg_base_url:
|
||||
base_url = cfg_base_url
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get("anthropic", "claude-haiku-4-5-20251001")
|
||||
logger.debug("Auxiliary client: Anthropic native (%s)", model)
|
||||
real_client = build_anthropic_client(token, _ANTHROPIC_DEFAULT_BASE_URL)
|
||||
return AnthropicAuxiliaryClient(real_client, model, token, _ANTHROPIC_DEFAULT_BASE_URL), model
|
||||
logger.debug("Auxiliary client: Anthropic native (%s) at %s", model, base_url)
|
||||
real_client = build_anthropic_client(token, base_url)
|
||||
return AnthropicAuxiliaryClient(real_client, model, token, base_url), model
|
||||
|
||||
|
||||
def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
@@ -1178,8 +1191,18 @@ def _get_cached_client(
|
||||
cache_key = (provider, async_mode, base_url or "", api_key or "")
|
||||
with _client_cache_lock:
|
||||
if cache_key in _client_cache:
|
||||
cached_client, cached_default = _client_cache[cache_key]
|
||||
return cached_client, model or cached_default
|
||||
cached_client, cached_default, cached_loop = _client_cache[cache_key]
|
||||
if async_mode:
|
||||
# Async clients are bound to the event loop that created them.
|
||||
# A cached async client whose loop has been closed will raise
|
||||
# "Event loop is closed" when httpx tries to clean up its
|
||||
# transport. Discard the stale client and create a fresh one.
|
||||
if cached_loop is not None and cached_loop.is_closed():
|
||||
del _client_cache[cache_key]
|
||||
else:
|
||||
return cached_client, model or cached_default
|
||||
else:
|
||||
return cached_client, model or cached_default
|
||||
# Build outside the lock
|
||||
client, default_model = resolve_provider_client(
|
||||
provider,
|
||||
@@ -1189,11 +1212,20 @@ def _get_cached_client(
|
||||
explicit_api_key=api_key,
|
||||
)
|
||||
if client is not None:
|
||||
# For async clients, remember which loop they were created on so we
|
||||
# can detect stale entries later.
|
||||
bound_loop = None
|
||||
if async_mode:
|
||||
try:
|
||||
import asyncio as _aio
|
||||
bound_loop = _aio.get_event_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
with _client_cache_lock:
|
||||
if cache_key not in _client_cache:
|
||||
_client_cache[cache_key] = (client, default_model)
|
||||
_client_cache[cache_key] = (client, default_model, bound_loop)
|
||||
else:
|
||||
client, default_model = _client_cache[cache_key]
|
||||
client, default_model, _ = _client_cache[cache_key]
|
||||
return client, model or default_model
|
||||
|
||||
|
||||
|
||||
@@ -46,17 +46,24 @@ class ContextCompressor:
|
||||
summary_model_override: str = None,
|
||||
base_url: str = "",
|
||||
api_key: str = "",
|
||||
config_context_length: int | None = None,
|
||||
provider: str = "",
|
||||
):
|
||||
self.model = model
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
self.provider = provider
|
||||
self.threshold_percent = threshold_percent
|
||||
self.protect_first_n = protect_first_n
|
||||
self.protect_last_n = protect_last_n
|
||||
self.summary_target_tokens = summary_target_tokens
|
||||
self.quiet_mode = quiet_mode
|
||||
|
||||
self.context_length = get_model_context_length(model, base_url=base_url, api_key=api_key)
|
||||
self.context_length = get_model_context_length(
|
||||
model, base_url=base_url, api_key=api_key,
|
||||
config_context_length=config_context_length,
|
||||
provider=provider,
|
||||
)
|
||||
self.threshold_tokens = int(self.context_length * threshold_percent)
|
||||
self.compression_count = 0
|
||||
self._context_probed = False # True after a step-down from context error
|
||||
|
||||
@@ -356,7 +356,7 @@ class CopilotACPClient:
|
||||
text_parts=text_parts,
|
||||
reasoning_parts=reasoning_parts,
|
||||
)
|
||||
return "".join(text_parts).strip(), "".join(reasoning_parts).strip()
|
||||
return "".join(text_parts), "".join(reasoning_parts)
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
@@ -380,7 +380,7 @@ class CopilotACPClient:
|
||||
content = update.get("content") or {}
|
||||
chunk_text = ""
|
||||
if isinstance(content, dict):
|
||||
chunk_text = str(content.get("text") or "").strip()
|
||||
chunk_text = str(content.get("text") or "")
|
||||
if kind == "agent_message_chunk" and chunk_text and text_parts is not None:
|
||||
text_parts.append(chunk_text)
|
||||
elif kind == "agent_thought_chunk" and chunk_text and reasoning_parts is not None:
|
||||
|
||||
@@ -612,3 +612,95 @@ def write_tty(text: str) -> None:
|
||||
except OSError:
|
||||
sys.stdout.write(text)
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Context pressure display (CLI user-facing warnings)
|
||||
# =========================================================================
|
||||
|
||||
# ANSI color codes for context pressure tiers
|
||||
_CYAN = "\033[36m"
|
||||
_YELLOW = "\033[33m"
|
||||
_BOLD = "\033[1m"
|
||||
_DIM_ANSI = "\033[2m"
|
||||
|
||||
# Bar characters
|
||||
_BAR_FILLED = "▰"
|
||||
_BAR_EMPTY = "▱"
|
||||
_BAR_WIDTH = 20
|
||||
|
||||
|
||||
def format_context_pressure(
|
||||
compaction_progress: float,
|
||||
threshold_tokens: int,
|
||||
threshold_percent: float,
|
||||
compression_enabled: bool = True,
|
||||
) -> str:
|
||||
"""Build a formatted context pressure line for CLI display.
|
||||
|
||||
The bar and percentage show progress toward the compaction threshold,
|
||||
NOT the raw context window. 100% = compaction fires.
|
||||
|
||||
Uses ANSI colors:
|
||||
- cyan at ~60% to compaction = informational
|
||||
- bold yellow at ~85% to compaction = warning
|
||||
|
||||
Args:
|
||||
compaction_progress: How close to compaction (0.0–1.0, 1.0 = fires).
|
||||
threshold_tokens: Compaction threshold in tokens.
|
||||
threshold_percent: Compaction threshold as a fraction of context window.
|
||||
compression_enabled: Whether auto-compression is active.
|
||||
"""
|
||||
pct_int = int(compaction_progress * 100)
|
||||
filled = min(int(compaction_progress * _BAR_WIDTH), _BAR_WIDTH)
|
||||
bar = _BAR_FILLED * filled + _BAR_EMPTY * (_BAR_WIDTH - filled)
|
||||
|
||||
threshold_k = f"{threshold_tokens // 1000}k" if threshold_tokens >= 1000 else str(threshold_tokens)
|
||||
threshold_pct_int = int(threshold_percent * 100)
|
||||
|
||||
# Tier styling
|
||||
if compaction_progress >= 0.85:
|
||||
color = f"{_BOLD}{_YELLOW}"
|
||||
icon = "⚠"
|
||||
if compression_enabled:
|
||||
hint = "compaction imminent"
|
||||
else:
|
||||
hint = "no auto-compaction"
|
||||
else:
|
||||
color = _CYAN
|
||||
icon = "◐"
|
||||
hint = "approaching compaction"
|
||||
|
||||
return (
|
||||
f" {color}{icon} context {bar} {pct_int}% to compaction{_ANSI_RESET}"
|
||||
f" {_DIM_ANSI}{threshold_k} threshold ({threshold_pct_int}%) · {hint}{_ANSI_RESET}"
|
||||
)
|
||||
|
||||
|
||||
def format_context_pressure_gateway(
|
||||
compaction_progress: float,
|
||||
threshold_percent: float,
|
||||
compression_enabled: bool = True,
|
||||
) -> str:
|
||||
"""Build a plain-text context pressure notification for messaging platforms.
|
||||
|
||||
No ANSI — just Unicode and plain text suitable for Telegram/Discord/etc.
|
||||
The percentage shows progress toward the compaction threshold.
|
||||
"""
|
||||
pct_int = int(compaction_progress * 100)
|
||||
filled = min(int(compaction_progress * _BAR_WIDTH), _BAR_WIDTH)
|
||||
bar = _BAR_FILLED * filled + _BAR_EMPTY * (_BAR_WIDTH - filled)
|
||||
|
||||
threshold_pct_int = int(threshold_percent * 100)
|
||||
|
||||
if compaction_progress >= 0.85:
|
||||
icon = "⚠️"
|
||||
if compression_enabled:
|
||||
hint = f"Context compaction is imminent (threshold: {threshold_pct_int}% of window)."
|
||||
else:
|
||||
hint = "Auto-compaction is disabled — context may be truncated."
|
||||
else:
|
||||
icon = "ℹ️"
|
||||
hint = f"Compaction threshold is at {threshold_pct_int}% of context window."
|
||||
|
||||
return f"{icon} Context: {bar} {pct_int}% to compaction\n{hint}"
|
||||
|
||||
+15
-12
@@ -181,22 +181,25 @@ class InsightsEngine:
|
||||
"billing_base_url, billing_mode, estimated_cost_usd, "
|
||||
"actual_cost_usd, cost_status, cost_source")
|
||||
|
||||
# Pre-computed query strings — f-string evaluated once at class definition,
|
||||
# not at runtime, so no user-controlled value can alter the query structure.
|
||||
_GET_SESSIONS_WITH_SOURCE = (
|
||||
f"SELECT {_SESSION_COLS} FROM sessions"
|
||||
" WHERE started_at >= ? AND source = ?"
|
||||
" ORDER BY started_at DESC"
|
||||
)
|
||||
_GET_SESSIONS_ALL = (
|
||||
f"SELECT {_SESSION_COLS} FROM sessions"
|
||||
" WHERE started_at >= ?"
|
||||
" ORDER BY started_at DESC"
|
||||
)
|
||||
|
||||
def _get_sessions(self, cutoff: float, source: str = None) -> List[Dict]:
|
||||
"""Fetch sessions within the time window."""
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
f"""SELECT {self._SESSION_COLS} FROM sessions
|
||||
WHERE started_at >= ? AND source = ?
|
||||
ORDER BY started_at DESC""",
|
||||
(cutoff, source),
|
||||
)
|
||||
cursor = self._conn.execute(self._GET_SESSIONS_WITH_SOURCE, (cutoff, source))
|
||||
else:
|
||||
cursor = self._conn.execute(
|
||||
f"""SELECT {self._SESSION_COLS} FROM sessions
|
||||
WHERE started_at >= ?
|
||||
ORDER BY started_at DESC""",
|
||||
(cutoff,),
|
||||
)
|
||||
cursor = self._conn.execute(self._GET_SESSIONS_ALL, (cutoff,))
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def _get_tool_usage(self, cutoff: float, source: str = None) -> List[Dict]:
|
||||
|
||||
+467
-106
@@ -19,6 +19,46 @@ from hermes_constants import OPENROUTER_MODELS_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Provider names that can appear as a "provider:" prefix before a model ID.
|
||||
# Only these are stripped — Ollama-style "model:tag" colons (e.g. "qwen3.5:27b")
|
||||
# are preserved so the full model name reaches cache lookups and server queries.
|
||||
_PROVIDER_PREFIXES: frozenset[str] = frozenset({
|
||||
"openrouter", "nous", "openai-codex", "copilot", "copilot-acp",
|
||||
"zai", "kimi-coding", "minimax", "minimax-cn", "anthropic", "deepseek",
|
||||
"opencode-zen", "opencode-go", "ai-gateway", "kilocode", "alibaba",
|
||||
"custom", "local",
|
||||
# Common aliases
|
||||
"glm", "z-ai", "z.ai", "zhipu", "github", "github-copilot",
|
||||
"github-models", "kimi", "moonshot", "claude", "deep-seek",
|
||||
"opencode", "zen", "go", "vercel", "kilo", "dashscope", "aliyun", "qwen",
|
||||
})
|
||||
|
||||
|
||||
_OLLAMA_TAG_PATTERN = re.compile(
|
||||
r"^(\d+\.?\d*b|latest|stable|q\d|fp?\d|instruct|chat|coder|vision|text)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _strip_provider_prefix(model: str) -> str:
|
||||
"""Strip a recognised provider prefix from a model string.
|
||||
|
||||
``"local:my-model"`` → ``"my-model"``
|
||||
``"qwen3.5:27b"`` → ``"qwen3.5:27b"`` (unchanged — not a provider prefix)
|
||||
``"qwen:0.5b"`` → ``"qwen:0.5b"`` (unchanged — Ollama model:tag)
|
||||
``"deepseek:latest"``→ ``"deepseek:latest"``(unchanged — Ollama model:tag)
|
||||
"""
|
||||
if ":" not in model or model.startswith("http"):
|
||||
return model
|
||||
prefix, suffix = model.split(":", 1)
|
||||
prefix_lower = prefix.strip().lower()
|
||||
if prefix_lower in _PROVIDER_PREFIXES:
|
||||
# Don't strip if suffix looks like an Ollama tag (e.g. "7b", "latest", "q4_0")
|
||||
if _OLLAMA_TAG_PATTERN.match(suffix.strip()):
|
||||
return model
|
||||
return suffix
|
||||
return model
|
||||
|
||||
_model_metadata_cache: Dict[str, Dict[str, Any]] = {}
|
||||
_model_metadata_cache_time: float = 0
|
||||
_MODEL_CACHE_TTL = 3600
|
||||
@@ -27,104 +67,52 @@ _endpoint_model_metadata_cache_time: Dict[str, float] = {}
|
||||
_ENDPOINT_MODEL_CACHE_TTL = 300
|
||||
|
||||
# Descending tiers for context length probing when the model is unknown.
|
||||
# We start high and step down on context-length errors until one works.
|
||||
# We start at 128K (a safe default for most modern models) and step down
|
||||
# on context-length errors until one works.
|
||||
CONTEXT_PROBE_TIERS = [
|
||||
2_000_000,
|
||||
1_000_000,
|
||||
512_000,
|
||||
200_000,
|
||||
128_000,
|
||||
64_000,
|
||||
32_000,
|
||||
16_000,
|
||||
8_000,
|
||||
]
|
||||
|
||||
# Default context length when no detection method succeeds.
|
||||
DEFAULT_FALLBACK_CONTEXT = CONTEXT_PROBE_TIERS[0]
|
||||
|
||||
# Thin fallback defaults — only broad model family patterns.
|
||||
# These fire only when provider is unknown AND models.dev/OpenRouter/Anthropic
|
||||
# all miss. Replaced the previous 80+ entry dict.
|
||||
# For provider-specific context lengths, models.dev is the primary source.
|
||||
DEFAULT_CONTEXT_LENGTHS = {
|
||||
"anthropic/claude-opus-4": 200000,
|
||||
"anthropic/claude-opus-4.5": 200000,
|
||||
"anthropic/claude-opus-4.6": 200000,
|
||||
"anthropic/claude-sonnet-4": 200000,
|
||||
"anthropic/claude-sonnet-4-20250514": 200000,
|
||||
"anthropic/claude-sonnet-4.5": 200000,
|
||||
"anthropic/claude-sonnet-4.6": 200000,
|
||||
"anthropic/claude-haiku-4.5": 200000,
|
||||
# Bare Anthropic model IDs (for native API provider)
|
||||
"claude-opus-4-6": 200000,
|
||||
"claude-sonnet-4-6": 200000,
|
||||
"claude-opus-4-5-20251101": 200000,
|
||||
"claude-sonnet-4-5-20250929": 200000,
|
||||
"claude-opus-4-1-20250805": 200000,
|
||||
"claude-opus-4-20250514": 200000,
|
||||
"claude-sonnet-4-20250514": 200000,
|
||||
"claude-haiku-4-5-20251001": 200000,
|
||||
"openai/gpt-5": 128000,
|
||||
"openai/gpt-4.1": 1047576,
|
||||
"openai/gpt-4.1-mini": 1047576,
|
||||
"openai/gpt-4o": 128000,
|
||||
"openai/gpt-4-turbo": 128000,
|
||||
"openai/gpt-4o-mini": 128000,
|
||||
"google/gemini-3-pro-preview": 1048576,
|
||||
"google/gemini-3-flash": 1048576,
|
||||
"google/gemini-2.5-flash": 1048576,
|
||||
"google/gemini-2.0-flash": 1048576,
|
||||
"google/gemini-2.5-pro": 1048576,
|
||||
"deepseek/deepseek-v3.2": 65536,
|
||||
"meta-llama/llama-3.3-70b-instruct": 131072,
|
||||
"deepseek/deepseek-chat-v3": 65536,
|
||||
"qwen/qwen-2.5-72b-instruct": 32768,
|
||||
"glm-4.7": 202752,
|
||||
"glm-5": 202752,
|
||||
"glm-4.5": 131072,
|
||||
"glm-4.5-flash": 131072,
|
||||
"kimi-for-coding": 262144,
|
||||
"kimi-k2.5": 262144,
|
||||
"kimi-k2-thinking": 262144,
|
||||
"kimi-k2-thinking-turbo": 262144,
|
||||
"kimi-k2-turbo-preview": 262144,
|
||||
"kimi-k2-0905-preview": 131072,
|
||||
"MiniMax-M2.7": 204800,
|
||||
"MiniMax-M2.7-highspeed": 204800,
|
||||
"MiniMax-M2.5": 204800,
|
||||
"MiniMax-M2.5-highspeed": 204800,
|
||||
"MiniMax-M2.1": 204800,
|
||||
# OpenCode Zen models
|
||||
"gpt-5.4-pro": 128000,
|
||||
"gpt-5.4": 128000,
|
||||
"gpt-5.3-codex": 128000,
|
||||
"gpt-5.3-codex-spark": 128000,
|
||||
"gpt-5.2": 128000,
|
||||
"gpt-5.2-codex": 128000,
|
||||
"gpt-5.1": 128000,
|
||||
"gpt-5.1-codex": 128000,
|
||||
"gpt-5.1-codex-max": 128000,
|
||||
"gpt-5.1-codex-mini": 128000,
|
||||
# Anthropic Claude 4.6 (1M context) — bare IDs only to avoid
|
||||
# fuzzy-match collisions (e.g. "anthropic/claude-sonnet-4" is a
|
||||
# substring of "anthropic/claude-sonnet-4.6").
|
||||
# OpenRouter-prefixed models resolve via OpenRouter live API or models.dev.
|
||||
"claude-opus-4-6": 1000000,
|
||||
"claude-sonnet-4-6": 1000000,
|
||||
"claude-opus-4.6": 1000000,
|
||||
"claude-sonnet-4.6": 1000000,
|
||||
# Catch-all for older Claude models (must sort after specific entries)
|
||||
"claude": 200000,
|
||||
# OpenAI
|
||||
"gpt-4.1": 1047576,
|
||||
"gpt-5": 128000,
|
||||
"gpt-5-codex": 128000,
|
||||
"gpt-5-nano": 128000,
|
||||
# Bare model IDs without provider prefix (avoid duplicates with entries above)
|
||||
"claude-opus-4-5": 200000,
|
||||
"claude-opus-4-1": 200000,
|
||||
"claude-sonnet-4-5": 200000,
|
||||
"claude-sonnet-4": 200000,
|
||||
"claude-haiku-4-5": 200000,
|
||||
"claude-3-5-haiku": 200000,
|
||||
"gemini-3.1-pro": 1048576,
|
||||
"gemini-3-pro": 1048576,
|
||||
"gemini-3-flash": 1048576,
|
||||
"minimax-m2.5": 204800,
|
||||
"minimax-m2.5-free": 204800,
|
||||
"minimax-m2.1": 204800,
|
||||
"glm-4.6": 202752,
|
||||
"kimi-k2": 262144,
|
||||
"qwen3-coder": 32768,
|
||||
"big-pickle": 128000,
|
||||
# Alibaba Cloud / DashScope Qwen models
|
||||
"qwen3.5-plus": 131072,
|
||||
"qwen3-max": 131072,
|
||||
"qwen3-coder-plus": 131072,
|
||||
"qwen3-coder-next": 131072,
|
||||
"qwen-plus-latest": 131072,
|
||||
"qwen3.5-flash": 131072,
|
||||
"qwen-vl-max": 32768,
|
||||
"gpt-4": 128000,
|
||||
# Google
|
||||
"gemini": 1048576,
|
||||
# DeepSeek
|
||||
"deepseek": 128000,
|
||||
# Meta
|
||||
"llama": 131072,
|
||||
# Qwen
|
||||
"qwen": 131072,
|
||||
# MiniMax
|
||||
"minimax": 204800,
|
||||
# GLM
|
||||
"glm": 202752,
|
||||
# Kimi
|
||||
"kimi": 262144,
|
||||
}
|
||||
|
||||
_CONTEXT_LENGTH_KEYS = (
|
||||
@@ -136,6 +124,8 @@ _CONTEXT_LENGTH_KEYS = (
|
||||
"max_input_tokens",
|
||||
"max_sequence_length",
|
||||
"max_seq_len",
|
||||
"n_ctx_train",
|
||||
"n_ctx",
|
||||
)
|
||||
|
||||
_MAX_COMPLETION_KEYS = (
|
||||
@@ -144,6 +134,9 @@ _MAX_COMPLETION_KEYS = (
|
||||
"max_tokens",
|
||||
)
|
||||
|
||||
# Local server hostnames / address patterns
|
||||
_LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1", "0.0.0.0")
|
||||
|
||||
|
||||
def _normalize_base_url(base_url: str) -> str:
|
||||
return (base_url or "").strip().rstrip("/")
|
||||
@@ -176,6 +169,99 @@ def _is_known_provider_base_url(base_url: str) -> bool:
|
||||
return any(known_host in host for known_host in known_hosts)
|
||||
|
||||
|
||||
def is_local_endpoint(base_url: str) -> bool:
|
||||
"""Return True if base_url points to a local machine (localhost / RFC-1918 / WSL)."""
|
||||
normalized = _normalize_base_url(base_url)
|
||||
if not normalized:
|
||||
return False
|
||||
url = normalized if "://" in normalized else f"http://{normalized}"
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
host = parsed.hostname or ""
|
||||
except Exception:
|
||||
return False
|
||||
if host in _LOCAL_HOSTS:
|
||||
return True
|
||||
# RFC-1918 private ranges and link-local
|
||||
import ipaddress
|
||||
try:
|
||||
addr = ipaddress.ip_address(host)
|
||||
return addr.is_private or addr.is_loopback or addr.is_link_local
|
||||
except ValueError:
|
||||
pass
|
||||
# Bare IP that looks like a private range (e.g. 172.26.x.x for WSL)
|
||||
parts = host.split(".")
|
||||
if len(parts) == 4:
|
||||
try:
|
||||
first, second = int(parts[0]), int(parts[1])
|
||||
if first == 10:
|
||||
return True
|
||||
if first == 172 and 16 <= second <= 31:
|
||||
return True
|
||||
if first == 192 and second == 168:
|
||||
return True
|
||||
except ValueError:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def detect_local_server_type(base_url: str) -> Optional[str]:
|
||||
"""Detect which local server is running at base_url by probing known endpoints.
|
||||
|
||||
Returns one of: "ollama", "lm-studio", "vllm", "llamacpp", or None.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
normalized = _normalize_base_url(base_url)
|
||||
server_url = normalized
|
||||
if server_url.endswith("/v1"):
|
||||
server_url = server_url[:-3]
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=2.0) as client:
|
||||
# LM Studio exposes /api/v1/models — check first (most specific)
|
||||
try:
|
||||
r = client.get(f"{server_url}/api/v1/models")
|
||||
if r.status_code == 200:
|
||||
return "lm-studio"
|
||||
except Exception:
|
||||
pass
|
||||
# Ollama exposes /api/tags and responds with {"models": [...]}
|
||||
# LM Studio returns {"error": "Unexpected endpoint"} with status 200
|
||||
# on this path, so we must verify the response contains "models".
|
||||
try:
|
||||
r = client.get(f"{server_url}/api/tags")
|
||||
if r.status_code == 200:
|
||||
try:
|
||||
data = r.json()
|
||||
if "models" in data:
|
||||
return "ollama"
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
# llama.cpp exposes /props
|
||||
try:
|
||||
r = client.get(f"{server_url}/props")
|
||||
if r.status_code == 200 and "default_generation_settings" in r.text:
|
||||
return "llamacpp"
|
||||
except Exception:
|
||||
pass
|
||||
# vLLM: /version
|
||||
try:
|
||||
r = client.get(f"{server_url}/version")
|
||||
if r.status_code == 200:
|
||||
data = r.json()
|
||||
if "version" in data:
|
||||
return "vllm"
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _iter_nested_dicts(value: Any):
|
||||
if isinstance(value, dict):
|
||||
yield value
|
||||
@@ -342,6 +428,25 @@ def fetch_endpoint_model_metadata(
|
||||
entry["pricing"] = pricing
|
||||
_add_model_aliases(cache, model_id, entry)
|
||||
|
||||
# If this is a llama.cpp server, query /props for actual allocated context
|
||||
is_llamacpp = any(
|
||||
m.get("owned_by") == "llamacpp"
|
||||
for m in payload.get("data", []) if isinstance(m, dict)
|
||||
)
|
||||
if is_llamacpp:
|
||||
try:
|
||||
props_url = candidate.rstrip("/").replace("/v1", "") + "/props"
|
||||
props_resp = requests.get(props_url, headers=headers, timeout=5)
|
||||
if props_resp.ok:
|
||||
props = props_resp.json()
|
||||
gen_settings = props.get("default_generation_settings", {})
|
||||
n_ctx = gen_settings.get("n_ctx")
|
||||
model_alias = props.get("model_alias", "")
|
||||
if n_ctx and model_alias and model_alias in cache:
|
||||
cache[model_alias]["context_length"] = n_ctx
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_endpoint_model_metadata_cache[normalized] = cache
|
||||
_endpoint_model_metadata_cache_time[normalized] = time.time()
|
||||
return cache
|
||||
@@ -362,7 +467,7 @@ def _get_context_cache_path() -> Path:
|
||||
|
||||
|
||||
def _load_context_cache() -> Dict[str, int]:
|
||||
"""Load the model+provider → context_length cache from disk."""
|
||||
"""Load the model+provider -> context_length cache from disk."""
|
||||
path = _get_context_cache_path()
|
||||
if not path.exists():
|
||||
return {}
|
||||
@@ -391,7 +496,7 @@ def save_context_length(model: str, base_url: str, length: int) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "w") as f:
|
||||
yaml.dump({"context_lengths": cache}, f, default_flow_style=False)
|
||||
logger.info("Cached context length %s → %s tokens", key, f"{length:,}")
|
||||
logger.info("Cached context length %s -> %s tokens", key, f"{length:,}")
|
||||
except Exception as e:
|
||||
logger.debug("Failed to save context length cache: %s", e)
|
||||
|
||||
@@ -439,16 +544,219 @@ def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
|
||||
return None
|
||||
|
||||
|
||||
def get_model_context_length(model: str, base_url: str = "", api_key: str = "") -> int:
|
||||
def _model_id_matches(candidate_id: str, lookup_model: str) -> bool:
|
||||
"""Return True if *candidate_id* (from server) matches *lookup_model* (configured).
|
||||
|
||||
Supports two forms:
|
||||
- Exact match: "nvidia-nemotron-super-49b-v1" == "nvidia-nemotron-super-49b-v1"
|
||||
- Slug match: "nvidia/nvidia-nemotron-super-49b-v1" matches "nvidia-nemotron-super-49b-v1"
|
||||
(the part after the last "/" equals lookup_model)
|
||||
|
||||
This covers LM Studio's native API which stores models as "publisher/slug"
|
||||
while users typically configure only the slug after the "local:" prefix.
|
||||
"""
|
||||
if candidate_id == lookup_model:
|
||||
return True
|
||||
# Slug match: basename of candidate equals the lookup name
|
||||
if "/" in candidate_id and candidate_id.rsplit("/", 1)[1] == lookup_model:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
|
||||
"""Query a local server for the model's context length."""
|
||||
import httpx
|
||||
|
||||
# Strip recognised provider prefix (e.g., "local:model-name" → "model-name").
|
||||
# Ollama "model:tag" colons (e.g. "qwen3.5:27b") are intentionally preserved.
|
||||
model = _strip_provider_prefix(model)
|
||||
|
||||
# Strip /v1 suffix to get the server root
|
||||
server_url = base_url.rstrip("/")
|
||||
if server_url.endswith("/v1"):
|
||||
server_url = server_url[:-3]
|
||||
|
||||
try:
|
||||
server_type = detect_local_server_type(base_url)
|
||||
except Exception:
|
||||
server_type = None
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=3.0) as client:
|
||||
# Ollama: /api/show returns model details with context info
|
||||
if server_type == "ollama":
|
||||
resp = client.post(f"{server_url}/api/show", json={"name": model})
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
# Check model_info for context length
|
||||
model_info = data.get("model_info", {})
|
||||
for key, value in model_info.items():
|
||||
if "context_length" in key and isinstance(value, (int, float)):
|
||||
return int(value)
|
||||
# Check parameters string for num_ctx
|
||||
params = data.get("parameters", "")
|
||||
if "num_ctx" in params:
|
||||
for line in params.split("\n"):
|
||||
if "num_ctx" in line:
|
||||
parts = line.strip().split()
|
||||
if len(parts) >= 2:
|
||||
try:
|
||||
return int(parts[-1])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# LM Studio native API: /api/v1/models returns max_context_length.
|
||||
# This is more reliable than the OpenAI-compat /v1/models which
|
||||
# doesn't include context window information for LM Studio servers.
|
||||
# Use _model_id_matches for fuzzy matching: LM Studio stores models as
|
||||
# "publisher/slug" but users configure only "slug" after "local:" prefix.
|
||||
if server_type == "lm-studio":
|
||||
resp = client.get(f"{server_url}/api/v1/models")
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
for m in data.get("models", []):
|
||||
if _model_id_matches(m.get("key", ""), model) or _model_id_matches(m.get("id", ""), model):
|
||||
# Prefer loaded instance context (actual runtime value)
|
||||
for inst in m.get("loaded_instances", []):
|
||||
cfg = inst.get("config", {})
|
||||
ctx = cfg.get("context_length")
|
||||
if ctx and isinstance(ctx, (int, float)):
|
||||
return int(ctx)
|
||||
# Fall back to max_context_length (theoretical model max)
|
||||
ctx = m.get("max_context_length") or m.get("context_length")
|
||||
if ctx and isinstance(ctx, (int, float)):
|
||||
return int(ctx)
|
||||
|
||||
# LM Studio / vLLM / llama.cpp: try /v1/models/{model}
|
||||
resp = client.get(f"{server_url}/v1/models/{model}")
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
# vLLM returns max_model_len
|
||||
ctx = data.get("max_model_len") or data.get("context_length") or data.get("max_tokens")
|
||||
if ctx and isinstance(ctx, (int, float)):
|
||||
return int(ctx)
|
||||
|
||||
# Try /v1/models and find the model in the list.
|
||||
# Use _model_id_matches to handle "publisher/slug" vs bare "slug".
|
||||
resp = client.get(f"{server_url}/v1/models")
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
models_list = data.get("data", [])
|
||||
for m in models_list:
|
||||
if _model_id_matches(m.get("id", ""), model):
|
||||
ctx = m.get("max_model_len") or m.get("context_length") or m.get("max_tokens")
|
||||
if ctx and isinstance(ctx, (int, float)):
|
||||
return int(ctx)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_model_version(model: str) -> str:
|
||||
"""Normalize version separators for matching.
|
||||
|
||||
Nous uses dashes: claude-opus-4-6, claude-sonnet-4-5
|
||||
OpenRouter uses dots: claude-opus-4.6, claude-sonnet-4.5
|
||||
Normalize both to dashes for comparison.
|
||||
"""
|
||||
return model.replace(".", "-")
|
||||
|
||||
|
||||
def _query_anthropic_context_length(model: str, base_url: str, api_key: str) -> Optional[int]:
|
||||
"""Query Anthropic's /v1/models endpoint for context length.
|
||||
|
||||
Only works with regular ANTHROPIC_API_KEY (sk-ant-api*).
|
||||
OAuth tokens (sk-ant-oat*) from Claude Code return 401.
|
||||
"""
|
||||
if not api_key or api_key.startswith("sk-ant-oat"):
|
||||
return None # OAuth tokens can't access /v1/models
|
||||
try:
|
||||
base = base_url.rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
url = f"{base}/v1/models?limit=1000"
|
||||
headers = {
|
||||
"x-api-key": api_key,
|
||||
"anthropic-version": "2023-06-01",
|
||||
}
|
||||
resp = requests.get(url, headers=headers, timeout=10)
|
||||
if resp.status_code != 200:
|
||||
return None
|
||||
data = resp.json()
|
||||
for m in data.get("data", []):
|
||||
if m.get("id") == model:
|
||||
ctx = m.get("max_input_tokens")
|
||||
if isinstance(ctx, int) and ctx > 0:
|
||||
return ctx
|
||||
except Exception as e:
|
||||
logger.debug("Anthropic /v1/models query failed: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_nous_context_length(model: str) -> Optional[int]:
|
||||
"""Resolve Nous Portal model context length via OpenRouter metadata.
|
||||
|
||||
Nous model IDs are bare (e.g. 'claude-opus-4-6') while OpenRouter uses
|
||||
prefixed IDs (e.g. 'anthropic/claude-opus-4.6'). Try suffix matching
|
||||
with version normalization (dot↔dash).
|
||||
"""
|
||||
metadata = fetch_model_metadata() # OpenRouter cache
|
||||
# Exact match first
|
||||
if model in metadata:
|
||||
return metadata[model].get("context_length")
|
||||
|
||||
normalized = _normalize_model_version(model).lower()
|
||||
|
||||
for or_id, entry in metadata.items():
|
||||
bare = or_id.split("/", 1)[1] if "/" in or_id else or_id
|
||||
if bare.lower() == model.lower() or _normalize_model_version(bare).lower() == normalized:
|
||||
return entry.get("context_length")
|
||||
|
||||
# Partial prefix match for cases like gemini-3-flash → gemini-3-flash-preview
|
||||
# Require match to be at a word boundary (followed by -, :, or end of string)
|
||||
model_lower = model.lower()
|
||||
for or_id, entry in metadata.items():
|
||||
bare = or_id.split("/", 1)[1] if "/" in or_id else or_id
|
||||
for candidate, query in [(bare.lower(), model_lower), (_normalize_model_version(bare).lower(), normalized)]:
|
||||
if candidate.startswith(query) and (
|
||||
len(candidate) == len(query) or candidate[len(query)] in "-:."
|
||||
):
|
||||
return entry.get("context_length")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_model_context_length(
|
||||
model: str,
|
||||
base_url: str = "",
|
||||
api_key: str = "",
|
||||
config_context_length: int | None = None,
|
||||
provider: str = "",
|
||||
) -> int:
|
||||
"""Get the context length for a model.
|
||||
|
||||
Resolution order:
|
||||
0. Explicit config override (model.context_length or custom_providers per-model)
|
||||
1. Persistent cache (previously discovered via probing)
|
||||
2. Active endpoint metadata (/models for explicit custom endpoints)
|
||||
3. OpenRouter API metadata
|
||||
4. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match for hosted routes only)
|
||||
5. First probe tier (2M) — will be narrowed on first context error
|
||||
3. Local server query (for local endpoints)
|
||||
4. Anthropic /v1/models API (API-key users only, not OAuth)
|
||||
5. OpenRouter live API metadata
|
||||
6. Nous suffix-match via OpenRouter cache
|
||||
7. models.dev registry lookup (provider-aware)
|
||||
8. Thin hardcoded defaults (broad family patterns)
|
||||
9. Default fallback (128K)
|
||||
"""
|
||||
# 0. Explicit config override — user knows best
|
||||
if config_context_length is not None and isinstance(config_context_length, int) and config_context_length > 0:
|
||||
return config_context_length
|
||||
|
||||
# Normalise provider-prefixed model names (e.g. "local:model-name" →
|
||||
# "model-name") so cache lookups and server queries use the bare ID that
|
||||
# local servers actually know about. Ollama "model:tag" colons are preserved.
|
||||
model = _strip_provider_prefix(model)
|
||||
|
||||
# 1. Check persistent cache (model+provider)
|
||||
if base_url:
|
||||
cached = get_cached_context_length(model, base_url)
|
||||
@@ -458,29 +766,82 @@ def get_model_context_length(model: str, base_url: str = "", api_key: str = "")
|
||||
# 2. Active endpoint metadata for explicit custom routes
|
||||
if _is_custom_endpoint(base_url):
|
||||
endpoint_metadata = fetch_endpoint_model_metadata(base_url, api_key=api_key)
|
||||
if model in endpoint_metadata:
|
||||
context_length = endpoint_metadata[model].get("context_length")
|
||||
matched = endpoint_metadata.get(model)
|
||||
if not matched:
|
||||
# Single-model servers: if only one model is loaded, use it
|
||||
if len(endpoint_metadata) == 1:
|
||||
matched = next(iter(endpoint_metadata.values()))
|
||||
else:
|
||||
# Fuzzy match: substring in either direction
|
||||
for key, entry in endpoint_metadata.items():
|
||||
if model in key or key in model:
|
||||
matched = entry
|
||||
break
|
||||
if matched:
|
||||
context_length = matched.get("context_length")
|
||||
if isinstance(context_length, int):
|
||||
return context_length
|
||||
if not _is_known_provider_base_url(base_url):
|
||||
# Explicit third-party endpoints should not borrow fuzzy global
|
||||
# defaults from unrelated providers with similarly named models.
|
||||
return CONTEXT_PROBE_TIERS[0]
|
||||
# 3. Try querying local server directly
|
||||
if is_local_endpoint(base_url):
|
||||
local_ctx = _query_local_context_length(model, base_url)
|
||||
if local_ctx and local_ctx > 0:
|
||||
save_context_length(model, base_url, local_ctx)
|
||||
return local_ctx
|
||||
logger.info(
|
||||
"Could not detect context length for model %r at %s — "
|
||||
"defaulting to %s tokens (probe-down). Set model.context_length "
|
||||
"in config.yaml to override.",
|
||||
model, base_url, f"{DEFAULT_FALLBACK_CONTEXT:,}",
|
||||
)
|
||||
return DEFAULT_FALLBACK_CONTEXT
|
||||
|
||||
# 3. OpenRouter API metadata
|
||||
# 4. Anthropic /v1/models API (only for regular API keys, not OAuth)
|
||||
if provider == "anthropic" or (
|
||||
base_url and "api.anthropic.com" in base_url
|
||||
):
|
||||
ctx = _query_anthropic_context_length(model, base_url or "https://api.anthropic.com", api_key)
|
||||
if ctx:
|
||||
return ctx
|
||||
|
||||
# 5. Provider-aware lookups (before generic OpenRouter cache)
|
||||
# These are provider-specific and take priority over the generic OR cache,
|
||||
# since the same model can have different context limits per provider
|
||||
# (e.g. claude-opus-4.6 is 1M on Anthropic but 128K on GitHub Copilot).
|
||||
if provider == "nous":
|
||||
ctx = _resolve_nous_context_length(model)
|
||||
if ctx:
|
||||
return ctx
|
||||
if provider:
|
||||
from agent.models_dev import lookup_models_dev_context
|
||||
ctx = lookup_models_dev_context(provider, model)
|
||||
if ctx:
|
||||
return ctx
|
||||
|
||||
# 6. OpenRouter live API metadata (provider-unaware fallback)
|
||||
metadata = fetch_model_metadata()
|
||||
if model in metadata:
|
||||
return metadata[model].get("context_length", 128000)
|
||||
|
||||
# 4. Hardcoded defaults (fuzzy match — longest key first for specificity)
|
||||
# 8. Hardcoded defaults (fuzzy match — longest key first for specificity)
|
||||
# Only check `default_model in model` (is the key a substring of the input).
|
||||
# The reverse (`model in default_model`) causes shorter names like
|
||||
# "claude-sonnet-4" to incorrectly match "claude-sonnet-4-6" and return 1M.
|
||||
for default_model, length in sorted(
|
||||
DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True
|
||||
):
|
||||
if default_model in model or model in default_model:
|
||||
if default_model in model:
|
||||
return length
|
||||
|
||||
# 5. Unknown model — start at highest probe tier
|
||||
return CONTEXT_PROBE_TIERS[0]
|
||||
# 9. Query local server as last resort
|
||||
if base_url and is_local_endpoint(base_url):
|
||||
local_ctx = _query_local_context_length(model, base_url)
|
||||
if local_ctx and local_ctx > 0:
|
||||
save_context_length(model, base_url, local_ctx)
|
||||
return local_ctx
|
||||
|
||||
# 10. Default fallback — 128K
|
||||
return DEFAULT_FALLBACK_CONTEXT
|
||||
|
||||
|
||||
def estimate_tokens_rough(text: str) -> int:
|
||||
|
||||
@@ -0,0 +1,171 @@
|
||||
"""Models.dev registry integration for provider-aware context length detection.
|
||||
|
||||
Fetches model metadata from https://models.dev/api.json — a community-maintained
|
||||
database of 3800+ models across 100+ providers, including per-provider context
|
||||
windows, pricing, and capabilities.
|
||||
|
||||
Data is cached in memory (1hr TTL) and on disk (~/.hermes/models_dev_cache.json)
|
||||
to avoid cold-start network latency.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MODELS_DEV_URL = "https://models.dev/api.json"
|
||||
_MODELS_DEV_CACHE_TTL = 3600 # 1 hour in-memory
|
||||
|
||||
# In-memory cache
|
||||
_models_dev_cache: Dict[str, Any] = {}
|
||||
_models_dev_cache_time: float = 0
|
||||
|
||||
# Provider ID mapping: Hermes provider names → models.dev provider IDs
|
||||
PROVIDER_TO_MODELS_DEV: Dict[str, str] = {
|
||||
"openrouter": "openrouter",
|
||||
"anthropic": "anthropic",
|
||||
"zai": "zai",
|
||||
"kimi-coding": "kimi-for-coding",
|
||||
"minimax": "minimax",
|
||||
"minimax-cn": "minimax-cn",
|
||||
"deepseek": "deepseek",
|
||||
"alibaba": "alibaba",
|
||||
"copilot": "github-copilot",
|
||||
"ai-gateway": "vercel",
|
||||
"opencode-zen": "opencode",
|
||||
"opencode-go": "opencode-go",
|
||||
"kilocode": "kilo",
|
||||
}
|
||||
|
||||
|
||||
def _get_cache_path() -> Path:
|
||||
"""Return path to disk cache file."""
|
||||
env_val = os.environ.get("HERMES_HOME", "")
|
||||
hermes_home = Path(env_val) if env_val else Path.home() / ".hermes"
|
||||
return hermes_home / "models_dev_cache.json"
|
||||
|
||||
|
||||
def _load_disk_cache() -> Dict[str, Any]:
|
||||
"""Load models.dev data from disk cache."""
|
||||
try:
|
||||
cache_path = _get_cache_path()
|
||||
if cache_path.exists():
|
||||
with open(cache_path, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to load models.dev disk cache: %s", e)
|
||||
return {}
|
||||
|
||||
|
||||
def _save_disk_cache(data: Dict[str, Any]) -> None:
|
||||
"""Save models.dev data to disk cache."""
|
||||
try:
|
||||
cache_path = _get_cache_path()
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(cache_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, separators=(",", ":"))
|
||||
except Exception as e:
|
||||
logger.debug("Failed to save models.dev disk cache: %s", e)
|
||||
|
||||
|
||||
def fetch_models_dev(force_refresh: bool = False) -> Dict[str, Any]:
|
||||
"""Fetch models.dev registry. In-memory cache (1hr) + disk fallback.
|
||||
|
||||
Returns the full registry dict keyed by provider ID, or empty dict on failure.
|
||||
"""
|
||||
global _models_dev_cache, _models_dev_cache_time
|
||||
|
||||
# Check in-memory cache
|
||||
if (
|
||||
not force_refresh
|
||||
and _models_dev_cache
|
||||
and (time.time() - _models_dev_cache_time) < _MODELS_DEV_CACHE_TTL
|
||||
):
|
||||
return _models_dev_cache
|
||||
|
||||
# Try network fetch
|
||||
try:
|
||||
response = requests.get(MODELS_DEV_URL, timeout=15)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
if isinstance(data, dict) and len(data) > 0:
|
||||
_models_dev_cache = data
|
||||
_models_dev_cache_time = time.time()
|
||||
_save_disk_cache(data)
|
||||
logger.debug(
|
||||
"Fetched models.dev registry: %d providers, %d total models",
|
||||
len(data),
|
||||
sum(len(p.get("models", {})) for p in data.values() if isinstance(p, dict)),
|
||||
)
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.debug("Failed to fetch models.dev: %s", e)
|
||||
|
||||
# Fall back to disk cache — use a short TTL (5 min) so we retry
|
||||
# the network fetch soon instead of serving stale data for a full hour.
|
||||
if not _models_dev_cache:
|
||||
_models_dev_cache = _load_disk_cache()
|
||||
if _models_dev_cache:
|
||||
_models_dev_cache_time = time.time() - _MODELS_DEV_CACHE_TTL + 300
|
||||
logger.debug("Loaded models.dev from disk cache (%d providers)", len(_models_dev_cache))
|
||||
|
||||
return _models_dev_cache
|
||||
|
||||
|
||||
def lookup_models_dev_context(provider: str, model: str) -> Optional[int]:
|
||||
"""Look up context_length for a provider+model combo in models.dev.
|
||||
|
||||
Returns the context window in tokens, or None if not found.
|
||||
Handles case-insensitive matching and filters out context=0 entries.
|
||||
"""
|
||||
mdev_provider_id = PROVIDER_TO_MODELS_DEV.get(provider)
|
||||
if not mdev_provider_id:
|
||||
return None
|
||||
|
||||
data = fetch_models_dev()
|
||||
provider_data = data.get(mdev_provider_id)
|
||||
if not isinstance(provider_data, dict):
|
||||
return None
|
||||
|
||||
models = provider_data.get("models", {})
|
||||
if not isinstance(models, dict):
|
||||
return None
|
||||
|
||||
# Exact match
|
||||
entry = models.get(model)
|
||||
if entry:
|
||||
ctx = _extract_context(entry)
|
||||
if ctx:
|
||||
return ctx
|
||||
|
||||
# Case-insensitive match
|
||||
model_lower = model.lower()
|
||||
for mid, mdata in models.items():
|
||||
if mid.lower() == model_lower:
|
||||
ctx = _extract_context(mdata)
|
||||
if ctx:
|
||||
return ctx
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _extract_context(entry: Dict[str, Any]) -> Optional[int]:
|
||||
"""Extract context_length from a models.dev model entry.
|
||||
|
||||
Returns None for invalid/zero values (some audio/image models have context=0).
|
||||
"""
|
||||
if not isinstance(entry, dict):
|
||||
return None
|
||||
limit = entry.get("limit")
|
||||
if not isinstance(limit, dict):
|
||||
return None
|
||||
ctx = limit.get("context")
|
||||
if isinstance(ctx, (int, float)) and ctx > 0:
|
||||
return int(ctx)
|
||||
return None
|
||||
@@ -206,11 +206,11 @@ PLATFORM_HINTS = {
|
||||
"contextually appropriate."
|
||||
),
|
||||
"cron": (
|
||||
"You are running as a scheduled cron job. Your final response is automatically "
|
||||
"delivered to the job's configured destination, so do not use send_message to "
|
||||
"send to that same target again. If you want the user to receive something in "
|
||||
"the scheduled destination, put it directly in your final response. Use "
|
||||
"send_message only for additional or different targets."
|
||||
"You are running as a scheduled cron job. There is no user present — you "
|
||||
"cannot ask questions, request clarification, or wait for follow-up. Execute "
|
||||
"the task fully and autonomously, making reasonable decisions where needed. "
|
||||
"Your final response is automatically delivered to the job's configured "
|
||||
"destination — put the primary content directly in your response."
|
||||
),
|
||||
"cli": (
|
||||
"You are a CLI AI Agent. Try not to use markdown but simple text "
|
||||
|
||||
@@ -973,6 +973,8 @@ def save_config_value(key_path: str, value: any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# HermesCLI Class
|
||||
# ============================================================================
|
||||
@@ -1046,6 +1048,14 @@ class HermesCLI:
|
||||
_config_model = _model_config.get("default", "") if isinstance(_model_config, dict) else (_model_config or "")
|
||||
_FALLBACK_MODEL = "anthropic/claude-opus-4.6"
|
||||
self.model = model or _config_model or _FALLBACK_MODEL
|
||||
# Auto-detect model from local server if still on fallback
|
||||
if self.model == _FALLBACK_MODEL:
|
||||
_base_url = _model_config.get("base_url", "") if isinstance(_model_config, dict) else ""
|
||||
if "localhost" in _base_url or "127.0.0.1" in _base_url:
|
||||
from hermes_cli.runtime_provider import _auto_detect_local_model
|
||||
_detected = _auto_detect_local_model(_base_url)
|
||||
if _detected:
|
||||
self.model = _detected
|
||||
# Track whether model was explicitly chosen by the user or fell back
|
||||
# to the global default. Provider-specific normalisation may override
|
||||
# the default silently but should warn when overriding an explicit choice.
|
||||
@@ -1251,6 +1261,8 @@ class HermesCLI:
|
||||
def _get_status_bar_snapshot(self) -> Dict[str, Any]:
|
||||
model_name = self.model or "unknown"
|
||||
model_short = model_name.split("/")[-1] if "/" in model_name else model_name
|
||||
if model_short.endswith(".gguf"):
|
||||
model_short = model_short[:-5]
|
||||
if len(model_short) > 26:
|
||||
model_short = f"{model_short[:23]}..."
|
||||
|
||||
@@ -1512,9 +1524,11 @@ class HermesCLI:
|
||||
# Track whether we're inside a reasoning/thinking block.
|
||||
# These tags are model-generated (system prompt tells the model
|
||||
# to use them) and get stripped from final_response. We must
|
||||
# suppress them during streaming too.
|
||||
_OPEN_TAGS = ("<REASONING_SCRATCHPAD>", "<think>", "<reasoning>", "<THINKING>")
|
||||
_CLOSE_TAGS = ("</REASONING_SCRATCHPAD>", "</think>", "</reasoning>", "</THINKING>")
|
||||
# suppress them during streaming too — unless show_reasoning is
|
||||
# enabled, in which case we route the inner content to the
|
||||
# reasoning display box instead of discarding it.
|
||||
_OPEN_TAGS = ("<REASONING_SCRATCHPAD>", "<think>", "<reasoning>", "<THINKING>", "<thinking>")
|
||||
_CLOSE_TAGS = ("</REASONING_SCRATCHPAD>", "</think>", "</reasoning>", "</THINKING>", "</thinking>")
|
||||
|
||||
# Append to a pre-filter buffer first
|
||||
self._stream_prefilt = getattr(self, "_stream_prefilt", "") + text
|
||||
@@ -1554,6 +1568,12 @@ class HermesCLI:
|
||||
idx = self._stream_prefilt.find(tag)
|
||||
if idx != -1:
|
||||
self._in_reasoning_block = False
|
||||
# When show_reasoning is on, route inner content to
|
||||
# the reasoning display box instead of discarding.
|
||||
if self.show_reasoning:
|
||||
inner = self._stream_prefilt[:idx]
|
||||
if inner:
|
||||
self._stream_reasoning_delta(inner)
|
||||
after = self._stream_prefilt[idx + len(tag):]
|
||||
self._stream_prefilt = ""
|
||||
# Process remaining text after close tag through full
|
||||
@@ -1561,10 +1581,15 @@ class HermesCLI:
|
||||
if after:
|
||||
self._stream_delta(after)
|
||||
return
|
||||
# Still inside reasoning block — keep only the tail that could
|
||||
# be a partial close tag prefix (save memory on long blocks).
|
||||
# When show_reasoning is on, stream reasoning content live
|
||||
# instead of silently accumulating. Keep only the tail that
|
||||
# could be a partial close tag prefix.
|
||||
max_tag_len = max(len(t) for t in _CLOSE_TAGS)
|
||||
if len(self._stream_prefilt) > max_tag_len:
|
||||
if self.show_reasoning:
|
||||
# Route the safe prefix to reasoning display
|
||||
safe_reasoning = self._stream_prefilt[:-max_tag_len]
|
||||
self._stream_reasoning_delta(safe_reasoning)
|
||||
self._stream_prefilt = self._stream_prefilt[-max_tag_len:]
|
||||
return
|
||||
|
||||
@@ -2721,6 +2746,7 @@ class HermesCLI:
|
||||
if self.agent:
|
||||
self.agent.session_id = self.session_id
|
||||
self.agent.session_start = self.session_start
|
||||
self.agent.reset_session_state()
|
||||
if hasattr(self.agent, "_last_flushed_db_idx"):
|
||||
self.agent._last_flushed_db_idx = 0
|
||||
if hasattr(self.agent, "_todo_store"):
|
||||
@@ -2880,6 +2906,14 @@ class HermesCLI:
|
||||
for mid, desc in curated:
|
||||
current_marker = " ← current" if (is_active and mid == self.model) else ""
|
||||
print(f" {mid}{current_marker}")
|
||||
elif p["id"] == "custom":
|
||||
from hermes_cli.models import _get_custom_base_url
|
||||
custom_url = _get_custom_base_url() or os.getenv("OPENAI_BASE_URL", "")
|
||||
if custom_url:
|
||||
print(f" endpoint: {custom_url}")
|
||||
if is_active:
|
||||
print(f" model: {self.model} ← current")
|
||||
print(f" (use /model custom:<model-name>)")
|
||||
else:
|
||||
print(f" (use /model {p['id']}:<model-name>)")
|
||||
print()
|
||||
@@ -3483,8 +3517,17 @@ class HermesCLI:
|
||||
# Parse provider:model syntax (e.g. "openrouter:anthropic/claude-sonnet-4.5")
|
||||
current_provider = self.provider or self.requested_provider or "openrouter"
|
||||
target_provider, new_model = parse_model_input(raw_input, current_provider)
|
||||
# Auto-detect provider when no explicit provider:model syntax was used
|
||||
if target_provider == current_provider:
|
||||
# Auto-detect provider when no explicit provider:model syntax was used.
|
||||
# Skip auto-detection for custom providers — the model name might
|
||||
# coincidentally match a known provider's catalog, but the user
|
||||
# intends to use it on their custom endpoint. Require explicit
|
||||
# provider:model syntax (e.g. /model openai-codex:gpt-5.2-codex)
|
||||
# to switch away from a custom endpoint.
|
||||
_base = self.base_url or ""
|
||||
is_custom = current_provider == "custom" or (
|
||||
"localhost" in _base or "127.0.0.1" in _base
|
||||
)
|
||||
if target_provider == current_provider and not is_custom:
|
||||
from hermes_cli.models import detect_provider_for_model
|
||||
detected = detect_provider_for_model(new_model, current_provider)
|
||||
if detected:
|
||||
@@ -3552,6 +3595,13 @@ class HermesCLI:
|
||||
if message:
|
||||
print(f" Reason: {message}")
|
||||
print(" Note: Model will revert on restart. Use a verified model to save to config.")
|
||||
|
||||
# Helpful hint when staying on a custom endpoint
|
||||
if is_custom and not provider_changed:
|
||||
endpoint = self.base_url or "custom endpoint"
|
||||
print(f" Endpoint: {endpoint}")
|
||||
print(f" Tip: To switch providers, use /model provider:model")
|
||||
print(f" e.g. /model openai-codex:gpt-5.2-codex")
|
||||
else:
|
||||
self._show_model_and_providers()
|
||||
elif canonical == "provider":
|
||||
@@ -3628,6 +3678,18 @@ class HermesCLI:
|
||||
self._handle_stop_command()
|
||||
elif canonical == "background":
|
||||
self._handle_background_command(cmd_original)
|
||||
elif canonical == "queue":
|
||||
if not self._agent_running:
|
||||
_cprint(" /queue only works while Hermes is busy. Just type your message normally.")
|
||||
else:
|
||||
# Extract prompt after "/queue " or "/q "
|
||||
parts = cmd_original.split(None, 1)
|
||||
payload = parts[1].strip() if len(parts) > 1 else ""
|
||||
if not payload:
|
||||
_cprint(" Usage: /queue <prompt>")
|
||||
else:
|
||||
self._pending_input.put(payload)
|
||||
_cprint(f" Queued for the next turn: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
elif canonical == "skin":
|
||||
self._handle_skin_command(cmd_original)
|
||||
elif canonical == "voice":
|
||||
@@ -3916,7 +3978,7 @@ class HermesCLI:
|
||||
parts = cmd.strip().split(None, 1)
|
||||
sub = parts[1].lower().strip() if len(parts) > 1 else "status"
|
||||
|
||||
_DEFAULT_CDP = "ws://localhost:9222"
|
||||
_DEFAULT_CDP = "http://localhost:9222"
|
||||
current = os.environ.get("BROWSER_CDP_URL", "").strip()
|
||||
|
||||
if sub.startswith("connect"):
|
||||
@@ -5877,7 +5939,12 @@ class HermesCLI:
|
||||
|
||||
@kb.add('tab', eager=True)
|
||||
def handle_tab(event):
|
||||
"""Tab: accept completion and re-trigger if we just completed a provider.
|
||||
"""Tab: accept completion, auto-suggestion, or start completions.
|
||||
|
||||
Priority:
|
||||
1. Completion menu open → accept selected completion
|
||||
2. Ghost text suggestion available → accept auto-suggestion
|
||||
3. Otherwise → start completion menu
|
||||
|
||||
After accepting a provider like 'anthropic:', the completion menu
|
||||
closes and complete_while_typing doesn't fire (no keystroke).
|
||||
@@ -5886,6 +5953,7 @@ class HermesCLI:
|
||||
"""
|
||||
buf = event.current_buffer
|
||||
if buf.complete_state:
|
||||
# Completion menu is open — accept the selection
|
||||
completion = buf.complete_state.current_completion
|
||||
if completion is None:
|
||||
# Menu open but nothing selected — select first then grab it
|
||||
@@ -5899,8 +5967,11 @@ class HermesCLI:
|
||||
text = buf.document.text_before_cursor
|
||||
if text.startswith("/model ") and text.endswith(":"):
|
||||
buf.start_completion()
|
||||
elif buf.suggestion and buf.suggestion.text:
|
||||
# No completion menu, but there's a ghost text auto-suggestion — accept it
|
||||
buf.insert_text(buf.suggestion.text)
|
||||
else:
|
||||
# No menu open — start completions from scratch
|
||||
# No menu and no suggestion — start completions from scratch
|
||||
buf.start_completion()
|
||||
|
||||
# --- Clarify tool: arrow-key navigation for multiple-choice questions ---
|
||||
|
||||
+18
-2
@@ -136,6 +136,10 @@ def _deliver_result(job: dict, content: str) -> None:
|
||||
"slack": Platform.SLACK,
|
||||
"whatsapp": Platform.WHATSAPP,
|
||||
"signal": Platform.SIGNAL,
|
||||
"matrix": Platform.MATRIX,
|
||||
"mattermost": Platform.MATTERMOST,
|
||||
"homeassistant": Platform.HOMEASSISTANT,
|
||||
"dingtalk": Platform.DINGTALK,
|
||||
"email": Platform.EMAIL,
|
||||
"sms": Platform.SMS,
|
||||
}
|
||||
@@ -207,11 +211,14 @@ def _build_job_prompt(job: dict) -> str:
|
||||
from tools.skills_tool import skill_view
|
||||
|
||||
parts = []
|
||||
skipped: list[str] = []
|
||||
for skill_name in skill_names:
|
||||
loaded = json.loads(skill_view(skill_name))
|
||||
if not loaded.get("success"):
|
||||
error = loaded.get("error") or f"Failed to load skill '{skill_name}'"
|
||||
raise RuntimeError(error)
|
||||
logger.warning("Cron job '%s': skill not found, skipping — %s", job.get("name", job.get("id")), error)
|
||||
skipped.append(skill_name)
|
||||
continue
|
||||
|
||||
content = str(loaded.get("content") or "").strip()
|
||||
if parts:
|
||||
@@ -224,6 +231,15 @@ def _build_job_prompt(job: dict) -> str:
|
||||
]
|
||||
)
|
||||
|
||||
if skipped:
|
||||
notice = (
|
||||
f"[SYSTEM: The following skill(s) were listed for this job but could not be found "
|
||||
f"and were skipped: {', '.join(skipped)}. "
|
||||
f"Start your response with a brief notice so the user is aware, e.g.: "
|
||||
f"'⚠️ Skill(s) not found and skipped: {', '.join(skipped)}']"
|
||||
)
|
||||
parts.insert(0, notice)
|
||||
|
||||
if prompt:
|
||||
parts.extend(["", f"The user has provided the following instruction alongside the skill invocation: {prompt}"])
|
||||
return "\n".join(parts)
|
||||
@@ -379,7 +395,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
providers_ignored=pr.get("ignore"),
|
||||
providers_order=pr.get("order"),
|
||||
provider_sort=pr.get("sort"),
|
||||
disabled_toolsets=["cronjob"],
|
||||
disabled_toolsets=["cronjob", "messaging", "clarify"],
|
||||
quiet_mode=True,
|
||||
platform="cron",
|
||||
session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}",
|
||||
|
||||
@@ -56,6 +56,7 @@ class Platform(Enum):
|
||||
SMS = "sms"
|
||||
DINGTALK = "dingtalk"
|
||||
API_SERVER = "api_server"
|
||||
WEBHOOK = "webhook"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -254,6 +255,9 @@ class GatewayConfig:
|
||||
# API Server uses enabled flag only (no token needed)
|
||||
elif platform == Platform.API_SERVER:
|
||||
connected.append(platform)
|
||||
# Webhook uses enabled flag only (secrets are per-route)
|
||||
elif platform == Platform.WEBHOOK:
|
||||
connected.append(platform)
|
||||
return connected
|
||||
|
||||
def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]:
|
||||
@@ -734,6 +738,22 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
if api_server_host:
|
||||
config.platforms[Platform.API_SERVER].extra["host"] = api_server_host
|
||||
|
||||
# Webhook platform
|
||||
webhook_enabled = os.getenv("WEBHOOK_ENABLED", "").lower() in ("true", "1", "yes")
|
||||
webhook_port = os.getenv("WEBHOOK_PORT")
|
||||
webhook_secret = os.getenv("WEBHOOK_SECRET", "")
|
||||
if webhook_enabled:
|
||||
if Platform.WEBHOOK not in config.platforms:
|
||||
config.platforms[Platform.WEBHOOK] = PlatformConfig()
|
||||
config.platforms[Platform.WEBHOOK].enabled = True
|
||||
if webhook_port:
|
||||
try:
|
||||
config.platforms[Platform.WEBHOOK].extra["port"] = int(webhook_port)
|
||||
except ValueError:
|
||||
pass
|
||||
if webhook_secret:
|
||||
config.platforms[Platform.WEBHOOK].extra["secret"] = webhook_secret
|
||||
|
||||
# Session settings
|
||||
idle_minutes = os.getenv("SESSION_IDLE_MINUTES")
|
||||
if idle_minutes:
|
||||
|
||||
@@ -179,6 +179,11 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
# Normalize account for self-message filtering
|
||||
self._account_normalized = self.account.strip()
|
||||
|
||||
# Track recently sent message timestamps to prevent echo-back loops
|
||||
# in Note to Self / self-chat mode (mirrors WhatsApp recentlySentIds)
|
||||
self._recent_sent_timestamps: set = set()
|
||||
self._max_recent_timestamps = 50
|
||||
|
||||
logger.info("Signal adapter initialized: url=%s account=%s groups=%s",
|
||||
self.http_url, _redact_phone(self.account),
|
||||
"enabled" if self.group_allow_from else "disabled")
|
||||
@@ -353,10 +358,26 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
# Unwrap nested envelope if present
|
||||
envelope_data = envelope.get("envelope", envelope)
|
||||
|
||||
# Filter syncMessage envelopes (sent transcripts, read receipts, etc.)
|
||||
# signal-cli may set syncMessage to null vs omitting it, so check key existence
|
||||
# Handle syncMessage: extract "Note to Self" messages (sent to own account)
|
||||
# while still filtering other sync events (read receipts, typing, etc.)
|
||||
is_note_to_self = False
|
||||
if "syncMessage" in envelope_data:
|
||||
return
|
||||
sync_msg = envelope_data.get("syncMessage")
|
||||
if sync_msg and isinstance(sync_msg, dict):
|
||||
sent_msg = sync_msg.get("sentMessage")
|
||||
if sent_msg and isinstance(sent_msg, dict):
|
||||
dest = sent_msg.get("destinationNumber") or sent_msg.get("destination")
|
||||
sent_ts = sent_msg.get("timestamp")
|
||||
if dest == self._account_normalized:
|
||||
# Check if this is an echo of our own outbound reply
|
||||
if sent_ts and sent_ts in self._recent_sent_timestamps:
|
||||
self._recent_sent_timestamps.discard(sent_ts)
|
||||
return
|
||||
# Genuine user Note to Self — promote to dataMessage
|
||||
is_note_to_self = True
|
||||
envelope_data = {**envelope_data, "dataMessage": sent_msg}
|
||||
if not is_note_to_self:
|
||||
return
|
||||
|
||||
# Extract sender info
|
||||
sender = (
|
||||
@@ -371,8 +392,8 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
logger.debug("Signal: ignoring envelope with no sender")
|
||||
return
|
||||
|
||||
# Self-message filtering — prevent reply loops
|
||||
if self._account_normalized and sender == self._account_normalized:
|
||||
# Self-message filtering — prevent reply loops (but allow Note to Self)
|
||||
if self._account_normalized and sender == self._account_normalized and not is_note_to_self:
|
||||
return
|
||||
|
||||
# Filter stories
|
||||
@@ -577,9 +598,18 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
result = await self._rpc("send", params)
|
||||
|
||||
if result is not None:
|
||||
self._track_sent_timestamp(result)
|
||||
return SendResult(success=True)
|
||||
return SendResult(success=False, error="RPC send failed")
|
||||
|
||||
def _track_sent_timestamp(self, rpc_result) -> None:
|
||||
"""Record outbound message timestamp for echo-back filtering."""
|
||||
ts = rpc_result.get("timestamp") if isinstance(rpc_result, dict) else None
|
||||
if ts:
|
||||
self._recent_sent_timestamps.add(ts)
|
||||
if len(self._recent_sent_timestamps) > self._max_recent_timestamps:
|
||||
self._recent_sent_timestamps.pop()
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata=None) -> None:
|
||||
"""Send a typing indicator."""
|
||||
params: Dict[str, Any] = {
|
||||
@@ -635,6 +665,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
|
||||
result = await self._rpc("send", params)
|
||||
if result is not None:
|
||||
self._track_sent_timestamp(result)
|
||||
return SendResult(success=True)
|
||||
return SendResult(success=False, error="RPC send with attachment failed")
|
||||
|
||||
@@ -665,6 +696,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
|
||||
result = await self._rpc("send", params)
|
||||
if result is not None:
|
||||
self._track_sent_timestamp(result)
|
||||
return SendResult(success=True)
|
||||
return SendResult(success=False, error="RPC send document failed")
|
||||
|
||||
|
||||
@@ -0,0 +1,557 @@
|
||||
"""Generic webhook platform adapter.
|
||||
|
||||
Runs an aiohttp HTTP server that receives webhook POSTs from external
|
||||
services (GitHub, GitLab, JIRA, Stripe, etc.), validates HMAC signatures,
|
||||
transforms payloads into agent prompts, and routes responses back to the
|
||||
source or to another configured platform.
|
||||
|
||||
Configuration lives in config.yaml under platforms.webhook.extra.routes.
|
||||
Each route defines:
|
||||
- events: which event types to accept (header-based filtering)
|
||||
- secret: HMAC secret for signature validation (REQUIRED)
|
||||
- prompt: template string formatted with the webhook payload
|
||||
- skills: optional list of skills to load for the agent
|
||||
- deliver: where to send the response (github_comment, telegram, etc.)
|
||||
- deliver_extra: additional delivery config (repo, pr_number, chat_id)
|
||||
|
||||
Security:
|
||||
- HMAC secret is required per route (validated at startup)
|
||||
- Rate limiting per route (fixed-window, configurable)
|
||||
- Idempotency cache prevents duplicate agent runs on webhook retries
|
||||
- Body size limits checked before reading payload
|
||||
- Set secret to "INSECURE_NO_AUTH" to skip validation (testing only)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from aiohttp import web
|
||||
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOHTTP_AVAILABLE = False
|
||||
web = None # type: ignore[assignment]
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_HOST = "0.0.0.0"
|
||||
DEFAULT_PORT = 8644
|
||||
_INSECURE_NO_AUTH = "INSECURE_NO_AUTH"
|
||||
|
||||
|
||||
def check_webhook_requirements() -> bool:
|
||||
"""Check if webhook adapter dependencies are available."""
|
||||
return AIOHTTP_AVAILABLE
|
||||
|
||||
|
||||
class WebhookAdapter(BasePlatformAdapter):
|
||||
"""Generic webhook receiver that triggers agent runs from HTTP POSTs."""
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.WEBHOOK)
|
||||
self._host: str = config.extra.get("host", DEFAULT_HOST)
|
||||
self._port: int = int(config.extra.get("port", DEFAULT_PORT))
|
||||
self._global_secret: str = config.extra.get("secret", "")
|
||||
self._routes: Dict[str, dict] = config.extra.get("routes", {})
|
||||
self._runner = None
|
||||
|
||||
# Delivery info keyed by session chat_id — consumed by send()
|
||||
self._delivery_info: Dict[str, dict] = {}
|
||||
|
||||
# Reference to gateway runner for cross-platform delivery (set externally)
|
||||
self.gateway_runner = None
|
||||
|
||||
# Idempotency: TTL cache of recently processed delivery IDs.
|
||||
# Prevents duplicate agent runs when webhook providers retry.
|
||||
self._seen_deliveries: Dict[str, float] = {}
|
||||
self._idempotency_ttl: int = 3600 # 1 hour
|
||||
|
||||
# Rate limiting: per-route timestamps in a fixed window.
|
||||
self._rate_counts: Dict[str, List[float]] = {}
|
||||
self._rate_limit: int = int(config.extra.get("rate_limit", 30)) # per minute
|
||||
|
||||
# Body size limit (auth-before-body pattern)
|
||||
self._max_body_bytes: int = int(
|
||||
config.extra.get("max_body_bytes", 1_048_576)
|
||||
) # 1MB
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def connect(self) -> bool:
|
||||
# Validate routes at startup — secret is required per route
|
||||
for name, route in self._routes.items():
|
||||
secret = route.get("secret", self._global_secret)
|
||||
if not secret:
|
||||
raise ValueError(
|
||||
f"[webhook] Route '{name}' has no HMAC secret. "
|
||||
f"Set 'secret' on the route or globally. "
|
||||
f"For testing without auth, set secret to '{_INSECURE_NO_AUTH}'."
|
||||
)
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/health", self._handle_health)
|
||||
app.router.add_post("/webhooks/{route_name}", self._handle_webhook)
|
||||
|
||||
self._runner = web.AppRunner(app)
|
||||
await self._runner.setup()
|
||||
site = web.TCPSite(self._runner, self._host, self._port)
|
||||
await site.start()
|
||||
self._mark_connected()
|
||||
|
||||
route_names = ", ".join(self._routes.keys()) or "(none configured)"
|
||||
logger.info(
|
||||
"[webhook] Listening on %s:%d — routes: %s",
|
||||
self._host,
|
||||
self._port,
|
||||
route_names,
|
||||
)
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self._runner:
|
||||
await self._runner.cleanup()
|
||||
self._runner = None
|
||||
self._mark_disconnected()
|
||||
logger.info("[webhook] Disconnected")
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Deliver the agent's response to the configured destination.
|
||||
|
||||
chat_id is ``webhook:{route}:{delivery_id}`` — we pop the delivery
|
||||
info stored during webhook receipt so it doesn't leak memory.
|
||||
"""
|
||||
delivery = self._delivery_info.pop(chat_id, {})
|
||||
deliver_type = delivery.get("deliver", "log")
|
||||
|
||||
if deliver_type == "log":
|
||||
logger.info("[webhook] Response for %s: %s", chat_id, content[:200])
|
||||
return SendResult(success=True)
|
||||
|
||||
if deliver_type == "github_comment":
|
||||
return await self._deliver_github_comment(content, delivery)
|
||||
|
||||
# Cross-platform delivery (telegram, discord, etc.)
|
||||
if self.gateway_runner and deliver_type in (
|
||||
"telegram",
|
||||
"discord",
|
||||
"slack",
|
||||
"signal",
|
||||
"sms",
|
||||
):
|
||||
return await self._deliver_cross_platform(
|
||||
deliver_type, content, delivery
|
||||
)
|
||||
|
||||
logger.warning("[webhook] Unknown deliver type: %s", deliver_type)
|
||||
return SendResult(
|
||||
success=False, error=f"Unknown deliver type: {deliver_type}"
|
||||
)
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
return {"name": chat_id, "type": "webhook"}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# HTTP handlers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _handle_health(self, request: "web.Request") -> "web.Response":
|
||||
"""GET /health — simple health check."""
|
||||
return web.json_response({"status": "ok", "platform": "webhook"})
|
||||
|
||||
async def _handle_webhook(self, request: "web.Request") -> "web.Response":
|
||||
"""POST /webhooks/{route_name} — receive and process a webhook event."""
|
||||
route_name = request.match_info.get("route_name", "")
|
||||
route_config = self._routes.get(route_name)
|
||||
|
||||
if not route_config:
|
||||
return web.json_response(
|
||||
{"error": f"Unknown route: {route_name}"}, status=404
|
||||
)
|
||||
|
||||
# ── Auth-before-body ─────────────────────────────────────
|
||||
# Check Content-Length before reading the full payload.
|
||||
content_length = request.content_length or 0
|
||||
if content_length > self._max_body_bytes:
|
||||
return web.json_response(
|
||||
{"error": "Payload too large"}, status=413
|
||||
)
|
||||
|
||||
# ── Rate limiting ────────────────────────────────────────
|
||||
now = time.time()
|
||||
window = self._rate_counts.setdefault(route_name, [])
|
||||
window[:] = [t for t in window if now - t < 60]
|
||||
if len(window) >= self._rate_limit:
|
||||
return web.json_response(
|
||||
{"error": "Rate limit exceeded"}, status=429
|
||||
)
|
||||
window.append(now)
|
||||
|
||||
# Read body
|
||||
try:
|
||||
raw_body = await request.read()
|
||||
except Exception as e:
|
||||
logger.error("[webhook] Failed to read body: %s", e)
|
||||
return web.json_response({"error": "Bad request"}, status=400)
|
||||
|
||||
# Validate HMAC signature (skip for INSECURE_NO_AUTH testing mode)
|
||||
secret = route_config.get("secret", self._global_secret)
|
||||
if secret and secret != _INSECURE_NO_AUTH:
|
||||
if not self._validate_signature(request, raw_body, secret):
|
||||
logger.warning(
|
||||
"[webhook] Invalid signature for route %s", route_name
|
||||
)
|
||||
return web.json_response(
|
||||
{"error": "Invalid signature"}, status=401
|
||||
)
|
||||
|
||||
# Parse payload
|
||||
try:
|
||||
payload = json.loads(raw_body)
|
||||
except json.JSONDecodeError:
|
||||
# Try form-encoded as fallback
|
||||
try:
|
||||
import urllib.parse
|
||||
|
||||
payload = dict(
|
||||
urllib.parse.parse_qsl(raw_body.decode("utf-8"))
|
||||
)
|
||||
except Exception:
|
||||
return web.json_response(
|
||||
{"error": "Cannot parse body"}, status=400
|
||||
)
|
||||
|
||||
# Check event type filter
|
||||
event_type = (
|
||||
request.headers.get("X-GitHub-Event", "")
|
||||
or request.headers.get("X-GitLab-Event", "")
|
||||
or payload.get("event_type", "")
|
||||
or "unknown"
|
||||
)
|
||||
allowed_events = route_config.get("events", [])
|
||||
if allowed_events and event_type not in allowed_events:
|
||||
logger.debug(
|
||||
"[webhook] Ignoring event %s for route %s (allowed: %s)",
|
||||
event_type,
|
||||
route_name,
|
||||
allowed_events,
|
||||
)
|
||||
return web.json_response(
|
||||
{"status": "ignored", "event": event_type}
|
||||
)
|
||||
|
||||
# Format prompt from template
|
||||
prompt_template = route_config.get("prompt", "")
|
||||
prompt = self._render_prompt(
|
||||
prompt_template, payload, event_type, route_name
|
||||
)
|
||||
|
||||
# Inject skill content if configured.
|
||||
# We call build_skill_invocation_message() directly rather than
|
||||
# using /skill-name slash commands — the gateway's command parser
|
||||
# would intercept those and break the flow.
|
||||
skills = route_config.get("skills", [])
|
||||
if skills:
|
||||
try:
|
||||
from agent.skill_commands import (
|
||||
build_skill_invocation_message,
|
||||
get_skill_commands,
|
||||
)
|
||||
|
||||
skill_cmds = get_skill_commands()
|
||||
for skill_name in skills:
|
||||
cmd_key = f"/{skill_name}"
|
||||
if cmd_key in skill_cmds:
|
||||
skill_content = build_skill_invocation_message(
|
||||
cmd_key, user_instruction=prompt
|
||||
)
|
||||
if skill_content:
|
||||
prompt = skill_content
|
||||
break # Load the first matching skill
|
||||
else:
|
||||
logger.warning(
|
||||
"[webhook] Skill '%s' not found", skill_name
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[webhook] Skill loading failed: %s", e)
|
||||
|
||||
# Build a unique delivery ID
|
||||
delivery_id = request.headers.get(
|
||||
"X-GitHub-Delivery",
|
||||
request.headers.get("X-Request-ID", str(int(time.time() * 1000))),
|
||||
)
|
||||
|
||||
# ── Idempotency ─────────────────────────────────────────
|
||||
# Skip duplicate deliveries (webhook retries).
|
||||
now = time.time()
|
||||
# Prune expired entries
|
||||
self._seen_deliveries = {
|
||||
k: v
|
||||
for k, v in self._seen_deliveries.items()
|
||||
if now - v < self._idempotency_ttl
|
||||
}
|
||||
if delivery_id in self._seen_deliveries:
|
||||
logger.info(
|
||||
"[webhook] Skipping duplicate delivery %s", delivery_id
|
||||
)
|
||||
return web.json_response(
|
||||
{"status": "duplicate", "delivery_id": delivery_id},
|
||||
status=200,
|
||||
)
|
||||
self._seen_deliveries[delivery_id] = now
|
||||
|
||||
# Use delivery_id in session key so concurrent webhooks on the
|
||||
# same route get independent agent runs (not queued/interrupted).
|
||||
session_chat_id = f"webhook:{route_name}:{delivery_id}"
|
||||
|
||||
# Store delivery info for send() — consumed (popped) on delivery
|
||||
deliver_config = {
|
||||
"deliver": route_config.get("deliver", "log"),
|
||||
"deliver_extra": self._render_delivery_extra(
|
||||
route_config.get("deliver_extra", {}), payload
|
||||
),
|
||||
"payload": payload,
|
||||
}
|
||||
self._delivery_info[session_chat_id] = deliver_config
|
||||
|
||||
# Build source and event
|
||||
source = self.build_source(
|
||||
chat_id=session_chat_id,
|
||||
chat_name=f"webhook/{route_name}",
|
||||
chat_type="webhook",
|
||||
user_id=f"webhook:{route_name}",
|
||||
user_name=route_name,
|
||||
)
|
||||
event = MessageEvent(
|
||||
text=prompt,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
raw_message=payload,
|
||||
message_id=delivery_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[webhook] %s event=%s route=%s prompt_len=%d delivery=%s",
|
||||
request.method,
|
||||
event_type,
|
||||
route_name,
|
||||
len(prompt),
|
||||
delivery_id,
|
||||
)
|
||||
|
||||
# Non-blocking — return 202 Accepted immediately
|
||||
asyncio.create_task(self.handle_message(event))
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"status": "accepted",
|
||||
"route": route_name,
|
||||
"event": event_type,
|
||||
"delivery_id": delivery_id,
|
||||
},
|
||||
status=202,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Signature validation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _validate_signature(
|
||||
self, request: "web.Request", body: bytes, secret: str
|
||||
) -> bool:
|
||||
"""Validate webhook signature (GitHub, GitLab, generic HMAC-SHA256)."""
|
||||
# GitHub: X-Hub-Signature-256 = sha256=<hex>
|
||||
gh_sig = request.headers.get("X-Hub-Signature-256", "")
|
||||
if gh_sig:
|
||||
expected = "sha256=" + hmac.new(
|
||||
secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
return hmac.compare_digest(gh_sig, expected)
|
||||
|
||||
# GitLab: X-Gitlab-Token = <plain secret>
|
||||
gl_token = request.headers.get("X-Gitlab-Token", "")
|
||||
if gl_token:
|
||||
return hmac.compare_digest(gl_token, secret)
|
||||
|
||||
# Generic: X-Webhook-Signature = <hex HMAC-SHA256>
|
||||
generic_sig = request.headers.get("X-Webhook-Signature", "")
|
||||
if generic_sig:
|
||||
expected = hmac.new(
|
||||
secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
return hmac.compare_digest(generic_sig, expected)
|
||||
|
||||
# No recognised signature header but secret is configured → reject
|
||||
logger.debug(
|
||||
"[webhook] Secret configured but no signature header found"
|
||||
)
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prompt rendering
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _render_prompt(
|
||||
self,
|
||||
template: str,
|
||||
payload: dict,
|
||||
event_type: str,
|
||||
route_name: str,
|
||||
) -> str:
|
||||
"""Render a prompt template with the webhook payload.
|
||||
|
||||
Supports dot-notation access into nested dicts:
|
||||
``{pull_request.title}`` → ``payload["pull_request"]["title"]``
|
||||
"""
|
||||
if not template:
|
||||
truncated = json.dumps(payload, indent=2)[:4000]
|
||||
return (
|
||||
f"Webhook event '{event_type}' on route "
|
||||
f"'{route_name}':\n\n```json\n{truncated}\n```"
|
||||
)
|
||||
|
||||
def _resolve(match: re.Match) -> str:
|
||||
key = match.group(1)
|
||||
value: Any = payload
|
||||
for part in key.split("."):
|
||||
if isinstance(value, dict):
|
||||
value = value.get(part, f"{{{key}}}")
|
||||
else:
|
||||
return f"{{{key}}}"
|
||||
if isinstance(value, (dict, list)):
|
||||
return json.dumps(value, indent=2)[:2000]
|
||||
return str(value)
|
||||
|
||||
return re.sub(r"\{([a-zA-Z0-9_.]+)\}", _resolve, template)
|
||||
|
||||
def _render_delivery_extra(
|
||||
self, extra: dict, payload: dict
|
||||
) -> dict:
|
||||
"""Render delivery_extra template values with payload data."""
|
||||
rendered: Dict[str, Any] = {}
|
||||
for key, value in extra.items():
|
||||
if isinstance(value, str):
|
||||
rendered[key] = self._render_prompt(value, payload, "", "")
|
||||
else:
|
||||
rendered[key] = value
|
||||
return rendered
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Response delivery
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _deliver_github_comment(
|
||||
self, content: str, delivery: dict
|
||||
) -> SendResult:
|
||||
"""Post agent response as a GitHub PR/issue comment via ``gh`` CLI."""
|
||||
extra = delivery.get("deliver_extra", {})
|
||||
repo = extra.get("repo", "")
|
||||
pr_number = extra.get("pr_number", "")
|
||||
|
||||
if not repo or not pr_number:
|
||||
logger.error(
|
||||
"[webhook] github_comment delivery missing repo or pr_number"
|
||||
)
|
||||
return SendResult(
|
||||
success=False, error="Missing repo or pr_number"
|
||||
)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
"gh",
|
||||
"pr",
|
||||
"comment",
|
||||
str(pr_number),
|
||||
"--repo",
|
||||
repo,
|
||||
"--body",
|
||||
content,
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
logger.info(
|
||||
"[webhook] Posted comment on %s#%s", repo, pr_number
|
||||
)
|
||||
return SendResult(success=True)
|
||||
else:
|
||||
logger.error(
|
||||
"[webhook] gh pr comment failed: %s", result.stderr
|
||||
)
|
||||
return SendResult(success=False, error=result.stderr)
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"[webhook] 'gh' CLI not found — install GitHub CLI for "
|
||||
"github_comment delivery"
|
||||
)
|
||||
return SendResult(
|
||||
success=False, error="gh CLI not installed"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("[webhook] github_comment delivery error: %s", e)
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def _deliver_cross_platform(
|
||||
self, platform_name: str, content: str, delivery: dict
|
||||
) -> SendResult:
|
||||
"""Route response to another platform (telegram, discord, etc.)."""
|
||||
if not self.gateway_runner:
|
||||
return SendResult(
|
||||
success=False,
|
||||
error="No gateway runner for cross-platform delivery",
|
||||
)
|
||||
|
||||
try:
|
||||
target_platform = Platform(platform_name)
|
||||
except ValueError:
|
||||
return SendResult(
|
||||
success=False, error=f"Unknown platform: {platform_name}"
|
||||
)
|
||||
|
||||
adapter = self.gateway_runner.adapters.get(target_platform)
|
||||
if not adapter:
|
||||
return SendResult(
|
||||
success=False,
|
||||
error=f"Platform {platform_name} not connected",
|
||||
)
|
||||
|
||||
# Use home channel if no specific chat_id in deliver_extra
|
||||
extra = delivery.get("deliver_extra", {})
|
||||
chat_id = extra.get("chat_id", "")
|
||||
if not chat_id:
|
||||
home = self.gateway_runner.config.get_home_channel(target_platform)
|
||||
if home:
|
||||
chat_id = home.chat_id
|
||||
else:
|
||||
return SendResult(
|
||||
success=False,
|
||||
error=f"No chat_id or home channel for {platform_name}",
|
||||
)
|
||||
|
||||
return await adapter.send(chat_id, content)
|
||||
@@ -182,9 +182,31 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
# Ensure session directory exists
|
||||
self._session_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Check if bridge is already running and connected
|
||||
import aiohttp
|
||||
import asyncio
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://127.0.0.1:{self._bridge_port}/health",
|
||||
timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
bridge_status = data.get("status", "unknown")
|
||||
if bridge_status == "connected":
|
||||
print(f"[{self.name}] Using existing bridge (status: {bridge_status})")
|
||||
self._running = True
|
||||
self._bridge_process = None # Not managed by us
|
||||
asyncio.create_task(self._poll_messages())
|
||||
return True
|
||||
else:
|
||||
print(f"[{self.name}] Bridge found but not connected (status: {bridge_status}), restarting")
|
||||
except Exception:
|
||||
pass # Bridge not running, start a new one
|
||||
|
||||
# Kill any orphaned bridge from a previous gateway run
|
||||
_kill_port_process(self._bridge_port)
|
||||
import asyncio
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Start the bridge process in its own process group.
|
||||
@@ -232,7 +254,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/health",
|
||||
f"http://127.0.0.1:{self._bridge_port}/health",
|
||||
timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
@@ -264,7 +286,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/health",
|
||||
f"http://127.0.0.1:{self._bridge_port}/health",
|
||||
timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
@@ -326,9 +348,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self._bridge_process.kill()
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error stopping bridge: {e}")
|
||||
|
||||
# Also kill any orphaned bridge processes on our port
|
||||
_kill_port_process(self._bridge_port)
|
||||
else:
|
||||
# Bridge was not started by us, don't kill it
|
||||
print(f"[{self.name}] Disconnecting (external bridge left running)")
|
||||
|
||||
self._running = False
|
||||
self._bridge_process = None
|
||||
@@ -358,7 +380,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
payload["replyTo"] = reply_to
|
||||
|
||||
async with session.post(
|
||||
f"http://localhost:{self._bridge_port}/send",
|
||||
f"http://127.0.0.1:{self._bridge_port}/send",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
@@ -394,7 +416,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
import aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"http://localhost:{self._bridge_port}/edit",
|
||||
f"http://127.0.0.1:{self._bridge_port}/edit",
|
||||
json={
|
||||
"chatId": chat_id,
|
||||
"messageId": message_id,
|
||||
@@ -439,7 +461,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"http://localhost:{self._bridge_port}/send-media",
|
||||
f"http://127.0.0.1:{self._bridge_port}/send-media",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=120),
|
||||
) as resp:
|
||||
@@ -515,7 +537,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
await session.post(
|
||||
f"http://localhost:{self._bridge_port}/typing",
|
||||
f"http://127.0.0.1:{self._bridge_port}/typing",
|
||||
json={"chatId": chat_id},
|
||||
timeout=aiohttp.ClientTimeout(total=5)
|
||||
)
|
||||
@@ -532,7 +554,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/chat/{chat_id}",
|
||||
f"http://127.0.0.1:{self._bridge_port}/chat/{chat_id}",
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
@@ -559,7 +581,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/messages",
|
||||
f"http://127.0.0.1:{self._bridge_port}/messages",
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
@@ -621,6 +643,11 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
print(f"[{self.name}] Failed to cache image: {e}", flush=True)
|
||||
cached_urls.append(url)
|
||||
media_types.append("image/jpeg")
|
||||
elif msg_type == MessageType.PHOTO and os.path.isabs(url):
|
||||
# Local file path — bridge already downloaded the image
|
||||
cached_urls.append(url)
|
||||
media_types.append("image/jpeg")
|
||||
print(f"[{self.name}] Using bridge-cached image: {url}", flush=True)
|
||||
elif msg_type == MessageType.VOICE and url.startswith(("http://", "https://")):
|
||||
try:
|
||||
cached_path = await cache_audio_from_url(url, ext=".ogg")
|
||||
|
||||
+225
-34
@@ -222,6 +222,12 @@ from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageTyp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Sentinel placed into _running_agents immediately when a session starts
|
||||
# processing, *before* any await. Prevents a second message for the same
|
||||
# session from bypassing the "already running" guard during the async gap
|
||||
# between the guard check and actual agent creation.
|
||||
_AGENT_PENDING_SENTINEL = object()
|
||||
|
||||
|
||||
def _resolve_runtime_agent_kwargs() -> dict:
|
||||
"""Resolve provider credentials for gateway-created AIAgent instances."""
|
||||
@@ -1050,6 +1056,8 @@ class GatewayRunner:
|
||||
self._running = False
|
||||
|
||||
for session_key, agent in list(self._running_agents.items()):
|
||||
if agent is _AGENT_PENDING_SENTINEL:
|
||||
continue
|
||||
try:
|
||||
agent.interrupt("Gateway shutting down")
|
||||
logger.debug("Interrupted running agent for session %s during shutdown", session_key[:20])
|
||||
@@ -1183,6 +1191,15 @@ class GatewayRunner:
|
||||
return None
|
||||
return APIServerAdapter(config)
|
||||
|
||||
elif platform == Platform.WEBHOOK:
|
||||
from gateway.platforms.webhook import WebhookAdapter, check_webhook_requirements
|
||||
if not check_webhook_requirements():
|
||||
logger.warning("Webhook: aiohttp not installed")
|
||||
return None
|
||||
adapter = WebhookAdapter(config)
|
||||
adapter.gateway_runner = self # For cross-platform delivery
|
||||
return adapter
|
||||
|
||||
return None
|
||||
|
||||
def _is_user_authorized(self, source: SessionSource) -> bool:
|
||||
@@ -1199,7 +1216,9 @@ class GatewayRunner:
|
||||
# Home Assistant events are system-generated (state changes), not
|
||||
# user-initiated messages. The HASS_TOKEN already authenticates the
|
||||
# connection, so HA events are always authorized.
|
||||
if source.platform == Platform.HOMEASSISTANT:
|
||||
# Webhook events are authenticated via HMAC signature validation in
|
||||
# the adapter itself — no user allowlist applies.
|
||||
if source.platform in (Platform.HOMEASSISTANT, Platform.WEBHOOK):
|
||||
return True
|
||||
|
||||
user_id = source.user_id
|
||||
@@ -1325,6 +1344,48 @@ class GatewayRunner:
|
||||
if event.get_command() == "status":
|
||||
return await self._handle_status_command(event)
|
||||
|
||||
# /reset and /new must bypass the running-agent guard so they
|
||||
# actually dispatch as commands instead of being queued as user
|
||||
# text (which would be fed back to the agent with the same
|
||||
# broken history — #2170). Interrupt the agent first, then
|
||||
# clear the adapter's pending queue so the stale "/reset" text
|
||||
# doesn't get re-processed as a user message after the
|
||||
# interrupt completes.
|
||||
from hermes_cli.commands import resolve_command as _resolve_cmd_inner
|
||||
_evt_cmd = event.get_command()
|
||||
_cmd_def_inner = _resolve_cmd_inner(_evt_cmd) if _evt_cmd else None
|
||||
if _cmd_def_inner and _cmd_def_inner.name == "new":
|
||||
running_agent = self._running_agents.get(_quick_key)
|
||||
if running_agent and running_agent is not _AGENT_PENDING_SENTINEL:
|
||||
running_agent.interrupt("Session reset requested")
|
||||
# Clear any pending messages so the old text doesn't replay
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter and hasattr(adapter, 'get_pending_message'):
|
||||
adapter.get_pending_message(_quick_key) # consume and discard
|
||||
self._pending_messages.pop(_quick_key, None)
|
||||
# Clean up the running agent entry so the reset handler
|
||||
# doesn't think an agent is still active.
|
||||
if _quick_key in self._running_agents:
|
||||
del self._running_agents[_quick_key]
|
||||
return await self._handle_reset_command(event)
|
||||
|
||||
# /queue <prompt> — queue without interrupting
|
||||
if event.get_command() in ("queue", "q"):
|
||||
queued_text = event.get_command_args().strip()
|
||||
if not queued_text:
|
||||
return "Usage: /queue <prompt>"
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter:
|
||||
from gateway.platforms.base import MessageEvent as _ME, MessageType as _MT
|
||||
queued_event = _ME(
|
||||
text=queued_text,
|
||||
message_type=_MT.TEXT,
|
||||
source=event.source,
|
||||
message_id=event.message_id,
|
||||
)
|
||||
adapter._pending_messages[_quick_key] = queued_event
|
||||
return "Queued for the next turn."
|
||||
|
||||
if event.message_type == MessageType.PHOTO:
|
||||
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
|
||||
adapter = self.adapters.get(source.platform)
|
||||
@@ -1346,7 +1407,18 @@ class GatewayRunner:
|
||||
adapter._pending_messages[_quick_key] = event
|
||||
return None
|
||||
|
||||
running_agent = self._running_agents[_quick_key]
|
||||
running_agent = self._running_agents.get(_quick_key)
|
||||
if running_agent is _AGENT_PENDING_SENTINEL:
|
||||
# Agent is being set up but not ready yet.
|
||||
if event.get_command() == "stop":
|
||||
# Nothing to interrupt — agent hasn't started yet.
|
||||
return "⏳ The agent is still starting up — nothing to stop yet."
|
||||
# Queue the message so it will be picked up after the
|
||||
# agent starts.
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter:
|
||||
adapter._pending_messages[_quick_key] = event
|
||||
return None
|
||||
logger.debug("PRIORITY interrupt for session %s", _quick_key[:20])
|
||||
running_agent.interrupt(event.text)
|
||||
if _quick_key in self._pending_messages:
|
||||
@@ -1354,7 +1426,7 @@ class GatewayRunner:
|
||||
else:
|
||||
self._pending_messages[_quick_key] = event.text
|
||||
return None
|
||||
|
||||
|
||||
# Check for commands
|
||||
command = event.get_command()
|
||||
|
||||
@@ -1441,6 +1513,12 @@ class GatewayRunner:
|
||||
if canonical == "reload-mcp":
|
||||
return await self._handle_reload_mcp_command(event)
|
||||
|
||||
if canonical == "approve":
|
||||
return await self._handle_approve_command(event)
|
||||
|
||||
if canonical == "deny":
|
||||
return await self._handle_deny_command(event)
|
||||
|
||||
if canonical == "update":
|
||||
return await self._handle_update_command(event)
|
||||
|
||||
@@ -1518,33 +1596,32 @@ class GatewayRunner:
|
||||
except Exception as e:
|
||||
logger.debug("Skill command check failed (non-fatal): %s", e)
|
||||
|
||||
# Check for pending exec approval responses
|
||||
session_key_preview = self._session_key_for_source(source)
|
||||
if session_key_preview in self._pending_approvals:
|
||||
user_text = event.text.strip().lower()
|
||||
if user_text in ("yes", "y", "approve", "ok", "go", "do it"):
|
||||
approval = self._pending_approvals.pop(session_key_preview)
|
||||
cmd = approval["command"]
|
||||
pattern_keys = approval.get("pattern_keys", [])
|
||||
if not pattern_keys:
|
||||
pk = approval.get("pattern_key", "")
|
||||
pattern_keys = [pk] if pk else []
|
||||
logger.info("User approved dangerous command: %s...", cmd[:60])
|
||||
from tools.terminal_tool import terminal_tool
|
||||
from tools.approval import approve_session
|
||||
for pk in pattern_keys:
|
||||
approve_session(session_key_preview, pk)
|
||||
result = terminal_tool(command=cmd, force=True)
|
||||
return f"✅ Command approved and executed.\n\n```\n{result[:3500]}\n```"
|
||||
elif user_text in ("no", "n", "deny", "cancel", "nope"):
|
||||
self._pending_approvals.pop(session_key_preview)
|
||||
return "❌ Command denied."
|
||||
elif user_text in ("full", "show", "view", "show full", "view full"):
|
||||
# Show full command without consuming the approval
|
||||
cmd = self._pending_approvals[session_key_preview]["command"]
|
||||
return f"Full command:\n\n```\n{cmd}\n```\n\nReply yes/no to approve or deny."
|
||||
# If it's not clearly an approval/denial, fall through to normal processing
|
||||
|
||||
# Pending exec approvals are handled by /approve and /deny commands above.
|
||||
# No bare text matching — "yes" in normal conversation must not trigger
|
||||
# execution of a dangerous command.
|
||||
|
||||
# ── Claim this session before any await ───────────────────────
|
||||
# Between here and _run_agent registering the real AIAgent, there
|
||||
# are numerous await points (hooks, vision enrichment, STT,
|
||||
# session hygiene compression). Without this sentinel a second
|
||||
# message arriving during any of those yields would pass the
|
||||
# "already running" guard and spin up a duplicate agent for the
|
||||
# same session — corrupting the transcript.
|
||||
self._running_agents[_quick_key] = _AGENT_PENDING_SENTINEL
|
||||
|
||||
try:
|
||||
return await self._handle_message_with_agent(event, source, _quick_key)
|
||||
finally:
|
||||
# If _run_agent replaced the sentinel with a real agent and
|
||||
# then cleaned it up, this is a no-op. If we exited early
|
||||
# (exception, command fallthrough, etc.) the sentinel must
|
||||
# not linger or the session would be permanently locked out.
|
||||
if self._running_agents.get(_quick_key) is _AGENT_PENDING_SENTINEL:
|
||||
del self._running_agents[_quick_key]
|
||||
|
||||
async def _handle_message_with_agent(self, event, source, _quick_key: str):
|
||||
"""Inner handler that runs under the _running_agents sentinel guard."""
|
||||
|
||||
# Get or create session
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
session_key = session_entry.session_key
|
||||
@@ -2059,9 +2136,22 @@ class GatewayRunner:
|
||||
# Check if the agent encountered a dangerous command needing approval
|
||||
try:
|
||||
from tools.approval import pop_pending
|
||||
import time as _time
|
||||
pending = pop_pending(session_key)
|
||||
if pending:
|
||||
pending["timestamp"] = _time.time()
|
||||
self._pending_approvals[session_key] = pending
|
||||
# Append structured instructions so the user knows how to respond
|
||||
cmd_preview = pending.get("command", "")
|
||||
if len(cmd_preview) > 200:
|
||||
cmd_preview = cmd_preview[:200] + "..."
|
||||
approval_hint = (
|
||||
f"\n\n⚠️ **Dangerous command requires approval:**\n"
|
||||
f"```\n{cmd_preview}\n```\n"
|
||||
f"Reply `/approve` to execute, `/approve session` to approve this pattern "
|
||||
f"for the session, or `/deny` to cancel."
|
||||
)
|
||||
response = (response or "") + approval_hint
|
||||
except Exception as e:
|
||||
logger.debug("Failed to check pending approvals: %s", e)
|
||||
|
||||
@@ -2295,8 +2385,10 @@ class GatewayRunner:
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
session_key = session_entry.session_key
|
||||
|
||||
if session_key in self._running_agents:
|
||||
agent = self._running_agents[session_key]
|
||||
agent = self._running_agents.get(session_key)
|
||||
if agent is _AGENT_PENDING_SENTINEL:
|
||||
return "⏳ The agent is still starting up — nothing to stop yet."
|
||||
if agent:
|
||||
agent.interrupt()
|
||||
return "⚡ Stopping the current task... The agent will finish its current step and respond."
|
||||
else:
|
||||
@@ -2384,8 +2476,14 @@ class GatewayRunner:
|
||||
lines = [
|
||||
f"🤖 **Current model:** `{current}`",
|
||||
f"**Provider:** {provider_label}",
|
||||
"",
|
||||
]
|
||||
# Show custom endpoint URL when using a custom provider
|
||||
if current_provider == "custom":
|
||||
from hermes_cli.models import _get_custom_base_url
|
||||
custom_url = _get_custom_base_url() or os.getenv("OPENAI_BASE_URL", "")
|
||||
if custom_url:
|
||||
lines.append(f"**Endpoint:** `{custom_url}`")
|
||||
lines.append("")
|
||||
curated = curated_models_for_provider(current_provider)
|
||||
if curated:
|
||||
lines.append(f"**Available models ({provider_label}):**")
|
||||
@@ -2395,7 +2493,7 @@ class GatewayRunner:
|
||||
lines.append(f"• `{mid}`{label}{marker}")
|
||||
lines.append("")
|
||||
lines.append("To change: `/model model-name`")
|
||||
lines.append("Switch provider: `/model provider:model-name`")
|
||||
lines.append("Switch provider: `/model provider-name` or `/model provider:model-name`")
|
||||
return "\n".join(lines)
|
||||
|
||||
# Parse provider:model syntax
|
||||
@@ -3696,6 +3794,78 @@ class GatewayRunner:
|
||||
logger.warning("MCP reload failed: %s", e)
|
||||
return f"❌ MCP reload failed: {e}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# /approve & /deny — explicit dangerous-command approval
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_APPROVAL_TIMEOUT_SECONDS = 300 # 5 minutes
|
||||
|
||||
async def _handle_approve_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /approve command — execute a pending dangerous command.
|
||||
|
||||
Usage:
|
||||
/approve — approve and execute the pending command
|
||||
/approve session — approve and remember for this session
|
||||
/approve always — approve this pattern permanently
|
||||
"""
|
||||
source = event.source
|
||||
session_key = self._session_key_for_source(source)
|
||||
|
||||
if session_key not in self._pending_approvals:
|
||||
return "No pending command to approve."
|
||||
|
||||
import time as _time
|
||||
approval = self._pending_approvals[session_key]
|
||||
|
||||
# Check for timeout
|
||||
ts = approval.get("timestamp", 0)
|
||||
if _time.time() - ts > self._APPROVAL_TIMEOUT_SECONDS:
|
||||
self._pending_approvals.pop(session_key, None)
|
||||
return "⚠️ Approval expired (timed out after 5 minutes). Ask the agent to try again."
|
||||
|
||||
self._pending_approvals.pop(session_key)
|
||||
cmd = approval["command"]
|
||||
pattern_keys = approval.get("pattern_keys", [])
|
||||
if not pattern_keys:
|
||||
pk = approval.get("pattern_key", "")
|
||||
pattern_keys = [pk] if pk else []
|
||||
|
||||
# Determine approval scope from args
|
||||
args = event.get_command_args().strip().lower()
|
||||
from tools.approval import approve_session, approve_permanent
|
||||
|
||||
if args in ("always", "permanent", "permanently"):
|
||||
for pk in pattern_keys:
|
||||
approve_permanent(pk)
|
||||
scope_msg = " (pattern approved permanently)"
|
||||
elif args in ("session", "ses"):
|
||||
for pk in pattern_keys:
|
||||
approve_session(session_key, pk)
|
||||
scope_msg = " (pattern approved for this session)"
|
||||
else:
|
||||
# One-time approval — just approve for session so the immediate
|
||||
# replay works, but don't advertise it as session-wide
|
||||
for pk in pattern_keys:
|
||||
approve_session(session_key, pk)
|
||||
scope_msg = ""
|
||||
|
||||
logger.info("User approved dangerous command via /approve: %s...%s", cmd[:60], scope_msg)
|
||||
from tools.terminal_tool import terminal_tool
|
||||
result = terminal_tool(command=cmd, force=True)
|
||||
return f"✅ Command approved and executed{scope_msg}.\n\n```\n{result[:3500]}\n```"
|
||||
|
||||
async def _handle_deny_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /deny command — reject a pending dangerous command."""
|
||||
source = event.source
|
||||
session_key = self._session_key_for_source(source)
|
||||
|
||||
if session_key not in self._pending_approvals:
|
||||
return "No pending command to deny."
|
||||
|
||||
self._pending_approvals.pop(session_key)
|
||||
logger.info("User denied dangerous command via /deny")
|
||||
return "❌ Command denied."
|
||||
|
||||
async def _handle_update_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /update command — update Hermes Agent to the latest version.
|
||||
|
||||
@@ -4411,6 +4581,26 @@ class GatewayRunner:
|
||||
except Exception as _e:
|
||||
logger.debug("agent:step hook error: %s", _e)
|
||||
|
||||
# Bridge sync status_callback → async adapter.send for context pressure
|
||||
_status_adapter = self.adapters.get(source.platform)
|
||||
_status_chat_id = source.chat_id
|
||||
_status_thread_metadata = {"thread_id": source.thread_id} if source.thread_id else None
|
||||
|
||||
def _status_callback_sync(event_type: str, message: str) -> None:
|
||||
if not _status_adapter:
|
||||
return
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
_status_adapter.send(
|
||||
_status_chat_id,
|
||||
message,
|
||||
metadata=_status_thread_metadata,
|
||||
),
|
||||
_loop_for_step,
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.debug("status_callback error (%s): %s", event_type, _e)
|
||||
|
||||
def run_sync():
|
||||
# Pass session_key to process registry via env var so background
|
||||
# processes can be mapped back to this gateway session
|
||||
@@ -4503,6 +4693,7 @@ class GatewayRunner:
|
||||
tool_progress_callback=progress_callback if tool_progress_enabled else None,
|
||||
step_callback=_step_callback_sync if _hooks_ref.loaded_hooks else None,
|
||||
stream_delta_callback=_stream_delta_cb,
|
||||
status_callback=_status_callback_sync,
|
||||
platform=platform_key,
|
||||
honcho_session_key=session_key,
|
||||
honcho_manager=honcho_manager,
|
||||
|
||||
+2
-2
@@ -145,7 +145,7 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
id="minimax",
|
||||
name="MiniMax",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.minimax.io/v1",
|
||||
inference_base_url="https://api.minimax.io/anthropic",
|
||||
api_key_env_vars=("MINIMAX_API_KEY",),
|
||||
base_url_env_var="MINIMAX_BASE_URL",
|
||||
),
|
||||
@@ -168,7 +168,7 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
id="minimax-cn",
|
||||
name="MiniMax (China)",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.minimaxi.com/v1",
|
||||
inference_base_url="https://api.minimaxi.com/anthropic",
|
||||
api_key_env_vars=("MINIMAX_CN_API_KEY",),
|
||||
base_url_env_var="MINIMAX_CN_BASE_URL",
|
||||
),
|
||||
|
||||
@@ -289,6 +289,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
_hero = HERMES_CADUCEUS
|
||||
left_lines = ["", _hero, ""]
|
||||
model_short = model.split("/")[-1] if "/" in model else model
|
||||
if model_short.endswith(".gguf"):
|
||||
model_short = model_short[:-5]
|
||||
if len(model_short) > 28:
|
||||
model_short = model_short[:25] + "..."
|
||||
ctx_str = f" [dim {dim}]·[/] [dim {dim}]{_format_context_length(context_length)} context[/]" if context_length else ""
|
||||
|
||||
@@ -61,8 +61,14 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
CommandDef("rollback", "List or restore filesystem checkpoints", "Session",
|
||||
args_hint="[number]"),
|
||||
CommandDef("stop", "Kill all running background processes", "Session"),
|
||||
CommandDef("approve", "Approve a pending dangerous command", "Session",
|
||||
gateway_only=True, args_hint="[session|always]"),
|
||||
CommandDef("deny", "Deny a pending dangerous command", "Session",
|
||||
gateway_only=True),
|
||||
CommandDef("background", "Run a prompt in the background", "Session",
|
||||
aliases=("bg",), args_hint="<prompt>"),
|
||||
CommandDef("queue", "Queue a prompt for the next turn (doesn't interrupt)", "Session",
|
||||
aliases=("q",), args_hint="<prompt>"),
|
||||
CommandDef("status", "Show session info", "Session",
|
||||
gateway_only=True),
|
||||
CommandDef("sethome", "Set this chat as the home channel", "Session",
|
||||
|
||||
@@ -670,6 +670,11 @@ OPTIONAL_ENV_VARS = {
|
||||
"password": True,
|
||||
"category": "tool",
|
||||
},
|
||||
"HONCHO_BASE_URL": {
|
||||
"description": "Base URL for self-hosted Honcho instances (no API key needed)",
|
||||
"prompt": "Honcho base URL (e.g. http://localhost:8000)",
|
||||
"category": "tool",
|
||||
},
|
||||
|
||||
# ── Messaging platforms ──
|
||||
"TELEGRAM_BOT_TOKEN": {
|
||||
@@ -807,6 +812,27 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "messaging",
|
||||
"advanced": True,
|
||||
},
|
||||
"WEBHOOK_ENABLED": {
|
||||
"description": "Enable the webhook platform adapter for receiving events from GitHub, GitLab, etc.",
|
||||
"prompt": "Enable webhooks (true/false)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"WEBHOOK_PORT": {
|
||||
"description": "Port for the webhook HTTP server (default: 8644).",
|
||||
"prompt": "Webhook port",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"WEBHOOK_SECRET": {
|
||||
"description": "Global HMAC secret for webhook signature validation (overridable per route in config.yaml).",
|
||||
"prompt": "Webhook secret",
|
||||
"url": None,
|
||||
"password": True,
|
||||
"category": "messaging",
|
||||
},
|
||||
|
||||
# ── Agent settings ──
|
||||
"MESSAGING_CWD": {
|
||||
|
||||
+35
-12
@@ -1137,10 +1137,21 @@ def _model_flow_custom(config):
|
||||
base_url = input(f"API base URL [{current_url or 'e.g. https://api.example.com/v1'}]: ").strip()
|
||||
api_key = input(f"API key [{current_key[:8] + '...' if current_key else 'optional'}]: ").strip()
|
||||
model_name = input("Model name (e.g. gpt-4, llama-3-70b): ").strip()
|
||||
context_length_str = input("Context length in tokens [leave blank for auto-detect]: ").strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\nCancelled.")
|
||||
return
|
||||
|
||||
context_length = None
|
||||
if context_length_str:
|
||||
try:
|
||||
context_length = int(context_length_str.replace(",", "").replace("k", "000").replace("K", "000"))
|
||||
if context_length <= 0:
|
||||
context_length = None
|
||||
except ValueError:
|
||||
print(f"Invalid context length: {context_length_str} — will auto-detect.")
|
||||
context_length = None
|
||||
|
||||
if not base_url and not current_url:
|
||||
print("No URL provided. Cancelled.")
|
||||
return
|
||||
@@ -1203,14 +1214,14 @@ def _model_flow_custom(config):
|
||||
print("Endpoint saved. Use `/model` in chat or `hermes model` to set a model.")
|
||||
|
||||
# Auto-save to custom_providers so it appears in the menu next time
|
||||
_save_custom_provider(effective_url, effective_key, model_name or "")
|
||||
_save_custom_provider(effective_url, effective_key, model_name or "", context_length=context_length)
|
||||
|
||||
|
||||
def _save_custom_provider(base_url, api_key="", model=""):
|
||||
def _save_custom_provider(base_url, api_key="", model="", context_length=None):
|
||||
"""Save a custom endpoint to custom_providers in config.yaml.
|
||||
|
||||
Deduplicates by base_url — if the URL already exists, updates the
|
||||
model name but doesn't add a duplicate entry.
|
||||
model name and context_length but doesn't add a duplicate entry.
|
||||
Auto-generates a display name from the URL hostname.
|
||||
"""
|
||||
from hermes_cli.config import load_config, save_config
|
||||
@@ -1220,14 +1231,24 @@ def _save_custom_provider(base_url, api_key="", model=""):
|
||||
if not isinstance(providers, list):
|
||||
providers = []
|
||||
|
||||
# Check if this URL is already saved — update model if so
|
||||
# Check if this URL is already saved — update model/context_length if so
|
||||
for entry in providers:
|
||||
if isinstance(entry, dict) and entry.get("base_url", "").rstrip("/") == base_url.rstrip("/"):
|
||||
changed = False
|
||||
if model and entry.get("model") != model:
|
||||
entry["model"] = model
|
||||
changed = True
|
||||
if model and context_length:
|
||||
models_cfg = entry.get("models", {})
|
||||
if not isinstance(models_cfg, dict):
|
||||
models_cfg = {}
|
||||
models_cfg[model] = {"context_length": context_length}
|
||||
entry["models"] = models_cfg
|
||||
changed = True
|
||||
if changed:
|
||||
cfg["custom_providers"] = providers
|
||||
save_config(cfg)
|
||||
return # already saved, updated model if needed
|
||||
return # already saved, updated if needed
|
||||
|
||||
# Auto-generate a name from the URL
|
||||
import re
|
||||
@@ -1249,6 +1270,8 @@ def _save_custom_provider(base_url, api_key="", model=""):
|
||||
entry["api_key"] = api_key
|
||||
if model:
|
||||
entry["model"] = model
|
||||
if model and context_length:
|
||||
entry["models"] = {model: {"context_length": context_length}}
|
||||
|
||||
providers.append(entry)
|
||||
cfg["custom_providers"] = providers
|
||||
@@ -3721,20 +3744,20 @@ For more help on a command:
|
||||
return
|
||||
has_titles = any(s.get("title") for s in sessions)
|
||||
if has_titles:
|
||||
print(f"{'Title':<22} {'Preview':<40} {'Last Active':<13} {'ID'}")
|
||||
print("─" * 100)
|
||||
print(f"{'Title':<32} {'Preview':<40} {'Last Active':<13} {'ID'}")
|
||||
print("─" * 110)
|
||||
else:
|
||||
print(f"{'Preview':<50} {'Last Active':<13} {'Src':<6} {'ID'}")
|
||||
print("─" * 90)
|
||||
print("─" * 95)
|
||||
for s in sessions:
|
||||
last_active = _relative_time(s.get("last_active"))
|
||||
preview = s.get("preview", "")[:38] if has_titles else s.get("preview", "")[:48]
|
||||
if has_titles:
|
||||
title = (s.get("title") or "—")[:20]
|
||||
sid = s["id"][:20]
|
||||
print(f"{title:<22} {preview:<40} {last_active:<13} {sid}")
|
||||
title = (s.get("title") or "—")[:30]
|
||||
sid = s["id"]
|
||||
print(f"{title:<32} {preview:<40} {last_active:<13} {sid}")
|
||||
else:
|
||||
sid = s["id"][:20]
|
||||
sid = s["id"]
|
||||
print(f"{preview:<50} {last_active:<13} {s['source']:<6} {sid}")
|
||||
|
||||
elif action == "export":
|
||||
|
||||
@@ -389,6 +389,7 @@ def detect_provider_for_model(
|
||||
Returns ``None`` when no confident match is found.
|
||||
|
||||
Priority:
|
||||
0. Bare provider name → switch to that provider's default model
|
||||
1. Direct provider with credentials (highest)
|
||||
2. Direct provider without credentials → remap to OpenRouter slug
|
||||
3. OpenRouter catalog match
|
||||
@@ -399,6 +400,21 @@ def detect_provider_for_model(
|
||||
|
||||
name_lower = name.lower()
|
||||
|
||||
# --- Step 0: bare provider name typed as model ---
|
||||
# If someone types `/model nous` or `/model anthropic`, treat it as a
|
||||
# provider switch and pick the first model from that provider's catalog.
|
||||
# Skip "custom" and "openrouter" — custom has no model catalog, and
|
||||
# openrouter requires an explicit model name to be useful.
|
||||
resolved_provider = _PROVIDER_ALIASES.get(name_lower, name_lower)
|
||||
if resolved_provider not in {"custom", "openrouter"}:
|
||||
default_models = _PROVIDER_MODELS.get(resolved_provider, [])
|
||||
if (
|
||||
resolved_provider in _PROVIDER_LABELS
|
||||
and default_models
|
||||
and resolved_provider != normalize_provider(current_provider)
|
||||
):
|
||||
return (resolved_provider, default_models[0])
|
||||
|
||||
# Aggregators list other providers' models — never auto-switch TO them
|
||||
_AGGREGATORS = {"nous", "openrouter"}
|
||||
|
||||
|
||||
@@ -24,11 +24,53 @@ def _normalize_custom_provider_name(value: str) -> str:
|
||||
return value.strip().lower().replace(" ", "-")
|
||||
|
||||
|
||||
def _detect_api_mode_for_url(base_url: str) -> Optional[str]:
|
||||
"""Auto-detect api_mode from the resolved base URL.
|
||||
|
||||
Direct api.openai.com endpoints need the Responses API for GPT-5.x
|
||||
tool calls with reasoning (chat/completions returns 400).
|
||||
"""
|
||||
normalized = (base_url or "").strip().lower().rstrip("/")
|
||||
if "api.openai.com" in normalized and "openrouter" not in normalized:
|
||||
return "codex_responses"
|
||||
return None
|
||||
|
||||
|
||||
def _auto_detect_local_model(base_url: str) -> str:
|
||||
"""Query a local server for its model name when only one model is loaded."""
|
||||
if not base_url:
|
||||
return ""
|
||||
try:
|
||||
import requests
|
||||
url = base_url.rstrip("/")
|
||||
if not url.endswith("/v1"):
|
||||
url += "/v1"
|
||||
resp = requests.get(url + "/models", timeout=5)
|
||||
if resp.ok:
|
||||
models = resp.json().get("data", [])
|
||||
if len(models) == 1:
|
||||
model_id = models[0].get("id", "")
|
||||
if model_id:
|
||||
return model_id
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
|
||||
|
||||
def _get_model_config() -> Dict[str, Any]:
|
||||
config = load_config()
|
||||
model_cfg = config.get("model")
|
||||
if isinstance(model_cfg, dict):
|
||||
return dict(model_cfg)
|
||||
cfg = dict(model_cfg)
|
||||
default = cfg.get("default", "").strip()
|
||||
base_url = cfg.get("base_url", "").strip()
|
||||
is_local = "localhost" in base_url or "127.0.0.1" in base_url
|
||||
is_fallback = not default or default == "anthropic/claude-opus-4.6"
|
||||
if is_local and is_fallback and base_url:
|
||||
detected = _auto_detect_local_model(base_url)
|
||||
if detected:
|
||||
cfg["default"] = detected
|
||||
return cfg
|
||||
if isinstance(model_cfg, str) and model_cfg.strip():
|
||||
return {"default": model_cfg.strip()}
|
||||
return {}
|
||||
@@ -155,7 +197,9 @@ def _resolve_named_custom_runtime(
|
||||
|
||||
return {
|
||||
"provider": "openrouter",
|
||||
"api_mode": custom_provider.get("api_mode", "chat_completions"),
|
||||
"api_mode": custom_provider.get("api_mode")
|
||||
or _detect_api_mode_for_url(base_url)
|
||||
or "chat_completions",
|
||||
"base_url": base_url,
|
||||
"api_key": api_key,
|
||||
"source": f"custom_provider:{custom_provider.get('name', requested_provider)}",
|
||||
@@ -233,7 +277,9 @@ def _resolve_openrouter_runtime(
|
||||
|
||||
return {
|
||||
"provider": "openrouter",
|
||||
"api_mode": _parse_api_mode(model_cfg.get("api_mode")) or "chat_completions",
|
||||
"api_mode": _parse_api_mode(model_cfg.get("api_mode"))
|
||||
or _detect_api_mode_for_url(base_url)
|
||||
or "chat_completions",
|
||||
"base_url": base_url,
|
||||
"api_key": api_key,
|
||||
"source": source,
|
||||
@@ -313,10 +359,14 @@ def resolve_runtime_provider(
|
||||
"No Anthropic credentials found. Set ANTHROPIC_TOKEN or ANTHROPIC_API_KEY, "
|
||||
"run 'claude setup-token', or authenticate with 'claude /login'."
|
||||
)
|
||||
# Allow base URL override from config.yaml model.base_url
|
||||
model_cfg = _get_model_config()
|
||||
cfg_base_url = (model_cfg.get("base_url") or "").strip().rstrip("/")
|
||||
base_url = cfg_base_url or "https://api.anthropic.com"
|
||||
return {
|
||||
"provider": "anthropic",
|
||||
"api_mode": "anthropic_messages",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
"base_url": base_url,
|
||||
"api_key": token,
|
||||
"source": "env",
|
||||
"requested_provider": requested_provider,
|
||||
@@ -340,13 +390,29 @@ def resolve_runtime_provider(
|
||||
if pconfig and pconfig.auth_type == "api_key":
|
||||
creds = resolve_api_key_provider_credentials(provider)
|
||||
model_cfg = _get_model_config()
|
||||
base_url = creds.get("base_url", "").rstrip("/")
|
||||
api_mode = "chat_completions"
|
||||
if provider == "copilot":
|
||||
api_mode = _copilot_runtime_api_mode(model_cfg, creds.get("api_key", ""))
|
||||
else:
|
||||
# Check explicit api_mode from model config first
|
||||
configured_mode = _parse_api_mode(model_cfg.get("api_mode"))
|
||||
if configured_mode:
|
||||
api_mode = configured_mode
|
||||
# Auto-detect Anthropic-compatible endpoints by URL convention
|
||||
# (e.g. https://api.minimax.io/anthropic, https://dashscope.../anthropic)
|
||||
elif base_url.rstrip("/").endswith("/anthropic"):
|
||||
api_mode = "anthropic_messages"
|
||||
# MiniMax providers always use Anthropic Messages API.
|
||||
# Auto-correct stale /v1 URLs (from old .env or config) to /anthropic.
|
||||
elif provider in ("minimax", "minimax-cn"):
|
||||
api_mode = "anthropic_messages"
|
||||
if base_url.rstrip("/").endswith("/v1"):
|
||||
base_url = base_url.rstrip("/")[:-3] + "/anthropic"
|
||||
return {
|
||||
"provider": provider,
|
||||
"api_mode": api_mode,
|
||||
"base_url": creds.get("base_url", "").rstrip("/"),
|
||||
"base_url": base_url,
|
||||
"api_key": creds.get("api_key", ""),
|
||||
"source": creds.get("source", "env"),
|
||||
"requested_provider": requested_provider,
|
||||
|
||||
+66
-86
@@ -1045,93 +1045,17 @@ def setup_model_provider(config: dict):
|
||||
print()
|
||||
print_header("Custom OpenAI-Compatible Endpoint")
|
||||
print_info("Works with any API that follows OpenAI's chat completions spec")
|
||||
print()
|
||||
|
||||
current_url = get_env_value("OPENAI_BASE_URL") or ""
|
||||
current_key = get_env_value("OPENAI_API_KEY")
|
||||
_raw_model = config.get("model", "")
|
||||
current_model = (
|
||||
_raw_model.get("default", "")
|
||||
if isinstance(_raw_model, dict)
|
||||
else (_raw_model or "")
|
||||
)
|
||||
|
||||
if current_url:
|
||||
print_info(f" Current URL: {current_url}")
|
||||
if current_key:
|
||||
print_info(f" Current key: {current_key[:8]}... (configured)")
|
||||
|
||||
base_url = prompt(
|
||||
" API base URL (e.g., https://api.example.com/v1)", current_url
|
||||
).strip()
|
||||
api_key = prompt(" API key", password=True)
|
||||
model_name = prompt(" Model name (e.g., gpt-4, claude-3-opus)", current_model)
|
||||
|
||||
if base_url:
|
||||
from hermes_cli.models import probe_api_models
|
||||
|
||||
probe = probe_api_models(api_key, base_url)
|
||||
if probe.get("used_fallback") and probe.get("resolved_base_url"):
|
||||
print_warning(
|
||||
f"Endpoint verification worked at {probe['resolved_base_url']}/models, "
|
||||
f"not the exact URL you entered. Saving the working base URL instead."
|
||||
)
|
||||
base_url = probe["resolved_base_url"]
|
||||
elif probe.get("models") is not None:
|
||||
print_success(
|
||||
f"Verified endpoint via {probe.get('probed_url')} "
|
||||
f"({len(probe.get('models') or [])} model(s) visible)"
|
||||
)
|
||||
else:
|
||||
print_warning(
|
||||
f"Could not verify this endpoint via {probe.get('probed_url')}. "
|
||||
f"Hermes will still save it."
|
||||
)
|
||||
if probe.get("suggested_base_url"):
|
||||
print_info(
|
||||
f" If this server expects /v1, try base URL: {probe['suggested_base_url']}"
|
||||
)
|
||||
|
||||
save_env_value("OPENAI_BASE_URL", base_url)
|
||||
if api_key:
|
||||
save_env_value("OPENAI_API_KEY", api_key)
|
||||
if model_name:
|
||||
_set_default_model(config, model_name)
|
||||
|
||||
try:
|
||||
from hermes_cli.auth import deactivate_provider
|
||||
|
||||
deactivate_provider()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Save provider and base_url to config.yaml so the gateway and CLI
|
||||
# both resolve the correct provider without relying on env-var heuristics.
|
||||
if base_url:
|
||||
import yaml
|
||||
|
||||
config_path = (
|
||||
Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
/ "config.yaml"
|
||||
)
|
||||
try:
|
||||
disk_cfg = {}
|
||||
if config_path.exists():
|
||||
disk_cfg = yaml.safe_load(config_path.read_text()) or {}
|
||||
model_section = disk_cfg.get("model", {})
|
||||
if isinstance(model_section, str):
|
||||
model_section = {"default": model_section}
|
||||
model_section["provider"] = "custom"
|
||||
model_section["base_url"] = base_url.rstrip("/")
|
||||
if model_name:
|
||||
model_section["default"] = model_name
|
||||
disk_cfg["model"] = model_section
|
||||
config_path.write_text(yaml.safe_dump(disk_cfg, sort_keys=False))
|
||||
except Exception as e:
|
||||
logger.debug("Could not save provider to config.yaml: %s", e)
|
||||
|
||||
_set_model_provider(config, "custom", base_url)
|
||||
|
||||
print_success("Custom endpoint configured")
|
||||
# Reuse the shared custom endpoint flow from `hermes model`.
|
||||
# This handles: URL/key/model/context-length prompts, endpoint probing,
|
||||
# env saving, config.yaml updates, and custom_providers persistence.
|
||||
from hermes_cli.main import _model_flow_custom
|
||||
_model_flow_custom(config)
|
||||
# _model_flow_custom handles model selection, config, env vars,
|
||||
# and custom_providers. Keep selected_provider = "custom" so
|
||||
# the model selection step below is skipped (line 1631 check)
|
||||
# but vision and TTS setup still run.
|
||||
|
||||
elif provider_idx == 4: # Z.AI / GLM
|
||||
selected_provider = "zai"
|
||||
@@ -2851,6 +2775,61 @@ def setup_gateway(config: dict):
|
||||
print_info("Run 'hermes whatsapp' to choose your mode (separate bot number")
|
||||
print_info("or personal self-chat) and pair via QR code.")
|
||||
|
||||
# ── Webhooks ──
|
||||
existing_webhook = get_env_value("WEBHOOK_ENABLED")
|
||||
if existing_webhook:
|
||||
print_info("Webhooks: already configured")
|
||||
if prompt_yes_no("Reconfigure webhooks?", False):
|
||||
existing_webhook = None
|
||||
|
||||
if not existing_webhook and prompt_yes_no("Set up webhooks? (GitHub, GitLab, etc.)", False):
|
||||
print()
|
||||
print_warning(
|
||||
"⚠ Webhook and SMS platforms require exposing gateway ports to the"
|
||||
)
|
||||
print_warning(
|
||||
" internet. For security, run the gateway in a sandboxed environment"
|
||||
)
|
||||
print_warning(
|
||||
" (Docker, VM, etc.) to limit blast radius from prompt injection."
|
||||
)
|
||||
print()
|
||||
print_info(
|
||||
" Full guide: https://hermes-agent.nousresearch.com/docs/user-guide/messaging/webhooks/"
|
||||
)
|
||||
print()
|
||||
|
||||
port = prompt("Webhook port (default 8644)")
|
||||
if port:
|
||||
try:
|
||||
save_env_value("WEBHOOK_PORT", str(int(port)))
|
||||
print_success(f"Webhook port set to {port}")
|
||||
except ValueError:
|
||||
print_warning("Invalid port number, using default 8644")
|
||||
|
||||
secret = prompt("Global HMAC secret (shared across all routes)", password=True)
|
||||
if secret:
|
||||
save_env_value("WEBHOOK_SECRET", secret)
|
||||
print_success("Webhook secret saved")
|
||||
else:
|
||||
print_warning("No secret set — you must configure per-route secrets in config.yaml")
|
||||
|
||||
save_env_value("WEBHOOK_ENABLED", "true")
|
||||
print()
|
||||
print_success("Webhooks enabled! Next steps:")
|
||||
print_info(" 1. Define webhook routes in ~/.hermes/config.yaml")
|
||||
print_info(" 2. Point your service (GitHub, GitLab, etc.) at:")
|
||||
print_info(" http://your-server:8644/webhooks/<route-name>")
|
||||
print()
|
||||
print_info(
|
||||
" Route configuration guide:"
|
||||
)
|
||||
print_info(
|
||||
" https://hermes-agent.nousresearch.com/docs/user-guide/messaging/webhooks/#configuring-routes"
|
||||
)
|
||||
print()
|
||||
print_info(" Open config in your editor: hermes config edit")
|
||||
|
||||
# ── Gateway Service Setup ──
|
||||
any_messaging = (
|
||||
get_env_value("TELEGRAM_BOT_TOKEN")
|
||||
@@ -2860,6 +2839,7 @@ def setup_gateway(config: dict):
|
||||
or get_env_value("MATRIX_ACCESS_TOKEN")
|
||||
or get_env_value("MATRIX_PASSWORD")
|
||||
or get_env_value("WHATSAPP_ENABLED")
|
||||
or get_env_value("WEBHOOK_ENABLED")
|
||||
)
|
||||
if any_messaging:
|
||||
print()
|
||||
|
||||
+5
-1
@@ -181,7 +181,11 @@ class SessionDB:
|
||||
]
|
||||
for name, column_type in new_columns:
|
||||
try:
|
||||
cursor.execute(f"ALTER TABLE sessions ADD COLUMN {name} {column_type}")
|
||||
# name and column_type come from the hardcoded tuple above,
|
||||
# not user input. Double-quote identifier escaping is applied
|
||||
# as defense-in-depth; SQLite DDL cannot be parameterized.
|
||||
safe_name = name.replace('"', '""')
|
||||
cursor.execute(f'ALTER TABLE sessions ADD COLUMN "{safe_name}" {column_type}')
|
||||
except sqlite3.OperationalError:
|
||||
pass
|
||||
cursor.execute("UPDATE schema_version SET version = 5")
|
||||
|
||||
@@ -117,11 +117,13 @@ class HonchoClientConfig:
|
||||
def from_env(cls, workspace_id: str = "hermes") -> HonchoClientConfig:
|
||||
"""Create config from environment variables (fallback)."""
|
||||
api_key = os.environ.get("HONCHO_API_KEY")
|
||||
base_url = os.environ.get("HONCHO_BASE_URL", "").strip() or None
|
||||
return cls(
|
||||
workspace_id=workspace_id,
|
||||
api_key=api_key,
|
||||
environment=os.environ.get("HONCHO_ENVIRONMENT", "production"),
|
||||
enabled=bool(api_key),
|
||||
base_url=base_url,
|
||||
enabled=bool(api_key or base_url),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -171,8 +173,14 @@ class HonchoClientConfig:
|
||||
or raw.get("environment", "production")
|
||||
)
|
||||
|
||||
# Auto-enable when API key is present (unless explicitly disabled)
|
||||
# Host-level enabled wins, then root-level, then auto-enable if key exists.
|
||||
base_url = (
|
||||
raw.get("baseUrl")
|
||||
or os.environ.get("HONCHO_BASE_URL", "").strip()
|
||||
or None
|
||||
)
|
||||
|
||||
# Auto-enable when API key or base_url is present (unless explicitly disabled)
|
||||
# Host-level enabled wins, then root-level, then auto-enable if key/url exists.
|
||||
host_enabled = host_block.get("enabled")
|
||||
root_enabled = raw.get("enabled")
|
||||
if host_enabled is not None:
|
||||
@@ -180,8 +188,8 @@ class HonchoClientConfig:
|
||||
elif root_enabled is not None:
|
||||
enabled = root_enabled
|
||||
else:
|
||||
# Not explicitly set anywhere -> auto-enable if API key exists
|
||||
enabled = bool(api_key)
|
||||
# Not explicitly set anywhere -> auto-enable if API key or base_url exists
|
||||
enabled = bool(api_key or base_url)
|
||||
|
||||
# write_frequency: accept int or string
|
||||
raw_wf = (
|
||||
@@ -214,6 +222,7 @@ class HonchoClientConfig:
|
||||
workspace_id=workspace,
|
||||
api_key=api_key,
|
||||
environment=environment,
|
||||
base_url=base_url,
|
||||
peer_name=host_block.get("peerName") or raw.get("peerName"),
|
||||
ai_peer=ai_peer,
|
||||
linked_hosts=linked_hosts,
|
||||
@@ -348,11 +357,12 @@ def get_honcho_client(config: HonchoClientConfig | None = None) -> Honcho:
|
||||
if config is None:
|
||||
config = HonchoClientConfig.from_global_config()
|
||||
|
||||
if not config.api_key:
|
||||
if not config.api_key and not config.base_url:
|
||||
raise ValueError(
|
||||
"Honcho API key not found. "
|
||||
"Get your API key at https://app.honcho.dev, "
|
||||
"then run 'hermes honcho setup' or set HONCHO_API_KEY."
|
||||
"then run 'hermes honcho setup' or set HONCHO_API_KEY. "
|
||||
"For local instances, set HONCHO_BASE_URL instead."
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
+96
-5
@@ -24,6 +24,7 @@ import json
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
|
||||
from tools.registry import registry
|
||||
@@ -36,6 +37,48 @@ logger = logging.getLogger(__name__)
|
||||
# Async Bridging (single source of truth -- used by registry.dispatch too)
|
||||
# =============================================================================
|
||||
|
||||
_tool_loop = None # persistent loop for the main (CLI) thread
|
||||
_tool_loop_lock = threading.Lock()
|
||||
_worker_thread_local = threading.local() # per-worker-thread persistent loops
|
||||
|
||||
|
||||
def _get_tool_loop():
|
||||
"""Return a long-lived event loop for running async tool handlers.
|
||||
|
||||
Using a persistent loop (instead of asyncio.run() which creates and
|
||||
*closes* a fresh loop every time) prevents "Event loop is closed"
|
||||
errors that occur when cached httpx/AsyncOpenAI clients attempt to
|
||||
close their transport on a dead loop during garbage collection.
|
||||
"""
|
||||
global _tool_loop
|
||||
with _tool_loop_lock:
|
||||
if _tool_loop is None or _tool_loop.is_closed():
|
||||
_tool_loop = asyncio.new_event_loop()
|
||||
return _tool_loop
|
||||
|
||||
|
||||
def _get_worker_loop():
|
||||
"""Return a persistent event loop for the current worker thread.
|
||||
|
||||
Each worker thread (e.g., delegate_task's ThreadPoolExecutor threads)
|
||||
gets its own long-lived loop stored in thread-local storage. This
|
||||
prevents the "Event loop is closed" errors that occurred when
|
||||
asyncio.run() was used per-call: asyncio.run() creates a loop, runs
|
||||
the coroutine, then *closes* the loop — but cached httpx/AsyncOpenAI
|
||||
clients remain bound to that now-dead loop and raise RuntimeError
|
||||
during garbage collection or subsequent use.
|
||||
|
||||
By keeping the loop alive for the thread's lifetime, cached clients
|
||||
stay valid and their cleanup runs on a live loop.
|
||||
"""
|
||||
loop = getattr(_worker_thread_local, 'loop', None)
|
||||
if loop is None or loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
_worker_thread_local.loop = loop
|
||||
return loop
|
||||
|
||||
|
||||
def _run_async(coro):
|
||||
"""Run an async coroutine from a sync context.
|
||||
|
||||
@@ -44,6 +87,15 @@ def _run_async(coro):
|
||||
disposable thread so asyncio.run() can create its own loop without
|
||||
conflicting.
|
||||
|
||||
For the common CLI path (no running loop), we use a persistent event
|
||||
loop so that cached async clients (httpx / AsyncOpenAI) remain bound
|
||||
to a live loop and don't trigger "Event loop is closed" on GC.
|
||||
|
||||
When called from a worker thread (parallel tool execution), we use a
|
||||
per-thread persistent loop to avoid both contention with the main
|
||||
thread's shared loop AND the "Event loop is closed" errors caused by
|
||||
asyncio.run()'s create-and-destroy lifecycle.
|
||||
|
||||
This is the single source of truth for sync->async bridging in tool
|
||||
handlers. The RL paths (agent_loop.py, tool_context.py) also provide
|
||||
outer thread-pool wrapping as defense-in-depth, but each handler is
|
||||
@@ -55,11 +107,23 @@ def _run_async(coro):
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
# Inside an async context (gateway, RL env) — run in a fresh thread.
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, coro)
|
||||
return future.result(timeout=300)
|
||||
return asyncio.run(coro)
|
||||
|
||||
# If we're on a worker thread (e.g., parallel tool execution in
|
||||
# delegate_task), use a per-thread persistent loop. This avoids
|
||||
# contention with the main thread's shared loop while keeping cached
|
||||
# httpx/AsyncOpenAI clients bound to a live loop for the thread's
|
||||
# lifetime — preventing "Event loop is closed" on GC cleanup.
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
worker_loop = _get_worker_loop()
|
||||
return worker_loop.run_until_complete(coro)
|
||||
|
||||
tool_loop = _get_tool_loop()
|
||||
return tool_loop.run_until_complete(coro)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -242,18 +306,45 @@ def get_tool_definitions(
|
||||
# Ask the registry for schemas (only returns tools whose check_fn passes)
|
||||
filtered_tools = registry.get_definitions(tools_to_include, quiet=quiet_mode)
|
||||
|
||||
# The set of tool names that actually passed check_fn filtering.
|
||||
# Use this (not tools_to_include) for any downstream schema that references
|
||||
# other tools by name — otherwise the model sees tools mentioned in
|
||||
# descriptions that don't actually exist, and hallucinates calls to them.
|
||||
available_tool_names = {t["function"]["name"] for t in filtered_tools}
|
||||
|
||||
# Rebuild execute_code schema to only list sandbox tools that are actually
|
||||
# enabled. Without this, the model sees "web_search is available in
|
||||
# execute_code" even when the user disabled the web toolset (#560-discord).
|
||||
if "execute_code" in tools_to_include:
|
||||
# available. Without this, the model sees "web_search is available in
|
||||
# execute_code" even when the API key isn't configured or the toolset is
|
||||
# disabled (#560-discord).
|
||||
if "execute_code" in available_tool_names:
|
||||
from tools.code_execution_tool import SANDBOX_ALLOWED_TOOLS, build_execute_code_schema
|
||||
sandbox_enabled = SANDBOX_ALLOWED_TOOLS & tools_to_include
|
||||
sandbox_enabled = SANDBOX_ALLOWED_TOOLS & available_tool_names
|
||||
dynamic_schema = build_execute_code_schema(sandbox_enabled)
|
||||
for i, td in enumerate(filtered_tools):
|
||||
if td.get("function", {}).get("name") == "execute_code":
|
||||
filtered_tools[i] = {"type": "function", "function": dynamic_schema}
|
||||
break
|
||||
|
||||
# Strip web tool cross-references from browser_navigate description when
|
||||
# web_search / web_extract are not available. The static schema says
|
||||
# "prefer web_search or web_extract" which causes the model to hallucinate
|
||||
# those tools when they're missing.
|
||||
if "browser_navigate" in available_tool_names:
|
||||
web_tools_available = {"web_search", "web_extract"} & available_tool_names
|
||||
if not web_tools_available:
|
||||
for i, td in enumerate(filtered_tools):
|
||||
if td.get("function", {}).get("name") == "browser_navigate":
|
||||
desc = td["function"].get("description", "")
|
||||
desc = desc.replace(
|
||||
" For simple information retrieval, prefer web_search or web_extract (faster, cheaper).",
|
||||
"",
|
||||
)
|
||||
filtered_tools[i] = {
|
||||
"type": "function",
|
||||
"function": {**td["function"], "description": desc},
|
||||
}
|
||||
break
|
||||
|
||||
if not quiet_mode:
|
||||
if filtered_tools:
|
||||
tool_names = [t["function"]["name"] for t in filtered_tools]
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
# MCP
|
||||
|
||||
Skills for building, testing, and deploying MCP (Model Context Protocol) servers.
|
||||
@@ -0,0 +1,299 @@
|
||||
---
|
||||
name: fastmcp
|
||||
description: Build, test, inspect, install, and deploy MCP servers with FastMCP in Python. Use when creating a new MCP server, wrapping an API or database as MCP tools, exposing resources or prompts, or preparing a FastMCP server for Claude Code, Cursor, or HTTP deployment.
|
||||
version: 1.0.0
|
||||
author: Hermes Agent
|
||||
license: MIT
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [MCP, FastMCP, Python, Tools, Resources, Prompts, Deployment]
|
||||
homepage: https://gofastmcp.com
|
||||
related_skills: [native-mcp, mcporter]
|
||||
prerequisites:
|
||||
commands: [python3]
|
||||
---
|
||||
|
||||
# FastMCP
|
||||
|
||||
Build MCP servers in Python with FastMCP, validate them locally, install them into MCP clients, and deploy them as HTTP endpoints.
|
||||
|
||||
## When to Use
|
||||
|
||||
Use this skill when the task is to:
|
||||
|
||||
- create a new MCP server in Python
|
||||
- wrap an API, database, CLI, or file-processing workflow as MCP tools
|
||||
- expose resources or prompts in addition to tools
|
||||
- smoke-test a server with the FastMCP CLI before wiring it into Hermes or another client
|
||||
- install a server into Claude Code, Claude Desktop, Cursor, or a similar MCP client
|
||||
- prepare a FastMCP server repo for HTTP deployment
|
||||
|
||||
Use `native-mcp` when the server already exists and only needs to be connected to Hermes. Use `mcporter` when the goal is ad-hoc CLI access to an existing MCP server instead of building one.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Install FastMCP in the working environment first:
|
||||
|
||||
```bash
|
||||
pip install fastmcp
|
||||
fastmcp version
|
||||
```
|
||||
|
||||
For the API template, install `httpx` if it is not already present:
|
||||
|
||||
```bash
|
||||
pip install httpx
|
||||
```
|
||||
|
||||
## Included Files
|
||||
|
||||
### Templates
|
||||
|
||||
- `templates/api_wrapper.py` - REST API wrapper with auth header support
|
||||
- `templates/database_server.py` - read-only SQLite query server
|
||||
- `templates/file_processor.py` - text-file inspection and search server
|
||||
|
||||
### Scripts
|
||||
|
||||
- `scripts/scaffold_fastmcp.py` - copy a starter template and replace the server name placeholder
|
||||
|
||||
### References
|
||||
|
||||
- `references/fastmcp-cli.md` - FastMCP CLI workflow, installation targets, and deployment checks
|
||||
|
||||
## Workflow
|
||||
|
||||
### 1. Pick the Smallest Viable Server Shape
|
||||
|
||||
Choose the narrowest useful surface area first:
|
||||
|
||||
- API wrapper: start with 1-3 high-value endpoints, not the whole API
|
||||
- database server: expose read-only introspection and a constrained query path
|
||||
- file processor: expose deterministic operations with explicit path arguments
|
||||
- prompts/resources: add only when the client needs reusable prompt templates or discoverable documents
|
||||
|
||||
Prefer a thin server with good names, docstrings, and schemas over a large server with vague tools.
|
||||
|
||||
### 2. Scaffold from a Template
|
||||
|
||||
Copy a template directly or use the scaffold helper:
|
||||
|
||||
```bash
|
||||
python ~/.hermes/skills/mcp/fastmcp/scripts/scaffold_fastmcp.py \
|
||||
--template api_wrapper \
|
||||
--name "Acme API" \
|
||||
--output ./acme_server.py
|
||||
```
|
||||
|
||||
Available templates:
|
||||
|
||||
```bash
|
||||
python ~/.hermes/skills/mcp/fastmcp/scripts/scaffold_fastmcp.py --list
|
||||
```
|
||||
|
||||
If copying manually, replace `__SERVER_NAME__` with a real server name.
|
||||
|
||||
### 3. Implement Tools First
|
||||
|
||||
Start with `@mcp.tool` functions before adding resources or prompts.
|
||||
|
||||
Rules for tool design:
|
||||
|
||||
- Give every tool a concrete verb-based name
|
||||
- Write docstrings as user-facing tool descriptions
|
||||
- Keep parameters explicit and typed
|
||||
- Return structured JSON-safe data where possible
|
||||
- Validate unsafe inputs early
|
||||
- Prefer read-only behavior by default for first versions
|
||||
|
||||
Good tool examples:
|
||||
|
||||
- `get_customer`
|
||||
- `search_tickets`
|
||||
- `describe_table`
|
||||
- `summarize_text_file`
|
||||
|
||||
Weak tool examples:
|
||||
|
||||
- `run`
|
||||
- `process`
|
||||
- `do_thing`
|
||||
|
||||
### 4. Add Resources and Prompts Only When They Help
|
||||
|
||||
Add `@mcp.resource` when the client benefits from fetching stable read-only content such as schemas, policy docs, or generated reports.
|
||||
|
||||
Add `@mcp.prompt` when the server should provide a reusable prompt template for a known workflow.
|
||||
|
||||
Do not turn every document into a prompt. Prefer:
|
||||
|
||||
- tools for actions
|
||||
- resources for data/document retrieval
|
||||
- prompts for reusable LLM instructions
|
||||
|
||||
### 5. Test the Server Before Integrating It Anywhere
|
||||
|
||||
Use the FastMCP CLI for local validation:
|
||||
|
||||
```bash
|
||||
fastmcp inspect acme_server.py:mcp
|
||||
fastmcp list acme_server.py --json
|
||||
fastmcp call acme_server.py search_resources query=router limit=5 --json
|
||||
```
|
||||
|
||||
For fast iterative debugging, run the server locally:
|
||||
|
||||
```bash
|
||||
fastmcp run acme_server.py:mcp
|
||||
```
|
||||
|
||||
To test HTTP transport locally:
|
||||
|
||||
```bash
|
||||
fastmcp run acme_server.py:mcp --transport http --host 127.0.0.1 --port 8000
|
||||
fastmcp list http://127.0.0.1:8000/mcp --json
|
||||
fastmcp call http://127.0.0.1:8000/mcp search_resources query=router --json
|
||||
```
|
||||
|
||||
Always run at least one real `fastmcp call` against each new tool before claiming the server works.
|
||||
|
||||
### 6. Install into a Client When Local Validation Passes
|
||||
|
||||
FastMCP can register the server with supported MCP clients:
|
||||
|
||||
```bash
|
||||
fastmcp install claude-code acme_server.py
|
||||
fastmcp install claude-desktop acme_server.py
|
||||
fastmcp install cursor acme_server.py -e .
|
||||
```
|
||||
|
||||
Use `fastmcp discover` to inspect named MCP servers already configured on the machine.
|
||||
|
||||
When the goal is Hermes integration, either:
|
||||
|
||||
- configure the server in `~/.hermes/config.yaml` using the `native-mcp` skill, or
|
||||
- keep using FastMCP CLI commands during development until the interface stabilizes
|
||||
|
||||
### 7. Deploy After the Local Contract Is Stable
|
||||
|
||||
For managed hosting, Prefect Horizon is the path FastMCP documents most directly. Before deployment:
|
||||
|
||||
```bash
|
||||
fastmcp inspect acme_server.py:mcp
|
||||
```
|
||||
|
||||
Make sure the repo contains:
|
||||
|
||||
- a Python file with the FastMCP server object
|
||||
- `requirements.txt` or `pyproject.toml`
|
||||
- any environment-variable documentation needed for deployment
|
||||
|
||||
For generic HTTP hosting, validate the HTTP transport locally first, then deploy on any Python-compatible platform that can expose the server port.
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### API Wrapper Pattern
|
||||
|
||||
Use when exposing a REST or HTTP API as MCP tools.
|
||||
|
||||
Recommended first slice:
|
||||
|
||||
- one read path
|
||||
- one list/search path
|
||||
- optional health check
|
||||
|
||||
Implementation notes:
|
||||
|
||||
- keep auth in environment variables, not hardcoded
|
||||
- centralize request logic in one helper
|
||||
- surface API errors with concise context
|
||||
- normalize inconsistent upstream payloads before returning them
|
||||
|
||||
Start from `templates/api_wrapper.py`.
|
||||
|
||||
### Database Pattern
|
||||
|
||||
Use when exposing safe query and inspection capabilities.
|
||||
|
||||
Recommended first slice:
|
||||
|
||||
- `list_tables`
|
||||
- `describe_table`
|
||||
- one constrained read query tool
|
||||
|
||||
Implementation notes:
|
||||
|
||||
- default to read-only DB access
|
||||
- reject non-`SELECT` SQL in early versions
|
||||
- limit row counts
|
||||
- return rows plus column names
|
||||
|
||||
Start from `templates/database_server.py`.
|
||||
|
||||
### File Processor Pattern
|
||||
|
||||
Use when the server needs to inspect or transform files on demand.
|
||||
|
||||
Recommended first slice:
|
||||
|
||||
- summarize file contents
|
||||
- search within files
|
||||
- extract deterministic metadata
|
||||
|
||||
Implementation notes:
|
||||
|
||||
- accept explicit file paths
|
||||
- check for missing files and encoding failures
|
||||
- cap previews and result counts
|
||||
- avoid shelling out unless a specific external tool is required
|
||||
|
||||
Start from `templates/file_processor.py`.
|
||||
|
||||
## Quality Bar
|
||||
|
||||
Before handing off a FastMCP server, verify all of the following:
|
||||
|
||||
- server imports cleanly
|
||||
- `fastmcp inspect <file.py:mcp>` succeeds
|
||||
- `fastmcp list <server spec> --json` succeeds
|
||||
- every new tool has at least one real `fastmcp call`
|
||||
- environment variables are documented
|
||||
- the tool surface is small enough to understand without guesswork
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### FastMCP command missing
|
||||
|
||||
Install the package in the active environment:
|
||||
|
||||
```bash
|
||||
pip install fastmcp
|
||||
fastmcp version
|
||||
```
|
||||
|
||||
### `fastmcp inspect` fails
|
||||
|
||||
Check that:
|
||||
|
||||
- the file imports without side effects that crash
|
||||
- the FastMCP instance is named correctly in `<file.py:object>`
|
||||
- optional dependencies from the template are installed
|
||||
|
||||
### Tool works in Python but not through CLI
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
fastmcp list server.py --json
|
||||
fastmcp call server.py your_tool_name --json
|
||||
```
|
||||
|
||||
This usually exposes naming mismatches, missing required arguments, or non-serializable return values.
|
||||
|
||||
### Hermes cannot see the deployed server
|
||||
|
||||
The server-building part may be correct while the Hermes config is not. Load the `native-mcp` skill and configure the server in `~/.hermes/config.yaml`, then restart Hermes.
|
||||
|
||||
## References
|
||||
|
||||
For CLI details, install targets, and deployment checks, read `references/fastmcp-cli.md`.
|
||||
@@ -0,0 +1,110 @@
|
||||
# FastMCP CLI Reference
|
||||
|
||||
Use this file when the task needs exact FastMCP CLI workflows rather than the higher-level guidance in `SKILL.md`.
|
||||
|
||||
## Install and Verify
|
||||
|
||||
```bash
|
||||
pip install fastmcp
|
||||
fastmcp version
|
||||
```
|
||||
|
||||
FastMCP documents `pip install fastmcp` and `fastmcp version` as the baseline installation and verification path.
|
||||
|
||||
## Run a Server
|
||||
|
||||
Run a server object from a Python file:
|
||||
|
||||
```bash
|
||||
fastmcp run server.py:mcp
|
||||
```
|
||||
|
||||
Run the same server over HTTP:
|
||||
|
||||
```bash
|
||||
fastmcp run server.py:mcp --transport http --host 127.0.0.1 --port 8000
|
||||
```
|
||||
|
||||
## Inspect a Server
|
||||
|
||||
Inspect what FastMCP will expose:
|
||||
|
||||
```bash
|
||||
fastmcp inspect server.py:mcp
|
||||
```
|
||||
|
||||
This is also the check FastMCP recommends before deploying to Prefect Horizon.
|
||||
|
||||
## List and Call Tools
|
||||
|
||||
List tools from a Python file:
|
||||
|
||||
```bash
|
||||
fastmcp list server.py --json
|
||||
```
|
||||
|
||||
List tools from an HTTP endpoint:
|
||||
|
||||
```bash
|
||||
fastmcp list http://127.0.0.1:8000/mcp --json
|
||||
```
|
||||
|
||||
Call a tool with key-value arguments:
|
||||
|
||||
```bash
|
||||
fastmcp call server.py search_resources query=router limit=5 --json
|
||||
```
|
||||
|
||||
Call a tool with a full JSON input payload:
|
||||
|
||||
```bash
|
||||
fastmcp call server.py create_item '{"name": "Widget", "tags": ["sale"]}' --json
|
||||
```
|
||||
|
||||
## Discover Named MCP Servers
|
||||
|
||||
Find named servers already configured in local MCP-aware tools:
|
||||
|
||||
```bash
|
||||
fastmcp discover
|
||||
```
|
||||
|
||||
FastMCP documents name-based resolution for Claude Desktop, Claude Code, Cursor, Gemini, Goose, and `./mcp.json`.
|
||||
|
||||
## Install into MCP Clients
|
||||
|
||||
Register a server with common clients:
|
||||
|
||||
```bash
|
||||
fastmcp install claude-code server.py
|
||||
fastmcp install claude-desktop server.py
|
||||
fastmcp install cursor server.py -e .
|
||||
```
|
||||
|
||||
FastMCP notes that client installs run in isolated environments, so declare dependencies explicitly when needed with flags such as `--with`, `--env-file`, or editable installs.
|
||||
|
||||
## Deployment Checks
|
||||
|
||||
### Prefect Horizon
|
||||
|
||||
Before pushing to Horizon:
|
||||
|
||||
```bash
|
||||
fastmcp inspect server.py:mcp
|
||||
```
|
||||
|
||||
FastMCP’s Horizon docs expect:
|
||||
|
||||
- a GitHub repo
|
||||
- a Python file containing the FastMCP server object
|
||||
- dependencies declared in `requirements.txt` or `pyproject.toml`
|
||||
- an entrypoint like `main.py:mcp`
|
||||
|
||||
### Generic HTTP Hosting
|
||||
|
||||
Before shipping to any other host:
|
||||
|
||||
1. Start the server locally with HTTP transport.
|
||||
2. Verify `fastmcp list` against the local `/mcp` URL.
|
||||
3. Verify at least one `fastmcp call`.
|
||||
4. Document required environment variables.
|
||||
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Copy a FastMCP starter template into a working file."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
SKILL_DIR = SCRIPT_DIR.parent
|
||||
TEMPLATE_DIR = SKILL_DIR / "templates"
|
||||
PLACEHOLDER = "__SERVER_NAME__"
|
||||
|
||||
|
||||
def list_templates() -> list[str]:
|
||||
return sorted(path.stem for path in TEMPLATE_DIR.glob("*.py"))
|
||||
|
||||
|
||||
def render_template(template_name: str, server_name: str) -> str:
|
||||
template_path = TEMPLATE_DIR / f"{template_name}.py"
|
||||
if not template_path.exists():
|
||||
available = ", ".join(list_templates())
|
||||
raise SystemExit(f"Unknown template '{template_name}'. Available: {available}")
|
||||
return template_path.read_text(encoding="utf-8").replace(PLACEHOLDER, server_name)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--template", help="Template name without .py suffix")
|
||||
parser.add_argument("--name", help="FastMCP server display name")
|
||||
parser.add_argument("--output", help="Destination Python file path")
|
||||
parser.add_argument("--force", action="store_true", help="Overwrite an existing output file")
|
||||
parser.add_argument("--list", action="store_true", help="List available templates and exit")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.list:
|
||||
for name in list_templates():
|
||||
print(name)
|
||||
return 0
|
||||
|
||||
if not args.template or not args.name or not args.output:
|
||||
parser.error("--template, --name, and --output are required unless --list is used")
|
||||
|
||||
output_path = Path(args.output).expanduser()
|
||||
if output_path.exists() and not args.force:
|
||||
raise SystemExit(f"Refusing to overwrite existing file: {output_path}")
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path.write_text(render_template(args.template, args.name), encoding="utf-8")
|
||||
print(f"Wrote {output_path}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastmcp import FastMCP
|
||||
|
||||
|
||||
mcp = FastMCP("__SERVER_NAME__")
|
||||
|
||||
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.example.com")
|
||||
API_TOKEN = os.getenv("API_TOKEN")
|
||||
REQUEST_TIMEOUT = float(os.getenv("API_TIMEOUT_SECONDS", "20"))
|
||||
|
||||
|
||||
def _headers() -> dict[str, str]:
|
||||
headers = {"Accept": "application/json"}
|
||||
if API_TOKEN:
|
||||
headers["Authorization"] = f"Bearer {API_TOKEN}"
|
||||
return headers
|
||||
|
||||
|
||||
def _request(method: str, path: str, *, params: dict[str, Any] | None = None) -> Any:
|
||||
url = f"{API_BASE_URL.rstrip('/')}/{path.lstrip('/')}"
|
||||
with httpx.Client(timeout=REQUEST_TIMEOUT, headers=_headers()) as client:
|
||||
response = client.request(method, url, params=params)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
@mcp.tool
|
||||
def health_check() -> dict[str, Any]:
|
||||
"""Check whether the upstream API is reachable."""
|
||||
payload = _request("GET", "/health")
|
||||
return {"base_url": API_BASE_URL, "result": payload}
|
||||
|
||||
|
||||
@mcp.tool
|
||||
def get_resource(resource_id: str) -> dict[str, Any]:
|
||||
"""Fetch one resource by ID from the upstream API."""
|
||||
payload = _request("GET", f"/resources/{resource_id}")
|
||||
return {"resource_id": resource_id, "data": payload}
|
||||
|
||||
|
||||
@mcp.tool
|
||||
def search_resources(query: str, limit: int = 10) -> dict[str, Any]:
|
||||
"""Search upstream resources by query string."""
|
||||
payload = _request("GET", "/resources", params={"q": query, "limit": limit})
|
||||
return {"query": query, "limit": limit, "results": payload}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mcp.run()
|
||||
@@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
from fastmcp import FastMCP
|
||||
|
||||
|
||||
mcp = FastMCP("__SERVER_NAME__")
|
||||
|
||||
DATABASE_PATH = os.getenv("SQLITE_PATH", "./app.db")
|
||||
MAX_ROWS = int(os.getenv("SQLITE_MAX_ROWS", "200"))
|
||||
TABLE_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
|
||||
def _connect() -> sqlite3.Connection:
|
||||
return sqlite3.connect(f"file:{DATABASE_PATH}?mode=ro", uri=True)
|
||||
|
||||
|
||||
def _reject_mutation(sql: str) -> None:
|
||||
normalized = sql.strip().lower()
|
||||
if not normalized.startswith("select"):
|
||||
raise ValueError("Only SELECT queries are allowed")
|
||||
|
||||
|
||||
def _validate_table_name(table_name: str) -> str:
|
||||
if not TABLE_NAME_RE.fullmatch(table_name):
|
||||
raise ValueError("Invalid table name")
|
||||
return table_name
|
||||
|
||||
|
||||
@mcp.tool
|
||||
def list_tables() -> list[str]:
|
||||
"""List user-defined SQLite tables."""
|
||||
with _connect() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
|
||||
).fetchall()
|
||||
return [row[0] for row in rows]
|
||||
|
||||
|
||||
@mcp.tool
|
||||
def describe_table(table_name: str) -> list[dict[str, Any]]:
|
||||
"""Describe columns for a SQLite table."""
|
||||
safe_table_name = _validate_table_name(table_name)
|
||||
with _connect() as conn:
|
||||
rows = conn.execute(f"PRAGMA table_info({safe_table_name})").fetchall()
|
||||
return [
|
||||
{
|
||||
"cid": row[0],
|
||||
"name": row[1],
|
||||
"type": row[2],
|
||||
"notnull": bool(row[3]),
|
||||
"default": row[4],
|
||||
"pk": bool(row[5]),
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
|
||||
@mcp.tool
|
||||
def query(sql: str, limit: int = 50) -> dict[str, Any]:
|
||||
"""Run a read-only SELECT query and return rows plus column names."""
|
||||
_reject_mutation(sql)
|
||||
safe_limit = max(0, min(limit, MAX_ROWS))
|
||||
wrapped_sql = f"SELECT * FROM ({sql.strip().rstrip(';')}) LIMIT {safe_limit}"
|
||||
with _connect() as conn:
|
||||
cursor = conn.execute(wrapped_sql)
|
||||
columns = [column[0] for column in cursor.description or []]
|
||||
rows = [dict(zip(columns, row)) for row in cursor.fetchall()]
|
||||
return {"limit": safe_limit, "columns": columns, "rows": rows}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mcp.run()
|
||||
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fastmcp import FastMCP
|
||||
|
||||
|
||||
mcp = FastMCP("__SERVER_NAME__")
|
||||
|
||||
|
||||
def _read_text(path: str) -> str:
|
||||
file_path = Path(path).expanduser()
|
||||
try:
|
||||
return file_path.read_text(encoding="utf-8")
|
||||
except FileNotFoundError as exc:
|
||||
raise ValueError(f"File not found: {file_path}") from exc
|
||||
except UnicodeDecodeError as exc:
|
||||
raise ValueError(f"File is not valid UTF-8 text: {file_path}") from exc
|
||||
|
||||
|
||||
@mcp.tool
|
||||
def summarize_text_file(path: str, preview_chars: int = 1200) -> dict[str, int | str]:
|
||||
"""Return basic metadata and a preview for a UTF-8 text file."""
|
||||
file_path = Path(path).expanduser()
|
||||
text = _read_text(path)
|
||||
return {
|
||||
"path": str(file_path),
|
||||
"characters": len(text),
|
||||
"lines": len(text.splitlines()),
|
||||
"preview": text[:preview_chars],
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool
|
||||
def search_text_file(path: str, needle: str, max_matches: int = 20) -> dict[str, Any]:
|
||||
"""Find matching lines in a UTF-8 text file."""
|
||||
file_path = Path(path).expanduser()
|
||||
matches: list[dict[str, Any]] = []
|
||||
for line_number, line in enumerate(_read_text(path).splitlines(), start=1):
|
||||
if needle.lower() in line.lower():
|
||||
matches.append({"line_number": line_number, "line": line})
|
||||
if len(matches) >= max_matches:
|
||||
break
|
||||
return {"path": str(file_path), "needle": needle, "matches": matches}
|
||||
|
||||
|
||||
@mcp.resource("file://{path}")
|
||||
def read_file_resource(path: str) -> str:
|
||||
"""Expose a text file as a resource."""
|
||||
return _read_text(path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mcp.run()
|
||||
+1
-1
@@ -92,7 +92,7 @@ hermes-agent = "run_agent:main"
|
||||
hermes-acp = "acp_adapter.entry:main"
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajectory_compressor", "toolset_distributions", "cli", "hermes_constants", "hermes_state", "hermes_time", "mini_swe_runner", "rl_cli", "utils"]
|
||||
py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajectory_compressor", "toolset_distributions", "cli", "hermes_constants", "hermes_state", "hermes_time", "mini_swe_runner", "minisweagent_path", "rl_cli", "utils"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["agent", "tools", "tools.*", "hermes_cli", "gateway", "gateway.*", "cron", "honcho_integration", "acp_adapter"]
|
||||
|
||||
+253
-33
@@ -400,6 +400,7 @@ class AIAgent:
|
||||
clarify_callback: callable = None,
|
||||
step_callback: callable = None,
|
||||
stream_delta_callback: callable = None,
|
||||
status_callback: callable = None,
|
||||
max_tokens: int = None,
|
||||
reasoning_config: Dict[str, Any] = None,
|
||||
prefill_messages: List[Dict[str, Any]] = None,
|
||||
@@ -493,9 +494,20 @@ class AIAgent:
|
||||
elif self.provider == "anthropic" or (provider_name is None and "api.anthropic.com" in self._base_url_lower):
|
||||
self.api_mode = "anthropic_messages"
|
||||
self.provider = "anthropic"
|
||||
elif self._base_url_lower.rstrip("/").endswith("/anthropic"):
|
||||
# Third-party Anthropic-compatible endpoints (e.g. MiniMax, DashScope)
|
||||
# use a URL convention ending in /anthropic. Auto-detect these so the
|
||||
# Anthropic Messages API adapter is used instead of chat completions.
|
||||
self.api_mode = "anthropic_messages"
|
||||
else:
|
||||
self.api_mode = "chat_completions"
|
||||
|
||||
# Direct OpenAI sessions use the Responses API path. GPT-5.x tool
|
||||
# calls with reasoning are rejected on /v1/chat/completions, and
|
||||
# Hermes is a tool-using client by default.
|
||||
if self.api_mode == "chat_completions" and self._is_direct_openai_url():
|
||||
self.api_mode = "codex_responses"
|
||||
|
||||
# Pre-warm OpenRouter model metadata cache in a background thread.
|
||||
# fetch_model_metadata() is cached for 1 hour; this avoids a blocking
|
||||
# HTTP request on the first API response when pricing is estimated.
|
||||
@@ -511,8 +523,13 @@ class AIAgent:
|
||||
self.clarify_callback = clarify_callback
|
||||
self.step_callback = step_callback
|
||||
self.stream_delta_callback = stream_delta_callback
|
||||
self.status_callback = status_callback
|
||||
self._last_reported_tool = None # Track for "new tool" mode
|
||||
|
||||
# Tool execution state — allows _vprint during tool execution
|
||||
# even when stream consumers are registered (no tokens streaming then)
|
||||
self._executing_tools = False
|
||||
|
||||
# Interrupt mechanism for breaking out of tool loops
|
||||
self._interrupt_requested = False
|
||||
self._interrupt_message = None # Optional message that triggered interrupt
|
||||
@@ -556,6 +573,12 @@ class AIAgent:
|
||||
self._budget_warning_threshold = 0.9 # 90% — urgent, respond now
|
||||
self._budget_pressure_enabled = True
|
||||
|
||||
# Context pressure warnings: notify the USER (not the LLM) as context
|
||||
# fills up. Purely informational — displayed in CLI output and sent via
|
||||
# status_callback for gateway platforms. Does NOT inject into messages.
|
||||
self._context_50_warned = False
|
||||
self._context_70_warned = False
|
||||
|
||||
# Persistent error log -- always writes WARNING+ to ~/.hermes/logs/errors.log
|
||||
# so tool failures, API errors, etc. are inspectable after the fact.
|
||||
# In gateway mode, each incoming message creates a new AIAgent instance,
|
||||
@@ -964,6 +987,39 @@ class AIAgent:
|
||||
compression_threshold = float(_compression_cfg.get("threshold", 0.50))
|
||||
compression_enabled = str(_compression_cfg.get("enabled", True)).lower() in ("true", "1", "yes")
|
||||
compression_summary_model = _compression_cfg.get("summary_model") or None
|
||||
|
||||
# Read explicit context_length override from model config
|
||||
_model_cfg = _agent_cfg.get("model", {})
|
||||
if isinstance(_model_cfg, dict):
|
||||
_config_context_length = _model_cfg.get("context_length")
|
||||
else:
|
||||
_config_context_length = None
|
||||
if _config_context_length is not None:
|
||||
try:
|
||||
_config_context_length = int(_config_context_length)
|
||||
except (TypeError, ValueError):
|
||||
_config_context_length = None
|
||||
|
||||
# Check custom_providers per-model context_length
|
||||
if _config_context_length is None:
|
||||
_custom_providers = _agent_cfg.get("custom_providers")
|
||||
if isinstance(_custom_providers, list):
|
||||
for _cp_entry in _custom_providers:
|
||||
if not isinstance(_cp_entry, dict):
|
||||
continue
|
||||
_cp_url = (_cp_entry.get("base_url") or "").rstrip("/")
|
||||
if _cp_url and _cp_url == self.base_url.rstrip("/"):
|
||||
_cp_models = _cp_entry.get("models", {})
|
||||
if isinstance(_cp_models, dict):
|
||||
_cp_model_cfg = _cp_models.get(self.model, {})
|
||||
if isinstance(_cp_model_cfg, dict):
|
||||
_cp_ctx = _cp_model_cfg.get("context_length")
|
||||
if _cp_ctx is not None:
|
||||
try:
|
||||
_config_context_length = int(_cp_ctx)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
break
|
||||
|
||||
self.context_compressor = ContextCompressor(
|
||||
model=self.model,
|
||||
@@ -975,6 +1031,8 @@ class AIAgent:
|
||||
quiet_mode=self.quiet_mode,
|
||||
base_url=self.base_url,
|
||||
api_key=getattr(self, "api_key", ""),
|
||||
config_context_length=_config_context_length,
|
||||
provider=self.provider,
|
||||
)
|
||||
self.compression_enabled = compression_enabled
|
||||
self._user_turn_count = 0
|
||||
@@ -998,6 +1056,46 @@ class AIAgent:
|
||||
print(f"📊 Context limit: {self.context_compressor.context_length:,} tokens (compress at {int(compression_threshold*100)}% = {self.context_compressor.threshold_tokens:,})")
|
||||
else:
|
||||
print(f"📊 Context limit: {self.context_compressor.context_length:,} tokens (auto-compression disabled)")
|
||||
|
||||
def reset_session_state(self):
|
||||
"""Reset all session-scoped token counters to 0 for a fresh session.
|
||||
|
||||
This method encapsulates the reset logic for all session-level metrics
|
||||
including:
|
||||
- Token usage counters (input, output, total, prompt, completion)
|
||||
- Cache read/write tokens
|
||||
- API call count
|
||||
- Reasoning tokens
|
||||
- Estimated cost tracking
|
||||
- Context compressor internal counters
|
||||
|
||||
The method safely handles optional attributes (e.g., context compressor)
|
||||
using ``hasattr`` checks.
|
||||
|
||||
This keeps the counter reset logic DRY and maintainable in one place
|
||||
rather than scattering it across multiple methods.
|
||||
"""
|
||||
# Token usage counters
|
||||
self.session_total_tokens = 0
|
||||
self.session_input_tokens = 0
|
||||
self.session_output_tokens = 0
|
||||
self.session_prompt_tokens = 0
|
||||
self.session_completion_tokens = 0
|
||||
self.session_cache_read_tokens = 0
|
||||
self.session_cache_write_tokens = 0
|
||||
self.session_reasoning_tokens = 0
|
||||
self.session_api_calls = 0
|
||||
self.session_estimated_cost_usd = 0.0
|
||||
self.session_cost_status = "unknown"
|
||||
self.session_cost_source = "none"
|
||||
|
||||
# Context compressor internal counters (if present)
|
||||
if hasattr(self, "context_compressor") and self.context_compressor:
|
||||
self.context_compressor.last_prompt_tokens = 0
|
||||
self.context_compressor.last_completion_tokens = 0
|
||||
self.context_compressor.last_total_tokens = 0
|
||||
self.context_compressor.compression_count = 0
|
||||
self.context_compressor._context_probed = False
|
||||
|
||||
@staticmethod
|
||||
def _safe_print(*args, **kwargs):
|
||||
@@ -1013,15 +1111,24 @@ class AIAgent:
|
||||
pass
|
||||
|
||||
def _vprint(self, *args, force: bool = False, **kwargs):
|
||||
"""Verbose print — suppressed when streaming TTS is active.
|
||||
"""Verbose print — suppressed when actively streaming tokens.
|
||||
|
||||
Pass ``force=True`` for error/warning messages that should always be
|
||||
shown even during streaming playback (TTS or display).
|
||||
|
||||
During tool execution (``_executing_tools`` is True), printing is
|
||||
allowed even with stream consumers registered because no tokens
|
||||
are being streamed at that point.
|
||||
"""
|
||||
if not force and self._has_stream_consumers():
|
||||
if not force and self._has_stream_consumers() and not self._executing_tools:
|
||||
return
|
||||
self._safe_print(*args, **kwargs)
|
||||
|
||||
def _is_direct_openai_url(self, base_url: str = None) -> bool:
|
||||
"""Return True when a base URL targets OpenAI's native API."""
|
||||
url = (base_url or self._base_url_lower).lower()
|
||||
return "api.openai.com" in url and "openrouter" not in url
|
||||
|
||||
def _max_tokens_param(self, value: int) -> dict:
|
||||
"""Return the correct max tokens kwarg for the current provider.
|
||||
|
||||
@@ -1029,41 +1136,44 @@ class AIAgent:
|
||||
'max_completion_tokens'. OpenRouter, local models, and older
|
||||
OpenAI models use 'max_tokens'.
|
||||
"""
|
||||
_is_direct_openai = (
|
||||
"api.openai.com" in self._base_url_lower
|
||||
and "openrouter" not in self._base_url_lower
|
||||
)
|
||||
if _is_direct_openai:
|
||||
if self._is_direct_openai_url():
|
||||
return {"max_completion_tokens": value}
|
||||
return {"max_tokens": value}
|
||||
|
||||
def _has_content_after_think_block(self, content: str) -> bool:
|
||||
"""
|
||||
Check if content has actual text after any <think></think> blocks.
|
||||
|
||||
Check if content has actual text after any reasoning/thinking blocks.
|
||||
|
||||
This detects cases where the model only outputs reasoning but no actual
|
||||
response, which indicates an incomplete generation that should be retried.
|
||||
|
||||
Must stay in sync with _strip_think_blocks() tag variants.
|
||||
|
||||
Args:
|
||||
content: The assistant message content to check
|
||||
|
||||
|
||||
Returns:
|
||||
True if there's meaningful content after think blocks, False otherwise
|
||||
"""
|
||||
if not content:
|
||||
return False
|
||||
|
||||
# Remove all <think>...</think> blocks (including nested ones, non-greedy)
|
||||
cleaned = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL)
|
||||
|
||||
|
||||
# Remove all reasoning tag variants (must match _strip_think_blocks)
|
||||
cleaned = self._strip_think_blocks(content)
|
||||
|
||||
# Check if there's any non-whitespace content remaining
|
||||
return bool(cleaned.strip())
|
||||
|
||||
def _strip_think_blocks(self, content: str) -> str:
|
||||
"""Remove <think>...</think> blocks from content, returning only visible text."""
|
||||
"""Remove reasoning/thinking blocks from content, returning only visible text."""
|
||||
if not content:
|
||||
return ""
|
||||
return re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL)
|
||||
# Strip all reasoning tag variants: <think>, <thinking>, <THINKING>,
|
||||
# <reasoning>, <REASONING_SCRATCHPAD>
|
||||
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL)
|
||||
content = re.sub(r'<thinking>.*?</thinking>', '', content, flags=re.DOTALL | re.IGNORECASE)
|
||||
content = re.sub(r'<reasoning>.*?</reasoning>', '', content, flags=re.DOTALL)
|
||||
content = re.sub(r'<REASONING_SCRATCHPAD>.*?</REASONING_SCRATCHPAD>', '', content, flags=re.DOTALL)
|
||||
return content
|
||||
|
||||
def _looks_like_codex_intermediate_ack(
|
||||
self,
|
||||
@@ -2338,13 +2448,22 @@ class AIAgent:
|
||||
# Replay encrypted reasoning items from previous turns
|
||||
# so the API can maintain coherent reasoning chains.
|
||||
codex_reasoning = msg.get("codex_reasoning_items")
|
||||
has_codex_reasoning = False
|
||||
if isinstance(codex_reasoning, list):
|
||||
for ri in codex_reasoning:
|
||||
if isinstance(ri, dict) and ri.get("encrypted_content"):
|
||||
items.append(ri)
|
||||
has_codex_reasoning = True
|
||||
|
||||
if content_text.strip():
|
||||
items.append({"role": "assistant", "content": content_text})
|
||||
elif has_codex_reasoning:
|
||||
# The Responses API requires a following item after each
|
||||
# reasoning item (otherwise: missing_following_item error).
|
||||
# When the assistant produced only reasoning with no visible
|
||||
# content, emit an empty assistant message as the required
|
||||
# following item.
|
||||
items.append({"role": "assistant", "content": ""})
|
||||
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if isinstance(tool_calls, list):
|
||||
@@ -2786,6 +2905,14 @@ class AIAgent:
|
||||
finish_reason = "tool_calls"
|
||||
elif has_incomplete_items or (saw_commentary_phase and not saw_final_answer_phase):
|
||||
finish_reason = "incomplete"
|
||||
elif reasoning_items_raw and not final_text:
|
||||
# Response contains only reasoning (encrypted thinking state) with
|
||||
# no visible content or tool calls. The model is still thinking and
|
||||
# needs another turn to produce the actual answer. Marking this as
|
||||
# "stop" would send it into the empty-content retry loop which burns
|
||||
# 3 retries then fails — treat it as incomplete instead so the Codex
|
||||
# continuation path handles it correctly.
|
||||
finish_reason = "incomplete"
|
||||
else:
|
||||
finish_reason = "stop"
|
||||
return assistant_message, finish_reason
|
||||
@@ -3472,13 +3599,15 @@ class AIAgent:
|
||||
fb_provider)
|
||||
return False
|
||||
|
||||
# Determine api_mode from provider
|
||||
# Determine api_mode from provider / base URL
|
||||
fb_api_mode = "chat_completions"
|
||||
fb_base_url = str(fb_client.base_url)
|
||||
if fb_provider == "openai-codex":
|
||||
fb_api_mode = "codex_responses"
|
||||
elif fb_provider == "anthropic":
|
||||
elif fb_provider == "anthropic" or fb_base_url.rstrip("/").lower().endswith("/anthropic"):
|
||||
fb_api_mode = "anthropic_messages"
|
||||
fb_base_url = str(fb_client.base_url)
|
||||
elif self._is_direct_openai_url(fb_base_url):
|
||||
fb_api_mode = "codex_responses"
|
||||
|
||||
old_model = self.model
|
||||
self.model = fb_model
|
||||
@@ -4265,6 +4394,10 @@ class AIAgent:
|
||||
except Exception as e:
|
||||
logger.debug("Session DB compression split failed: %s", e)
|
||||
|
||||
# Reset context pressure warnings — usage drops after compaction
|
||||
self._context_50_warned = False
|
||||
self._context_70_warned = False
|
||||
|
||||
return compressed, new_system_prompt
|
||||
|
||||
def _execute_tool_calls(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
|
||||
@@ -4276,14 +4409,19 @@ class AIAgent:
|
||||
"""
|
||||
tool_calls = assistant_message.tool_calls
|
||||
|
||||
if not _should_parallelize_tool_batch(tool_calls):
|
||||
return self._execute_tool_calls_sequential(
|
||||
# Allow _vprint during tool execution even with stream consumers
|
||||
self._executing_tools = True
|
||||
try:
|
||||
if not _should_parallelize_tool_batch(tool_calls):
|
||||
return self._execute_tool_calls_sequential(
|
||||
assistant_message, messages, effective_task_id, api_call_count
|
||||
)
|
||||
|
||||
return self._execute_tool_calls_concurrent(
|
||||
assistant_message, messages, effective_task_id, api_call_count
|
||||
)
|
||||
|
||||
return self._execute_tool_calls_concurrent(
|
||||
assistant_message, messages, effective_task_id, api_call_count
|
||||
)
|
||||
finally:
|
||||
self._executing_tools = False
|
||||
|
||||
def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str) -> str:
|
||||
"""Invoke a single tool and return the result string. No display logic.
|
||||
@@ -4840,6 +4978,45 @@ class AIAgent:
|
||||
)
|
||||
return None
|
||||
|
||||
def _emit_context_pressure(self, compaction_progress: float, compressor) -> None:
|
||||
"""Notify the user that context is approaching the compaction threshold.
|
||||
|
||||
Args:
|
||||
compaction_progress: How close to compaction (0.0–1.0, where 1.0 = fires).
|
||||
compressor: The ContextCompressor instance (for threshold/context info).
|
||||
|
||||
Purely user-facing — does NOT modify the message stream.
|
||||
For CLI: prints a formatted line with a progress bar.
|
||||
For gateway: fires status_callback so the platform can send a chat message.
|
||||
"""
|
||||
from agent.display import format_context_pressure, format_context_pressure_gateway
|
||||
|
||||
threshold_pct = compressor.threshold_tokens / compressor.context_length if compressor.context_length else 0.5
|
||||
|
||||
# CLI output — always shown (these are user-facing status notifications,
|
||||
# not verbose debug output, so they bypass quiet_mode).
|
||||
# Gateway users also get the callback below.
|
||||
if self.platform in (None, "cli"):
|
||||
line = format_context_pressure(
|
||||
compaction_progress=compaction_progress,
|
||||
threshold_tokens=compressor.threshold_tokens,
|
||||
threshold_percent=threshold_pct,
|
||||
compression_enabled=self.compression_enabled,
|
||||
)
|
||||
self._safe_print(line)
|
||||
|
||||
# Gateway / external consumers
|
||||
if self.status_callback:
|
||||
try:
|
||||
msg = format_context_pressure_gateway(
|
||||
compaction_progress=compaction_progress,
|
||||
threshold_percent=threshold_pct,
|
||||
compression_enabled=self.compression_enabled,
|
||||
)
|
||||
self.status_callback("context_pressure", msg)
|
||||
except Exception:
|
||||
logger.debug("status_callback error in context pressure", exc_info=True)
|
||||
|
||||
def _handle_max_iterations(self, messages: list, api_call_count: int) -> str:
|
||||
"""Request a summary when max iterations are reached. Returns the final response text."""
|
||||
print(f"⚠️ Reached maximum iterations ({self.max_iterations}). Requesting summary...")
|
||||
@@ -5340,14 +5517,17 @@ class AIAgent:
|
||||
self._vprint(f"\n{self.log_prefix}🔄 Making API call #{api_call_count}/{self.max_iterations}...")
|
||||
self._vprint(f"{self.log_prefix} 📊 Request size: {len(api_messages)} messages, ~{approx_tokens:,} tokens (~{total_chars:,} chars)")
|
||||
self._vprint(f"{self.log_prefix} 🔧 Available tools: {len(self.tools) if self.tools else 0}")
|
||||
elif not self._has_stream_consumers():
|
||||
# Animated thinking spinner in quiet mode (skip during streaming)
|
||||
else:
|
||||
# Animated thinking spinner in quiet mode
|
||||
face = random.choice(KawaiiSpinner.KAWAII_THINKING)
|
||||
verb = random.choice(KawaiiSpinner.THINKING_VERBS)
|
||||
if self.thinking_callback:
|
||||
# CLI TUI mode: use prompt_toolkit widget instead of raw spinner
|
||||
# (works in both streaming and non-streaming modes)
|
||||
self.thinking_callback(f"{face} {verb}...")
|
||||
else:
|
||||
elif not self._has_stream_consumers():
|
||||
# Raw KawaiiSpinner only when no streaming consumers
|
||||
# (would conflict with streamed token output)
|
||||
spinner_type = random.choice(['brain', 'sparkle', 'pulse', 'moon', 'star'])
|
||||
thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type)
|
||||
thinking_spinner.start()
|
||||
@@ -6196,15 +6376,24 @@ class AIAgent:
|
||||
interim_msg = self._build_assistant_message(assistant_message, finish_reason)
|
||||
interim_has_content = bool((interim_msg.get("content") or "").strip())
|
||||
interim_has_reasoning = bool(interim_msg.get("reasoning", "").strip()) if isinstance(interim_msg.get("reasoning"), str) else False
|
||||
interim_has_codex_reasoning = bool(interim_msg.get("codex_reasoning_items"))
|
||||
|
||||
if interim_has_content or interim_has_reasoning:
|
||||
if interim_has_content or interim_has_reasoning or interim_has_codex_reasoning:
|
||||
last_msg = messages[-1] if messages else None
|
||||
# Duplicate detection: two consecutive incomplete assistant
|
||||
# messages with identical content AND reasoning are collapsed.
|
||||
# For reasoning-only messages (codex_reasoning_items differ but
|
||||
# visible content/reasoning are both empty), we also compare
|
||||
# the encrypted items to avoid silently dropping new state.
|
||||
last_codex_items = last_msg.get("codex_reasoning_items") if isinstance(last_msg, dict) else None
|
||||
interim_codex_items = interim_msg.get("codex_reasoning_items")
|
||||
duplicate_interim = (
|
||||
isinstance(last_msg, dict)
|
||||
and last_msg.get("role") == "assistant"
|
||||
and last_msg.get("finish_reason") == "incomplete"
|
||||
and (last_msg.get("content") or "") == (interim_msg.get("content") or "")
|
||||
and (last_msg.get("reasoning") or "") == (interim_msg.get("reasoning") or "")
|
||||
and last_codex_items == interim_codex_items
|
||||
)
|
||||
if not duplicate_interim:
|
||||
messages.append(interim_msg)
|
||||
@@ -6403,6 +6592,23 @@ class AIAgent:
|
||||
+ _compressor.last_completion_tokens
|
||||
+ _new_chars // 3 # conservative: JSON-heavy tool results ≈ 3 chars/token
|
||||
)
|
||||
|
||||
# ── Context pressure warnings (user-facing only) ──────────
|
||||
# Notify the user (NOT the LLM) as context approaches the
|
||||
# compaction threshold. Thresholds are relative to where
|
||||
# compaction fires, not the raw context window.
|
||||
# Does not inject into messages — just prints to CLI output
|
||||
# and fires status_callback for gateway platforms.
|
||||
if _compressor.threshold_tokens > 0:
|
||||
_compaction_progress = _estimated_next_prompt / _compressor.threshold_tokens
|
||||
if _compaction_progress >= 0.85 and not self._context_70_warned:
|
||||
self._context_70_warned = True
|
||||
self._context_50_warned = True # skip first tier if we jumped past it
|
||||
self._emit_context_pressure(_compaction_progress, _compressor)
|
||||
elif _compaction_progress >= 0.60 and not self._context_50_warned:
|
||||
self._context_50_warned = True
|
||||
self._emit_context_pressure(_compaction_progress, _compressor)
|
||||
|
||||
if self.compression_enabled and _compressor.should_compress(_estimated_next_prompt):
|
||||
messages, active_system_prompt = self._compress_context(
|
||||
messages, system_message,
|
||||
@@ -6488,7 +6694,21 @@ class AIAgent:
|
||||
self._response_was_previewed = True
|
||||
break
|
||||
|
||||
# No fallback -- append the empty message as-is
|
||||
# No fallback -- if reasoning_text exists, the model put its
|
||||
# entire response inside <think> tags; use that as the content.
|
||||
if reasoning_text:
|
||||
self._vprint(f"{self.log_prefix}Using reasoning as response content (model wrapped entire response in think tags).", force=True)
|
||||
final_response = reasoning_text
|
||||
empty_msg = {
|
||||
"role": "assistant",
|
||||
"content": final_response,
|
||||
"reasoning": reasoning_text,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
messages.append(empty_msg)
|
||||
break
|
||||
|
||||
# Truly empty -- no reasoning and no content
|
||||
empty_msg = {
|
||||
"role": "assistant",
|
||||
"content": final_response,
|
||||
@@ -6496,10 +6716,10 @@ class AIAgent:
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
messages.append(empty_msg)
|
||||
|
||||
|
||||
self._cleanup_task_resources(effective_task_id)
|
||||
self._persist_session(messages, conversation_history)
|
||||
|
||||
|
||||
return {
|
||||
"final_response": final_response or None,
|
||||
"messages": messages,
|
||||
|
||||
@@ -18,12 +18,13 @@
|
||||
* node bridge.js --port 3000 --session ~/.hermes/whatsapp/session
|
||||
*/
|
||||
|
||||
import { makeWASocket, useMultiFileAuthState, DisconnectReason, fetchLatestBaileysVersion } from '@whiskeysockets/baileys';
|
||||
import { makeWASocket, useMultiFileAuthState, DisconnectReason, fetchLatestBaileysVersion, downloadMediaMessage } from '@whiskeysockets/baileys';
|
||||
import express from 'express';
|
||||
import { Boom } from '@hapi/boom';
|
||||
import pino from 'pino';
|
||||
import path from 'path';
|
||||
import { mkdirSync, readFileSync, existsSync } from 'fs';
|
||||
import { mkdirSync, readFileSync, writeFileSync, existsSync, readdirSync } from 'fs';
|
||||
import { randomBytes } from 'crypto';
|
||||
import qrcode from 'qrcode-terminal';
|
||||
|
||||
// Parse CLI args
|
||||
@@ -41,6 +42,7 @@ const WHATSAPP_DEBUG =
|
||||
|
||||
const PORT = parseInt(getArg('port', '3000'), 10);
|
||||
const SESSION_DIR = getArg('session', path.join(process.env.HOME || '~', '.hermes', 'whatsapp', 'session'));
|
||||
const IMAGE_CACHE_DIR = path.join(process.env.HOME || '~', '.hermes', 'image_cache');
|
||||
const PAIR_ONLY = args.includes('--pair-only');
|
||||
const WHATSAPP_MODE = getArg('mode', process.env.WHATSAPP_MODE || 'self-chat'); // "bot" or "self-chat"
|
||||
const ALLOWED_USERS = (process.env.WHATSAPP_ALLOWED_USERS || '').split(',').map(s => s.trim()).filter(Boolean);
|
||||
@@ -55,6 +57,22 @@ function formatOutgoingMessage(message) {
|
||||
|
||||
mkdirSync(SESSION_DIR, { recursive: true });
|
||||
|
||||
// Build LID → phone reverse map from session files (lid-mapping-{phone}.json)
|
||||
function buildLidMap() {
|
||||
const map = {};
|
||||
try {
|
||||
for (const f of readdirSync(SESSION_DIR)) {
|
||||
const m = f.match(/^lid-mapping-(\d+)\.json$/);
|
||||
if (!m) continue;
|
||||
const phone = m[1];
|
||||
const lid = JSON.parse(readFileSync(path.join(SESSION_DIR, f), 'utf8'));
|
||||
if (lid) map[String(lid)] = phone;
|
||||
}
|
||||
} catch {}
|
||||
return map;
|
||||
}
|
||||
let lidToPhone = buildLidMap();
|
||||
|
||||
const logger = pino({ level: 'warn' });
|
||||
|
||||
// Message queue for polling
|
||||
@@ -80,9 +98,16 @@ async function startSocket() {
|
||||
browser: ['Hermes Agent', 'Chrome', '120.0'],
|
||||
syncFullHistory: false,
|
||||
markOnlineOnConnect: false,
|
||||
// Required for Baileys 7.x: without this, incoming messages that need
|
||||
// E2EE session re-establishment are silently dropped (msg.message === null)
|
||||
getMessage: async (key) => {
|
||||
// We don't maintain a message store, so return a placeholder.
|
||||
// This is enough for Baileys to complete the retry handshake.
|
||||
return { conversation: '' };
|
||||
},
|
||||
});
|
||||
|
||||
sock.ev.on('creds.update', saveCreds);
|
||||
sock.ev.on('creds.update', () => { saveCreds(); lidToPhone = buildLidMap(); });
|
||||
|
||||
sock.ev.on('connection.update', (update) => {
|
||||
const { connection, lastDisconnect, qr } = update;
|
||||
@@ -120,7 +145,7 @@ async function startSocket() {
|
||||
}
|
||||
});
|
||||
|
||||
sock.ev.on('messages.upsert', ({ messages, type }) => {
|
||||
sock.ev.on('messages.upsert', async ({ messages, type }) => {
|
||||
// In self-chat mode, your own messages commonly arrive as 'append' rather
|
||||
// than 'notify'. Accept both and filter agent echo-backs below.
|
||||
if (type !== 'notify' && type !== 'append') return;
|
||||
@@ -163,9 +188,10 @@ async function startSocket() {
|
||||
if (!isSelfChat) continue;
|
||||
}
|
||||
|
||||
// Check allowlist for messages from others
|
||||
if (!msg.key.fromMe && ALLOWED_USERS.length > 0 && !ALLOWED_USERS.includes(senderNumber)) {
|
||||
continue;
|
||||
// Check allowlist for messages from others (resolve LID → phone if needed)
|
||||
if (!msg.key.fromMe && ALLOWED_USERS.length > 0) {
|
||||
const resolvedNumber = lidToPhone[senderNumber] || senderNumber;
|
||||
if (!ALLOWED_USERS.includes(resolvedNumber)) continue;
|
||||
}
|
||||
|
||||
// Extract message body
|
||||
@@ -182,6 +208,18 @@ async function startSocket() {
|
||||
body = msg.message.imageMessage.caption || '';
|
||||
hasMedia = true;
|
||||
mediaType = 'image';
|
||||
try {
|
||||
const buf = await downloadMediaMessage(msg, 'buffer', {}, { logger, reuploadRequest: sock.updateMediaMessage });
|
||||
const mime = msg.message.imageMessage.mimetype || 'image/jpeg';
|
||||
const extMap = { 'image/jpeg': '.jpg', 'image/png': '.png', 'image/webp': '.webp', 'image/gif': '.gif' };
|
||||
const ext = extMap[mime] || '.jpg';
|
||||
mkdirSync(IMAGE_CACHE_DIR, { recursive: true });
|
||||
const filePath = path.join(IMAGE_CACHE_DIR, `img_${randomBytes(6).toString('hex')}${ext}`);
|
||||
writeFileSync(filePath, buf);
|
||||
mediaUrls.push(filePath);
|
||||
} catch (err) {
|
||||
console.error('[bridge] Failed to download image:', err.message);
|
||||
}
|
||||
} else if (msg.message.videoMessage) {
|
||||
body = msg.message.videoMessage.caption || '';
|
||||
hasMedia = true;
|
||||
@@ -195,6 +233,11 @@ async function startSocket() {
|
||||
mediaType = 'document';
|
||||
}
|
||||
|
||||
// For media without caption, use a placeholder so the API message is never empty
|
||||
if (hasMedia && !body) {
|
||||
body = `[${mediaType} received]`;
|
||||
}
|
||||
|
||||
// Ignore Hermes' own reply messages in self-chat mode to avoid loops.
|
||||
if (msg.key.fromMe && ((REPLY_PREFIX && body.startsWith(REPLY_PREFIX)) || recentlySentIds.has(msg.key.id))) {
|
||||
if (WHATSAPP_DEBUG) {
|
||||
@@ -433,7 +476,7 @@ if (PAIR_ONLY) {
|
||||
console.log();
|
||||
startSocket();
|
||||
} else {
|
||||
app.listen(PORT, () => {
|
||||
app.listen(PORT, '127.0.0.1', () => {
|
||||
console.log(`🌉 WhatsApp bridge listening on port ${PORT} (mode: ${WHATSAPP_MODE})`);
|
||||
console.log(`📁 Session stored in: ${SESSION_DIR}`);
|
||||
if (ALLOWED_USERS.length > 0) {
|
||||
|
||||
@@ -16,7 +16,7 @@ Use this skill when a user asks about configuring Hermes, enabling features, set
|
||||
- API keys: `~/.hermes/.env`
|
||||
- Skills: `~/.hermes/skills/`
|
||||
- Hermes install: `~/.hermes/hermes-agent/`
|
||||
- Venv: `~/.hermes/hermes-agent/.venv/` (or `venv/`)
|
||||
- Venv: `~/.hermes/hermes-agent/venv/`
|
||||
|
||||
## CLI Overview
|
||||
|
||||
@@ -98,7 +98,7 @@ The interactive setup wizard walks through:
|
||||
Run it from terminal:
|
||||
```bash
|
||||
cd ~/.hermes/hermes-agent
|
||||
source .venv/bin/activate
|
||||
source venv/bin/activate
|
||||
python -m hermes_cli.main setup
|
||||
```
|
||||
|
||||
@@ -140,7 +140,7 @@ Voice messages from Telegram/Discord/WhatsApp/Slack/Signal are auto-transcribed
|
||||
|
||||
```bash
|
||||
cd ~/.hermes/hermes-agent
|
||||
source .venv/bin/activate # or: source venv/bin/activate
|
||||
source venv/bin/activate
|
||||
pip install faster-whisper
|
||||
```
|
||||
|
||||
@@ -189,7 +189,7 @@ Hermes can reply with voice when users send voice messages.
|
||||
|
||||
```bash
|
||||
cd ~/.hermes/hermes-agent
|
||||
source .venv/bin/activate
|
||||
source venv/bin/activate
|
||||
python -m hermes_cli.main tools
|
||||
```
|
||||
|
||||
@@ -217,7 +217,7 @@ Use `/reset` in the chat to start a fresh session with the new toolset. Tool cha
|
||||
Some tools need extra packages:
|
||||
|
||||
```bash
|
||||
cd ~/.hermes/hermes-agent && source .venv/bin/activate
|
||||
cd ~/.hermes/hermes-agent && source venv/bin/activate
|
||||
|
||||
pip install faster-whisper # Local STT (voice transcription)
|
||||
pip install browserbase # Browser automation
|
||||
|
||||
@@ -12,7 +12,7 @@ training server.
|
||||
|
||||
```bash
|
||||
cd ~/.hermes/hermes-agent
|
||||
source .venv/bin/activate
|
||||
source venv/bin/activate
|
||||
|
||||
python environments/your_env.py process \
|
||||
--env.total_steps 1 \
|
||||
|
||||
+172
-1
@@ -1,15 +1,21 @@
|
||||
"""Tests for acp_adapter.session — SessionManager and SessionState."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from acp_adapter.session import SessionManager, SessionState
|
||||
from hermes_state import SessionDB
|
||||
|
||||
|
||||
def _mock_agent():
|
||||
return MagicMock(name="MockAIAgent")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def manager():
|
||||
"""SessionManager with a mock agent factory (avoids needing API keys)."""
|
||||
return SessionManager(agent_factory=lambda: MagicMock(name="MockAIAgent"))
|
||||
return SessionManager(agent_factory=_mock_agent)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -110,3 +116,168 @@ class TestListAndCleanup:
|
||||
assert manager.get_session(state.session_id) is None
|
||||
# Removing again returns False
|
||||
assert manager.remove_session(state.session_id) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# persistence — sessions survive process restarts (via SessionDB)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPersistence:
|
||||
"""Verify that sessions are persisted to SessionDB and can be restored."""
|
||||
|
||||
def test_create_session_writes_to_db(self, manager):
|
||||
state = manager.create_session(cwd="/project")
|
||||
db = manager._get_db()
|
||||
assert db is not None
|
||||
row = db.get_session(state.session_id)
|
||||
assert row is not None
|
||||
assert row["source"] == "acp"
|
||||
# cwd stored in model_config JSON
|
||||
mc = json.loads(row["model_config"])
|
||||
assert mc["cwd"] == "/project"
|
||||
|
||||
def test_get_session_restores_from_db(self, manager):
|
||||
"""Simulate process restart: create session, drop from memory, get again."""
|
||||
state = manager.create_session(cwd="/work")
|
||||
state.history.append({"role": "user", "content": "hello"})
|
||||
state.history.append({"role": "assistant", "content": "hi there"})
|
||||
manager.save_session(state.session_id)
|
||||
|
||||
sid = state.session_id
|
||||
|
||||
# Drop from in-memory store (simulates process restart).
|
||||
with manager._lock:
|
||||
del manager._sessions[sid]
|
||||
|
||||
# get_session should transparently restore from DB.
|
||||
restored = manager.get_session(sid)
|
||||
assert restored is not None
|
||||
assert restored.session_id == sid
|
||||
assert restored.cwd == "/work"
|
||||
assert len(restored.history) == 2
|
||||
assert restored.history[0]["content"] == "hello"
|
||||
assert restored.history[1]["content"] == "hi there"
|
||||
# Agent should have been recreated.
|
||||
assert restored.agent is not None
|
||||
|
||||
def test_save_session_updates_db(self, manager):
|
||||
state = manager.create_session()
|
||||
state.history.append({"role": "user", "content": "test"})
|
||||
manager.save_session(state.session_id)
|
||||
|
||||
db = manager._get_db()
|
||||
messages = db.get_messages_as_conversation(state.session_id)
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "test"
|
||||
|
||||
def test_remove_session_deletes_from_db(self, manager):
|
||||
state = manager.create_session()
|
||||
db = manager._get_db()
|
||||
assert db.get_session(state.session_id) is not None
|
||||
manager.remove_session(state.session_id)
|
||||
assert db.get_session(state.session_id) is None
|
||||
|
||||
def test_cleanup_removes_all_from_db(self, manager):
|
||||
s1 = manager.create_session()
|
||||
s2 = manager.create_session()
|
||||
db = manager._get_db()
|
||||
assert db.get_session(s1.session_id) is not None
|
||||
assert db.get_session(s2.session_id) is not None
|
||||
manager.cleanup()
|
||||
assert db.get_session(s1.session_id) is None
|
||||
assert db.get_session(s2.session_id) is None
|
||||
|
||||
def test_list_sessions_includes_db_only(self, manager):
|
||||
"""Sessions only in DB (not in memory) appear in list_sessions."""
|
||||
state = manager.create_session(cwd="/db-only")
|
||||
sid = state.session_id
|
||||
|
||||
# Drop from memory.
|
||||
with manager._lock:
|
||||
del manager._sessions[sid]
|
||||
|
||||
listing = manager.list_sessions()
|
||||
ids = {s["session_id"] for s in listing}
|
||||
assert sid in ids
|
||||
|
||||
def test_fork_restores_source_from_db(self, manager):
|
||||
"""Forking a session that is only in DB should work."""
|
||||
original = manager.create_session()
|
||||
original.history.append({"role": "user", "content": "context"})
|
||||
manager.save_session(original.session_id)
|
||||
|
||||
# Drop original from memory.
|
||||
with manager._lock:
|
||||
del manager._sessions[original.session_id]
|
||||
|
||||
forked = manager.fork_session(original.session_id, cwd="/fork")
|
||||
assert forked is not None
|
||||
assert len(forked.history) == 1
|
||||
assert forked.history[0]["content"] == "context"
|
||||
assert forked.session_id != original.session_id
|
||||
|
||||
def test_update_cwd_restores_from_db(self, manager):
|
||||
state = manager.create_session(cwd="/old")
|
||||
sid = state.session_id
|
||||
|
||||
with manager._lock:
|
||||
del manager._sessions[sid]
|
||||
|
||||
updated = manager.update_cwd(sid, "/new")
|
||||
assert updated is not None
|
||||
assert updated.cwd == "/new"
|
||||
|
||||
# Should also be persisted in DB.
|
||||
db = manager._get_db()
|
||||
row = db.get_session(sid)
|
||||
mc = json.loads(row["model_config"])
|
||||
assert mc["cwd"] == "/new"
|
||||
|
||||
def test_only_restores_acp_sessions(self, manager):
|
||||
"""get_session should not restore non-ACP sessions from DB."""
|
||||
db = manager._get_db()
|
||||
# Manually create a CLI session in the DB.
|
||||
db.create_session(session_id="cli-session-123", source="cli", model="test")
|
||||
# Should not be found via ACP SessionManager.
|
||||
assert manager.get_session("cli-session-123") is None
|
||||
|
||||
def test_sessions_searchable_via_fts(self, manager):
|
||||
"""ACP sessions stored in SessionDB are searchable via FTS5."""
|
||||
state = manager.create_session()
|
||||
state.history.append({"role": "user", "content": "how do I configure nginx"})
|
||||
state.history.append({"role": "assistant", "content": "Here is the nginx config..."})
|
||||
manager.save_session(state.session_id)
|
||||
|
||||
db = manager._get_db()
|
||||
results = db.search_messages("nginx")
|
||||
assert len(results) > 0
|
||||
session_ids = {r["session_id"] for r in results}
|
||||
assert state.session_id in session_ids
|
||||
|
||||
def test_tool_calls_persisted(self, manager):
|
||||
"""Messages with tool_calls should round-trip through the DB."""
|
||||
state = manager.create_session()
|
||||
state.history.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"id": "tc_1", "type": "function",
|
||||
"function": {"name": "terminal", "arguments": "{}"}}],
|
||||
})
|
||||
state.history.append({
|
||||
"role": "tool",
|
||||
"content": "output here",
|
||||
"tool_call_id": "tc_1",
|
||||
"name": "terminal",
|
||||
})
|
||||
manager.save_session(state.session_id)
|
||||
|
||||
# Drop from memory, restore from DB.
|
||||
with manager._lock:
|
||||
del manager._sessions[state.session_id]
|
||||
|
||||
restored = manager.get_session(state.session_id)
|
||||
assert restored is not None
|
||||
assert len(restored.history) == 2
|
||||
assert restored.history[0].get("tool_calls") is not None
|
||||
assert restored.history[1].get("tool_call_id") == "tc_1"
|
||||
|
||||
@@ -22,6 +22,7 @@ from unittest.mock import patch, MagicMock
|
||||
from agent.model_metadata import (
|
||||
CONTEXT_PROBE_TIERS,
|
||||
DEFAULT_CONTEXT_LENGTHS,
|
||||
_strip_provider_prefix,
|
||||
estimate_tokens_rough,
|
||||
estimate_messages_tokens_rough,
|
||||
get_model_context_length,
|
||||
@@ -105,9 +106,14 @@ class TestEstimateMessagesTokensRough:
|
||||
# =========================================================================
|
||||
|
||||
class TestDefaultContextLengths:
|
||||
def test_claude_models_200k(self):
|
||||
def test_claude_models_context_lengths(self):
|
||||
for key, value in DEFAULT_CONTEXT_LENGTHS.items():
|
||||
if "claude" in key:
|
||||
if "claude" not in key:
|
||||
continue
|
||||
# Claude 4.6 models have 1M context
|
||||
if "4.6" in key or "4-6" in key:
|
||||
assert value == 1000000, f"{key} should be 1000000"
|
||||
else:
|
||||
assert value == 200000, f"{key} should be 200000"
|
||||
|
||||
def test_gpt4_models_128k_or_1m(self):
|
||||
@@ -218,6 +224,122 @@ class TestGetModelContextLength:
|
||||
|
||||
assert result == CONTEXT_PROBE_TIERS[0]
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
@patch("agent.model_metadata.fetch_endpoint_model_metadata")
|
||||
def test_custom_endpoint_single_model_fallback(self, mock_endpoint_fetch, mock_fetch):
|
||||
"""Single-model servers: use the only model even if name doesn't match."""
|
||||
mock_fetch.return_value = {}
|
||||
mock_endpoint_fetch.return_value = {
|
||||
"Qwen3.5-9B-Q4_K_M.gguf": {"context_length": 131072}
|
||||
}
|
||||
|
||||
result = get_model_context_length(
|
||||
"qwen3.5:9b",
|
||||
base_url="http://myserver.example.com:8080/v1",
|
||||
api_key="test-key",
|
||||
)
|
||||
|
||||
assert result == 131072
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
@patch("agent.model_metadata.fetch_endpoint_model_metadata")
|
||||
def test_custom_endpoint_fuzzy_substring_match(self, mock_endpoint_fetch, mock_fetch):
|
||||
"""Fuzzy match: configured model name is substring of endpoint model."""
|
||||
mock_fetch.return_value = {}
|
||||
mock_endpoint_fetch.return_value = {
|
||||
"org/llama-3.3-70b-instruct-fp8": {"context_length": 131072},
|
||||
"org/qwen-2.5-72b": {"context_length": 32768},
|
||||
}
|
||||
|
||||
result = get_model_context_length(
|
||||
"llama-3.3-70b-instruct",
|
||||
base_url="http://myserver.example.com:8080/v1",
|
||||
api_key="test-key",
|
||||
)
|
||||
|
||||
assert result == 131072
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_config_context_length_overrides_all(self, mock_fetch):
|
||||
"""Explicit config_context_length takes priority over everything."""
|
||||
mock_fetch.return_value = {
|
||||
"test/model": {"context_length": 200000}
|
||||
}
|
||||
|
||||
result = get_model_context_length(
|
||||
"test/model",
|
||||
config_context_length=65536,
|
||||
)
|
||||
|
||||
assert result == 65536
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_config_context_length_zero_is_ignored(self, mock_fetch):
|
||||
"""config_context_length=0 should be treated as unset."""
|
||||
mock_fetch.return_value = {}
|
||||
|
||||
result = get_model_context_length(
|
||||
"anthropic/claude-sonnet-4",
|
||||
config_context_length=0,
|
||||
)
|
||||
|
||||
assert result == 200000
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_config_context_length_none_is_ignored(self, mock_fetch):
|
||||
"""config_context_length=None should be treated as unset."""
|
||||
mock_fetch.return_value = {}
|
||||
|
||||
result = get_model_context_length(
|
||||
"anthropic/claude-sonnet-4",
|
||||
config_context_length=None,
|
||||
)
|
||||
|
||||
assert result == 200000
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# _strip_provider_prefix — Ollama model:tag vs provider:model
|
||||
# =========================================================================
|
||||
|
||||
class TestStripProviderPrefix:
|
||||
def test_known_provider_prefix_is_stripped(self):
|
||||
assert _strip_provider_prefix("local:my-model") == "my-model"
|
||||
assert _strip_provider_prefix("openrouter:anthropic/claude-sonnet-4") == "anthropic/claude-sonnet-4"
|
||||
assert _strip_provider_prefix("anthropic:claude-sonnet-4") == "claude-sonnet-4"
|
||||
|
||||
def test_ollama_model_tag_preserved(self):
|
||||
"""Ollama model:tag format must NOT be stripped."""
|
||||
assert _strip_provider_prefix("qwen3.5:27b") == "qwen3.5:27b"
|
||||
assert _strip_provider_prefix("llama3.3:70b") == "llama3.3:70b"
|
||||
assert _strip_provider_prefix("gemma2:9b") == "gemma2:9b"
|
||||
assert _strip_provider_prefix("codellama:13b-instruct-q4_0") == "codellama:13b-instruct-q4_0"
|
||||
|
||||
def test_http_urls_preserved(self):
|
||||
assert _strip_provider_prefix("http://example.com") == "http://example.com"
|
||||
assert _strip_provider_prefix("https://example.com") == "https://example.com"
|
||||
|
||||
def test_no_colon_returns_unchanged(self):
|
||||
assert _strip_provider_prefix("gpt-4o") == "gpt-4o"
|
||||
assert _strip_provider_prefix("anthropic/claude-sonnet-4") == "anthropic/claude-sonnet-4"
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_ollama_model_tag_not_mangled_in_context_lookup(self, mock_fetch):
|
||||
"""Ensure 'qwen3.5:27b' is NOT reduced to '27b' during context length lookup.
|
||||
|
||||
We mock a custom endpoint that knows 'qwen3.5:27b' — the full name
|
||||
must reach the endpoint metadata lookup intact.
|
||||
"""
|
||||
mock_fetch.return_value = {}
|
||||
with patch("agent.model_metadata.fetch_endpoint_model_metadata") as mock_ep, \
|
||||
patch("agent.model_metadata._is_custom_endpoint", return_value=True):
|
||||
mock_ep.return_value = {"qwen3.5:27b": {"context_length": 32768}}
|
||||
result = get_model_context_length(
|
||||
"qwen3.5:27b",
|
||||
base_url="http://localhost:11434/v1",
|
||||
)
|
||||
assert result == 32768
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# fetch_model_metadata — caching, TTL, slugs, failures
|
||||
@@ -350,35 +472,35 @@ class TestContextProbeTiers:
|
||||
for i in range(len(CONTEXT_PROBE_TIERS) - 1):
|
||||
assert CONTEXT_PROBE_TIERS[i] > CONTEXT_PROBE_TIERS[i + 1]
|
||||
|
||||
def test_first_tier_is_2m(self):
|
||||
assert CONTEXT_PROBE_TIERS[0] == 2_000_000
|
||||
def test_first_tier_is_128k(self):
|
||||
assert CONTEXT_PROBE_TIERS[0] == 128_000
|
||||
|
||||
def test_last_tier_is_32k(self):
|
||||
assert CONTEXT_PROBE_TIERS[-1] == 32_000
|
||||
def test_last_tier_is_8k(self):
|
||||
assert CONTEXT_PROBE_TIERS[-1] == 8_000
|
||||
|
||||
|
||||
class TestGetNextProbeTier:
|
||||
def test_from_2m(self):
|
||||
assert get_next_probe_tier(2_000_000) == 1_000_000
|
||||
|
||||
def test_from_1m(self):
|
||||
assert get_next_probe_tier(1_000_000) == 512_000
|
||||
|
||||
def test_from_128k(self):
|
||||
assert get_next_probe_tier(128_000) == 64_000
|
||||
|
||||
def test_from_32k_returns_none(self):
|
||||
assert get_next_probe_tier(32_000) is None
|
||||
def test_from_64k(self):
|
||||
assert get_next_probe_tier(64_000) == 32_000
|
||||
|
||||
def test_from_32k(self):
|
||||
assert get_next_probe_tier(32_000) == 16_000
|
||||
|
||||
def test_from_8k_returns_none(self):
|
||||
assert get_next_probe_tier(8_000) is None
|
||||
|
||||
def test_from_below_min_returns_none(self):
|
||||
assert get_next_probe_tier(16_000) is None
|
||||
assert get_next_probe_tier(4_000) is None
|
||||
|
||||
def test_from_arbitrary_value(self):
|
||||
assert get_next_probe_tier(300_000) == 200_000
|
||||
assert get_next_probe_tier(100_000) == 64_000
|
||||
|
||||
def test_above_max_tier(self):
|
||||
"""Value above 2M should return 2M."""
|
||||
assert get_next_probe_tier(5_000_000) == 2_000_000
|
||||
"""Value above 128K should return 128K."""
|
||||
assert get_next_probe_tier(500_000) == 128_000
|
||||
|
||||
def test_zero_returns_none(self):
|
||||
assert get_next_probe_tier(0) is None
|
||||
|
||||
@@ -0,0 +1,197 @@
|
||||
"""Tests for agent.models_dev — models.dev registry integration."""
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from agent.models_dev import (
|
||||
PROVIDER_TO_MODELS_DEV,
|
||||
_extract_context,
|
||||
fetch_models_dev,
|
||||
lookup_models_dev_context,
|
||||
)
|
||||
|
||||
|
||||
SAMPLE_REGISTRY = {
|
||||
"anthropic": {
|
||||
"id": "anthropic",
|
||||
"name": "Anthropic",
|
||||
"models": {
|
||||
"claude-opus-4-6": {
|
||||
"id": "claude-opus-4-6",
|
||||
"limit": {"context": 1000000, "output": 128000},
|
||||
},
|
||||
"claude-sonnet-4-6": {
|
||||
"id": "claude-sonnet-4-6",
|
||||
"limit": {"context": 1000000, "output": 64000},
|
||||
},
|
||||
"claude-sonnet-4-0": {
|
||||
"id": "claude-sonnet-4-0",
|
||||
"limit": {"context": 200000, "output": 64000},
|
||||
},
|
||||
},
|
||||
},
|
||||
"github-copilot": {
|
||||
"id": "github-copilot",
|
||||
"name": "GitHub Copilot",
|
||||
"models": {
|
||||
"claude-opus-4.6": {
|
||||
"id": "claude-opus-4.6",
|
||||
"limit": {"context": 128000, "output": 32000},
|
||||
},
|
||||
},
|
||||
},
|
||||
"kilo": {
|
||||
"id": "kilo",
|
||||
"name": "Kilo Gateway",
|
||||
"models": {
|
||||
"anthropic/claude-sonnet-4.6": {
|
||||
"id": "anthropic/claude-sonnet-4.6",
|
||||
"limit": {"context": 1000000, "output": 128000},
|
||||
},
|
||||
},
|
||||
},
|
||||
"deepseek": {
|
||||
"id": "deepseek",
|
||||
"name": "DeepSeek",
|
||||
"models": {
|
||||
"deepseek-chat": {
|
||||
"id": "deepseek-chat",
|
||||
"limit": {"context": 128000, "output": 8192},
|
||||
},
|
||||
},
|
||||
},
|
||||
"audio-only": {
|
||||
"id": "audio-only",
|
||||
"models": {
|
||||
"tts-model": {
|
||||
"id": "tts-model",
|
||||
"limit": {"context": 0, "output": 0},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestProviderMapping:
|
||||
def test_all_mapped_providers_are_strings(self):
|
||||
for hermes_id, mdev_id in PROVIDER_TO_MODELS_DEV.items():
|
||||
assert isinstance(hermes_id, str)
|
||||
assert isinstance(mdev_id, str)
|
||||
|
||||
def test_known_providers_mapped(self):
|
||||
assert PROVIDER_TO_MODELS_DEV["anthropic"] == "anthropic"
|
||||
assert PROVIDER_TO_MODELS_DEV["copilot"] == "github-copilot"
|
||||
assert PROVIDER_TO_MODELS_DEV["kilocode"] == "kilo"
|
||||
assert PROVIDER_TO_MODELS_DEV["ai-gateway"] == "vercel"
|
||||
|
||||
def test_unmapped_provider_not_in_dict(self):
|
||||
assert "nous" not in PROVIDER_TO_MODELS_DEV
|
||||
assert "openai-codex" not in PROVIDER_TO_MODELS_DEV
|
||||
|
||||
|
||||
class TestExtractContext:
|
||||
def test_valid_entry(self):
|
||||
assert _extract_context({"limit": {"context": 128000}}) == 128000
|
||||
|
||||
def test_zero_context_returns_none(self):
|
||||
assert _extract_context({"limit": {"context": 0}}) is None
|
||||
|
||||
def test_missing_limit_returns_none(self):
|
||||
assert _extract_context({"id": "test"}) is None
|
||||
|
||||
def test_missing_context_returns_none(self):
|
||||
assert _extract_context({"limit": {"output": 8192}}) is None
|
||||
|
||||
def test_non_dict_returns_none(self):
|
||||
assert _extract_context("not a dict") is None
|
||||
|
||||
def test_float_context_coerced_to_int(self):
|
||||
assert _extract_context({"limit": {"context": 131072.0}}) == 131072
|
||||
|
||||
|
||||
class TestLookupModelsDevContext:
|
||||
@patch("agent.models_dev.fetch_models_dev")
|
||||
def test_exact_match(self, mock_fetch):
|
||||
mock_fetch.return_value = SAMPLE_REGISTRY
|
||||
assert lookup_models_dev_context("anthropic", "claude-opus-4-6") == 1000000
|
||||
|
||||
@patch("agent.models_dev.fetch_models_dev")
|
||||
def test_case_insensitive_match(self, mock_fetch):
|
||||
mock_fetch.return_value = SAMPLE_REGISTRY
|
||||
assert lookup_models_dev_context("anthropic", "Claude-Opus-4-6") == 1000000
|
||||
|
||||
@patch("agent.models_dev.fetch_models_dev")
|
||||
def test_provider_not_mapped(self, mock_fetch):
|
||||
mock_fetch.return_value = SAMPLE_REGISTRY
|
||||
assert lookup_models_dev_context("nous", "some-model") is None
|
||||
|
||||
@patch("agent.models_dev.fetch_models_dev")
|
||||
def test_model_not_found(self, mock_fetch):
|
||||
mock_fetch.return_value = SAMPLE_REGISTRY
|
||||
assert lookup_models_dev_context("anthropic", "nonexistent-model") is None
|
||||
|
||||
@patch("agent.models_dev.fetch_models_dev")
|
||||
def test_provider_aware_context(self, mock_fetch):
|
||||
"""Same model, different context per provider."""
|
||||
mock_fetch.return_value = SAMPLE_REGISTRY
|
||||
# Anthropic direct: 1M
|
||||
assert lookup_models_dev_context("anthropic", "claude-opus-4-6") == 1000000
|
||||
# GitHub Copilot: only 128K for same model
|
||||
assert lookup_models_dev_context("copilot", "claude-opus-4.6") == 128000
|
||||
|
||||
@patch("agent.models_dev.fetch_models_dev")
|
||||
def test_zero_context_filtered(self, mock_fetch):
|
||||
mock_fetch.return_value = SAMPLE_REGISTRY
|
||||
# audio-only is not a mapped provider, but test the filtering directly
|
||||
data = SAMPLE_REGISTRY["audio-only"]["models"]["tts-model"]
|
||||
assert _extract_context(data) is None
|
||||
|
||||
@patch("agent.models_dev.fetch_models_dev")
|
||||
def test_empty_registry(self, mock_fetch):
|
||||
mock_fetch.return_value = {}
|
||||
assert lookup_models_dev_context("anthropic", "claude-opus-4-6") is None
|
||||
|
||||
|
||||
class TestFetchModelsDev:
|
||||
@patch("agent.models_dev.requests.get")
|
||||
def test_fetch_success(self, mock_get):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = SAMPLE_REGISTRY
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
# Clear caches
|
||||
import agent.models_dev as md
|
||||
md._models_dev_cache = {}
|
||||
md._models_dev_cache_time = 0
|
||||
|
||||
with patch.object(md, "_save_disk_cache"):
|
||||
result = fetch_models_dev(force_refresh=True)
|
||||
|
||||
assert "anthropic" in result
|
||||
assert len(result) == len(SAMPLE_REGISTRY)
|
||||
|
||||
@patch("agent.models_dev.requests.get")
|
||||
def test_fetch_failure_returns_stale_cache(self, mock_get):
|
||||
mock_get.side_effect = Exception("network error")
|
||||
|
||||
import agent.models_dev as md
|
||||
md._models_dev_cache = SAMPLE_REGISTRY
|
||||
md._models_dev_cache_time = 0 # expired
|
||||
|
||||
with patch.object(md, "_load_disk_cache", return_value=SAMPLE_REGISTRY):
|
||||
result = fetch_models_dev(force_refresh=True)
|
||||
|
||||
assert "anthropic" in result
|
||||
|
||||
@patch("agent.models_dev.requests.get")
|
||||
def test_in_memory_cache_used(self, mock_get):
|
||||
import agent.models_dev as md
|
||||
import time
|
||||
md._models_dev_cache = SAMPLE_REGISTRY
|
||||
md._models_dev_cache_time = time.time() # fresh
|
||||
|
||||
result = fetch_models_dev()
|
||||
mock_get.assert_not_called()
|
||||
assert result == SAMPLE_REGISTRY
|
||||
@@ -7,7 +7,7 @@ from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from cron.scheduler import _resolve_origin, _resolve_delivery_target, _deliver_result, run_job, SILENT_MARKER
|
||||
from cron.scheduler import _resolve_origin, _resolve_delivery_target, _deliver_result, run_job, SILENT_MARKER, _build_job_prompt
|
||||
|
||||
|
||||
class TestResolveOrigin:
|
||||
@@ -532,14 +532,53 @@ class TestBuildJobPromptSilentHint:
|
||||
"""Verify _build_job_prompt always injects [SILENT] guidance."""
|
||||
|
||||
def test_hint_always_present(self):
|
||||
from cron.scheduler import _build_job_prompt
|
||||
job = {"prompt": "Check for updates"}
|
||||
result = _build_job_prompt(job)
|
||||
assert "[SILENT]" in result
|
||||
assert "Check for updates" in result
|
||||
|
||||
def test_hint_present_even_without_prompt(self):
|
||||
from cron.scheduler import _build_job_prompt
|
||||
job = {"prompt": ""}
|
||||
result = _build_job_prompt(job)
|
||||
assert "[SILENT]" in result
|
||||
|
||||
|
||||
class TestBuildJobPromptMissingSkill:
|
||||
"""Verify that a missing skill logs a warning and does not crash the job."""
|
||||
|
||||
def _missing_skill_view(self, name: str) -> str:
|
||||
return json.dumps({"success": False, "error": f"Skill '{name}' not found."})
|
||||
|
||||
def test_missing_skill_does_not_raise(self):
|
||||
"""Job should run even when a referenced skill is not installed."""
|
||||
with patch("tools.skills_tool.skill_view", side_effect=self._missing_skill_view):
|
||||
result = _build_job_prompt({"skills": ["ghost-skill"], "prompt": "do something"})
|
||||
# prompt is preserved even though skill was skipped
|
||||
assert "do something" in result
|
||||
|
||||
def test_missing_skill_injects_user_notice_into_prompt(self):
|
||||
"""A system notice about the missing skill is injected into the prompt."""
|
||||
with patch("tools.skills_tool.skill_view", side_effect=self._missing_skill_view):
|
||||
result = _build_job_prompt({"skills": ["ghost-skill"], "prompt": "do something"})
|
||||
assert "ghost-skill" in result
|
||||
assert "not found" in result.lower() or "skipped" in result.lower()
|
||||
|
||||
def test_missing_skill_logs_warning(self, caplog):
|
||||
"""A warning is logged when a skill cannot be found."""
|
||||
with caplog.at_level(logging.WARNING, logger="cron.scheduler"):
|
||||
with patch("tools.skills_tool.skill_view", side_effect=self._missing_skill_view):
|
||||
_build_job_prompt({"name": "My Job", "skills": ["ghost-skill"], "prompt": "do something"})
|
||||
assert any("ghost-skill" in record.message for record in caplog.records)
|
||||
|
||||
def test_valid_skill_loaded_alongside_missing(self):
|
||||
"""A valid skill is still loaded when another skill in the list is missing."""
|
||||
|
||||
def _mixed_skill_view(name: str) -> str:
|
||||
if name == "real-skill":
|
||||
return json.dumps({"success": True, "content": "Real skill content."})
|
||||
return json.dumps({"success": False, "error": f"Skill '{name}' not found."})
|
||||
|
||||
with patch("tools.skills_tool.skill_view", side_effect=_mixed_skill_view):
|
||||
result = _build_job_prompt({"skills": ["ghost-skill", "real-skill"], "prompt": "go"})
|
||||
assert "Real skill content." in result
|
||||
assert "go" in result
|
||||
|
||||
@@ -0,0 +1,240 @@
|
||||
"""Tests for /approve and /deny gateway commands.
|
||||
|
||||
Verifies that dangerous command approvals require explicit /approve or /deny
|
||||
slash commands, not bare "yes"/"no" text matching.
|
||||
"""
|
||||
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||
|
||||
|
||||
def _make_source() -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="u1",
|
||||
chat_id="c1",
|
||||
user_name="tester",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def _make_event(text: str) -> MessageEvent:
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
source=_make_source(),
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
|
||||
def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
|
||||
)
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
runner._voice_mode = {}
|
||||
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
|
||||
runner.session_store = MagicMock()
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._show_reasoning = False
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
runner._set_session_env = lambda _context: None
|
||||
return runner
|
||||
|
||||
|
||||
def _make_pending_approval(command="sudo rm -rf /tmp/test", pattern_key="sudo"):
|
||||
return {
|
||||
"command": command,
|
||||
"pattern_key": pattern_key,
|
||||
"pattern_keys": [pattern_key],
|
||||
"description": "sudo command",
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# /approve command
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestApproveCommand:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_executes_pending_command(self):
|
||||
"""Basic /approve executes the pending command."""
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
event = _make_event("/approve")
|
||||
with patch("tools.terminal_tool.terminal_tool", return_value="done") as mock_term:
|
||||
result = await runner._handle_approve_command(event)
|
||||
|
||||
assert "✅ Command approved and executed" in result
|
||||
mock_term.assert_called_once_with(command="sudo rm -rf /tmp/test", force=True)
|
||||
assert session_key not in runner._pending_approvals
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_session_remembers_pattern(self):
|
||||
"""/approve session approves the pattern for the session."""
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
event = _make_event("/approve session")
|
||||
with (
|
||||
patch("tools.terminal_tool.terminal_tool", return_value="done"),
|
||||
patch("tools.approval.approve_session") as mock_session,
|
||||
):
|
||||
result = await runner._handle_approve_command(event)
|
||||
|
||||
assert "pattern approved for this session" in result
|
||||
mock_session.assert_called_once_with(session_key, "sudo")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_always_approves_permanently(self):
|
||||
"""/approve always approves the pattern permanently."""
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
event = _make_event("/approve always")
|
||||
with (
|
||||
patch("tools.terminal_tool.terminal_tool", return_value="done"),
|
||||
patch("tools.approval.approve_permanent") as mock_perm,
|
||||
):
|
||||
result = await runner._handle_approve_command(event)
|
||||
|
||||
assert "pattern approved permanently" in result
|
||||
mock_perm.assert_called_once_with("sudo")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_no_pending(self):
|
||||
"""/approve with no pending approval returns helpful message."""
|
||||
runner = _make_runner()
|
||||
event = _make_event("/approve")
|
||||
result = await runner._handle_approve_command(event)
|
||||
assert "No pending command" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_expired(self):
|
||||
"""/approve on a timed-out approval rejects it."""
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
approval = _make_pending_approval()
|
||||
approval["timestamp"] = time.time() - 600 # 10 minutes ago
|
||||
runner._pending_approvals[session_key] = approval
|
||||
|
||||
event = _make_event("/approve")
|
||||
result = await runner._handle_approve_command(event)
|
||||
|
||||
assert "expired" in result
|
||||
assert session_key not in runner._pending_approvals
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# /deny command
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDenyCommand:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_clears_pending(self):
|
||||
"""/deny clears the pending approval."""
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
event = _make_event("/deny")
|
||||
result = await runner._handle_deny_command(event)
|
||||
|
||||
assert "❌ Command denied" in result
|
||||
assert session_key not in runner._pending_approvals
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_no_pending(self):
|
||||
"""/deny with no pending approval returns helpful message."""
|
||||
runner = _make_runner()
|
||||
event = _make_event("/deny")
|
||||
result = await runner._handle_deny_command(event)
|
||||
assert "No pending command" in result
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Bare "yes" must NOT trigger approval
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBareTextNoLongerApproves:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_yes_does_not_execute_pending_command(self):
|
||||
"""Saying 'yes' in normal conversation must not execute a pending command.
|
||||
|
||||
This is the core bug from issue #1888: bare text matching against
|
||||
'yes'/'no' could intercept unrelated user messages.
|
||||
"""
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
# Simulate the user saying "yes" as a normal message.
|
||||
# The old code would have executed the pending command.
|
||||
# Now it should fall through to normal processing (agent handles it).
|
||||
event = _make_event("yes")
|
||||
|
||||
# The approval should still be pending — "yes" is not /approve
|
||||
# We can't easily run _handle_message end-to-end, but we CAN verify
|
||||
# the old text-matching block no longer exists by confirming the
|
||||
# approval is untouched after the command dispatch section.
|
||||
# The key assertion is that _pending_approvals is NOT consumed.
|
||||
assert session_key in runner._pending_approvals
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Approval hint appended to response
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestApprovalHint:
|
||||
|
||||
def test_approval_hint_appended_to_response(self):
|
||||
"""When a pending approval is collected, structured instructions
|
||||
should be appended to the agent response."""
|
||||
# This tests the approval collection logic at the end of _handle_message.
|
||||
# We verify the hint format directly.
|
||||
cmd = "sudo rm -rf /tmp/dangerous"
|
||||
cmd_preview = cmd
|
||||
hint = (
|
||||
f"\n\n⚠️ **Dangerous command requires approval:**\n"
|
||||
f"```\n{cmd_preview}\n```\n"
|
||||
f"Reply `/approve` to execute, `/approve session` to approve this pattern "
|
||||
f"for the session, or `/deny` to cancel."
|
||||
)
|
||||
assert "/approve" in hint
|
||||
assert "/deny" in hint
|
||||
assert cmd in hint
|
||||
@@ -0,0 +1,267 @@
|
||||
"""Tests for the session race guard that prevents concurrent agent runs.
|
||||
|
||||
The sentinel-based guard ensures that when _handle_message passes the
|
||||
"is an agent already running?" check and proceeds to the slow async
|
||||
setup path (vision enrichment, STT, hooks, session hygiene), a second
|
||||
message for the same session is correctly recognized as "already running"
|
||||
and routed through the interrupt/queue path instead of spawning a
|
||||
duplicate agent.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.run import GatewayRunner, _AGENT_PENDING_SENTINEL
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
class _FakeAdapter:
|
||||
"""Minimal adapter stub for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self._pending_messages = {}
|
||||
|
||||
async def send(self, chat_id, text, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
def _make_runner():
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
|
||||
)
|
||||
runner.adapters = {Platform.TELEGRAM: _FakeAdapter()}
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._voice_mode = {}
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
return runner
|
||||
|
||||
|
||||
def _make_event(text="hello", chat_id="12345"):
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM, chat_id=chat_id, chat_type="dm"
|
||||
)
|
||||
return MessageEvent(text=text, message_type=MessageType.TEXT, source=source)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Test 1: Sentinel is placed before _handle_message_with_agent runs
|
||||
# ------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_sentinel_placed_before_agent_setup():
|
||||
"""After passing the 'not running' guard, the sentinel must be
|
||||
written into _running_agents *before* any await, so that a
|
||||
concurrent message sees the session as occupied."""
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
session_key = build_session_key(event.source)
|
||||
|
||||
# Patch _handle_message_with_agent to capture state at entry
|
||||
sentinel_was_set = False
|
||||
|
||||
async def mock_inner(self_inner, ev, src, qk):
|
||||
nonlocal sentinel_was_set
|
||||
sentinel_was_set = runner._running_agents.get(qk) is _AGENT_PENDING_SENTINEL
|
||||
return "ok"
|
||||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
|
||||
await runner._handle_message(event)
|
||||
|
||||
assert sentinel_was_set, (
|
||||
"Sentinel must be in _running_agents when _handle_message_with_agent starts"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Test 2: Sentinel is cleaned up after _handle_message_with_agent
|
||||
# ------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_sentinel_cleaned_up_after_handler_returns():
|
||||
"""If _handle_message_with_agent returns normally, the sentinel
|
||||
must be removed so the session is not permanently locked."""
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
session_key = build_session_key(event.source)
|
||||
|
||||
async def mock_inner(self_inner, ev, src, qk):
|
||||
return "ok"
|
||||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
|
||||
await runner._handle_message(event)
|
||||
|
||||
assert session_key not in runner._running_agents, (
|
||||
"Sentinel must be removed after handler completes"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Test 3: Sentinel cleaned up on exception
|
||||
# ------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_sentinel_cleaned_up_on_exception():
|
||||
"""If _handle_message_with_agent raises, the sentinel must still
|
||||
be cleaned up so the session is not permanently locked."""
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
session_key = build_session_key(event.source)
|
||||
|
||||
async def mock_inner(self_inner, ev, src, qk):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
await runner._handle_message(event)
|
||||
|
||||
assert session_key not in runner._running_agents, (
|
||||
"Sentinel must be removed even if handler raises"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Test 4: Second message during sentinel sees "already running"
|
||||
# ------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_second_message_during_sentinel_queued_not_duplicate():
|
||||
"""While the sentinel is set (agent setup in progress), a second
|
||||
message for the same session must hit the 'already running' branch
|
||||
and be queued — not start a second agent."""
|
||||
runner = _make_runner()
|
||||
event1 = _make_event(text="first message")
|
||||
event2 = _make_event(text="second message")
|
||||
session_key = build_session_key(event1.source)
|
||||
|
||||
barrier = asyncio.Event()
|
||||
|
||||
async def slow_inner(self_inner, ev, src, qk):
|
||||
# Simulate slow setup — wait until test tells us to proceed
|
||||
await barrier.wait()
|
||||
return "ok"
|
||||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", slow_inner):
|
||||
# Start first message (will block at barrier)
|
||||
task1 = asyncio.create_task(runner._handle_message(event1))
|
||||
# Yield so task1 enters slow_inner and sentinel is set
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Verify sentinel is set
|
||||
assert runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL
|
||||
|
||||
# Second message should see "already running" and be queued
|
||||
result2 = await runner._handle_message(event2)
|
||||
assert result2 is None, "Second message should return None (queued)"
|
||||
|
||||
# The second message should have been queued in adapter pending
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
assert session_key in adapter._pending_messages, (
|
||||
"Second message should be queued as pending"
|
||||
)
|
||||
assert adapter._pending_messages[session_key] is event2
|
||||
|
||||
# Let first message complete
|
||||
barrier.set()
|
||||
await task1
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Test 5: Sentinel not placed for command messages
|
||||
# ------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_command_messages_do_not_leave_sentinel():
|
||||
"""Slash commands (/help, /status, etc.) return early from
|
||||
_handle_message. They must NOT leave a sentinel behind."""
|
||||
runner = _make_runner()
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm"
|
||||
)
|
||||
event = MessageEvent(
|
||||
text="/help", message_type=MessageType.TEXT, source=source
|
||||
)
|
||||
session_key = build_session_key(source)
|
||||
|
||||
# Mock the help handler to avoid needing full runner setup
|
||||
runner._handle_help_command = AsyncMock(return_value="Help text")
|
||||
# Need hooks for command emission
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
|
||||
await runner._handle_message(event)
|
||||
|
||||
assert session_key not in runner._running_agents, (
|
||||
"Command handlers must not leave sentinel in _running_agents"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Test 6: /stop during sentinel returns helpful message
|
||||
# ------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_during_sentinel_returns_message():
|
||||
"""If /stop arrives while the sentinel is set (agent still starting),
|
||||
it should return a helpful message instead of crashing or queuing."""
|
||||
runner = _make_runner()
|
||||
event1 = _make_event(text="hello")
|
||||
session_key = build_session_key(event1.source)
|
||||
|
||||
barrier = asyncio.Event()
|
||||
|
||||
async def slow_inner(self_inner, ev, src, qk):
|
||||
await barrier.wait()
|
||||
return "ok"
|
||||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", slow_inner):
|
||||
task1 = asyncio.create_task(runner._handle_message(event1))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Sentinel should be set
|
||||
assert runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL
|
||||
|
||||
# Send /stop — should get a message, not crash
|
||||
stop_event = _make_event(text="/stop")
|
||||
result = await runner._handle_message(stop_event)
|
||||
assert result is not None, "/stop during sentinel should return a message"
|
||||
assert "starting up" in result.lower()
|
||||
|
||||
# Should NOT be queued as pending
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
assert session_key not in adapter._pending_messages
|
||||
|
||||
barrier.set()
|
||||
await task1
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Test 7: Shutdown skips sentinel entries
|
||||
# ------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_skips_sentinel():
|
||||
"""During gateway shutdown, sentinel entries in _running_agents
|
||||
should be skipped without raising AttributeError."""
|
||||
runner = _make_runner()
|
||||
session_key = "telegram:dm:99999"
|
||||
|
||||
# Simulate a sentinel in _running_agents
|
||||
runner._running_agents[session_key] = _AGENT_PENDING_SENTINEL
|
||||
|
||||
# Also add a real agent mock to verify it still gets interrupted
|
||||
real_agent = MagicMock()
|
||||
runner._running_agents["telegram:dm:88888"] = real_agent
|
||||
|
||||
runner.adapters = {} # No adapters to disconnect
|
||||
runner._running = True
|
||||
runner._shutdown_event = asyncio.Event()
|
||||
runner._exit_reason = None
|
||||
runner._shutdown_all_gateway_honcho = lambda: None
|
||||
|
||||
with patch("gateway.status.remove_pid_file"), \
|
||||
patch("gateway.status.write_runtime_status"):
|
||||
await runner.stop()
|
||||
|
||||
# Real agent should have been interrupted
|
||||
real_agent.interrupt.assert_called_once()
|
||||
# Should not have raised on the sentinel
|
||||
@@ -0,0 +1,619 @@
|
||||
"""Unit tests for the generic webhook platform adapter.
|
||||
|
||||
Covers:
|
||||
- HMAC signature validation (GitHub, GitLab, generic)
|
||||
- Prompt rendering with dot-notation template variables
|
||||
- Event type filtering
|
||||
- HTTP handler behaviour (404, 202, health)
|
||||
- Idempotency cache (duplicate delivery IDs)
|
||||
- Rate limiting (fixed-window, per route)
|
||||
- Body size limits
|
||||
- INSECURE_NO_AUTH bypass
|
||||
- Session isolation for concurrent webhooks
|
||||
- Delivery info cleanup after send()
|
||||
- connect / disconnect lifecycle
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent, MessageType, SendResult
|
||||
from gateway.platforms.webhook import (
|
||||
WebhookAdapter,
|
||||
_INSECURE_NO_AUTH,
|
||||
check_webhook_requirements,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_config(
|
||||
routes=None,
|
||||
secret="",
|
||||
rate_limit=30,
|
||||
max_body_bytes=1_048_576,
|
||||
host="0.0.0.0",
|
||||
port=0, # let OS pick a free port in tests
|
||||
):
|
||||
"""Build a PlatformConfig suitable for WebhookAdapter."""
|
||||
extra = {
|
||||
"host": host,
|
||||
"port": port,
|
||||
"routes": routes or {},
|
||||
"rate_limit": rate_limit,
|
||||
"max_body_bytes": max_body_bytes,
|
||||
}
|
||||
if secret:
|
||||
extra["secret"] = secret
|
||||
return PlatformConfig(enabled=True, extra=extra)
|
||||
|
||||
|
||||
def _make_adapter(routes=None, **kwargs):
|
||||
"""Create a WebhookAdapter with sensible defaults for testing."""
|
||||
config = _make_config(routes=routes, **kwargs)
|
||||
return WebhookAdapter(config)
|
||||
|
||||
|
||||
def _create_app(adapter: WebhookAdapter) -> web.Application:
|
||||
"""Build the aiohttp Application from the adapter (without starting a full server)."""
|
||||
app = web.Application()
|
||||
app.router.add_get("/health", adapter._handle_health)
|
||||
app.router.add_post("/webhooks/{route_name}", adapter._handle_webhook)
|
||||
return app
|
||||
|
||||
|
||||
def _mock_request(headers=None, body=b"", content_length=None, match_info=None):
|
||||
"""Build a lightweight mock aiohttp request for non-HTTP tests."""
|
||||
req = MagicMock()
|
||||
req.headers = headers or {}
|
||||
req.content_length = content_length if content_length is not None else len(body)
|
||||
req.match_info = match_info or {}
|
||||
req.method = "POST"
|
||||
|
||||
async def _read():
|
||||
return body
|
||||
|
||||
req.read = _read
|
||||
return req
|
||||
|
||||
|
||||
def _github_signature(body: bytes, secret: str) -> str:
|
||||
"""Compute X-Hub-Signature-256 for *body* using *secret*."""
|
||||
return "sha256=" + hmac.new(
|
||||
secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
|
||||
def _generic_signature(body: bytes, secret: str) -> str:
|
||||
"""Compute X-Webhook-Signature (plain HMAC-SHA256 hex) for *body*."""
|
||||
return hmac.new(secret.encode(), body, hashlib.sha256).hexdigest()
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Signature validation
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestValidateSignature:
|
||||
"""Tests for WebhookAdapter._validate_signature."""
|
||||
|
||||
def test_validate_github_signature_valid(self):
|
||||
"""Valid X-Hub-Signature-256 is accepted."""
|
||||
adapter = _make_adapter()
|
||||
body = b'{"action": "opened"}'
|
||||
secret = "webhook-secret-42"
|
||||
sig = _github_signature(body, secret)
|
||||
req = _mock_request(headers={"X-Hub-Signature-256": sig})
|
||||
assert adapter._validate_signature(req, body, secret) is True
|
||||
|
||||
def test_validate_github_signature_invalid(self):
|
||||
"""Wrong X-Hub-Signature-256 is rejected."""
|
||||
adapter = _make_adapter()
|
||||
body = b'{"action": "opened"}'
|
||||
secret = "webhook-secret-42"
|
||||
req = _mock_request(headers={"X-Hub-Signature-256": "sha256=deadbeef"})
|
||||
assert adapter._validate_signature(req, body, secret) is False
|
||||
|
||||
def test_validate_gitlab_token(self):
|
||||
"""GitLab plain-token match via X-Gitlab-Token."""
|
||||
adapter = _make_adapter()
|
||||
secret = "gl-token-value"
|
||||
req = _mock_request(headers={"X-Gitlab-Token": secret})
|
||||
assert adapter._validate_signature(req, b"{}", secret) is True
|
||||
|
||||
def test_validate_gitlab_token_wrong(self):
|
||||
"""Wrong X-Gitlab-Token is rejected."""
|
||||
adapter = _make_adapter()
|
||||
req = _mock_request(headers={"X-Gitlab-Token": "wrong"})
|
||||
assert adapter._validate_signature(req, b"{}", "correct") is False
|
||||
|
||||
def test_validate_no_signature_with_secret_rejects(self):
|
||||
"""Secret configured but no recognised signature header → reject."""
|
||||
adapter = _make_adapter()
|
||||
req = _mock_request(headers={}) # no sig headers at all
|
||||
assert adapter._validate_signature(req, b"{}", "my-secret") is False
|
||||
|
||||
def test_validate_no_secret_allows_all(self):
|
||||
"""When the secret is empty/falsy, the validator is never even called
|
||||
by the handler (secret check is 'if secret and secret != _INSECURE...').
|
||||
Verify that an empty secret isn't accidentally passed to the validator."""
|
||||
# This tests the semantics: empty secret means skip validation entirely.
|
||||
# The handler code does: if secret and secret != _INSECURE_NO_AUTH: validate
|
||||
# So with an empty secret, _validate_signature is never reached.
|
||||
# We just verify the code path is correct by constructing an adapter
|
||||
# with no secret and confirming the route config resolves to "".
|
||||
adapter = _make_adapter(
|
||||
routes={"test": {"prompt": "hello"}},
|
||||
secret="",
|
||||
)
|
||||
# The route has no secret, global secret is empty
|
||||
route_secret = adapter._routes["test"].get("secret", adapter._global_secret)
|
||||
assert not route_secret # empty → validation is skipped in handler
|
||||
|
||||
def test_validate_generic_signature_valid(self):
|
||||
"""Valid X-Webhook-Signature (generic HMAC-SHA256 hex) is accepted."""
|
||||
adapter = _make_adapter()
|
||||
body = b'{"event": "push"}'
|
||||
secret = "generic-secret"
|
||||
sig = _generic_signature(body, secret)
|
||||
req = _mock_request(headers={"X-Webhook-Signature": sig})
|
||||
assert adapter._validate_signature(req, body, secret) is True
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Prompt rendering
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestRenderPrompt:
|
||||
"""Tests for WebhookAdapter._render_prompt."""
|
||||
|
||||
def test_render_prompt_dot_notation(self):
|
||||
"""Dot-notation {pull_request.title} resolves nested keys."""
|
||||
adapter = _make_adapter()
|
||||
payload = {"pull_request": {"title": "Fix bug", "number": 42}}
|
||||
result = adapter._render_prompt(
|
||||
"PR #{pull_request.number}: {pull_request.title}",
|
||||
payload,
|
||||
"pull_request",
|
||||
"github",
|
||||
)
|
||||
assert result == "PR #42: Fix bug"
|
||||
|
||||
def test_render_prompt_missing_key_preserved(self):
|
||||
"""{nonexistent} is left as-is when key doesn't exist in payload."""
|
||||
adapter = _make_adapter()
|
||||
result = adapter._render_prompt(
|
||||
"Hello {nonexistent}!",
|
||||
{"action": "opened"},
|
||||
"push",
|
||||
"test",
|
||||
)
|
||||
assert "{nonexistent}" in result
|
||||
|
||||
def test_render_prompt_no_template_dumps_json(self):
|
||||
"""Empty template → JSON dump fallback with event/route context."""
|
||||
adapter = _make_adapter()
|
||||
payload = {"key": "value"}
|
||||
result = adapter._render_prompt("", payload, "push", "my-route")
|
||||
assert "push" in result
|
||||
assert "my-route" in result
|
||||
assert "key" in result
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Delivery extra rendering
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestRenderDeliveryExtra:
|
||||
def test_render_delivery_extra_templates(self):
|
||||
"""String values in deliver_extra are rendered with payload data."""
|
||||
adapter = _make_adapter()
|
||||
extra = {"repo": "{repository.full_name}", "pr_number": "{number}", "static": 42}
|
||||
payload = {"repository": {"full_name": "org/repo"}, "number": 7}
|
||||
result = adapter._render_delivery_extra(extra, payload)
|
||||
assert result["repo"] == "org/repo"
|
||||
assert result["pr_number"] == "7"
|
||||
assert result["static"] == 42 # non-string left as-is
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Event filtering
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestEventFilter:
|
||||
"""Tests for event type filtering in _handle_webhook."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_filter_accepts_matching(self):
|
||||
"""Matching event type passes through."""
|
||||
routes = {
|
||||
"gh": {
|
||||
"secret": _INSECURE_NO_AUTH,
|
||||
"events": ["pull_request"],
|
||||
"prompt": "PR: {action}",
|
||||
}
|
||||
}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
# Stub handle_message to avoid running the agent
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/gh",
|
||||
json={"action": "opened"},
|
||||
headers={"X-GitHub-Event": "pull_request"},
|
||||
)
|
||||
assert resp.status == 202
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_filter_rejects_non_matching(self):
|
||||
"""Non-matching event type returns 200 with status=ignored."""
|
||||
routes = {
|
||||
"gh": {
|
||||
"secret": _INSECURE_NO_AUTH,
|
||||
"events": ["pull_request"],
|
||||
"prompt": "test",
|
||||
}
|
||||
}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/gh",
|
||||
json={"action": "opened"},
|
||||
headers={"X-GitHub-Event": "push"},
|
||||
)
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["status"] == "ignored"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_filter_empty_allows_all(self):
|
||||
"""No events list → accept any event type."""
|
||||
routes = {
|
||||
"all": {
|
||||
"secret": _INSECURE_NO_AUTH,
|
||||
"prompt": "got it",
|
||||
}
|
||||
}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/all",
|
||||
json={"action": "any"},
|
||||
headers={"X-GitHub-Event": "whatever"},
|
||||
)
|
||||
assert resp.status == 202
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# HTTP handling
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestHTTPHandling:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_route_returns_404(self):
|
||||
"""POST to an unknown route returns 404."""
|
||||
adapter = _make_adapter(routes={"real": {"secret": _INSECURE_NO_AUTH, "prompt": "x"}})
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post("/webhooks/nonexistent", json={"a": 1})
|
||||
assert resp.status == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_handler_returns_202(self):
|
||||
"""Valid request returns 202 Accepted."""
|
||||
routes = {"test": {"secret": _INSECURE_NO_AUTH, "prompt": "hi"}}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post("/webhooks/test", json={"data": "value"})
|
||||
assert resp.status == 202
|
||||
data = await resp.json()
|
||||
assert data["status"] == "accepted"
|
||||
assert data["route"] == "test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_endpoint(self):
|
||||
"""GET /health returns 200 with status=ok."""
|
||||
adapter = _make_adapter()
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.get("/health")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["platform"] == "webhook"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_starts_server(self):
|
||||
"""connect() starts the HTTP listener and marks adapter as connected."""
|
||||
routes = {"r1": {"secret": _INSECURE_NO_AUTH, "prompt": "x"}}
|
||||
adapter = _make_adapter(routes=routes, port=0)
|
||||
# Use port 0 — the OS picks a free port, but aiohttp requires a real bind.
|
||||
# We just test that the method completes and marks connected.
|
||||
# Need to mock TCPSite to avoid actual binding.
|
||||
with patch("gateway.platforms.webhook.web.AppRunner") as MockRunner, \
|
||||
patch("gateway.platforms.webhook.web.TCPSite") as MockSite:
|
||||
mock_runner_inst = AsyncMock()
|
||||
MockRunner.return_value = mock_runner_inst
|
||||
mock_site_inst = AsyncMock()
|
||||
MockSite.return_value = mock_site_inst
|
||||
|
||||
result = await adapter.connect()
|
||||
assert result is True
|
||||
assert adapter.is_connected
|
||||
mock_runner_inst.setup.assert_awaited_once()
|
||||
mock_site_inst.start.assert_awaited_once()
|
||||
|
||||
await adapter.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_cleans_up(self):
|
||||
"""disconnect() stops the server and marks adapter disconnected."""
|
||||
adapter = _make_adapter()
|
||||
# Simulate a runner that was previously set up
|
||||
mock_runner = AsyncMock()
|
||||
adapter._runner = mock_runner
|
||||
adapter._running = True
|
||||
|
||||
await adapter.disconnect()
|
||||
mock_runner.cleanup.assert_awaited_once()
|
||||
assert adapter._runner is None
|
||||
assert not adapter.is_connected
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Idempotency
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestIdempotency:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duplicate_delivery_id_returns_200(self):
|
||||
"""Second request with same delivery ID returns 200 duplicate."""
|
||||
routes = {"idem": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
headers = {"X-GitHub-Delivery": "delivery-123"}
|
||||
resp1 = await cli.post("/webhooks/idem", json={"a": 1}, headers=headers)
|
||||
assert resp1.status == 202
|
||||
|
||||
resp2 = await cli.post("/webhooks/idem", json={"a": 1}, headers=headers)
|
||||
assert resp2.status == 200
|
||||
data = await resp2.json()
|
||||
assert data["status"] == "duplicate"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_delivery_id_allows_reprocess(self):
|
||||
"""After TTL expires, the same delivery ID is accepted again."""
|
||||
routes = {"idem": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
adapter._idempotency_ttl = 1 # 1 second TTL for test speed
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
headers = {"X-GitHub-Delivery": "delivery-456"}
|
||||
|
||||
resp1 = await cli.post("/webhooks/idem", json={"x": 1}, headers=headers)
|
||||
assert resp1.status == 202
|
||||
|
||||
# Backdate the cache entry so it appears expired
|
||||
adapter._seen_deliveries["delivery-456"] = time.time() - 3700
|
||||
|
||||
resp2 = await cli.post("/webhooks/idem", json={"x": 1}, headers=headers)
|
||||
assert resp2.status == 202 # re-accepted
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Rate limiting
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestRateLimiting:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_rejects_excess(self):
|
||||
"""Exceeding the rate limit returns 429."""
|
||||
routes = {"limited": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}}
|
||||
adapter = _make_adapter(routes=routes, rate_limit=2)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
# Two requests within limit
|
||||
for i in range(2):
|
||||
resp = await cli.post(
|
||||
"/webhooks/limited",
|
||||
json={"n": i},
|
||||
headers={"X-GitHub-Delivery": f"d-{i}"},
|
||||
)
|
||||
assert resp.status == 202, f"Request {i} should be accepted"
|
||||
|
||||
# Third request should be rate-limited
|
||||
resp = await cli.post(
|
||||
"/webhooks/limited",
|
||||
json={"n": 99},
|
||||
headers={"X-GitHub-Delivery": "d-99"},
|
||||
)
|
||||
assert resp.status == 429
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_window_resets(self):
|
||||
"""After the 60-second window passes, requests are allowed again."""
|
||||
routes = {"limited": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}}
|
||||
adapter = _make_adapter(routes=routes, rate_limit=1)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/limited",
|
||||
json={"n": 1},
|
||||
headers={"X-GitHub-Delivery": "d-a"},
|
||||
)
|
||||
assert resp.status == 202
|
||||
|
||||
# Backdate all rate-limit timestamps to > 60 seconds ago
|
||||
adapter._rate_counts["limited"] = [time.time() - 120]
|
||||
|
||||
resp = await cli.post(
|
||||
"/webhooks/limited",
|
||||
json={"n": 2},
|
||||
headers={"X-GitHub-Delivery": "d-b"},
|
||||
)
|
||||
assert resp.status == 202 # allowed again
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Body size limit
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestBodySize:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oversized_payload_rejected(self):
|
||||
"""Content-Length > max_body_bytes returns 413."""
|
||||
routes = {"big": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}}
|
||||
adapter = _make_adapter(routes=routes, max_body_bytes=100)
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
large_payload = {"data": "x" * 200}
|
||||
resp = await cli.post(
|
||||
"/webhooks/big",
|
||||
json=large_payload,
|
||||
headers={"Content-Length": "999999"},
|
||||
)
|
||||
assert resp.status == 413
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# INSECURE_NO_AUTH
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestInsecureNoAuth:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insecure_no_auth_skips_validation(self):
|
||||
"""Setting secret to _INSECURE_NO_AUTH bypasses signature check."""
|
||||
routes = {"open": {"secret": _INSECURE_NO_AUTH, "prompt": "hello"}}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
# No signature header at all — should still be accepted
|
||||
resp = await cli.post("/webhooks/open", json={"test": True})
|
||||
assert resp.status == 202
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Session isolation
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestSessionIsolation:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_webhooks_get_independent_sessions(self):
|
||||
"""Two events on the same route produce different session keys."""
|
||||
routes = {"ci": {"secret": _INSECURE_NO_AUTH, "prompt": "build"}}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
|
||||
captured_events = []
|
||||
|
||||
async def _capture(event):
|
||||
captured_events.append(event)
|
||||
|
||||
adapter.handle_message = _capture
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp1 = await cli.post(
|
||||
"/webhooks/ci",
|
||||
json={"ref": "main"},
|
||||
headers={"X-GitHub-Delivery": "aaa-111"},
|
||||
)
|
||||
assert resp1.status == 202
|
||||
|
||||
resp2 = await cli.post(
|
||||
"/webhooks/ci",
|
||||
json={"ref": "dev"},
|
||||
headers={"X-GitHub-Delivery": "bbb-222"},
|
||||
)
|
||||
assert resp2.status == 202
|
||||
|
||||
# Wait for the async tasks to be created
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(captured_events) == 2
|
||||
ids = {ev.source.chat_id for ev in captured_events}
|
||||
assert len(ids) == 2, "Each delivery must have a unique session chat_id"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Delivery info cleanup
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestDeliveryCleanup:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delivery_info_cleaned_after_send(self):
|
||||
"""send() pops delivery_info so the entry doesn't leak memory."""
|
||||
adapter = _make_adapter()
|
||||
chat_id = "webhook:test:d-xyz"
|
||||
adapter._delivery_info[chat_id] = {
|
||||
"deliver": "log",
|
||||
"deliver_extra": {},
|
||||
"payload": {"x": 1},
|
||||
}
|
||||
|
||||
result = await adapter.send(chat_id, "Agent response here")
|
||||
assert result.success is True
|
||||
assert chat_id not in adapter._delivery_info
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# check_webhook_requirements
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestCheckRequirements:
|
||||
def test_returns_true_when_aiohttp_available(self):
|
||||
assert check_webhook_requirements() is True
|
||||
|
||||
@patch("gateway.platforms.webhook.AIOHTTP_AVAILABLE", False)
|
||||
def test_returns_false_without_aiohttp(self):
|
||||
assert check_webhook_requirements() is False
|
||||
@@ -0,0 +1,337 @@
|
||||
"""Integration tests for the generic webhook platform adapter.
|
||||
|
||||
These tests exercise end-to-end flows through the webhook adapter:
|
||||
1. GitHub PR webhook → agent MessageEvent created
|
||||
2. Skills config injects skill content into the prompt
|
||||
3. Cross-platform delivery routes to a mock Telegram adapter
|
||||
4. GitHub comment delivery invokes ``gh`` CLI (mocked subprocess)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from gateway.config import (
|
||||
GatewayConfig,
|
||||
HomeChannel,
|
||||
Platform,
|
||||
PlatformConfig,
|
||||
)
|
||||
from gateway.platforms.base import MessageEvent, MessageType, SendResult
|
||||
from gateway.platforms.webhook import WebhookAdapter, _INSECURE_NO_AUTH
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_adapter(routes, **extra_kw) -> WebhookAdapter:
|
||||
"""Create a WebhookAdapter with the given routes."""
|
||||
extra = {"host": "0.0.0.0", "port": 0, "routes": routes}
|
||||
extra.update(extra_kw)
|
||||
config = PlatformConfig(enabled=True, extra=extra)
|
||||
return WebhookAdapter(config)
|
||||
|
||||
|
||||
def _create_app(adapter: WebhookAdapter) -> web.Application:
|
||||
"""Build the aiohttp Application from the adapter."""
|
||||
app = web.Application()
|
||||
app.router.add_get("/health", adapter._handle_health)
|
||||
app.router.add_post("/webhooks/{route_name}", adapter._handle_webhook)
|
||||
return app
|
||||
|
||||
|
||||
def _github_signature(body: bytes, secret: str) -> str:
|
||||
"""Compute X-Hub-Signature-256 for *body* using *secret*."""
|
||||
return "sha256=" + hmac.new(
|
||||
secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
|
||||
# A realistic GitHub pull_request event payload (trimmed)
|
||||
GITHUB_PR_PAYLOAD = {
|
||||
"action": "opened",
|
||||
"number": 42,
|
||||
"pull_request": {
|
||||
"title": "Add webhook adapter",
|
||||
"body": "This PR adds a generic webhook platform adapter.",
|
||||
"html_url": "https://github.com/org/repo/pull/42",
|
||||
"user": {"login": "contributor"},
|
||||
"head": {"ref": "feature/webhooks"},
|
||||
"base": {"ref": "main"},
|
||||
},
|
||||
"repository": {
|
||||
"full_name": "org/repo",
|
||||
"html_url": "https://github.com/org/repo",
|
||||
},
|
||||
"sender": {"login": "contributor"},
|
||||
}
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Test 1: GitHub PR webhook triggers agent
|
||||
# ===================================================================
|
||||
|
||||
class TestGitHubPRWebhook:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_github_pr_webhook_triggers_agent(self):
|
||||
"""POST with a realistic GitHub PR payload should:
|
||||
1. Return 202 Accepted
|
||||
2. Call handle_message with a MessageEvent
|
||||
3. The event text contains the rendered prompt
|
||||
4. The event source has chat_type 'webhook'
|
||||
"""
|
||||
secret = "gh-webhook-test-secret"
|
||||
routes = {
|
||||
"github-pr": {
|
||||
"secret": secret,
|
||||
"events": ["pull_request"],
|
||||
"prompt": (
|
||||
"Review PR #{number} by {sender.login}: "
|
||||
"{pull_request.title}\n\n{pull_request.body}"
|
||||
),
|
||||
"deliver": "log",
|
||||
}
|
||||
}
|
||||
adapter = _make_adapter(routes)
|
||||
|
||||
captured_events: list[MessageEvent] = []
|
||||
|
||||
async def _capture(event: MessageEvent):
|
||||
captured_events.append(event)
|
||||
|
||||
adapter.handle_message = _capture
|
||||
|
||||
app = _create_app(adapter)
|
||||
body = json.dumps(GITHUB_PR_PAYLOAD).encode()
|
||||
sig = _github_signature(body, secret)
|
||||
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/github-pr",
|
||||
data=body,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"X-GitHub-Event": "pull_request",
|
||||
"X-Hub-Signature-256": sig,
|
||||
"X-GitHub-Delivery": "gh-delivery-001",
|
||||
},
|
||||
)
|
||||
assert resp.status == 202
|
||||
data = await resp.json()
|
||||
assert data["status"] == "accepted"
|
||||
assert data["route"] == "github-pr"
|
||||
assert data["event"] == "pull_request"
|
||||
assert data["delivery_id"] == "gh-delivery-001"
|
||||
|
||||
# Let the asyncio.create_task fire
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(captured_events) == 1
|
||||
event = captured_events[0]
|
||||
assert "Review PR #42 by contributor" in event.text
|
||||
assert "Add webhook adapter" in event.text
|
||||
assert event.source.chat_type == "webhook"
|
||||
assert event.source.platform == Platform.WEBHOOK
|
||||
assert "github-pr" in event.source.chat_id
|
||||
assert event.message_id == "gh-delivery-001"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Test 2: Skills injected into prompt
|
||||
# ===================================================================
|
||||
|
||||
class TestSkillsInjection:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skills_injected_into_prompt(self):
|
||||
"""When a route has skills: [code-review], the adapter should
|
||||
call build_skill_invocation_message() and use its output as the
|
||||
prompt instead of the raw template render."""
|
||||
routes = {
|
||||
"pr-review": {
|
||||
"secret": _INSECURE_NO_AUTH,
|
||||
"events": ["pull_request"],
|
||||
"prompt": "Review this PR: {pull_request.title}",
|
||||
"skills": ["code-review"],
|
||||
}
|
||||
}
|
||||
adapter = _make_adapter(routes)
|
||||
|
||||
captured_events: list[MessageEvent] = []
|
||||
|
||||
async def _capture(event: MessageEvent):
|
||||
captured_events.append(event)
|
||||
|
||||
adapter.handle_message = _capture
|
||||
|
||||
skill_content = (
|
||||
"You are a code reviewer. Review the following:\n"
|
||||
"Review this PR: Add webhook adapter"
|
||||
)
|
||||
|
||||
# The imports are lazy (inside the handler), so patch the source module
|
||||
with patch(
|
||||
"agent.skill_commands.build_skill_invocation_message",
|
||||
return_value=skill_content,
|
||||
) as mock_build, patch(
|
||||
"agent.skill_commands.get_skill_commands",
|
||||
return_value={"/code-review": {"name": "code-review"}},
|
||||
):
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/pr-review",
|
||||
json=GITHUB_PR_PAYLOAD,
|
||||
headers={
|
||||
"X-GitHub-Event": "pull_request",
|
||||
"X-GitHub-Delivery": "skill-test-001",
|
||||
},
|
||||
)
|
||||
assert resp.status == 202
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert len(captured_events) == 1
|
||||
event = captured_events[0]
|
||||
# The prompt should be the skill content, not the raw template
|
||||
assert "You are a code reviewer" in event.text
|
||||
mock_build.assert_called_once()
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Test 3: Cross-platform delivery (webhook → Telegram)
|
||||
# ===================================================================
|
||||
|
||||
class TestCrossPlatformDelivery:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cross_platform_delivery(self):
|
||||
"""When deliver='telegram', the response is routed to the
|
||||
Telegram adapter via gateway_runner.adapters."""
|
||||
routes = {
|
||||
"alerts": {
|
||||
"secret": _INSECURE_NO_AUTH,
|
||||
"prompt": "Alert: {message}",
|
||||
"deliver": "telegram",
|
||||
"deliver_extra": {"chat_id": "12345"},
|
||||
}
|
||||
}
|
||||
adapter = _make_adapter(routes)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
# Set up a mock gateway runner with a mock Telegram adapter
|
||||
mock_tg_adapter = AsyncMock()
|
||||
mock_tg_adapter.send = AsyncMock(return_value=SendResult(success=True))
|
||||
|
||||
mock_runner = MagicMock()
|
||||
mock_runner.adapters = {Platform.TELEGRAM: mock_tg_adapter}
|
||||
mock_runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="fake")}
|
||||
)
|
||||
adapter.gateway_runner = mock_runner
|
||||
|
||||
# First, simulate a webhook POST to set up delivery_info
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/alerts",
|
||||
json={"message": "Server is on fire!"},
|
||||
headers={"X-GitHub-Delivery": "alert-001"},
|
||||
)
|
||||
assert resp.status == 202
|
||||
|
||||
# The adapter should have stored delivery info
|
||||
chat_id = "webhook:alerts:alert-001"
|
||||
assert chat_id in adapter._delivery_info
|
||||
|
||||
# Now call send() as if the agent has finished
|
||||
result = await adapter.send(chat_id, "I've acknowledged the alert.")
|
||||
|
||||
assert result.success is True
|
||||
mock_tg_adapter.send.assert_awaited_once_with(
|
||||
"12345", "I've acknowledged the alert."
|
||||
)
|
||||
# Delivery info should be cleaned up
|
||||
assert chat_id not in adapter._delivery_info
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Test 4: GitHub comment delivery via gh CLI
|
||||
# ===================================================================
|
||||
|
||||
class TestGitHubCommentDelivery:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_github_comment_delivery(self):
|
||||
"""When deliver='github_comment', the adapter invokes
|
||||
``gh pr comment`` via subprocess.run (mocked)."""
|
||||
routes = {
|
||||
"pr-bot": {
|
||||
"secret": _INSECURE_NO_AUTH,
|
||||
"prompt": "Review: {pull_request.title}",
|
||||
"deliver": "github_comment",
|
||||
"deliver_extra": {
|
||||
"repo": "{repository.full_name}",
|
||||
"pr_number": "{number}",
|
||||
},
|
||||
}
|
||||
}
|
||||
adapter = _make_adapter(routes)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
# POST a webhook to set up delivery info
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/pr-bot",
|
||||
json=GITHUB_PR_PAYLOAD,
|
||||
headers={
|
||||
"X-GitHub-Event": "pull_request",
|
||||
"X-GitHub-Delivery": "gh-comment-001",
|
||||
},
|
||||
)
|
||||
assert resp.status == 202
|
||||
|
||||
chat_id = "webhook:pr-bot:gh-comment-001"
|
||||
assert chat_id in adapter._delivery_info
|
||||
|
||||
# Verify deliver_extra was rendered with payload data
|
||||
delivery = adapter._delivery_info[chat_id]
|
||||
assert delivery["deliver_extra"]["repo"] == "org/repo"
|
||||
assert delivery["deliver_extra"]["pr_number"] == "42"
|
||||
|
||||
# Mock subprocess.run and call send()
|
||||
mock_result = MagicMock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = "Comment posted"
|
||||
mock_result.stderr = ""
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.webhook.subprocess.run",
|
||||
return_value=mock_result,
|
||||
) as mock_run:
|
||||
result = await adapter.send(
|
||||
chat_id, "LGTM! The code looks great."
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
mock_run.assert_called_once_with(
|
||||
[
|
||||
"gh", "pr", "comment", "42",
|
||||
"--repo", "org/repo",
|
||||
"--body", "LGTM! The code looks great.",
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
# Delivery info cleaned up
|
||||
assert chat_id not in adapter._delivery_info
|
||||
@@ -97,30 +97,32 @@ def test_custom_setup_clears_active_oauth_provider(tmp_path, monkeypatch):
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
|
||||
prompt_values = iter(
|
||||
[
|
||||
"https://custom.example/v1",
|
||||
"custom-api-key",
|
||||
"custom/model",
|
||||
]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.setup.prompt",
|
||||
lambda *args, **kwargs: next(prompt_values),
|
||||
)
|
||||
# _model_flow_custom uses builtins.input (URL, key, model, context_length)
|
||||
input_values = iter([
|
||||
"https://custom.example/v1",
|
||||
"custom-api-key",
|
||||
"custom/model",
|
||||
"", # context_length (blank = auto-detect)
|
||||
])
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt="": next(input_values))
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
monkeypatch.setattr("hermes_cli.main._save_custom_provider", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.models.probe_api_models",
|
||||
lambda api_key, base_url: {"models": ["m"], "probed_url": base_url + "/models"},
|
||||
)
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
|
||||
reloaded = load_config()
|
||||
|
||||
# Core assertion: switching to custom endpoint clears OAuth provider
|
||||
assert get_active_provider() is None
|
||||
assert isinstance(reloaded["model"], dict)
|
||||
assert reloaded["model"]["provider"] == "custom"
|
||||
assert reloaded["model"]["base_url"] == "https://custom.example/v1"
|
||||
assert reloaded["model"]["default"] == "custom/model"
|
||||
|
||||
# _model_flow_custom writes config via its own load/save cycle
|
||||
reloaded = load_config()
|
||||
if isinstance(reloaded.get("model"), dict):
|
||||
assert reloaded["model"].get("provider") == "custom"
|
||||
assert reloaded["model"].get("default") == "custom/model"
|
||||
|
||||
|
||||
def test_codex_setup_uses_runtime_access_token_for_live_model_list(tmp_path, monkeypatch):
|
||||
|
||||
@@ -99,21 +99,21 @@ def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch):
|
||||
return tts_idx
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
|
||||
def fake_prompt(message, current=None, **kwargs):
|
||||
if "API base URL" in message:
|
||||
return "http://localhost:8000"
|
||||
if "API key" in message:
|
||||
return "local-key"
|
||||
if "Model name" in message:
|
||||
return "llm"
|
||||
return ""
|
||||
# _model_flow_custom uses builtins.input (URL, key, model, context_length)
|
||||
input_values = iter([
|
||||
"http://localhost:8000",
|
||||
"local-key",
|
||||
"llm",
|
||||
"", # context_length (blank = auto-detect)
|
||||
])
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt="": next(input_values))
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt", fake_prompt)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: [])
|
||||
monkeypatch.setattr("hermes_cli.main._save_custom_provider", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.models.probe_api_models",
|
||||
lambda api_key, base_url: {
|
||||
@@ -126,16 +126,19 @@ def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch):
|
||||
)
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
|
||||
env = _read_env(tmp_path)
|
||||
reloaded = load_config()
|
||||
|
||||
# _model_flow_custom saves env vars and config to disk
|
||||
assert env.get("OPENAI_BASE_URL") == "http://localhost:8000/v1"
|
||||
assert env.get("OPENAI_API_KEY") == "local-key"
|
||||
assert reloaded["model"]["provider"] == "custom"
|
||||
assert reloaded["model"]["base_url"] == "http://localhost:8000/v1"
|
||||
assert reloaded["model"]["default"] == "llm"
|
||||
|
||||
# The model config is saved as a dict by _model_flow_custom
|
||||
reloaded = load_config()
|
||||
model_cfg = reloaded.get("model", {})
|
||||
if isinstance(model_cfg, dict):
|
||||
assert model_cfg.get("provider") == "custom"
|
||||
assert model_cfg.get("default") == "llm"
|
||||
|
||||
|
||||
def test_setup_keep_current_config_provider_uses_provider_specific_model_menu(tmp_path, monkeypatch):
|
||||
|
||||
@@ -60,6 +60,21 @@ class TestFromEnv:
|
||||
config = HonchoClientConfig.from_env(workspace_id="custom")
|
||||
assert config.workspace_id == "custom"
|
||||
|
||||
def test_reads_base_url_from_env(self):
|
||||
with patch.dict(os.environ, {"HONCHO_BASE_URL": "http://localhost:8000"}, clear=False):
|
||||
config = HonchoClientConfig.from_env()
|
||||
assert config.base_url == "http://localhost:8000"
|
||||
assert config.enabled is True
|
||||
|
||||
def test_enabled_without_api_key_when_base_url_set(self):
|
||||
"""base_url alone (no API key) is sufficient to enable a local instance."""
|
||||
with patch.dict(os.environ, {"HONCHO_BASE_URL": "http://localhost:8000"}, clear=False):
|
||||
os.environ.pop("HONCHO_API_KEY", None)
|
||||
config = HonchoClientConfig.from_env()
|
||||
assert config.api_key is None
|
||||
assert config.base_url == "http://localhost:8000"
|
||||
assert config.enabled is True
|
||||
|
||||
|
||||
class TestFromGlobalConfig:
|
||||
def test_missing_config_falls_back_to_env(self, tmp_path):
|
||||
@@ -188,6 +203,36 @@ class TestFromGlobalConfig:
|
||||
config = HonchoClientConfig.from_global_config(config_path=config_file)
|
||||
assert config.api_key == "env-key"
|
||||
|
||||
def test_base_url_env_fallback(self, tmp_path):
|
||||
"""HONCHO_BASE_URL env var is used when no baseUrl in config JSON."""
|
||||
config_file = tmp_path / "config.json"
|
||||
config_file.write_text(json.dumps({"workspace": "local"}))
|
||||
|
||||
with patch.dict(os.environ, {"HONCHO_BASE_URL": "http://localhost:8000"}, clear=False):
|
||||
config = HonchoClientConfig.from_global_config(config_path=config_file)
|
||||
assert config.base_url == "http://localhost:8000"
|
||||
assert config.enabled is True
|
||||
|
||||
def test_base_url_from_config_root(self, tmp_path):
|
||||
"""baseUrl in config root is read and takes precedence over env var."""
|
||||
config_file = tmp_path / "config.json"
|
||||
config_file.write_text(json.dumps({"baseUrl": "http://config-host:9000"}))
|
||||
|
||||
with patch.dict(os.environ, {"HONCHO_BASE_URL": "http://localhost:8000"}, clear=False):
|
||||
config = HonchoClientConfig.from_global_config(config_path=config_file)
|
||||
assert config.base_url == "http://config-host:9000"
|
||||
|
||||
def test_base_url_not_read_from_host_block(self, tmp_path):
|
||||
"""baseUrl is a root-level connection setting, not overridable per-host (consistent with apiKey)."""
|
||||
config_file = tmp_path / "config.json"
|
||||
config_file.write_text(json.dumps({
|
||||
"baseUrl": "http://root:9000",
|
||||
"hosts": {"hermes": {"baseUrl": "http://host-block:9001"}},
|
||||
}))
|
||||
|
||||
config = HonchoClientConfig.from_global_config(config_path=config_file)
|
||||
assert config.base_url == "http://root:9000"
|
||||
|
||||
|
||||
class TestResolveSessionName:
|
||||
def test_manual_override(self):
|
||||
|
||||
@@ -578,21 +578,39 @@ class TestConvertMessages:
|
||||
|
||||
def test_converts_tool_results(self):
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{"id": "tc_1", "function": {"name": "test_tool", "arguments": "{}"}},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "tc_1", "content": "result data"},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
assert result[0]["role"] == "user"
|
||||
assert result[0]["content"][0]["type"] == "tool_result"
|
||||
assert result[0]["content"][0]["tool_use_id"] == "tc_1"
|
||||
# tool result is in the second message (user role)
|
||||
user_msg = [m for m in result if m["role"] == "user"][0]
|
||||
assert user_msg["content"][0]["type"] == "tool_result"
|
||||
assert user_msg["content"][0]["tool_use_id"] == "tc_1"
|
||||
|
||||
def test_merges_consecutive_tool_results(self):
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{"id": "tc_1", "function": {"name": "tool_a", "arguments": "{}"}},
|
||||
{"id": "tc_2", "function": {"name": "tool_b", "arguments": "{}"}},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "tc_1", "content": "result 1"},
|
||||
{"role": "tool", "tool_call_id": "tc_2", "content": "result 2"},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
assert len(result) == 1
|
||||
assert len(result[0]["content"]) == 2
|
||||
# assistant + merged user (with 2 tool_results)
|
||||
user_msgs = [m for m in result if m["role"] == "user"]
|
||||
assert len(user_msgs) == 1
|
||||
assert len(user_msgs[0]["content"]) == 2
|
||||
|
||||
def test_strips_orphaned_tool_use(self):
|
||||
messages = [
|
||||
@@ -610,6 +628,51 @@ class TestConvertMessages:
|
||||
assistant_blocks = result[0]["content"]
|
||||
assert all(b.get("type") != "tool_use" for b in assistant_blocks)
|
||||
|
||||
def test_strips_orphaned_tool_result(self):
|
||||
"""tool_result with no matching tool_use should be stripped.
|
||||
|
||||
This happens when context compression removes the assistant message
|
||||
containing the tool_use but leaves the subsequent tool_result intact.
|
||||
Anthropic rejects orphaned tool_results with a 400.
|
||||
"""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
# The assistant tool_use message was removed by compression,
|
||||
# but the tool_result survived:
|
||||
{"role": "tool", "tool_call_id": "tc_gone", "content": "stale result"},
|
||||
{"role": "user", "content": "Thanks"},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
# tc_gone has no matching tool_use — its tool_result should be stripped
|
||||
for m in result:
|
||||
if m["role"] == "user" and isinstance(m["content"], list):
|
||||
assert all(
|
||||
b.get("type") != "tool_result"
|
||||
for b in m["content"]
|
||||
), "Orphaned tool_result should have been stripped"
|
||||
|
||||
def test_strips_orphaned_tool_result_preserves_valid(self):
|
||||
"""Orphaned tool_results are stripped while valid ones survive."""
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{"id": "tc_valid", "function": {"name": "search", "arguments": "{}"}},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "tc_valid", "content": "good result"},
|
||||
{"role": "tool", "tool_call_id": "tc_orphan", "content": "stale result"},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
user_msg = [m for m in result if m["role"] == "user"][0]
|
||||
tool_results = [
|
||||
b for b in user_msg["content"] if b.get("type") == "tool_result"
|
||||
]
|
||||
assert len(tool_results) == 1
|
||||
assert tool_results[0]["tool_use_id"] == "tc_valid"
|
||||
|
||||
def test_system_with_cache_control(self):
|
||||
messages = [
|
||||
{
|
||||
@@ -641,11 +704,19 @@ class TestConvertMessages:
|
||||
def test_tool_cache_control_is_preserved_on_tool_result_block(self):
|
||||
messages = apply_anthropic_cache_control([
|
||||
{"role": "system", "content": "System prompt"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{"id": "tc_1", "function": {"name": "test_tool", "arguments": "{}"}},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "tc_1", "content": "result"},
|
||||
])
|
||||
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
tool_block = result[0]["content"][0]
|
||||
user_msg = [m for m in result if m["role"] == "user"][0]
|
||||
tool_block = user_msg["content"][0]
|
||||
|
||||
assert tool_block["type"] == "tool_result"
|
||||
assert tool_block["tool_use_id"] == "tc_1"
|
||||
|
||||
@@ -92,8 +92,8 @@ class TestProviderRegistry:
|
||||
assert PROVIDER_REGISTRY["copilot-acp"].inference_base_url == "acp://copilot"
|
||||
assert PROVIDER_REGISTRY["zai"].inference_base_url == "https://api.z.ai/api/paas/v4"
|
||||
assert PROVIDER_REGISTRY["kimi-coding"].inference_base_url == "https://api.moonshot.ai/v1"
|
||||
assert PROVIDER_REGISTRY["minimax"].inference_base_url == "https://api.minimax.io/v1"
|
||||
assert PROVIDER_REGISTRY["minimax-cn"].inference_base_url == "https://api.minimaxi.com/v1"
|
||||
assert PROVIDER_REGISTRY["minimax"].inference_base_url == "https://api.minimax.io/anthropic"
|
||||
assert PROVIDER_REGISTRY["minimax-cn"].inference_base_url == "https://api.minimaxi.com/anthropic"
|
||||
assert PROVIDER_REGISTRY["ai-gateway"].inference_base_url == "https://ai-gateway.vercel.sh/v1"
|
||||
assert PROVIDER_REGISTRY["kilocode"].inference_base_url == "https://api.kilo.ai/api/gateway"
|
||||
|
||||
@@ -399,14 +399,14 @@ class TestResolveApiKeyProviderCredentials:
|
||||
creds = resolve_api_key_provider_credentials("minimax")
|
||||
assert creds["provider"] == "minimax"
|
||||
assert creds["api_key"] == "mm-secret-key"
|
||||
assert creds["base_url"] == "https://api.minimax.io/v1"
|
||||
assert creds["base_url"] == "https://api.minimax.io/anthropic"
|
||||
|
||||
def test_resolve_minimax_cn_with_key(self, monkeypatch):
|
||||
monkeypatch.setenv("MINIMAX_CN_API_KEY", "mmcn-secret-key")
|
||||
creds = resolve_api_key_provider_credentials("minimax-cn")
|
||||
assert creds["provider"] == "minimax-cn"
|
||||
assert creds["api_key"] == "mmcn-secret-key"
|
||||
assert creds["base_url"] == "https://api.minimaxi.com/v1"
|
||||
assert creds["base_url"] == "https://api.minimaxi.com/anthropic"
|
||||
|
||||
def test_resolve_ai_gateway_with_key(self, monkeypatch):
|
||||
monkeypatch.setenv("AI_GATEWAY_API_KEY", "gw-secret-key")
|
||||
|
||||
@@ -42,6 +42,7 @@ def _make_cli(env_overrides=None, config_overrides=None, **kwargs):
|
||||
"prompt_toolkit.key_binding": MagicMock(),
|
||||
"prompt_toolkit.completion": MagicMock(),
|
||||
"prompt_toolkit.formatted_text": MagicMock(),
|
||||
"prompt_toolkit.auto_suggest": MagicMock(),
|
||||
}
|
||||
with patch.dict(sys.modules, prompt_toolkit_stubs), \
|
||||
patch.dict("os.environ", clean_env, clear=False):
|
||||
|
||||
@@ -12,6 +12,17 @@ from hermes_state import SessionDB
|
||||
from tools.todo_tool import TodoStore
|
||||
|
||||
|
||||
class _FakeCompressor:
|
||||
"""Minimal stand-in for ContextCompressor."""
|
||||
|
||||
def __init__(self):
|
||||
self.last_prompt_tokens = 500
|
||||
self.last_completion_tokens = 200
|
||||
self.last_total_tokens = 700
|
||||
self.compression_count = 3
|
||||
self._context_probed = True
|
||||
|
||||
|
||||
class _FakeAgent:
|
||||
def __init__(self, session_id: str, session_start):
|
||||
self.session_id = session_id
|
||||
@@ -25,6 +36,42 @@ class _FakeAgent:
|
||||
self.flush_memories = MagicMock()
|
||||
self._invalidate_system_prompt = MagicMock()
|
||||
|
||||
# Token counters (non-zero to verify reset)
|
||||
self.session_total_tokens = 1000
|
||||
self.session_input_tokens = 600
|
||||
self.session_output_tokens = 400
|
||||
self.session_prompt_tokens = 550
|
||||
self.session_completion_tokens = 350
|
||||
self.session_cache_read_tokens = 100
|
||||
self.session_cache_write_tokens = 50
|
||||
self.session_reasoning_tokens = 80
|
||||
self.session_api_calls = 5
|
||||
self.session_estimated_cost_usd = 0.42
|
||||
self.session_cost_status = "estimated"
|
||||
self.session_cost_source = "openrouter"
|
||||
self.context_compressor = _FakeCompressor()
|
||||
|
||||
def reset_session_state(self):
|
||||
"""Mirror the real AIAgent.reset_session_state()."""
|
||||
self.session_total_tokens = 0
|
||||
self.session_input_tokens = 0
|
||||
self.session_output_tokens = 0
|
||||
self.session_prompt_tokens = 0
|
||||
self.session_completion_tokens = 0
|
||||
self.session_cache_read_tokens = 0
|
||||
self.session_cache_write_tokens = 0
|
||||
self.session_reasoning_tokens = 0
|
||||
self.session_api_calls = 0
|
||||
self.session_estimated_cost_usd = 0.0
|
||||
self.session_cost_status = "unknown"
|
||||
self.session_cost_source = "none"
|
||||
if hasattr(self, "context_compressor") and self.context_compressor:
|
||||
self.context_compressor.last_prompt_tokens = 0
|
||||
self.context_compressor.last_completion_tokens = 0
|
||||
self.context_compressor.last_total_tokens = 0
|
||||
self.context_compressor.compression_count = 0
|
||||
self.context_compressor._context_probed = False
|
||||
|
||||
|
||||
def _make_cli(env_overrides=None, config_overrides=None, **kwargs):
|
||||
"""Create a HermesCLI instance with minimal mocking."""
|
||||
@@ -58,6 +105,7 @@ def _make_cli(env_overrides=None, config_overrides=None, **kwargs):
|
||||
"prompt_toolkit.key_binding": MagicMock(),
|
||||
"prompt_toolkit.completion": MagicMock(),
|
||||
"prompt_toolkit.formatted_text": MagicMock(),
|
||||
"prompt_toolkit.auto_suggest": MagicMock(),
|
||||
}
|
||||
with patch.dict(sys.modules, prompt_toolkit_stubs), patch.dict(
|
||||
"os.environ", clean_env, clear=False
|
||||
@@ -137,3 +185,38 @@ def test_clear_command_starts_new_session_before_redrawing(tmp_path):
|
||||
cli.console.clear.assert_called_once()
|
||||
cli.show_banner.assert_called_once()
|
||||
assert cli.conversation_history == []
|
||||
|
||||
|
||||
def test_new_session_resets_token_counters(tmp_path):
|
||||
"""Regression test for #2099: /new must zero all token counters."""
|
||||
cli = _prepare_cli_with_active_session(tmp_path)
|
||||
|
||||
# Verify counters are non-zero before reset
|
||||
agent = cli.agent
|
||||
assert agent.session_total_tokens > 0
|
||||
assert agent.session_api_calls > 0
|
||||
assert agent.context_compressor.compression_count > 0
|
||||
|
||||
cli.process_command("/new")
|
||||
|
||||
# All agent token counters must be zero
|
||||
assert agent.session_total_tokens == 0
|
||||
assert agent.session_input_tokens == 0
|
||||
assert agent.session_output_tokens == 0
|
||||
assert agent.session_prompt_tokens == 0
|
||||
assert agent.session_completion_tokens == 0
|
||||
assert agent.session_cache_read_tokens == 0
|
||||
assert agent.session_cache_write_tokens == 0
|
||||
assert agent.session_reasoning_tokens == 0
|
||||
assert agent.session_api_calls == 0
|
||||
assert agent.session_estimated_cost_usd == 0.0
|
||||
assert agent.session_cost_status == "unknown"
|
||||
assert agent.session_cost_source == "none"
|
||||
|
||||
# Context compressor counters must also be zero
|
||||
comp = agent.context_compressor
|
||||
assert comp.last_prompt_tokens == 0
|
||||
assert comp.last_completion_tokens == 0
|
||||
assert comp.last_total_tokens == 0
|
||||
assert comp.compression_count == 0
|
||||
assert comp._context_probed is False
|
||||
|
||||
@@ -459,7 +459,7 @@ def test_model_flow_custom_saves_verified_v1_base_url(monkeypatch, capsys):
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.config.save_config", lambda cfg: None)
|
||||
|
||||
answers = iter(["http://localhost:8000", "local-key", "llm"])
|
||||
answers = iter(["http://localhost:8000", "local-key", "llm", ""])
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers))
|
||||
|
||||
hermes_main._model_flow_custom({})
|
||||
|
||||
@@ -0,0 +1,249 @@
|
||||
"""Tests for context pressure warnings (user-facing, not injected into messages).
|
||||
|
||||
Covers:
|
||||
- Display formatting (CLI and gateway variants)
|
||||
- Flag tracking and threshold logic on AIAgent
|
||||
- Flag reset after compression
|
||||
- status_callback invocation
|
||||
"""
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.display import format_context_pressure, format_context_pressure_gateway
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Display formatting tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatContextPressure:
|
||||
"""CLI context pressure display (agent/display.py).
|
||||
|
||||
The bar shows progress toward the compaction threshold, not the
|
||||
raw context window. 60% = 60% of the way to compaction.
|
||||
"""
|
||||
|
||||
def test_60_percent_uses_info_icon(self):
|
||||
line = format_context_pressure(0.60, 100_000, 0.50)
|
||||
assert "◐" in line
|
||||
assert "60% to compaction" in line
|
||||
|
||||
def test_85_percent_uses_warning_icon(self):
|
||||
line = format_context_pressure(0.85, 100_000, 0.50)
|
||||
assert "⚠" in line
|
||||
assert "85% to compaction" in line
|
||||
|
||||
def test_bar_length_scales_with_progress(self):
|
||||
line_60 = format_context_pressure(0.60, 100_000, 0.50)
|
||||
line_85 = format_context_pressure(0.85, 100_000, 0.50)
|
||||
assert line_85.count("▰") > line_60.count("▰")
|
||||
|
||||
def test_shows_threshold_tokens(self):
|
||||
line = format_context_pressure(0.60, 100_000, 0.50)
|
||||
assert "100k" in line
|
||||
|
||||
def test_small_threshold(self):
|
||||
line = format_context_pressure(0.60, 500, 0.50)
|
||||
assert "500" in line
|
||||
|
||||
def test_shows_threshold_percent(self):
|
||||
line = format_context_pressure(0.85, 100_000, 0.50)
|
||||
assert "50%" in line # threshold percent shown
|
||||
|
||||
def test_imminent_hint_at_85(self):
|
||||
line = format_context_pressure(0.85, 100_000, 0.50)
|
||||
assert "compaction imminent" in line
|
||||
|
||||
def test_approaching_hint_below_85(self):
|
||||
line = format_context_pressure(0.60, 100_000, 0.80)
|
||||
assert "approaching compaction" in line
|
||||
|
||||
def test_no_compaction_when_disabled(self):
|
||||
line = format_context_pressure(0.85, 100_000, 0.50, compression_enabled=False)
|
||||
assert "no auto-compaction" in line
|
||||
|
||||
def test_returns_string(self):
|
||||
result = format_context_pressure(0.65, 128_000, 0.50)
|
||||
assert isinstance(result, str)
|
||||
|
||||
def test_over_100_percent_capped(self):
|
||||
"""Progress > 1.0 should not break the bar."""
|
||||
line = format_context_pressure(1.05, 100_000, 0.50)
|
||||
assert "▰" in line
|
||||
assert line.count("▰") == 20
|
||||
|
||||
|
||||
class TestFormatContextPressureGateway:
|
||||
"""Gateway (plain text) context pressure display."""
|
||||
|
||||
def test_60_percent_informational(self):
|
||||
msg = format_context_pressure_gateway(0.60, 0.50)
|
||||
assert "60% to compaction" in msg
|
||||
assert "50%" in msg # threshold shown
|
||||
|
||||
def test_85_percent_warning(self):
|
||||
msg = format_context_pressure_gateway(0.85, 0.50)
|
||||
assert "85% to compaction" in msg
|
||||
assert "imminent" in msg
|
||||
|
||||
def test_no_compaction_warning(self):
|
||||
msg = format_context_pressure_gateway(0.85, 0.50, compression_enabled=False)
|
||||
assert "disabled" in msg
|
||||
|
||||
def test_no_ansi_codes(self):
|
||||
msg = format_context_pressure_gateway(0.85, 0.50)
|
||||
assert "\033[" not in msg
|
||||
|
||||
def test_has_progress_bar(self):
|
||||
msg = format_context_pressure_gateway(0.85, 0.50)
|
||||
assert "▰" in msg
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AIAgent context pressure flag tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_tool_defs(*names):
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": n,
|
||||
"description": f"{n} tool",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
for n in names
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def agent():
|
||||
"""Minimal AIAgent with mocked internals."""
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
a.client = MagicMock()
|
||||
return a
|
||||
|
||||
|
||||
class TestContextPressureFlags:
|
||||
"""Context pressure warning flag tracking on AIAgent."""
|
||||
|
||||
def test_flags_initialized_false(self, agent):
|
||||
assert agent._context_50_warned is False
|
||||
assert agent._context_70_warned is False
|
||||
|
||||
def test_emit_calls_status_callback(self, agent):
|
||||
"""status_callback should be invoked with event type and message."""
|
||||
cb = MagicMock()
|
||||
agent.status_callback = cb
|
||||
|
||||
compressor = MagicMock()
|
||||
compressor.context_length = 200_000
|
||||
compressor.threshold_tokens = 100_000 # 50%
|
||||
|
||||
agent._emit_context_pressure(0.85, compressor)
|
||||
|
||||
cb.assert_called_once()
|
||||
args = cb.call_args[0]
|
||||
assert args[0] == "context_pressure"
|
||||
assert "85% to compaction" in args[1]
|
||||
|
||||
def test_emit_no_callback_no_crash(self, agent):
|
||||
"""No status_callback set — should not crash."""
|
||||
agent.status_callback = None
|
||||
|
||||
compressor = MagicMock()
|
||||
compressor.context_length = 200_000
|
||||
compressor.threshold_tokens = 100_000
|
||||
|
||||
# Should not raise
|
||||
agent._emit_context_pressure(0.60, compressor)
|
||||
|
||||
def test_emit_prints_for_cli_platform(self, agent, capsys):
|
||||
"""CLI platform should always print context pressure, even in quiet_mode."""
|
||||
agent.quiet_mode = True
|
||||
agent.platform = "cli"
|
||||
agent.status_callback = None
|
||||
|
||||
compressor = MagicMock()
|
||||
compressor.context_length = 200_000
|
||||
compressor.threshold_tokens = 100_000
|
||||
|
||||
agent._emit_context_pressure(0.85, compressor)
|
||||
captured = capsys.readouterr()
|
||||
assert "▰" in captured.out
|
||||
assert "to compaction" in captured.out
|
||||
|
||||
def test_emit_skips_print_for_gateway_platform(self, agent, capsys):
|
||||
"""Gateway platforms get the callback, not CLI print."""
|
||||
agent.platform = "telegram"
|
||||
agent.status_callback = None
|
||||
|
||||
compressor = MagicMock()
|
||||
compressor.context_length = 200_000
|
||||
compressor.threshold_tokens = 100_000
|
||||
|
||||
agent._emit_context_pressure(0.85, compressor)
|
||||
captured = capsys.readouterr()
|
||||
assert "▰" not in captured.out
|
||||
|
||||
def test_flags_reset_on_compression(self, agent):
|
||||
"""After _compress_context, context pressure flags should reset."""
|
||||
agent._context_50_warned = True
|
||||
agent._context_70_warned = True
|
||||
agent.compression_enabled = True
|
||||
|
||||
# Mock the compressor's compress method to return minimal valid output
|
||||
agent.context_compressor = MagicMock()
|
||||
agent.context_compressor.compress.return_value = [
|
||||
{"role": "user", "content": "Summary of conversation so far."}
|
||||
]
|
||||
agent.context_compressor.context_length = 200_000
|
||||
agent.context_compressor.threshold_tokens = 100_000
|
||||
|
||||
# Mock _todo_store
|
||||
agent._todo_store = MagicMock()
|
||||
agent._todo_store.format_for_injection.return_value = None
|
||||
|
||||
# Mock _build_system_prompt
|
||||
agent._build_system_prompt = MagicMock(return_value="system prompt")
|
||||
agent._cached_system_prompt = "old system prompt"
|
||||
agent._session_db = None
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi there"},
|
||||
]
|
||||
agent._compress_context(messages, "system prompt")
|
||||
|
||||
assert agent._context_50_warned is False
|
||||
assert agent._context_70_warned is False
|
||||
|
||||
def test_emit_callback_error_handled(self, agent):
|
||||
"""If status_callback raises, it should be caught gracefully."""
|
||||
cb = MagicMock(side_effect=RuntimeError("callback boom"))
|
||||
agent.status_callback = cb
|
||||
|
||||
compressor = MagicMock()
|
||||
compressor.context_length = 200_000
|
||||
compressor.threshold_tokens = 100_000
|
||||
|
||||
# Should not raise
|
||||
agent._emit_context_pressure(0.85, compressor)
|
||||
@@ -0,0 +1,493 @@
|
||||
"""Tests for _query_local_context_length and the local server fallback in
|
||||
get_model_context_length.
|
||||
|
||||
All tests use synthetic inputs — no filesystem or live server required.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _query_local_context_length — unit tests with mocked httpx
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQueryLocalContextLengthOllama:
|
||||
"""_query_local_context_length with server_type == 'ollama'."""
|
||||
|
||||
def _make_resp(self, status_code, body):
|
||||
resp = MagicMock()
|
||||
resp.status_code = status_code
|
||||
resp.json.return_value = body
|
||||
return resp
|
||||
|
||||
def test_ollama_model_info_context_length(self):
|
||||
"""Reads context length from model_info dict in /api/show response."""
|
||||
from agent.model_metadata import _query_local_context_length
|
||||
|
||||
show_resp = self._make_resp(200, {
|
||||
"model_info": {"llama.context_length": 131072}
|
||||
})
|
||||
models_resp = self._make_resp(404, {})
|
||||
|
||||
client_mock = MagicMock()
|
||||
client_mock.__enter__ = lambda s: client_mock
|
||||
client_mock.__exit__ = MagicMock(return_value=False)
|
||||
client_mock.post.return_value = show_resp
|
||||
client_mock.get.return_value = models_resp
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \
|
||||
patch("httpx.Client", return_value=client_mock):
|
||||
result = _query_local_context_length("omnicoder-9b", "http://localhost:11434/v1")
|
||||
|
||||
assert result == 131072
|
||||
|
||||
def test_ollama_parameters_num_ctx(self):
|
||||
"""Falls back to num_ctx in parameters string when model_info lacks context_length."""
|
||||
from agent.model_metadata import _query_local_context_length
|
||||
|
||||
show_resp = self._make_resp(200, {
|
||||
"model_info": {},
|
||||
"parameters": "num_ctx 32768\ntemperature 0.7\n"
|
||||
})
|
||||
models_resp = self._make_resp(404, {})
|
||||
|
||||
client_mock = MagicMock()
|
||||
client_mock.__enter__ = lambda s: client_mock
|
||||
client_mock.__exit__ = MagicMock(return_value=False)
|
||||
client_mock.post.return_value = show_resp
|
||||
client_mock.get.return_value = models_resp
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \
|
||||
patch("httpx.Client", return_value=client_mock):
|
||||
result = _query_local_context_length("some-model", "http://localhost:11434/v1")
|
||||
|
||||
assert result == 32768
|
||||
|
||||
def test_ollama_show_404_falls_through(self):
|
||||
"""When /api/show returns 404, falls through to /v1/models/{model}."""
|
||||
from agent.model_metadata import _query_local_context_length
|
||||
|
||||
show_resp = self._make_resp(404, {})
|
||||
model_detail_resp = self._make_resp(200, {"max_model_len": 65536})
|
||||
|
||||
client_mock = MagicMock()
|
||||
client_mock.__enter__ = lambda s: client_mock
|
||||
client_mock.__exit__ = MagicMock(return_value=False)
|
||||
client_mock.post.return_value = show_resp
|
||||
client_mock.get.return_value = model_detail_resp
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \
|
||||
patch("httpx.Client", return_value=client_mock):
|
||||
result = _query_local_context_length("some-model", "http://localhost:11434/v1")
|
||||
|
||||
assert result == 65536
|
||||
|
||||
|
||||
class TestQueryLocalContextLengthVllm:
|
||||
"""_query_local_context_length with vLLM-style /v1/models/{model} response."""
|
||||
|
||||
def _make_resp(self, status_code, body):
|
||||
resp = MagicMock()
|
||||
resp.status_code = status_code
|
||||
resp.json.return_value = body
|
||||
return resp
|
||||
|
||||
def test_vllm_max_model_len(self):
|
||||
"""Reads max_model_len from /v1/models/{model} response."""
|
||||
from agent.model_metadata import _query_local_context_length
|
||||
|
||||
detail_resp = self._make_resp(200, {"id": "omnicoder-9b", "max_model_len": 100000})
|
||||
list_resp = self._make_resp(404, {})
|
||||
|
||||
client_mock = MagicMock()
|
||||
client_mock.__enter__ = lambda s: client_mock
|
||||
client_mock.__exit__ = MagicMock(return_value=False)
|
||||
client_mock.post.return_value = self._make_resp(404, {})
|
||||
client_mock.get.return_value = detail_resp
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="vllm"), \
|
||||
patch("httpx.Client", return_value=client_mock):
|
||||
result = _query_local_context_length("omnicoder-9b", "http://localhost:8000/v1")
|
||||
|
||||
assert result == 100000
|
||||
|
||||
def test_vllm_context_length_key(self):
|
||||
"""Reads context_length from /v1/models/{model} response."""
|
||||
from agent.model_metadata import _query_local_context_length
|
||||
|
||||
detail_resp = self._make_resp(200, {"id": "some-model", "context_length": 32768})
|
||||
|
||||
client_mock = MagicMock()
|
||||
client_mock.__enter__ = lambda s: client_mock
|
||||
client_mock.__exit__ = MagicMock(return_value=False)
|
||||
client_mock.post.return_value = self._make_resp(404, {})
|
||||
client_mock.get.return_value = detail_resp
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="vllm"), \
|
||||
patch("httpx.Client", return_value=client_mock):
|
||||
result = _query_local_context_length("some-model", "http://localhost:8000/v1")
|
||||
|
||||
assert result == 32768
|
||||
|
||||
|
||||
class TestQueryLocalContextLengthModelsList:
|
||||
"""_query_local_context_length: falls back to /v1/models list."""
|
||||
|
||||
def _make_resp(self, status_code, body):
|
||||
resp = MagicMock()
|
||||
resp.status_code = status_code
|
||||
resp.json.return_value = body
|
||||
return resp
|
||||
|
||||
def test_models_list_max_model_len(self):
|
||||
"""Finds context length for model in /v1/models list."""
|
||||
from agent.model_metadata import _query_local_context_length
|
||||
|
||||
detail_resp = self._make_resp(404, {})
|
||||
list_resp = self._make_resp(200, {
|
||||
"data": [
|
||||
{"id": "other-model", "max_model_len": 4096},
|
||||
{"id": "omnicoder-9b", "max_model_len": 131072},
|
||||
]
|
||||
})
|
||||
|
||||
call_count = [0]
|
||||
def side_effect(url, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return detail_resp # /v1/models/omnicoder-9b
|
||||
return list_resp # /v1/models
|
||||
|
||||
client_mock = MagicMock()
|
||||
client_mock.__enter__ = lambda s: client_mock
|
||||
client_mock.__exit__ = MagicMock(return_value=False)
|
||||
client_mock.post.return_value = self._make_resp(404, {})
|
||||
client_mock.get.side_effect = side_effect
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value=None), \
|
||||
patch("httpx.Client", return_value=client_mock):
|
||||
result = _query_local_context_length("omnicoder-9b", "http://localhost:1234")
|
||||
|
||||
assert result == 131072
|
||||
|
||||
def test_models_list_model_not_found_returns_none(self):
|
||||
"""Returns None when model is not in the /v1/models list."""
|
||||
from agent.model_metadata import _query_local_context_length
|
||||
|
||||
detail_resp = self._make_resp(404, {})
|
||||
list_resp = self._make_resp(200, {
|
||||
"data": [{"id": "other-model", "max_model_len": 4096}]
|
||||
})
|
||||
|
||||
call_count = [0]
|
||||
def side_effect(url, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return detail_resp
|
||||
return list_resp
|
||||
|
||||
client_mock = MagicMock()
|
||||
client_mock.__enter__ = lambda s: client_mock
|
||||
client_mock.__exit__ = MagicMock(return_value=False)
|
||||
client_mock.post.return_value = self._make_resp(404, {})
|
||||
client_mock.get.side_effect = side_effect
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value=None), \
|
||||
patch("httpx.Client", return_value=client_mock):
|
||||
result = _query_local_context_length("omnicoder-9b", "http://localhost:1234")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestQueryLocalContextLengthLmStudio:
|
||||
"""_query_local_context_length with LM Studio native /api/v1/models response."""
|
||||
|
||||
def _make_resp(self, status_code, body):
|
||||
resp = MagicMock()
|
||||
resp.status_code = status_code
|
||||
resp.json.return_value = body
|
||||
return resp
|
||||
|
||||
def _make_client(self, native_resp, detail_resp, list_resp):
|
||||
"""Build a mock httpx.Client with sequenced GET responses."""
|
||||
client_mock = MagicMock()
|
||||
client_mock.__enter__ = lambda s: client_mock
|
||||
client_mock.__exit__ = MagicMock(return_value=False)
|
||||
client_mock.post.return_value = self._make_resp(404, {})
|
||||
|
||||
responses = [native_resp, detail_resp, list_resp]
|
||||
call_idx = [0]
|
||||
|
||||
def get_side_effect(url, **kwargs):
|
||||
idx = call_idx[0]
|
||||
call_idx[0] += 1
|
||||
if idx < len(responses):
|
||||
return responses[idx]
|
||||
return self._make_resp(404, {})
|
||||
|
||||
client_mock.get.side_effect = get_side_effect
|
||||
return client_mock
|
||||
|
||||
def test_lmstudio_exact_key_match(self):
|
||||
"""Reads max_context_length when key matches exactly."""
|
||||
from agent.model_metadata import _query_local_context_length
|
||||
|
||||
native_resp = self._make_resp(200, {
|
||||
"models": [
|
||||
{"key": "nvidia/nvidia-nemotron-super-49b-v1", "id": "nvidia/nvidia-nemotron-super-49b-v1",
|
||||
"max_context_length": 131072},
|
||||
]
|
||||
})
|
||||
client_mock = self._make_client(
|
||||
native_resp,
|
||||
self._make_resp(404, {}),
|
||||
self._make_resp(404, {}),
|
||||
)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
|
||||
patch("httpx.Client", return_value=client_mock):
|
||||
result = _query_local_context_length(
|
||||
"nvidia/nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1"
|
||||
)
|
||||
|
||||
assert result == 131072
|
||||
|
||||
def test_lmstudio_slug_only_matches_key_with_publisher_prefix(self):
|
||||
"""Fuzzy match: bare model slug matches key that includes publisher prefix.
|
||||
|
||||
When the user configures the model as "local:nvidia-nemotron-super-49b-v1"
|
||||
(slug only, no publisher), but LM Studio's native API stores it as
|
||||
"nvidia/nvidia-nemotron-super-49b-v1", the lookup must still succeed.
|
||||
"""
|
||||
from agent.model_metadata import _query_local_context_length
|
||||
|
||||
native_resp = self._make_resp(200, {
|
||||
"models": [
|
||||
{"key": "nvidia/nvidia-nemotron-super-49b-v1",
|
||||
"id": "nvidia/nvidia-nemotron-super-49b-v1",
|
||||
"max_context_length": 131072},
|
||||
]
|
||||
})
|
||||
client_mock = self._make_client(
|
||||
native_resp,
|
||||
self._make_resp(404, {}),
|
||||
self._make_resp(404, {}),
|
||||
)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
|
||||
patch("httpx.Client", return_value=client_mock):
|
||||
# Model passed in is just the slug after stripping "local:" prefix
|
||||
result = _query_local_context_length(
|
||||
"nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1"
|
||||
)
|
||||
|
||||
assert result == 131072
|
||||
|
||||
def test_lmstudio_v1_models_list_slug_fuzzy_match(self):
|
||||
"""Fuzzy match also works for /v1/models list when exact match fails.
|
||||
|
||||
LM Studio's OpenAI-compat /v1/models returns id like
|
||||
"nvidia/nvidia-nemotron-super-49b-v1" — must match bare slug.
|
||||
"""
|
||||
from agent.model_metadata import _query_local_context_length
|
||||
|
||||
# native /api/v1/models: no match
|
||||
native_resp = self._make_resp(404, {})
|
||||
# /v1/models/{model}: no match
|
||||
detail_resp = self._make_resp(404, {})
|
||||
# /v1/models list: model found with publisher prefix, includes context_length
|
||||
list_resp = self._make_resp(200, {
|
||||
"data": [
|
||||
{"id": "nvidia/nvidia-nemotron-super-49b-v1", "context_length": 131072},
|
||||
]
|
||||
})
|
||||
client_mock = self._make_client(native_resp, detail_resp, list_resp)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
|
||||
patch("httpx.Client", return_value=client_mock):
|
||||
result = _query_local_context_length(
|
||||
"nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1"
|
||||
)
|
||||
|
||||
assert result == 131072
|
||||
|
||||
def test_lmstudio_loaded_instances_context_length(self):
|
||||
"""Reads active context_length from loaded_instances when max_context_length absent."""
|
||||
from agent.model_metadata import _query_local_context_length
|
||||
|
||||
native_resp = self._make_resp(200, {
|
||||
"models": [
|
||||
{
|
||||
"key": "nvidia/nvidia-nemotron-super-49b-v1",
|
||||
"id": "nvidia/nvidia-nemotron-super-49b-v1",
|
||||
"loaded_instances": [
|
||||
{"config": {"context_length": 65536}},
|
||||
],
|
||||
},
|
||||
]
|
||||
})
|
||||
client_mock = self._make_client(
|
||||
native_resp,
|
||||
self._make_resp(404, {}),
|
||||
self._make_resp(404, {}),
|
||||
)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
|
||||
patch("httpx.Client", return_value=client_mock):
|
||||
result = _query_local_context_length(
|
||||
"nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1"
|
||||
)
|
||||
|
||||
assert result == 65536
|
||||
|
||||
def test_lmstudio_loaded_instance_beats_max_context_length(self):
|
||||
"""loaded_instances context_length takes priority over max_context_length.
|
||||
|
||||
LM Studio may show max_context_length=1_048_576 (theoretical model max)
|
||||
while the actual loaded context is 122_651 (runtime setting). The loaded
|
||||
value is the real constraint and must be preferred.
|
||||
"""
|
||||
from agent.model_metadata import _query_local_context_length
|
||||
|
||||
native_resp = self._make_resp(200, {
|
||||
"models": [
|
||||
{
|
||||
"key": "nvidia/nvidia-nemotron-3-nano-4b",
|
||||
"id": "nvidia/nvidia-nemotron-3-nano-4b",
|
||||
"max_context_length": 1_048_576,
|
||||
"loaded_instances": [
|
||||
{"config": {"context_length": 122_651}},
|
||||
],
|
||||
},
|
||||
]
|
||||
})
|
||||
client_mock = self._make_client(
|
||||
native_resp,
|
||||
self._make_resp(404, {}),
|
||||
self._make_resp(404, {}),
|
||||
)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
|
||||
patch("httpx.Client", return_value=client_mock):
|
||||
result = _query_local_context_length(
|
||||
"nvidia-nemotron-3-nano-4b", "http://192.168.1.22:1234/v1"
|
||||
)
|
||||
|
||||
assert result == 122_651, (
|
||||
f"Expected loaded instance context (122651) but got {result}. "
|
||||
"max_context_length (1048576) must not win over loaded_instances."
|
||||
)
|
||||
|
||||
|
||||
class TestQueryLocalContextLengthNetworkError:
|
||||
"""_query_local_context_length handles network failures gracefully."""
|
||||
|
||||
def test_connection_error_returns_none(self):
|
||||
"""Returns None when the server is unreachable."""
|
||||
from agent.model_metadata import _query_local_context_length
|
||||
|
||||
client_mock = MagicMock()
|
||||
client_mock.__enter__ = lambda s: client_mock
|
||||
client_mock.__exit__ = MagicMock(return_value=False)
|
||||
client_mock.post.side_effect = Exception("Connection refused")
|
||||
client_mock.get.side_effect = Exception("Connection refused")
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value=None), \
|
||||
patch("httpx.Client", return_value=client_mock):
|
||||
result = _query_local_context_length("omnicoder-9b", "http://localhost:11434/v1")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_model_context_length — integration-style tests with mocked helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetModelContextLengthLocalFallback:
|
||||
"""get_model_context_length uses local server query before falling back to 2M."""
|
||||
|
||||
def test_local_endpoint_unknown_model_queries_server(self):
|
||||
"""Unknown model on local endpoint gets ctx from server, not 2M default."""
|
||||
from agent.model_metadata import get_model_context_length
|
||||
|
||||
with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
|
||||
patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
|
||||
patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
|
||||
patch("agent.model_metadata.is_local_endpoint", return_value=True), \
|
||||
patch("agent.model_metadata._query_local_context_length", return_value=131072), \
|
||||
patch("agent.model_metadata.save_context_length") as mock_save:
|
||||
result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")
|
||||
|
||||
assert result == 131072
|
||||
|
||||
def test_local_endpoint_unknown_model_result_is_cached(self):
|
||||
"""Context length returned from local server is persisted to cache."""
|
||||
from agent.model_metadata import get_model_context_length
|
||||
|
||||
with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
|
||||
patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
|
||||
patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
|
||||
patch("agent.model_metadata.is_local_endpoint", return_value=True), \
|
||||
patch("agent.model_metadata._query_local_context_length", return_value=131072), \
|
||||
patch("agent.model_metadata.save_context_length") as mock_save:
|
||||
get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")
|
||||
|
||||
mock_save.assert_called_once_with("omnicoder-9b", "http://localhost:11434/v1", 131072)
|
||||
|
||||
def test_local_endpoint_server_returns_none_falls_back_to_2m(self):
|
||||
"""When local server returns None, still falls back to 2M probe tier."""
|
||||
from agent.model_metadata import get_model_context_length, CONTEXT_PROBE_TIERS
|
||||
|
||||
with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
|
||||
patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
|
||||
patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
|
||||
patch("agent.model_metadata.is_local_endpoint", return_value=True), \
|
||||
patch("agent.model_metadata._query_local_context_length", return_value=None):
|
||||
result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")
|
||||
|
||||
assert result == CONTEXT_PROBE_TIERS[0]
|
||||
|
||||
def test_non_local_endpoint_does_not_query_local_server(self):
|
||||
"""For non-local endpoints, _query_local_context_length is not called."""
|
||||
from agent.model_metadata import get_model_context_length, CONTEXT_PROBE_TIERS
|
||||
|
||||
with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
|
||||
patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
|
||||
patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
|
||||
patch("agent.model_metadata.is_local_endpoint", return_value=False), \
|
||||
patch("agent.model_metadata._query_local_context_length") as mock_query:
|
||||
result = get_model_context_length(
|
||||
"unknown-model", "https://some-cloud-api.example.com/v1"
|
||||
)
|
||||
|
||||
mock_query.assert_not_called()
|
||||
|
||||
def test_cached_result_skips_local_query(self):
|
||||
"""Cached context length is returned without querying the local server."""
|
||||
from agent.model_metadata import get_model_context_length
|
||||
|
||||
with patch("agent.model_metadata.get_cached_context_length", return_value=65536), \
|
||||
patch("agent.model_metadata._query_local_context_length") as mock_query:
|
||||
result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")
|
||||
|
||||
assert result == 65536
|
||||
mock_query.assert_not_called()
|
||||
|
||||
def test_no_base_url_does_not_query_local_server(self):
|
||||
"""When base_url is empty, local server is not queried."""
|
||||
from agent.model_metadata import get_model_context_length
|
||||
|
||||
with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
|
||||
patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
|
||||
patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
|
||||
patch("agent.model_metadata._query_local_context_length") as mock_query:
|
||||
result = get_model_context_length("unknown-xyz-model", "")
|
||||
|
||||
mock_query.assert_not_called()
|
||||
@@ -0,0 +1,307 @@
|
||||
"""Regression tests for the _run_async() event-loop lifecycle.
|
||||
|
||||
These tests verify the fix for GitHub issue #2104:
|
||||
"Event loop is closed" after vision_analyze used as first call in session.
|
||||
|
||||
Root cause: asyncio.run() creates and *closes* a fresh event loop on every
|
||||
call. Cached httpx/AsyncOpenAI clients that were bound to the now-dead loop
|
||||
would crash with RuntimeError("Event loop is closed") when garbage-collected.
|
||||
|
||||
The fix replaces asyncio.run() with a persistent event loop in _run_async().
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _get_current_loop():
|
||||
"""Return the running event loop from inside a coroutine."""
|
||||
return asyncio.get_event_loop()
|
||||
|
||||
|
||||
async def _create_and_return_transport():
|
||||
"""Simulate an async client creating a transport on the current loop.
|
||||
|
||||
Returns a simple asyncio.Future bound to the running loop so we can
|
||||
later check whether the loop is still alive.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
fut = loop.create_future()
|
||||
fut.set_result("ok")
|
||||
return loop, fut
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRunAsyncLoopLifecycle:
|
||||
"""Verify _run_async() keeps the event loop alive after returning."""
|
||||
|
||||
def test_loop_not_closed_after_run_async(self):
|
||||
"""The loop used by _run_async must still be open after the call."""
|
||||
from model_tools import _run_async
|
||||
|
||||
loop = _run_async(_get_current_loop())
|
||||
|
||||
assert not loop.is_closed(), (
|
||||
"_run_async() closed the event loop — cached async clients will "
|
||||
"crash with 'Event loop is closed' on GC (issue #2104)"
|
||||
)
|
||||
|
||||
def test_same_loop_reused_across_calls(self):
|
||||
"""Consecutive _run_async calls should reuse the same loop."""
|
||||
from model_tools import _run_async
|
||||
|
||||
loop1 = _run_async(_get_current_loop())
|
||||
loop2 = _run_async(_get_current_loop())
|
||||
|
||||
assert loop1 is loop2, (
|
||||
"_run_async() created a new loop on the second call — cached "
|
||||
"async clients from the first call would be orphaned"
|
||||
)
|
||||
|
||||
def test_cached_transport_survives_between_calls(self):
|
||||
"""A transport/future created in call 1 must be valid in call 2."""
|
||||
from model_tools import _run_async
|
||||
|
||||
loop, fut = _run_async(_create_and_return_transport())
|
||||
|
||||
assert not loop.is_closed()
|
||||
assert fut.result() == "ok"
|
||||
|
||||
loop2 = _run_async(_get_current_loop())
|
||||
assert loop2 is loop, "Loop changed between calls"
|
||||
assert not loop.is_closed(), "Loop closed before second call"
|
||||
|
||||
|
||||
class TestRunAsyncWorkerThread:
|
||||
"""Verify worker threads get persistent per-thread loops (delegate_task fix)."""
|
||||
|
||||
def test_worker_thread_loop_not_closed(self):
|
||||
"""A worker thread's loop must stay open after _run_async returns,
|
||||
so cached httpx/AsyncOpenAI clients don't crash on GC."""
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from model_tools import _run_async
|
||||
|
||||
def _run_on_worker():
|
||||
loop = _run_async(_get_current_loop())
|
||||
still_open = not loop.is_closed()
|
||||
return loop, still_open
|
||||
|
||||
with ThreadPoolExecutor(max_workers=1) as pool:
|
||||
loop, still_open = pool.submit(_run_on_worker).result()
|
||||
|
||||
assert still_open, (
|
||||
"Worker thread's event loop was closed after _run_async — "
|
||||
"cached async clients will crash with 'Event loop is closed'"
|
||||
)
|
||||
|
||||
def test_worker_thread_reuses_loop_across_calls(self):
|
||||
"""Multiple _run_async calls on the same worker thread should
|
||||
reuse the same persistent loop (not create-and-destroy each time)."""
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from model_tools import _run_async
|
||||
|
||||
def _run_twice_on_worker():
|
||||
loop1 = _run_async(_get_current_loop())
|
||||
loop2 = _run_async(_get_current_loop())
|
||||
return loop1, loop2
|
||||
|
||||
with ThreadPoolExecutor(max_workers=1) as pool:
|
||||
loop1, loop2 = pool.submit(_run_twice_on_worker).result()
|
||||
|
||||
assert loop1 is loop2, (
|
||||
"Worker thread created different loops for consecutive calls — "
|
||||
"cached clients from the first call would be orphaned"
|
||||
)
|
||||
assert not loop1.is_closed()
|
||||
|
||||
def test_parallel_workers_get_separate_loops(self):
|
||||
"""Different worker threads must get their own loops to avoid
|
||||
contention (the original reason for the worker-thread branch)."""
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from model_tools import _run_async
|
||||
|
||||
barrier = threading.Barrier(3, timeout=5)
|
||||
|
||||
def _get_loop_id():
|
||||
# Use a barrier to force all 3 threads to be alive simultaneously,
|
||||
# ensuring the ThreadPoolExecutor actually uses 3 distinct threads.
|
||||
loop = _run_async(_get_current_loop())
|
||||
barrier.wait()
|
||||
return id(loop), not loop.is_closed(), threading.current_thread().ident
|
||||
|
||||
with ThreadPoolExecutor(max_workers=3) as pool:
|
||||
futures = [pool.submit(_get_loop_id) for _ in range(3)]
|
||||
results = [f.result() for f in as_completed(futures)]
|
||||
|
||||
loop_ids = {r[0] for r in results}
|
||||
thread_ids = {r[2] for r in results}
|
||||
all_open = all(r[1] for r in results)
|
||||
|
||||
assert all_open, "At least one worker thread's loop was closed"
|
||||
# The barrier guarantees 3 distinct threads were used
|
||||
assert len(thread_ids) == 3, f"Expected 3 threads, got {len(thread_ids)}"
|
||||
# Each thread should have its own loop
|
||||
assert len(loop_ids) == 3, (
|
||||
f"Expected 3 distinct loops for 3 parallel workers, "
|
||||
f"got {len(loop_ids)} — workers may be contending on a shared loop"
|
||||
)
|
||||
|
||||
def test_worker_loop_separate_from_main_loop(self):
|
||||
"""Worker thread loops must be different from the main thread's
|
||||
persistent loop to avoid cross-thread contention."""
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from model_tools import _run_async, _get_tool_loop
|
||||
|
||||
main_loop = _get_tool_loop()
|
||||
|
||||
def _get_worker_loop_id():
|
||||
loop = _run_async(_get_current_loop())
|
||||
return id(loop)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=1) as pool:
|
||||
worker_loop_id = pool.submit(_get_worker_loop_id).result()
|
||||
|
||||
assert worker_loop_id != id(main_loop), (
|
||||
"Worker thread used the main thread's loop — this would cause "
|
||||
"cross-thread contention on the event loop"
|
||||
)
|
||||
|
||||
|
||||
class TestRunAsyncWithRunningLoop:
|
||||
"""When a loop is already running, _run_async falls back to a thread."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_from_async_context(self):
|
||||
"""_run_async should still work when called from inside an
|
||||
already-running event loop (gateway / Atropos path)."""
|
||||
from model_tools import _run_async
|
||||
|
||||
async def _simple():
|
||||
return 42
|
||||
|
||||
result = await asyncio.get_event_loop().run_in_executor(
|
||||
None, _run_async, _simple()
|
||||
)
|
||||
assert result == 42
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: full vision_analyze dispatch chain
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _mock_vision_response():
|
||||
"""Build a fake LLM response matching async_call_llm's return shape."""
|
||||
message = SimpleNamespace(content="A cat sitting on a chair.")
|
||||
choice = SimpleNamespace(index=0, message=message, finish_reason="stop")
|
||||
return SimpleNamespace(choices=[choice], model="test/vision", usage=None)
|
||||
|
||||
|
||||
class TestVisionDispatchLoopSafety:
|
||||
"""Simulate the full registry.dispatch('vision_analyze') chain and
|
||||
verify the event loop stays alive afterwards — the exact scenario
|
||||
from issue #2104."""
|
||||
|
||||
def test_vision_dispatch_keeps_loop_alive(self, tmp_path):
|
||||
"""After dispatching vision_analyze via the registry, the event
|
||||
loop must remain open so cached async clients don't crash on GC."""
|
||||
from model_tools import _run_async, _get_tool_loop
|
||||
from tools.registry import registry
|
||||
|
||||
fake_response = _mock_vision_response()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"tools.vision_tools.async_call_llm",
|
||||
new_callable=AsyncMock,
|
||||
return_value=fake_response,
|
||||
),
|
||||
patch(
|
||||
"tools.vision_tools._download_image",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=lambda url, dest, **kw: _write_fake_image(dest),
|
||||
),
|
||||
patch(
|
||||
"tools.vision_tools._validate_image_url",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"tools.vision_tools._image_to_base64_data_url",
|
||||
return_value="data:image/jpeg;base64,abc",
|
||||
),
|
||||
):
|
||||
result_json = registry.dispatch(
|
||||
"vision_analyze",
|
||||
{"image_url": "https://example.com/cat.png", "question": "What is this?"},
|
||||
)
|
||||
|
||||
result = json.loads(result_json)
|
||||
assert result.get("success") is True, f"dispatch failed: {result}"
|
||||
assert "cat" in result.get("analysis", "").lower()
|
||||
|
||||
loop = _get_tool_loop()
|
||||
assert not loop.is_closed(), (
|
||||
"Event loop closed after vision_analyze dispatch — cached async "
|
||||
"clients will crash with 'Event loop is closed' (issue #2104)"
|
||||
)
|
||||
|
||||
def test_two_consecutive_vision_dispatches(self, tmp_path):
|
||||
"""Two back-to-back vision_analyze dispatches must both succeed
|
||||
and share the same loop (simulates 'first call fails, second
|
||||
works' from the issue report)."""
|
||||
from model_tools import _get_tool_loop
|
||||
from tools.registry import registry
|
||||
|
||||
fake_response = _mock_vision_response()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"tools.vision_tools.async_call_llm",
|
||||
new_callable=AsyncMock,
|
||||
return_value=fake_response,
|
||||
),
|
||||
patch(
|
||||
"tools.vision_tools._download_image",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=lambda url, dest, **kw: _write_fake_image(dest),
|
||||
),
|
||||
patch(
|
||||
"tools.vision_tools._validate_image_url",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"tools.vision_tools._image_to_base64_data_url",
|
||||
return_value="data:image/jpeg;base64,abc",
|
||||
),
|
||||
):
|
||||
args = {"image_url": "https://example.com/cat.png", "question": "Describe"}
|
||||
|
||||
r1 = json.loads(registry.dispatch("vision_analyze", args))
|
||||
loop_after_first = _get_tool_loop()
|
||||
|
||||
r2 = json.loads(registry.dispatch("vision_analyze", args))
|
||||
loop_after_second = _get_tool_loop()
|
||||
|
||||
assert r1.get("success") is True
|
||||
assert r2.get("success") is True
|
||||
assert loop_after_first is loop_after_second, "Loop changed between dispatches"
|
||||
assert not loop_after_second.is_closed()
|
||||
|
||||
|
||||
def _write_fake_image(dest):
|
||||
"""Write minimal bytes so vision_analyze_tool thinks download succeeded."""
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
dest.write_bytes(b"\xff\xd8\xff" + b"\x00" * 16)
|
||||
return dest
|
||||
@@ -830,3 +830,212 @@ def test_dump_api_request_debug_uses_chat_completions_url(monkeypatch, tmp_path)
|
||||
|
||||
payload = json.loads(dump_file.read_text())
|
||||
assert payload["request"]["url"] == "http://127.0.0.1:9208/v1/chat/completions"
|
||||
|
||||
|
||||
# --- Reasoning-only response tests (fix for empty content retry loop) ---
|
||||
|
||||
|
||||
def _codex_reasoning_only_response(*, encrypted_content="enc_abc123", summary_text="Thinking..."):
|
||||
"""Codex response containing only reasoning items — no message text, no tool calls."""
|
||||
return SimpleNamespace(
|
||||
output=[
|
||||
SimpleNamespace(
|
||||
type="reasoning",
|
||||
id="rs_001",
|
||||
encrypted_content=encrypted_content,
|
||||
summary=[SimpleNamespace(type="summary_text", text=summary_text)],
|
||||
status="completed",
|
||||
)
|
||||
],
|
||||
usage=SimpleNamespace(input_tokens=50, output_tokens=100, total_tokens=150),
|
||||
status="completed",
|
||||
model="gpt-5-codex",
|
||||
)
|
||||
|
||||
|
||||
def test_normalize_codex_response_marks_reasoning_only_as_incomplete(monkeypatch):
|
||||
"""A response with only reasoning items and no content should be 'incomplete', not 'stop'.
|
||||
|
||||
Without this fix, reasoning-only responses get finish_reason='stop' which
|
||||
sends them into the empty-content retry loop (3 retries then failure).
|
||||
"""
|
||||
agent = _build_agent(monkeypatch)
|
||||
assistant_message, finish_reason = agent._normalize_codex_response(
|
||||
_codex_reasoning_only_response()
|
||||
)
|
||||
|
||||
assert finish_reason == "incomplete"
|
||||
assert assistant_message.content == ""
|
||||
assert assistant_message.codex_reasoning_items is not None
|
||||
assert len(assistant_message.codex_reasoning_items) == 1
|
||||
assert assistant_message.codex_reasoning_items[0]["encrypted_content"] == "enc_abc123"
|
||||
|
||||
|
||||
def test_normalize_codex_response_reasoning_with_content_is_stop(monkeypatch):
|
||||
"""If a response has both reasoning and message content, it should still be 'stop'."""
|
||||
agent = _build_agent(monkeypatch)
|
||||
response = SimpleNamespace(
|
||||
output=[
|
||||
SimpleNamespace(
|
||||
type="reasoning",
|
||||
id="rs_001",
|
||||
encrypted_content="enc_xyz",
|
||||
summary=[SimpleNamespace(type="summary_text", text="Thinking...")],
|
||||
status="completed",
|
||||
),
|
||||
SimpleNamespace(
|
||||
type="message",
|
||||
content=[SimpleNamespace(type="output_text", text="Here is the answer.")],
|
||||
status="completed",
|
||||
),
|
||||
],
|
||||
usage=SimpleNamespace(input_tokens=50, output_tokens=100, total_tokens=150),
|
||||
status="completed",
|
||||
model="gpt-5-codex",
|
||||
)
|
||||
assistant_message, finish_reason = agent._normalize_codex_response(response)
|
||||
|
||||
assert finish_reason == "stop"
|
||||
assert "Here is the answer" in assistant_message.content
|
||||
|
||||
|
||||
def test_run_conversation_codex_continues_after_reasoning_only_response(monkeypatch):
|
||||
"""End-to-end: reasoning-only → final message should succeed, not hit retry loop."""
|
||||
agent = _build_agent(monkeypatch)
|
||||
responses = [
|
||||
_codex_reasoning_only_response(),
|
||||
_codex_message_response("The final answer is 42."),
|
||||
]
|
||||
monkeypatch.setattr(agent, "_interruptible_api_call", lambda api_kwargs: responses.pop(0))
|
||||
|
||||
result = agent.run_conversation("what is the answer?")
|
||||
|
||||
assert result["completed"] is True
|
||||
assert result["final_response"] == "The final answer is 42."
|
||||
# The reasoning-only turn should be in messages as an incomplete interim
|
||||
assert any(
|
||||
msg.get("role") == "assistant"
|
||||
and msg.get("finish_reason") == "incomplete"
|
||||
and msg.get("codex_reasoning_items") is not None
|
||||
for msg in result["messages"]
|
||||
)
|
||||
|
||||
|
||||
def test_run_conversation_codex_preserves_encrypted_reasoning_in_interim(monkeypatch):
|
||||
"""Encrypted codex_reasoning_items must be preserved in interim messages
|
||||
even when there is no visible reasoning text or content."""
|
||||
agent = _build_agent(monkeypatch)
|
||||
# Response with encrypted reasoning but no human-readable summary
|
||||
reasoning_response = SimpleNamespace(
|
||||
output=[
|
||||
SimpleNamespace(
|
||||
type="reasoning",
|
||||
id="rs_002",
|
||||
encrypted_content="enc_opaque_blob",
|
||||
summary=[],
|
||||
status="completed",
|
||||
)
|
||||
],
|
||||
usage=SimpleNamespace(input_tokens=50, output_tokens=100, total_tokens=150),
|
||||
status="completed",
|
||||
model="gpt-5-codex",
|
||||
)
|
||||
responses = [
|
||||
reasoning_response,
|
||||
_codex_message_response("Done thinking."),
|
||||
]
|
||||
monkeypatch.setattr(agent, "_interruptible_api_call", lambda api_kwargs: responses.pop(0))
|
||||
|
||||
result = agent.run_conversation("think hard")
|
||||
|
||||
assert result["completed"] is True
|
||||
assert result["final_response"] == "Done thinking."
|
||||
# The interim message must have codex_reasoning_items preserved
|
||||
interim_msgs = [
|
||||
msg for msg in result["messages"]
|
||||
if msg.get("role") == "assistant"
|
||||
and msg.get("finish_reason") == "incomplete"
|
||||
]
|
||||
assert len(interim_msgs) >= 1
|
||||
assert interim_msgs[0].get("codex_reasoning_items") is not None
|
||||
assert interim_msgs[0]["codex_reasoning_items"][0]["encrypted_content"] == "enc_opaque_blob"
|
||||
|
||||
|
||||
def test_chat_messages_to_responses_input_reasoning_only_has_following_item(monkeypatch):
|
||||
"""When converting a reasoning-only interim message to Responses API input,
|
||||
the reasoning items must be followed by an assistant message (even if empty)
|
||||
to satisfy the API's 'required following item' constraint."""
|
||||
agent = _build_agent(monkeypatch)
|
||||
messages = [
|
||||
{"role": "user", "content": "think hard"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning": None,
|
||||
"finish_reason": "incomplete",
|
||||
"codex_reasoning_items": [
|
||||
{"type": "reasoning", "id": "rs_001", "encrypted_content": "enc_abc", "summary": []},
|
||||
],
|
||||
},
|
||||
]
|
||||
items = agent._chat_messages_to_responses_input(messages)
|
||||
|
||||
# Find the reasoning item
|
||||
reasoning_indices = [i for i, it in enumerate(items) if it.get("type") == "reasoning"]
|
||||
assert len(reasoning_indices) == 1
|
||||
ri_idx = reasoning_indices[0]
|
||||
|
||||
# There must be a following item after the reasoning
|
||||
assert ri_idx < len(items) - 1, "Reasoning item must not be the last item (missing_following_item)"
|
||||
following = items[ri_idx + 1]
|
||||
assert following.get("role") == "assistant"
|
||||
|
||||
|
||||
def test_duplicate_detection_distinguishes_different_codex_reasoning(monkeypatch):
|
||||
"""Two consecutive reasoning-only responses with different encrypted content
|
||||
must NOT be treated as duplicates."""
|
||||
agent = _build_agent(monkeypatch)
|
||||
responses = [
|
||||
# First reasoning-only response
|
||||
SimpleNamespace(
|
||||
output=[
|
||||
SimpleNamespace(
|
||||
type="reasoning", id="rs_001",
|
||||
encrypted_content="enc_first", summary=[], status="completed",
|
||||
)
|
||||
],
|
||||
usage=SimpleNamespace(input_tokens=50, output_tokens=100, total_tokens=150),
|
||||
status="completed", model="gpt-5-codex",
|
||||
),
|
||||
# Second reasoning-only response (different encrypted content)
|
||||
SimpleNamespace(
|
||||
output=[
|
||||
SimpleNamespace(
|
||||
type="reasoning", id="rs_002",
|
||||
encrypted_content="enc_second", summary=[], status="completed",
|
||||
)
|
||||
],
|
||||
usage=SimpleNamespace(input_tokens=50, output_tokens=100, total_tokens=150),
|
||||
status="completed", model="gpt-5-codex",
|
||||
),
|
||||
_codex_message_response("Final answer after thinking."),
|
||||
]
|
||||
monkeypatch.setattr(agent, "_interruptible_api_call", lambda api_kwargs: responses.pop(0))
|
||||
|
||||
result = agent.run_conversation("think very hard")
|
||||
|
||||
assert result["completed"] is True
|
||||
assert result["final_response"] == "Final answer after thinking."
|
||||
# Both reasoning-only interim messages should be in history (not collapsed)
|
||||
interim_msgs = [
|
||||
msg for msg in result["messages"]
|
||||
if msg.get("role") == "assistant"
|
||||
and msg.get("finish_reason") == "incomplete"
|
||||
]
|
||||
assert len(interim_msgs) == 2
|
||||
encrypted_contents = [
|
||||
msg["codex_reasoning_items"][0]["encrypted_content"]
|
||||
for msg in interim_msgs
|
||||
]
|
||||
assert "enc_first" in encrypted_contents
|
||||
assert "enc_second" in encrypted_contents
|
||||
|
||||
@@ -438,10 +438,116 @@ def test_named_custom_provider_without_api_mode_defaults(monkeypatch):
|
||||
lambda p: {
|
||||
"name": "my-server",
|
||||
"base_url": "http://localhost:8000/v1",
|
||||
"api_key": "sk-test",
|
||||
"api_key": "***",
|
||||
},
|
||||
)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="my-server")
|
||||
|
||||
assert resolved["api_mode"] == "chat_completions"
|
||||
|
||||
|
||||
def test_anthropic_messages_in_valid_api_modes():
|
||||
"""anthropic_messages should be accepted by _parse_api_mode."""
|
||||
assert rp._parse_api_mode("anthropic_messages") == "anthropic_messages"
|
||||
|
||||
|
||||
def test_api_key_provider_anthropic_url_auto_detection(monkeypatch):
|
||||
"""API-key providers with /anthropic base URL should auto-detect anthropic_messages mode."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "minimax")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
monkeypatch.setenv("MINIMAX_API_KEY", "test-minimax-key")
|
||||
monkeypatch.setenv("MINIMAX_BASE_URL", "https://api.minimax.io/anthropic")
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="minimax")
|
||||
|
||||
assert resolved["provider"] == "minimax"
|
||||
assert resolved["api_mode"] == "anthropic_messages"
|
||||
assert resolved["base_url"] == "https://api.minimax.io/anthropic"
|
||||
|
||||
|
||||
def test_api_key_provider_explicit_api_mode_config(monkeypatch):
|
||||
"""API-key providers should respect api_mode from model config."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "minimax")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {"api_mode": "anthropic_messages"})
|
||||
monkeypatch.setenv("MINIMAX_API_KEY", "test-minimax-key")
|
||||
monkeypatch.delenv("MINIMAX_BASE_URL", raising=False)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="minimax")
|
||||
|
||||
assert resolved["provider"] == "minimax"
|
||||
assert resolved["api_mode"] == "anthropic_messages"
|
||||
|
||||
|
||||
def test_minimax_default_url_uses_anthropic_messages(monkeypatch):
|
||||
"""MiniMax with default /anthropic URL should auto-detect anthropic_messages mode."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "minimax")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
monkeypatch.setenv("MINIMAX_API_KEY", "test-minimax-key")
|
||||
monkeypatch.delenv("MINIMAX_BASE_URL", raising=False)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="minimax")
|
||||
|
||||
assert resolved["provider"] == "minimax"
|
||||
assert resolved["api_mode"] == "anthropic_messages"
|
||||
assert resolved["base_url"] == "https://api.minimax.io/anthropic"
|
||||
|
||||
|
||||
def test_minimax_stale_v1_url_auto_corrected(monkeypatch):
|
||||
"""MiniMax with stale /v1 base URL should be auto-corrected to /anthropic."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "minimax")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
monkeypatch.setenv("MINIMAX_API_KEY", "test-minimax-key")
|
||||
monkeypatch.setenv("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="minimax")
|
||||
|
||||
assert resolved["provider"] == "minimax"
|
||||
assert resolved["api_mode"] == "anthropic_messages"
|
||||
assert resolved["base_url"] == "https://api.minimax.io/anthropic"
|
||||
|
||||
|
||||
def test_minimax_cn_stale_v1_url_auto_corrected(monkeypatch):
|
||||
"""MiniMax-CN with stale /v1 base URL should be auto-corrected to /anthropic."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "minimax-cn")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
monkeypatch.setenv("MINIMAX_CN_API_KEY", "test-minimax-cn-key")
|
||||
monkeypatch.setenv("MINIMAX_CN_BASE_URL", "https://api.minimaxi.com/v1")
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="minimax-cn")
|
||||
|
||||
assert resolved["provider"] == "minimax-cn"
|
||||
assert resolved["api_mode"] == "anthropic_messages"
|
||||
assert resolved["base_url"] == "https://api.minimaxi.com/anthropic"
|
||||
|
||||
|
||||
def test_minimax_explicit_api_mode_respected(monkeypatch):
|
||||
"""Explicit api_mode config should override MiniMax auto-detection."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "minimax")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {"api_mode": "chat_completions"})
|
||||
monkeypatch.setenv("MINIMAX_API_KEY", "test-minimax-key")
|
||||
monkeypatch.delenv("MINIMAX_BASE_URL", raising=False)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="minimax")
|
||||
|
||||
assert resolved["provider"] == "minimax"
|
||||
assert resolved["api_mode"] == "chat_completions"
|
||||
|
||||
|
||||
def test_named_custom_provider_anthropic_api_mode(monkeypatch):
|
||||
"""Custom providers should accept api_mode: anthropic_messages."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-anthropic-proxy")
|
||||
monkeypatch.setattr(
|
||||
rp, "_get_named_custom_provider",
|
||||
lambda p: {
|
||||
"name": "my-anthropic-proxy",
|
||||
"base_url": "https://proxy.example.com/anthropic",
|
||||
"api_key": "test-key",
|
||||
"api_mode": "anthropic_messages",
|
||||
},
|
||||
)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="my-anthropic-proxy")
|
||||
|
||||
assert resolved["api_mode"] == "anthropic_messages"
|
||||
assert resolved["base_url"] == "https://proxy.example.com/anthropic"
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
"""Tests that verify SQL injection mitigations in insights and state modules."""
|
||||
|
||||
import re
|
||||
|
||||
from agent.insights import InsightsEngine
|
||||
|
||||
|
||||
def test_session_cols_no_injection_chars():
|
||||
"""_SESSION_COLS must not contain SQL injection vectors."""
|
||||
cols = InsightsEngine._SESSION_COLS
|
||||
assert ";" not in cols
|
||||
assert "--" not in cols
|
||||
assert "'" not in cols
|
||||
assert "DROP" not in cols.upper()
|
||||
|
||||
|
||||
def test_get_sessions_all_query_is_parameterized():
|
||||
"""_GET_SESSIONS_ALL must use a ? placeholder for the cutoff value."""
|
||||
query = InsightsEngine._GET_SESSIONS_ALL
|
||||
assert "?" in query
|
||||
assert "started_at >= ?" in query
|
||||
# Must not embed any runtime-variable content via brace interpolation
|
||||
assert "{" not in query
|
||||
|
||||
|
||||
def test_get_sessions_with_source_query_is_parameterized():
|
||||
"""_GET_SESSIONS_WITH_SOURCE must use ? placeholders for both parameters."""
|
||||
query = InsightsEngine._GET_SESSIONS_WITH_SOURCE
|
||||
assert query.count("?") == 2
|
||||
assert "started_at >= ?" in query
|
||||
assert "source = ?" in query
|
||||
assert "{" not in query
|
||||
|
||||
|
||||
def test_session_col_names_are_safe_identifiers():
|
||||
"""Every column name listed in _SESSION_COLS must be a simple identifier."""
|
||||
cols = InsightsEngine._SESSION_COLS
|
||||
identifiers = [c.strip() for c in cols.split(",")]
|
||||
safe_identifier = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
|
||||
for col in identifiers:
|
||||
assert safe_identifier.match(col), (
|
||||
f"Column name {col!r} is not a safe SQL identifier"
|
||||
)
|
||||
@@ -0,0 +1,47 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
|
||||
HOST = "example-host"
|
||||
PORT = 9223
|
||||
WS_URL = f"ws://{HOST}:{PORT}/devtools/browser/abc123"
|
||||
HTTP_URL = f"http://{HOST}:{PORT}"
|
||||
VERSION_URL = f"{HTTP_URL}/json/version"
|
||||
|
||||
|
||||
class TestResolveCdpOverride:
|
||||
def test_keeps_full_devtools_websocket_url(self):
|
||||
from tools.browser_tool import _resolve_cdp_override
|
||||
|
||||
assert _resolve_cdp_override(WS_URL) == WS_URL
|
||||
|
||||
def test_resolves_http_discovery_endpoint_to_websocket(self):
|
||||
from tools.browser_tool import _resolve_cdp_override
|
||||
|
||||
response = Mock()
|
||||
response.raise_for_status.return_value = None
|
||||
response.json.return_value = {"webSocketDebuggerUrl": WS_URL}
|
||||
|
||||
with patch("tools.browser_tool.requests.get", return_value=response) as mock_get:
|
||||
resolved = _resolve_cdp_override(HTTP_URL)
|
||||
|
||||
assert resolved == WS_URL
|
||||
mock_get.assert_called_once_with(VERSION_URL, timeout=10)
|
||||
|
||||
def test_resolves_bare_ws_hostport_to_discovery_websocket(self):
|
||||
from tools.browser_tool import _resolve_cdp_override
|
||||
|
||||
response = Mock()
|
||||
response.raise_for_status.return_value = None
|
||||
response.json.return_value = {"webSocketDebuggerUrl": WS_URL}
|
||||
|
||||
with patch("tools.browser_tool.requests.get", return_value=response) as mock_get:
|
||||
resolved = _resolve_cdp_override(f"ws://{HOST}:{PORT}")
|
||||
|
||||
assert resolved == WS_URL
|
||||
mock_get.assert_called_once_with(VERSION_URL, timeout=10)
|
||||
|
||||
def test_falls_back_to_raw_url_when_discovery_fails(self):
|
||||
from tools.browser_tool import _resolve_cdp_override
|
||||
|
||||
with patch("tools.browser_tool.requests.get", side_effect=RuntimeError("boom")):
|
||||
assert _resolve_cdp_override(HTTP_URL) == HTTP_URL
|
||||
@@ -64,7 +64,8 @@ def make_env(daytona_sdk, monkeypatch):
|
||||
|
||||
def _factory(
|
||||
sandbox=None,
|
||||
find_one_side_effect=None,
|
||||
get_side_effect=None,
|
||||
list_return=None,
|
||||
home_dir="/root",
|
||||
persistent=True,
|
||||
**kwargs,
|
||||
@@ -76,11 +77,17 @@ def make_env(daytona_sdk, monkeypatch):
|
||||
mock_client = MagicMock()
|
||||
mock_client.create.return_value = sandbox
|
||||
|
||||
if find_one_side_effect is not None:
|
||||
mock_client.find_one.side_effect = find_one_side_effect
|
||||
if get_side_effect is not None:
|
||||
mock_client.get.side_effect = get_side_effect
|
||||
else:
|
||||
# Default: no existing sandbox found
|
||||
mock_client.find_one.side_effect = daytona_sdk.DaytonaError("not found")
|
||||
# Default: no existing sandbox found via get()
|
||||
mock_client.get.side_effect = daytona_sdk.DaytonaError("not found")
|
||||
|
||||
# Default: no legacy sandbox found via list()
|
||||
if list_return is not None:
|
||||
mock_client.list.return_value = list_return
|
||||
else:
|
||||
mock_client.list.return_value = SimpleNamespace(items=[])
|
||||
|
||||
daytona_sdk.Daytona = MagicMock(return_value=mock_client)
|
||||
|
||||
@@ -131,24 +138,46 @@ class TestCwdResolution:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPersistence:
|
||||
def test_persistent_resumes_existing_sandbox(self, make_env):
|
||||
def test_persistent_resumes_via_get(self, make_env):
|
||||
existing = _make_sandbox(sandbox_id="sb-existing")
|
||||
existing.process.exec.return_value = _make_exec_response(result="/root")
|
||||
env = make_env(find_one_side_effect=lambda **kw: existing, persistent=True)
|
||||
env = make_env(get_side_effect=lambda name: existing, persistent=True,
|
||||
task_id="mytask")
|
||||
existing.start.assert_called_once()
|
||||
# Should NOT have called create since find_one succeeded
|
||||
env._mock_client.get.assert_called_once_with("hermes-mytask")
|
||||
env._mock_client.create.assert_not_called()
|
||||
|
||||
def test_persistent_resumes_legacy_via_list(self, make_env, daytona_sdk):
|
||||
legacy = _make_sandbox(sandbox_id="sb-legacy")
|
||||
legacy.process.exec.return_value = _make_exec_response(result="/root")
|
||||
env = make_env(
|
||||
get_side_effect=daytona_sdk.DaytonaError("not found"),
|
||||
list_return=SimpleNamespace(items=[legacy]),
|
||||
persistent=True,
|
||||
task_id="mytask",
|
||||
)
|
||||
legacy.start.assert_called_once()
|
||||
env._mock_client.list.assert_called_once_with(
|
||||
labels={"hermes_task_id": "mytask"}, page=1, limit=1)
|
||||
env._mock_client.create.assert_not_called()
|
||||
|
||||
def test_persistent_creates_new_when_none_found(self, make_env, daytona_sdk):
|
||||
env = make_env(
|
||||
find_one_side_effect=daytona_sdk.DaytonaError("not found"),
|
||||
get_side_effect=daytona_sdk.DaytonaError("not found"),
|
||||
persistent=True,
|
||||
task_id="mytask",
|
||||
)
|
||||
env._mock_client.create.assert_called_once()
|
||||
# Verify the name and labels were passed to CreateSandboxFromImageParams
|
||||
# by checking get() was called with the right sandbox name
|
||||
env._mock_client.get.assert_called_with("hermes-mytask")
|
||||
env._mock_client.list.assert_called_with(
|
||||
labels={"hermes_task_id": "mytask"}, page=1, limit=1)
|
||||
|
||||
def test_non_persistent_skips_find_one(self, make_env):
|
||||
def test_non_persistent_skips_lookup(self, make_env):
|
||||
env = make_env(persistent=False)
|
||||
env._mock_client.find_one.assert_not_called()
|
||||
env._mock_client.get.assert_not_called()
|
||||
env._mock_client.list.assert_not_called()
|
||||
env._mock_client.create.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from tools.delegate_tool import (
|
||||
MAX_DEPTH,
|
||||
check_delegate_requirements,
|
||||
delegate_task,
|
||||
_build_child_agent,
|
||||
_build_child_system_prompt,
|
||||
_strip_blocked_tools,
|
||||
_resolve_delegation_credentials,
|
||||
@@ -291,6 +292,58 @@ class TestToolNamePreservation(unittest.TestCase):
|
||||
|
||||
self.assertEqual(model_tools._last_resolved_tool_names, original_tools)
|
||||
|
||||
def test_build_child_agent_does_not_raise_name_error(self):
|
||||
"""Regression: _build_child_agent must not reference _saved_tool_names.
|
||||
|
||||
The bug introduced by the e7844e9c merge conflict: line 235 inside
|
||||
_build_child_agent read `list(_saved_tool_names)` where that variable
|
||||
is only defined later in _run_single_child. Calling _build_child_agent
|
||||
standalone (without _run_single_child's scope) must never raise NameError.
|
||||
"""
|
||||
parent = _make_mock_parent(depth=0)
|
||||
|
||||
with patch("run_agent.AIAgent"):
|
||||
try:
|
||||
_build_child_agent(
|
||||
task_index=0,
|
||||
goal="regression check",
|
||||
context=None,
|
||||
toolsets=None,
|
||||
model=None,
|
||||
max_iterations=10,
|
||||
parent_agent=parent,
|
||||
)
|
||||
except NameError as exc:
|
||||
self.fail(
|
||||
f"_build_child_agent raised NameError — "
|
||||
f"_saved_tool_names leaked back into wrong scope: {exc}"
|
||||
)
|
||||
|
||||
def test_saved_tool_names_set_on_child_before_run(self):
|
||||
"""_run_single_child must set _delegate_saved_tool_names on the child
|
||||
from model_tools._last_resolved_tool_names before run_conversation."""
|
||||
import model_tools
|
||||
|
||||
parent = _make_mock_parent(depth=0)
|
||||
expected_tools = ["read_file", "web_search", "execute_code"]
|
||||
model_tools._last_resolved_tool_names = list(expected_tools)
|
||||
|
||||
captured = {}
|
||||
|
||||
with patch("run_agent.AIAgent") as MockAgent:
|
||||
mock_child = MagicMock()
|
||||
|
||||
def capture_and_return(user_message):
|
||||
captured["saved"] = list(mock_child._delegate_saved_tool_names)
|
||||
return {"final_response": "ok", "completed": True, "api_calls": 1}
|
||||
|
||||
mock_child.run_conversation.side_effect = capture_and_return
|
||||
MockAgent.return_value = mock_child
|
||||
|
||||
delegate_task(goal="capture test", parent_agent=parent)
|
||||
|
||||
self.assertEqual(captured["saved"], expected_tools)
|
||||
|
||||
|
||||
class TestDelegateObservability(unittest.TestCase):
|
||||
"""Tests for enriched metadata returned by _run_single_child."""
|
||||
|
||||
@@ -106,6 +106,18 @@ class TestSchemaConversion:
|
||||
assert schema["parameters"]["type"] == "object"
|
||||
assert schema["parameters"]["properties"] == {}
|
||||
|
||||
def test_object_schema_without_properties_gets_normalized(self):
|
||||
from tools.mcp_tool import _convert_mcp_schema
|
||||
|
||||
mcp_tool = _make_mcp_tool(
|
||||
name="ask",
|
||||
description="Ask Crawl4AI",
|
||||
input_schema={"type": "object"},
|
||||
)
|
||||
schema = _convert_mcp_schema("crawl4ai", mcp_tool)
|
||||
|
||||
assert schema["parameters"] == {"type": "object", "properties": {}}
|
||||
|
||||
def test_tool_name_prefix_format(self):
|
||||
from tools.mcp_tool import _convert_mcp_schema
|
||||
|
||||
@@ -1893,6 +1905,33 @@ class TestSamplingCallbackText:
|
||||
messages = call_args.kwargs["messages"]
|
||||
assert messages[0] == {"role": "system", "content": "Be helpful"}
|
||||
|
||||
def test_server_tools_with_object_schema_are_normalized(self):
|
||||
"""Server-provided tools should gain empty properties for object schemas."""
|
||||
fake_client = MagicMock()
|
||||
fake_client.chat.completions.create.return_value = _make_llm_response()
|
||||
server_tool = SimpleNamespace(
|
||||
name="ask",
|
||||
description="Ask Crawl4AI",
|
||||
inputSchema={"type": "object"},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client.call_llm",
|
||||
return_value=fake_client.chat.completions.create.return_value,
|
||||
) as mock_call:
|
||||
params = _make_sampling_params(tools=[server_tool])
|
||||
asyncio.run(self.handler(None, params))
|
||||
|
||||
tools = mock_call.call_args.kwargs["tools"]
|
||||
assert tools == [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ask",
|
||||
"description": "Ask Crawl4AI",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}]
|
||||
|
||||
def test_length_stop_reason(self):
|
||||
"""finish_reason='length' maps to stopReason='maxTokens'."""
|
||||
fake_client = MagicMock()
|
||||
|
||||
+51
-2
@@ -106,14 +106,63 @@ def _get_extraction_model() -> Optional[str]:
|
||||
return os.getenv("AUXILIARY_WEB_EXTRACT_MODEL", "").strip() or None
|
||||
|
||||
|
||||
def _resolve_cdp_override(cdp_url: str) -> str:
|
||||
"""Normalize a user-supplied CDP endpoint into a concrete connectable URL.
|
||||
|
||||
Accepts:
|
||||
- full websocket endpoints: ws://host:port/devtools/browser/...
|
||||
- HTTP discovery endpoints: http://host:port or http://host:port/json/version
|
||||
- bare websocket host:port values like ws://host:port
|
||||
|
||||
For discovery-style endpoints we fetch /json/version and return the
|
||||
webSocketDebuggerUrl so downstream tools always receive a concrete browser
|
||||
websocket instead of an ambiguous host:port URL.
|
||||
"""
|
||||
raw = (cdp_url or "").strip()
|
||||
if not raw:
|
||||
return ""
|
||||
|
||||
lowered = raw.lower()
|
||||
if "/devtools/browser/" in lowered:
|
||||
return raw
|
||||
|
||||
discovery_url = raw
|
||||
if lowered.startswith("ws://") or lowered.startswith("wss://"):
|
||||
if raw.count(":") == 2 and raw.rstrip("/").rsplit(":", 1)[-1].isdigit() and "/" not in raw.split(":", 2)[-1]:
|
||||
discovery_url = ("http://" if lowered.startswith("ws://") else "https://") + raw.split("://", 1)[1]
|
||||
else:
|
||||
return raw
|
||||
|
||||
if discovery_url.lower().endswith("/json/version"):
|
||||
version_url = discovery_url
|
||||
else:
|
||||
version_url = discovery_url.rstrip("/") + "/json/version"
|
||||
|
||||
try:
|
||||
response = requests.get(version_url, timeout=10)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to resolve CDP endpoint %s via %s: %s", raw, version_url, exc)
|
||||
return raw
|
||||
|
||||
ws_url = str(payload.get("webSocketDebuggerUrl") or "").strip()
|
||||
if ws_url:
|
||||
logger.info("Resolved CDP endpoint %s -> %s", raw, ws_url)
|
||||
return ws_url
|
||||
|
||||
logger.warning("CDP discovery at %s did not return webSocketDebuggerUrl; using raw endpoint", version_url)
|
||||
return raw
|
||||
|
||||
|
||||
def _get_cdp_override() -> str:
|
||||
"""Return a user-supplied CDP URL override, or empty string.
|
||||
"""Return a normalized user-supplied CDP URL override, or empty string.
|
||||
|
||||
When ``BROWSER_CDP_URL`` is set (e.g. via ``/browser connect``), we skip
|
||||
both Browserbase and the local headless launcher and connect directly to
|
||||
the supplied Chrome DevTools Protocol endpoint.
|
||||
"""
|
||||
return os.environ.get("BROWSER_CDP_URL", "").strip()
|
||||
return _resolve_cdp_override(os.environ.get("BROWSER_CDP_URL", ""))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
@@ -336,11 +336,9 @@ Jobs run in a fresh session with no current-chat context, so prompts must be sel
|
||||
If skill or skills are provided on create, the future cron run loads those skills in order, then follows the prompt as the task instruction.
|
||||
On update, passing skills=[] clears attached skills.
|
||||
|
||||
NOTE: The agent's final response is auto-delivered to the target — do NOT use
|
||||
send_message in the prompt for that same destination. Same-target send_message
|
||||
calls are skipped to avoid duplicate cron deliveries. Put the primary
|
||||
user-facing content in the final response, and use send_message only for
|
||||
additional or different targets.
|
||||
NOTE: The agent's final response is auto-delivered to the target. Put the primary
|
||||
user-facing content in the final response. Cron jobs run autonomously with no user
|
||||
present — they cannot ask questions or request clarification.
|
||||
|
||||
Important safety rule: cron-run sessions should not recursively schedule more cron jobs.""",
|
||||
"parameters": {
|
||||
@@ -372,7 +370,7 @@ Important safety rule: cron-run sessions should not recursively schedule more cr
|
||||
},
|
||||
"deliver": {
|
||||
"type": "string",
|
||||
"description": "Delivery target: origin, local, telegram, discord, signal, sms, or platform:chat_id"
|
||||
"description": "Delivery target: origin, local, telegram, discord, slack, whatsapp, signal, matrix, mattermost, homeassistant, dingtalk, email, sms, or platform:chat_id"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
|
||||
+28
-17
@@ -232,8 +232,6 @@ def _build_child_agent(
|
||||
tool_progress_callback=child_progress_cb,
|
||||
iteration_budget=shared_budget,
|
||||
)
|
||||
child._delegate_saved_tool_names = list(_saved_tool_names)
|
||||
|
||||
# Set delegation depth so children can't spawn grandchildren
|
||||
child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1
|
||||
|
||||
@@ -264,12 +262,11 @@ def _run_single_child(
|
||||
# Get the progress callback from the child agent
|
||||
child_progress_cb = getattr(child, 'tool_progress_callback', None)
|
||||
|
||||
# Save the parent's resolved tool names before the child agent can
|
||||
# overwrite the process-global via get_tool_definitions().
|
||||
# This must be in _run_single_child (not _build_child_agent) so the
|
||||
# save/restore happens in the same scope as the try/finally.
|
||||
# Restore parent tool names using the value saved before child construction
|
||||
# mutated the global. This is the correct parent toolset, not the child's.
|
||||
import model_tools
|
||||
_saved_tool_names = list(model_tools._last_resolved_tool_names)
|
||||
_saved_tool_names = getattr(child, "_delegate_saved_tool_names",
|
||||
list(model_tools._last_resolved_tool_names))
|
||||
|
||||
try:
|
||||
result = child.run_conversation(user_message=goal)
|
||||
@@ -466,18 +463,32 @@ def delegate_task(
|
||||
# Track goal labels for progress display (truncated for readability)
|
||||
task_labels = [t["goal"][:40] for t in task_list]
|
||||
|
||||
# Save parent tool names BEFORE any child construction mutates the global.
|
||||
# _build_child_agent() calls AIAgent() which calls get_tool_definitions(),
|
||||
# which overwrites model_tools._last_resolved_tool_names with child's toolset.
|
||||
import model_tools as _model_tools
|
||||
_parent_tool_names = list(_model_tools._last_resolved_tool_names)
|
||||
|
||||
# Build all child agents on the main thread (thread-safe construction)
|
||||
# Wrapped in try/finally so the global is always restored even if a
|
||||
# child build raises (otherwise _last_resolved_tool_names stays corrupted).
|
||||
children = []
|
||||
for i, t in enumerate(task_list):
|
||||
child = _build_child_agent(
|
||||
task_index=i, goal=t["goal"], context=t.get("context"),
|
||||
toolsets=t.get("toolsets") or toolsets, model=creds["model"],
|
||||
max_iterations=effective_max_iter, parent_agent=parent_agent,
|
||||
override_provider=creds["provider"], override_base_url=creds["base_url"],
|
||||
override_api_key=creds["api_key"],
|
||||
override_api_mode=creds["api_mode"],
|
||||
)
|
||||
children.append((i, t, child))
|
||||
try:
|
||||
for i, t in enumerate(task_list):
|
||||
child = _build_child_agent(
|
||||
task_index=i, goal=t["goal"], context=t.get("context"),
|
||||
toolsets=t.get("toolsets") or toolsets, model=creds["model"],
|
||||
max_iterations=effective_max_iter, parent_agent=parent_agent,
|
||||
override_provider=creds["provider"], override_base_url=creds["base_url"],
|
||||
override_api_key=creds["api_key"],
|
||||
override_api_mode=creds["api_mode"],
|
||||
)
|
||||
# Override with correct parent tool names (before child construction mutated global)
|
||||
child._delegate_saved_tool_names = _parent_tool_names
|
||||
children.append((i, t, child))
|
||||
finally:
|
||||
# Authoritative restore: reset global to parent's tool names after all children built
|
||||
_model_tools._last_resolved_tool_names = _parent_tool_names
|
||||
|
||||
if n_tasks == 1:
|
||||
# Single task -- run directly (no thread pool overhead)
|
||||
|
||||
@@ -68,11 +68,13 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
resources = Resources(cpu=cpu, memory=memory_gib, disk=disk_gib)
|
||||
|
||||
labels = {"hermes_task_id": task_id}
|
||||
sandbox_name = f"hermes-{task_id}"
|
||||
|
||||
# Try to resume an existing stopped sandbox for this task
|
||||
# Try to resume an existing sandbox for this task
|
||||
if self._persistent:
|
||||
# 1. Try name-based lookup (new path)
|
||||
try:
|
||||
self._sandbox = self._daytona.find_one(labels=labels)
|
||||
self._sandbox = self._daytona.get(sandbox_name)
|
||||
self._sandbox.start()
|
||||
logger.info("Daytona: resumed sandbox %s for task %s",
|
||||
self._sandbox.id, task_id)
|
||||
@@ -83,11 +85,26 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
task_id, e)
|
||||
self._sandbox = None
|
||||
|
||||
# 2. Legacy fallback: find sandbox created before the naming migration
|
||||
if self._sandbox is None:
|
||||
try:
|
||||
page = self._daytona.list(labels=labels, page=1, limit=1)
|
||||
if page.items:
|
||||
self._sandbox = page.items[0]
|
||||
self._sandbox.start()
|
||||
logger.info("Daytona: resumed legacy sandbox %s for task %s",
|
||||
self._sandbox.id, task_id)
|
||||
except Exception as e:
|
||||
logger.debug("Daytona: no legacy sandbox found for task %s: %s",
|
||||
task_id, e)
|
||||
self._sandbox = None
|
||||
|
||||
# Create a fresh sandbox if we don't have one
|
||||
if self._sandbox is None:
|
||||
self._sandbox = self._daytona.create(
|
||||
CreateSandboxFromImageParams(
|
||||
image=image,
|
||||
name=sandbox_name,
|
||||
labels=labels,
|
||||
auto_stop_interval=0,
|
||||
resources=resources,
|
||||
|
||||
+15
-5
@@ -605,7 +605,9 @@ class SamplingHandler:
|
||||
"function": {
|
||||
"name": getattr(t, "name", ""),
|
||||
"description": getattr(t, "description", "") or "",
|
||||
"parameters": getattr(t, "inputSchema", {}) or {},
|
||||
"parameters": _normalize_mcp_input_schema(
|
||||
getattr(t, "inputSchema", None)
|
||||
),
|
||||
},
|
||||
}
|
||||
for t in server_tools
|
||||
@@ -1213,6 +1215,17 @@ def _make_check_fn(server_name: str):
|
||||
# Discovery & registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _normalize_mcp_input_schema(schema: dict | None) -> dict:
|
||||
"""Normalize MCP input schemas for LLM tool-calling compatibility."""
|
||||
if not schema:
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
if schema.get("type") == "object" and "properties" not in schema:
|
||||
return {**schema, "properties": {}}
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
||||
"""Convert an MCP tool listing to the Hermes registry schema format.
|
||||
|
||||
@@ -1231,10 +1244,7 @@ def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
||||
return {
|
||||
"name": prefixed_name,
|
||||
"description": mcp_tool.description or f"MCP tool {mcp_tool.name} from {server_name}",
|
||||
"parameters": mcp_tool.inputSchema if mcp_tool.inputSchema else {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
"parameters": _normalize_mcp_input_schema(mcp_tool.inputSchema),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -124,6 +124,10 @@ def _handle_send(args):
|
||||
"slack": Platform.SLACK,
|
||||
"whatsapp": Platform.WHATSAPP,
|
||||
"signal": Platform.SIGNAL,
|
||||
"matrix": Platform.MATRIX,
|
||||
"mattermost": Platform.MATTERMOST,
|
||||
"homeassistant": Platform.HOMEASSISTANT,
|
||||
"dingtalk": Platform.DINGTALK,
|
||||
"email": Platform.EMAIL,
|
||||
"sms": Platform.SMS,
|
||||
}
|
||||
|
||||
+2
-1
@@ -239,6 +239,7 @@ def _generate_openai_tts(text: str, output_path: str, tts_config: Dict[str, Any]
|
||||
oai_config = tts_config.get("openai", {})
|
||||
model = oai_config.get("model", DEFAULT_OPENAI_MODEL)
|
||||
voice = oai_config.get("voice", DEFAULT_OPENAI_VOICE)
|
||||
base_url = oai_config.get("base_url", "https://api.openai.com/v1")
|
||||
|
||||
# Determine response format from extension
|
||||
if output_path.endswith(".ogg"):
|
||||
@@ -247,7 +248,7 @@ def _generate_openai_tts(text: str, output_path: str, tts_config: Dict[str, Any]
|
||||
response_format = "mp3"
|
||||
|
||||
OpenAIClient = _import_openai_client()
|
||||
client = OpenAIClient(api_key=api_key, base_url="https://api.openai.com/v1")
|
||||
client = OpenAIClient(api_key=api_key, base_url=base_url)
|
||||
response = client.audio.speech.create(
|
||||
model=model,
|
||||
voice=voice,
|
||||
|
||||
@@ -305,14 +305,14 @@ For docs-only examples, the exact file set may differ. The point is to cover:
|
||||
Run tests with xdist disabled:
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
source venv/bin/activate
|
||||
python -m pytest tests/test_runtime_provider_resolution.py tests/test_cli_provider_resolution.py tests/test_cli_model_command.py tests/test_setup_model_selection.py -n0 -q
|
||||
```
|
||||
|
||||
For deeper changes, run the full suite before pushing:
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
source venv/bin/activate
|
||||
python -m pytest tests/ -n0 -q
|
||||
```
|
||||
|
||||
@@ -321,14 +321,14 @@ python -m pytest tests/ -n0 -q
|
||||
After tests, run a real smoke test.
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
source venv/bin/activate
|
||||
python -m hermes_cli.main chat -q "Say hello" --provider your-provider --model your-model
|
||||
```
|
||||
|
||||
Also test the interactive flows if you changed menus:
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
source venv/bin/activate
|
||||
python -m hermes_cli.main model
|
||||
python -m hermes_cli.main setup
|
||||
```
|
||||
|
||||
@@ -51,11 +51,13 @@ hermes setup # Or configure everything at once
|
||||
| **MiniMax China** | China-region MiniMax endpoint | Set `MINIMAX_CN_API_KEY` |
|
||||
| **Alibaba Cloud** | Qwen models via DashScope | Set `DASHSCOPE_API_KEY` |
|
||||
| **Kilo Code** | KiloCode-hosted models | Set `KILOCODE_API_KEY` |
|
||||
| **OpenCode Zen** | Pay-as-you-go access to curated models | Set `OPENCODE_ZEN_API_KEY` |
|
||||
| **OpenCode Go** | $10/month subscription for open models | Set `OPENCODE_GO_API_KEY` |
|
||||
| **Vercel AI Gateway** | Vercel AI Gateway routing | Set `AI_GATEWAY_API_KEY` |
|
||||
| **Custom Endpoint** | VLLM, SGLang, or any OpenAI-compatible API | Set base URL + API key |
|
||||
| **Custom Endpoint** | VLLM, SGLang, Ollama, or any OpenAI-compatible API | Set base URL + API key |
|
||||
|
||||
:::tip
|
||||
You can switch providers at any time with `hermes model` — no code changes, no lock-in.
|
||||
You can switch providers at any time with `hermes model` — no code changes, no lock-in. When configuring a custom endpoint, Hermes will prompt for the context window size and auto-detect it when possible. See [Context Length Detection](../user-guide/configuration.md#context-length-detection) for details.
|
||||
:::
|
||||
|
||||
## 3. Start Chatting
|
||||
|
||||
@@ -66,7 +66,7 @@ Common options:
|
||||
| `-q`, `--query "..."` | One-shot, non-interactive prompt. |
|
||||
| `-m`, `--model <model>` | Override the model for this run. |
|
||||
| `-t`, `--toolsets <csv>` | Enable a comma-separated set of toolsets. |
|
||||
| `--provider <provider>` | Force a provider: `auto`, `openrouter`, `nous`, `openai-codex`, `copilot`, `copilot-acp`, `anthropic`, `zai`, `kimi-coding`, `minimax`, `minimax-cn`. |
|
||||
| `--provider <provider>` | Force a provider: `auto`, `openrouter`, `nous`, `openai-codex`, `copilot`, `copilot-acp`, `anthropic`, `zai`, `kimi-coding`, `minimax`, `minimax-cn`, `opencode-zen`, `opencode-go`, `ai-gateway`, `kilocode`, `alibaba`. |
|
||||
| `-v`, `--verbose` | Verbose output. |
|
||||
| `-Q`, `--quiet` | Programmatic mode: suppress banner/spinner/tool previews. |
|
||||
| `--resume <session>` / `--continue [name]` | Resume a session directly from `chat`. |
|
||||
|
||||
@@ -41,6 +41,12 @@ All variables go in `~/.hermes/.env`. You can also set them with `hermes config
|
||||
| `ANTHROPIC_TOKEN` | Manual or legacy Anthropic OAuth/setup-token override |
|
||||
| `DASHSCOPE_API_KEY` | Alibaba Cloud DashScope API key for Qwen models ([modelstudio.console.alibabacloud.com](https://modelstudio.console.alibabacloud.com/)) |
|
||||
| `DASHSCOPE_BASE_URL` | Custom DashScope base URL (default: international endpoint) |
|
||||
| `DEEPSEEK_API_KEY` | DeepSeek API key for direct DeepSeek access ([platform.deepseek.com](https://platform.deepseek.com/api_keys)) |
|
||||
| `DEEPSEEK_BASE_URL` | Custom DeepSeek API base URL |
|
||||
| `OPENCODE_ZEN_API_KEY` | OpenCode Zen API key — pay-as-you-go access to curated models ([opencode.ai](https://opencode.ai/auth)) |
|
||||
| `OPENCODE_ZEN_BASE_URL` | Override OpenCode Zen base URL |
|
||||
| `OPENCODE_GO_API_KEY` | OpenCode Go API key — $10/month subscription for open models ([opencode.ai](https://opencode.ai/auth)) |
|
||||
| `OPENCODE_GO_BASE_URL` | Override OpenCode Go base URL |
|
||||
| `CLAUDE_CODE_OAUTH_TOKEN` | Explicit Claude Code token override if you export one manually |
|
||||
| `HERMES_MODEL` | Preferred model name (checked before `LLM_MODEL`, used by gateway) |
|
||||
| `LLM_MODEL` | Default model name (fallback when not set in config.yaml) |
|
||||
@@ -71,6 +77,7 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe
|
||||
| `PARALLEL_API_KEY` | AI-native web search ([parallel.ai](https://parallel.ai/)) |
|
||||
| `FIRECRAWL_API_KEY` | Web scraping ([firecrawl.dev](https://firecrawl.dev/)) |
|
||||
| `FIRECRAWL_API_URL` | Custom Firecrawl API endpoint for self-hosted instances (optional) |
|
||||
| `TAVILY_API_KEY` | Tavily API key for AI-native web search, extract, and crawl ([app.tavily.com](https://app.tavily.com/home)) |
|
||||
| `BROWSERBASE_API_KEY` | Browser automation ([browserbase.com](https://browserbase.com/)) |
|
||||
| `BROWSERBASE_PROJECT_ID` | Browserbase project ID |
|
||||
| `BROWSER_USE_API_KEY` | Browser Use cloud browser API key ([browser-use.com](https://browser-use.com/)) |
|
||||
@@ -83,7 +90,9 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe
|
||||
| `GROQ_BASE_URL` | Override the Groq OpenAI-compatible STT endpoint |
|
||||
| `STT_OPENAI_MODEL` | Override the OpenAI STT model (default: `whisper-1`) |
|
||||
| `STT_OPENAI_BASE_URL` | Override the OpenAI-compatible STT endpoint |
|
||||
| `GITHUB_TOKEN` | GitHub token for Skills Hub (higher API rate limits, skill publish) |
|
||||
| `HONCHO_API_KEY` | Cross-session user modeling ([honcho.dev](https://honcho.dev/)) |
|
||||
| `HONCHO_BASE_URL` | Base URL for self-hosted Honcho instances (default: Honcho cloud). No API key required for local instances |
|
||||
| `TINKER_API_KEY` | RL training ([tinker-console.thinkingmachines.ai](https://tinker-console.thinkingmachines.ai/)) |
|
||||
| `WANDB_API_KEY` | RL training metrics ([wandb.ai](https://wandb.ai/)) |
|
||||
| `DAYTONA_API_KEY` | Daytona cloud sandboxes ([daytona.io](https://daytona.io/)) |
|
||||
@@ -199,6 +208,9 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe
|
||||
| `MATRIX_ENCRYPTION` | Enable end-to-end encryption (`true`/`false`, default: `false`) |
|
||||
| `HASS_TOKEN` | Home Assistant Long-Lived Access Token (enables HA platform + tools) |
|
||||
| `HASS_URL` | Home Assistant URL (default: `http://homeassistant.local:8123`) |
|
||||
| `WEBHOOK_ENABLED` | Enable the webhook platform adapter (`true`/`false`) |
|
||||
| `WEBHOOK_PORT` | HTTP server port for receiving webhooks (default: `8644`) |
|
||||
| `WEBHOOK_SECRET` | Global HMAC secret for webhook signature validation (used as fallback when routes don't specify their own) |
|
||||
| `API_SERVER_ENABLED` | Enable the OpenAI-compatible API server (`true`/`false`). Runs alongside other platforms. |
|
||||
| `API_SERVER_KEY` | Bearer token for API server authentication. If empty, all requests are allowed (local-only use). |
|
||||
| `API_SERVER_PORT` | Port for the API server (default: `8642`) |
|
||||
@@ -211,7 +223,7 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe
|
||||
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `HERMES_MAX_ITERATIONS` | Max tool-calling iterations per conversation (default: 60) |
|
||||
| `HERMES_MAX_ITERATIONS` | Max tool-calling iterations per conversation (default: 90) |
|
||||
| `HERMES_TOOL_PROGRESS` | Deprecated compatibility variable for tool progress display. Prefer `display.tool_progress` in `config.yaml`. |
|
||||
| `HERMES_TOOL_PROGRESS_MODE` | Deprecated compatibility variable for tool progress mode. Prefer `display.tool_progress` in `config.yaml`. |
|
||||
| `HERMES_HUMAN_DELAY_MODE` | Response pacing: `off`/`natural`/`custom` |
|
||||
@@ -221,6 +233,7 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe
|
||||
| `HERMES_API_TIMEOUT` | LLM API call timeout in seconds (default: `900`) |
|
||||
| `HERMES_EXEC_ASK` | Enable execution approval prompts in gateway mode (`true`/`false`) |
|
||||
| `HERMES_BACKGROUND_NOTIFICATIONS` | Background process notification mode in gateway: `all` (default), `result`, `error`, `off` |
|
||||
| `HERMES_EPHEMERAL_SYSTEM_PROMPT` | Ephemeral system prompt injected at API-call time (never persisted to sessions) |
|
||||
|
||||
## Session Settings
|
||||
|
||||
|
||||
@@ -42,18 +42,25 @@ API calls go **only to the LLM provider you configure** (e.g., OpenRouter, your
|
||||
|
||||
### Can I use it offline / with local models?
|
||||
|
||||
Yes. Point Hermes at any local OpenAI-compatible server:
|
||||
Yes. Run `hermes model`, select **Custom endpoint**, and enter your server's URL:
|
||||
|
||||
```bash
|
||||
hermes config set OPENAI_BASE_URL http://localhost:11434/v1 # Ollama
|
||||
hermes config set OPENAI_API_KEY ollama # Any non-empty value
|
||||
hermes config set HERMES_MODEL llama3.1
|
||||
hermes model
|
||||
# Select: Custom endpoint (enter URL manually)
|
||||
# API base URL: http://localhost:11434/v1
|
||||
# API key: ollama
|
||||
# Model name: qwen3.5:27b
|
||||
# Context length: 32768 ← set this to match your server's actual context window
|
||||
```
|
||||
|
||||
You can also save the endpoint interactively with `hermes model`. Hermes persists that custom endpoint in `config.yaml`, and auxiliary tasks configured with provider `main` follow the same saved endpoint.
|
||||
Hermes persists the endpoint in `config.yaml` and prompts for the context window size so compression triggers at the right time. If you leave context length blank, Hermes auto-detects it from the server's `/models` endpoint or [models.dev](https://models.dev).
|
||||
|
||||
This works with Ollama, vLLM, llama.cpp server, SGLang, LocalAI, and others. See the [Configuration guide](../user-guide/configuration.md) for details.
|
||||
|
||||
:::tip Ollama users
|
||||
If you set a custom `num_ctx` in Ollama (e.g., `ollama run --num_ctx 16384`), make sure to set the matching context length in Hermes — Ollama's `/api/show` reports the model's *maximum* context, not the effective `num_ctx` you configured.
|
||||
:::
|
||||
|
||||
### How much does it cost?
|
||||
|
||||
Hermes Agent itself is **free and open-source** (MIT license). You pay only for the LLM API usage from your chosen provider. Local models are completely free to run.
|
||||
@@ -200,7 +207,7 @@ hermes chat --model openrouter/meta-llama/llama-3.1-70b-instruct
|
||||
|
||||
#### Context length exceeded
|
||||
|
||||
**Cause:** The conversation has grown too long for the model's context window.
|
||||
**Cause:** The conversation has grown too long for the model's context window, or Hermes detected the wrong context length for your model.
|
||||
|
||||
**Solution:**
|
||||
```bash
|
||||
@@ -214,6 +221,35 @@ hermes chat
|
||||
hermes chat --model openrouter/google/gemini-2.0-flash-001
|
||||
```
|
||||
|
||||
If this happens on the first long conversation, Hermes may have the wrong context length for your model. Check what it detected:
|
||||
|
||||
```bash
|
||||
# Look at the status bar — it shows the detected context length
|
||||
/context
|
||||
```
|
||||
|
||||
To fix context detection, set it explicitly:
|
||||
|
||||
```yaml
|
||||
# In ~/.hermes/config.yaml
|
||||
model:
|
||||
default: your-model-name
|
||||
context_length: 131072 # your model's actual context window
|
||||
```
|
||||
|
||||
Or for custom endpoints, add it per-model:
|
||||
|
||||
```yaml
|
||||
custom_providers:
|
||||
- name: "My Server"
|
||||
base_url: "http://localhost:11434/v1"
|
||||
models:
|
||||
qwen3.5:27b:
|
||||
context_length: 32768
|
||||
```
|
||||
|
||||
See [Context Length Detection](../user-guide/configuration.md#context-length-detection) for how auto-detection works and all override options.
|
||||
|
||||
---
|
||||
|
||||
### Terminal Issues
|
||||
|
||||
@@ -21,9 +21,8 @@ Type `/` in the CLI to open the autocomplete menu. Built-in commands are case-in
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/new` | Start a new conversation (reset history) |
|
||||
| `/reset` | Reset conversation only (keep screen) |
|
||||
| `/clear` | Clear screen and reset conversation (fresh start) |
|
||||
| `/new` (alias: `/reset`) | Start a new session (fresh session ID + history) |
|
||||
| `/clear` | Clear screen and start a new session |
|
||||
| `/history` | Show conversation history |
|
||||
| `/save` | Save the current conversation |
|
||||
| `/retry` | Retry the last message (resend to agent) |
|
||||
@@ -31,6 +30,8 @@ Type `/` in the CLI to open the autocomplete menu. Built-in commands are case-in
|
||||
| `/title` | Set a title for the current session (usage: /title My Session Name) |
|
||||
| `/compress` | Manually compress conversation context (flush memories + summarize) |
|
||||
| `/rollback` | List or restore filesystem checkpoints (usage: /rollback [number]) |
|
||||
| `/stop` | Kill all running background processes |
|
||||
| `/statusbar` (alias: `/sb`) | Toggle the context/model status bar on or off |
|
||||
| `/background <prompt>` | Run a prompt in a separate background session. The agent processes your prompt independently — your current session stays free for other work. Results appear as a panel when the task finishes. See [CLI Background Sessions](/docs/user-guide/cli#background-sessions). |
|
||||
| `/plan [request]` | Load the bundled `plan` skill to write a markdown plan instead of executing the work. Plans are saved under `.hermes/plans/` relative to the active workspace/backend working directory. |
|
||||
|
||||
@@ -58,6 +59,7 @@ Type `/` in the CLI to open the autocomplete menu. Built-in commands are case-in
|
||||
| `/skills` | Search, install, inspect, or manage skills from online registries |
|
||||
| `/cron` | Manage scheduled tasks (list, add/create, edit, pause, resume, run, remove) |
|
||||
| `/reload-mcp` | Reload MCP servers from config.yaml |
|
||||
| `/plugins` | List installed plugins and their status |
|
||||
|
||||
### Info
|
||||
|
||||
@@ -95,7 +97,7 @@ The messaging gateway supports the following built-in commands inside Telegram,
|
||||
| `/new` | Start a new conversation. |
|
||||
| `/reset` | Reset conversation history. |
|
||||
| `/status` | Show session info. |
|
||||
| `/stop` | Interrupt the running agent without queuing a follow-up prompt. |
|
||||
| `/stop` | Kill all running background processes and interrupt the running agent. |
|
||||
| `/model [provider:model]` | Show or change the model, including provider switches. |
|
||||
| `/provider` | Show provider availability and auth status. |
|
||||
| `/personality [name]` | Set a personality overlay for the session. |
|
||||
@@ -113,13 +115,15 @@ The messaging gateway supports the following built-in commands inside Telegram,
|
||||
| `/background <prompt>` | Run a prompt in a separate background session. Results are delivered back to the same chat when the task finishes. See [Messaging Background Sessions](/docs/user-guide/messaging/#background-sessions). |
|
||||
| `/plan [request]` | Load the bundled `plan` skill to write a markdown plan instead of executing the work. Plans are saved under `.hermes/plans/` relative to the active workspace/backend working directory. |
|
||||
| `/reload-mcp` | Reload MCP servers from config. |
|
||||
| `/approve` | Approve and execute a pending dangerous command (terminal commands flagged for review). |
|
||||
| `/deny` | Reject a pending dangerous command. |
|
||||
| `/update` | Update Hermes Agent to the latest version. |
|
||||
| `/help` | Show messaging help. |
|
||||
| `/<skill-name>` | Invoke any installed skill by name. |
|
||||
|
||||
## Notes
|
||||
|
||||
- `/skin`, `/tools`, `/toolsets`, `/browser`, `/config`, `/prompt`, `/cron`, `/skills`, `/platforms`, `/paste`, and `/verbose` are **CLI-only** commands.
|
||||
- `/status`, `/stop`, `/sethome`, `/resume`, and `/update` are **messaging-only** commands.
|
||||
- `/skin`, `/tools`, `/toolsets`, `/browser`, `/config`, `/prompt`, `/cron`, `/skills`, `/platforms`, `/paste`, `/verbose`, `/statusbar`, and `/plugins` are **CLI-only** commands.
|
||||
- `/status`, `/sethome`, `/update`, `/approve`, and `/deny` are **messaging-only** commands.
|
||||
- `/background`, `/voice`, `/reload-mcp`, and `/rollback` work in **both** the CLI and the messaging gateway.
|
||||
- `/voice join`, `/voice channel`, and `/voice leave` are only meaningful on Discord.
|
||||
|
||||
@@ -141,6 +141,19 @@ This page documents the built-in Hermes tool registry as it exists in code. Avai
|
||||
|------|-------------|----------------------|
|
||||
| `todo` | Manage your task list for the current session. Use for complex tasks with 3+ steps or when the user provides multiple tasks. Call with no parameters to read the current list. Writing: - Provide 'todos' array to create/update items - merge=… | — |
|
||||
|
||||
## `vision` toolset
|
||||
|
||||
| Tool | Description | Requires environment |
|
||||
|------|-------------|----------------------|
|
||||
| `vision_analyze` | Analyze images using AI vision. Provides a comprehensive description and answers a specific question about the image content. | — |
|
||||
|
||||
## `web` toolset
|
||||
|
||||
| Tool | Description | Requires environment |
|
||||
|------|-------------|----------------------|
|
||||
| `web_search` | Search the web for information on any topic. Returns up to 5 relevant results with titles, URLs, and descriptions. | PARALLEL_API_KEY or FIRECRAWL_API_KEY or TAVILY_API_KEY |
|
||||
| `web_extract` | Extract content from web page URLs. Returns page content in markdown format. Also works with PDF URLs — pass the PDF link directly and it converts to markdown text. Pages under 5000 chars return full markdown; larger pages are LLM-summarized. | PARALLEL_API_KEY or FIRECRAWL_API_KEY or TAVILY_API_KEY |
|
||||
|
||||
## `tts` toolset
|
||||
|
||||
| Tool | Description | Requires environment |
|
||||
|
||||
@@ -10,26 +10,29 @@ Toolsets are named bundles of tools that you can enable with `hermes chat --tool
|
||||
|
||||
| Toolset | Kind | Resolves to |
|
||||
|---------|------|-------------|
|
||||
| `browser` | core | `browser_back`, `browser_click`, `browser_close`, `browser_get_images`, `browser_navigate`, `browser_press`, `browser_scroll`, `browser_snapshot`, `browser_type`, `browser_vision`, `web_search` |
|
||||
| `browser` | core | `browser_back`, `browser_click`, `browser_close`, `browser_console`, `browser_get_images`, `browser_navigate`, `browser_press`, `browser_scroll`, `browser_snapshot`, `browser_type`, `browser_vision`, `web_search` |
|
||||
| `clarify` | core | `clarify` |
|
||||
| `code_execution` | core | `execute_code` |
|
||||
| `cronjob` | core | `cronjob` |
|
||||
| `debugging` | composite | `patch`, `process`, `read_file`, `search_files`, `terminal`, `web_extract`, `web_search`, `write_file` |
|
||||
| `delegation` | core | `delegate_task` |
|
||||
| `file` | core | `patch`, `read_file`, `search_files`, `write_file` |
|
||||
| `hermes-cli` | platform | `browser_back`, `browser_click`, `browser_close`, `browser_get_images`, `browser_navigate`, `browser_press`, `browser_scroll`, `browser_snapshot`, `browser_type`, `browser_vision`, `clarify`, `delegate_task`, `execute_code`, `ha_call_service`, `ha_get_state`, `ha_list_entities`, `ha_list_services`, `honcho_conclude`, `honcho_context`, `honcho_profile`, `honcho_search`, `image_generate`, `cronjob`, `memory`, `mixture_of_agents`, `patch`, `process`, `read_file`, `search_files`, `send_message`, `session_search`, `skill_manage`, `skill_view`, `skills_list`, `terminal`, `text_to_speech`, `todo`, `vision_analyze`, `web_extract`, `web_search`, `write_file` |
|
||||
| `hermes-discord` | platform | `browser_back`, `browser_click`, `browser_close`, `browser_get_images`, `browser_navigate`, `browser_press`, `browser_scroll`, `browser_snapshot`, `browser_type`, `browser_vision`, `clarify`, `delegate_task`, `execute_code`, `ha_call_service`, `ha_get_state`, `ha_list_entities`, `ha_list_services`, `honcho_conclude`, `honcho_context`, `honcho_profile`, `honcho_search`, `image_generate`, `cronjob`, `memory`, `mixture_of_agents`, `patch`, `process`, `read_file`, `search_files`, `send_message`, `session_search`, `skill_manage`, `skill_view`, `skills_list`, `terminal`, `text_to_speech`, `todo`, `vision_analyze`, `web_extract`, `web_search`, `write_file` |
|
||||
| `hermes-email` | platform | `browser_back`, `browser_click`, `browser_close`, `browser_get_images`, `browser_navigate`, `browser_press`, `browser_scroll`, `browser_snapshot`, `browser_type`, `browser_vision`, `clarify`, `delegate_task`, `execute_code`, `ha_call_service`, `ha_get_state`, `ha_list_entities`, `ha_list_services`, `honcho_conclude`, `honcho_context`, `honcho_profile`, `honcho_search`, `image_generate`, `cronjob`, `memory`, `mixture_of_agents`, `patch`, `process`, `read_file`, `search_files`, `send_message`, `session_search`, `skill_manage`, `skill_view`, `skills_list`, `terminal`, `text_to_speech`, `todo`, `vision_analyze`, `web_extract`, `web_search`, `write_file` |
|
||||
| `hermes-gateway` | platform | `browser_back`, `browser_click`, `browser_close`, `browser_get_images`, `browser_navigate`, `browser_press`, `browser_scroll`, `browser_snapshot`, `browser_type`, `browser_vision`, `clarify`, `delegate_task`, `execute_code`, `ha_call_service`, `ha_get_state`, `ha_list_entities`, `ha_list_services`, `honcho_conclude`, `honcho_context`, `honcho_profile`, `honcho_search`, `image_generate`, `cronjob`, `memory`, `mixture_of_agents`, `patch`, `process`, `read_file`, `search_files`, `send_message`, `session_search`, `skill_manage`, `skill_view`, `skills_list`, `terminal`, `text_to_speech`, `todo`, `vision_analyze`, `web_extract`, `web_search`, `write_file` |
|
||||
| `hermes-homeassistant` | platform | `browser_back`, `browser_click`, `browser_close`, `browser_get_images`, `browser_navigate`, `browser_press`, `browser_scroll`, `browser_snapshot`, `browser_type`, `browser_vision`, `clarify`, `delegate_task`, `execute_code`, `ha_call_service`, `ha_get_state`, `ha_list_entities`, `ha_list_services`, `honcho_conclude`, `honcho_context`, `honcho_profile`, `honcho_search`, `image_generate`, `cronjob`, `memory`, `mixture_of_agents`, `patch`, `process`, `read_file`, `search_files`, `send_message`, `session_search`, `skill_manage`, `skill_view`, `skills_list`, `terminal`, `text_to_speech`, `todo`, `vision_analyze`, `web_extract`, `web_search`, `write_file` |
|
||||
| `hermes-signal` | platform | `browser_back`, `browser_click`, `browser_close`, `browser_get_images`, `browser_navigate`, `browser_press`, `browser_scroll`, `browser_snapshot`, `browser_type`, `browser_vision`, `clarify`, `delegate_task`, `execute_code`, `ha_call_service`, `ha_get_state`, `ha_list_entities`, `ha_list_services`, `honcho_conclude`, `honcho_context`, `honcho_profile`, `honcho_search`, `image_generate`, `cronjob`, `memory`, `mixture_of_agents`, `patch`, `process`, `read_file`, `search_files`, `send_message`, `session_search`, `skill_manage`, `skill_view`, `skills_list`, `terminal`, `text_to_speech`, `todo`, `vision_analyze`, `web_extract`, `web_search`, `write_file` |
|
||||
| `hermes-slack` | platform | `browser_back`, `browser_click`, `browser_close`, `browser_get_images`, `browser_navigate`, `browser_press`, `browser_scroll`, `browser_snapshot`, `browser_type`, `browser_vision`, `clarify`, `delegate_task`, `execute_code`, `ha_call_service`, `ha_get_state`, `ha_list_entities`, `ha_list_services`, `honcho_conclude`, `honcho_context`, `honcho_profile`, `honcho_search`, `image_generate`, `cronjob`, `memory`, `mixture_of_agents`, `patch`, `process`, `read_file`, `search_files`, `send_message`, `session_search`, `skill_manage`, `skill_view`, `skills_list`, `terminal`, `text_to_speech`, `todo`, `vision_analyze`, `web_extract`, `web_search`, `write_file` |
|
||||
| `hermes-telegram` | platform | `browser_back`, `browser_click`, `browser_close`, `browser_get_images`, `browser_navigate`, `browser_press`, `browser_scroll`, `browser_snapshot`, `browser_type`, `browser_vision`, `clarify`, `delegate_task`, `execute_code`, `ha_call_service`, `ha_get_state`, `ha_list_entities`, `ha_list_services`, `honcho_conclude`, `honcho_context`, `honcho_profile`, `honcho_search`, `image_generate`, `cronjob`, `memory`, `mixture_of_agents`, `patch`, `process`, `read_file`, `search_files`, `send_message`, `session_search`, `skill_manage`, `skill_view`, `skills_list`, `terminal`, `text_to_speech`, `todo`, `vision_analyze`, `web_extract`, `web_search`, `write_file` |
|
||||
| `hermes-whatsapp` | platform | `browser_back`, `browser_click`, `browser_close`, `browser_get_images`, `browser_navigate`, `browser_press`, `browser_scroll`, `browser_snapshot`, `browser_type`, `browser_vision`, `clarify`, `delegate_task`, `execute_code`, `ha_call_service`, `ha_get_state`, `ha_list_entities`, `ha_list_services`, `honcho_conclude`, `honcho_context`, `honcho_profile`, `honcho_search`, `image_generate`, `cronjob`, `memory`, `mixture_of_agents`, `patch`, `process`, `read_file`, `search_files`, `send_message`, `session_search`, `skill_manage`, `skill_view`, `skills_list`, `terminal`, `text_to_speech`, `todo`, `vision_analyze`, `web_extract`, `web_search`, `write_file` |
|
||||
| `hermes-acp` | platform | `browser_back`, `browser_click`, `browser_close`, `browser_console`, `browser_get_images`, `browser_navigate`, `browser_press`, `browser_scroll`, `browser_snapshot`, `browser_type`, `browser_vision`, `delegate_task`, `execute_code`, `memory`, `patch`, `process`, `read_file`, `search_files`, `session_search`, `skill_manage`, `skill_view`, `skills_list`, `terminal`, `todo`, `vision_analyze`, `web_extract`, `web_search`, `write_file` |
|
||||
| `hermes-cli` | platform | `browser_back`, `browser_click`, `browser_close`, `browser_console`, `browser_get_images`, `browser_navigate`, `browser_press`, `browser_scroll`, `browser_snapshot`, `browser_type`, `browser_vision`, `clarify`, `cronjob`, `delegate_task`, `execute_code`, `ha_call_service`, `ha_get_state`, `ha_list_entities`, `ha_list_services`, `honcho_conclude`, `honcho_context`, `honcho_profile`, `honcho_search`, `image_generate`, `memory`, `mixture_of_agents`, `patch`, `process`, `read_file`, `search_files`, `send_message`, `session_search`, `skill_manage`, `skill_view`, `skills_list`, `terminal`, `text_to_speech`, `todo`, `vision_analyze`, `web_extract`, `web_search`, `write_file` |
|
||||
| `hermes-discord` | platform | _(same as hermes-cli)_ |
|
||||
| `hermes-email` | platform | _(same as hermes-cli)_ |
|
||||
| `hermes-gateway` | composite | Union of all messaging platform toolsets |
|
||||
| `hermes-homeassistant` | platform | _(same as hermes-cli)_ |
|
||||
| `hermes-signal` | platform | _(same as hermes-cli)_ |
|
||||
| `hermes-slack` | platform | _(same as hermes-cli)_ |
|
||||
| `hermes-sms` | platform | _(same as hermes-cli)_ |
|
||||
| `hermes-telegram` | platform | _(same as hermes-cli)_ |
|
||||
| `hermes-whatsapp` | platform | _(same as hermes-cli)_ |
|
||||
| `homeassistant` | core | `ha_call_service`, `ha_get_state`, `ha_list_entities`, `ha_list_services` |
|
||||
| `honcho` | core | `honcho_conclude`, `honcho_context`, `honcho_profile`, `honcho_search` |
|
||||
| `image_gen` | core | `image_generate` |
|
||||
| `memory` | core | `memory` |
|
||||
| `messaging` | core | `send_message` |
|
||||
| `moa` | core | `mixture_of_agents` |
|
||||
| `rl` | core | `rl_check_status`, `rl_edit_config`, `rl_get_current_config`, `rl_get_results`, `rl_list_environments`, `rl_list_runs`, `rl_select_environment`, `rl_start_training`, `rl_stop_training`, `rl_test_inference` |
|
||||
| `safe` | composite | `image_generate`, `mixture_of_agents`, `vision_analyze`, `web_extract`, `web_search` |
|
||||
|
||||
@@ -94,7 +94,7 @@ When resuming a previous session (`hermes -c` or `hermes --resume <id>`), a "Pre
|
||||
| `Ctrl+B` | Start/stop voice recording when voice mode is enabled (`voice.record_key`, default: `ctrl+b`) |
|
||||
| `Ctrl+C` | Interrupt agent (double-press within 2s to force exit) |
|
||||
| `Ctrl+D` | Exit |
|
||||
| `Tab` | Autocomplete slash commands |
|
||||
| `Tab` | Accept auto-suggestion (ghost text) or autocomplete slash commands |
|
||||
|
||||
## Slash Commands
|
||||
|
||||
|
||||
@@ -74,7 +74,8 @@ You need at least one way to connect to an LLM. Use `hermes model` to switch pro
|
||||
| **MiniMax China** | `MINIMAX_CN_API_KEY` in `~/.hermes/.env` (provider: `minimax-cn`) |
|
||||
| **Alibaba Cloud** | `DASHSCOPE_API_KEY` in `~/.hermes/.env` (provider: `alibaba`, aliases: `dashscope`, `qwen`) |
|
||||
| **Kilo Code** | `KILOCODE_API_KEY` in `~/.hermes/.env` (provider: `kilocode`) |
|
||||
| **Alibaba Cloud** | `DASHSCOPE_API_KEY` in `~/.hermes/.env` (provider: `alibaba`) |
|
||||
| **OpenCode Zen** | `OPENCODE_ZEN_API_KEY` in `~/.hermes/.env` (provider: `opencode-zen`) |
|
||||
| **OpenCode Go** | `OPENCODE_GO_API_KEY` in `~/.hermes/.env` (provider: `opencode-go`) |
|
||||
| **Custom Endpoint** | `hermes model` (saved in `config.yaml`) or `OPENAI_BASE_URL` + `OPENAI_API_KEY` in `~/.hermes/.env` |
|
||||
|
||||
:::info Codex Note
|
||||
@@ -413,6 +414,54 @@ LLM_MODEL=meta-llama/Llama-3.1-70B-Instruct-Turbo
|
||||
|
||||
---
|
||||
|
||||
### Context Length Detection
|
||||
|
||||
Hermes uses a multi-source resolution chain to detect the correct context window for your model and provider:
|
||||
|
||||
1. **Config override** — `model.context_length` in config.yaml (highest priority)
|
||||
2. **Custom provider per-model** — `custom_providers[].models.<id>.context_length`
|
||||
3. **Persistent cache** — previously discovered values (survives restarts)
|
||||
4. **Endpoint `/models`** — queries your server's API (local/custom endpoints)
|
||||
5. **Anthropic `/v1/models`** — queries Anthropic's API for `max_input_tokens` (API-key users only)
|
||||
6. **OpenRouter API** — live model metadata from OpenRouter
|
||||
7. **Nous Portal** — suffix-matches Nous model IDs against OpenRouter metadata
|
||||
8. **[models.dev](https://models.dev)** — community-maintained registry with provider-specific context lengths for 3800+ models across 100+ providers
|
||||
9. **Fallback defaults** — broad model family patterns (128K default)
|
||||
|
||||
For most setups this works out of the box. The system is provider-aware — the same model can have different context limits depending on who serves it (e.g., `claude-opus-4.6` is 1M on Anthropic direct but 128K on GitHub Copilot).
|
||||
|
||||
To set the context length explicitly, add `context_length` to your model config:
|
||||
|
||||
```yaml
|
||||
model:
|
||||
default: "qwen3.5:9b"
|
||||
base_url: "http://localhost:8080/v1"
|
||||
context_length: 131072 # tokens
|
||||
```
|
||||
|
||||
For custom endpoints, you can also set context length per model:
|
||||
|
||||
```yaml
|
||||
custom_providers:
|
||||
- name: "My Local LLM"
|
||||
base_url: "http://localhost:11434/v1"
|
||||
models:
|
||||
qwen3.5:27b:
|
||||
context_length: 32768
|
||||
deepseek-r1:70b:
|
||||
context_length: 65536
|
||||
```
|
||||
|
||||
`hermes model` will prompt for context length when configuring a custom endpoint. Leave it blank for auto-detection.
|
||||
|
||||
:::tip When to set this manually
|
||||
- You're using Ollama with a custom `num_ctx` that's lower than the model's maximum
|
||||
- You want to limit context below the model's maximum (e.g., 8k on a 128k model to save VRAM)
|
||||
- You're running behind a proxy that doesn't expose `/v1/models`
|
||||
:::
|
||||
|
||||
---
|
||||
|
||||
### Choosing the Right Setup
|
||||
|
||||
| Use Case | Recommended |
|
||||
@@ -805,6 +854,31 @@ agent:
|
||||
|
||||
Budget pressure is enabled by default. The agent sees warnings naturally as part of tool results, encouraging it to consolidate its work and deliver a response before running out of iterations.
|
||||
|
||||
## Context Pressure Warnings
|
||||
|
||||
Separate from iteration budget pressure, context pressure tracks how close the conversation is to the **compaction threshold** — the point where context compression fires to summarize older messages. This helps both you and the agent understand when the conversation is getting long.
|
||||
|
||||
| Progress | Level | What happens |
|
||||
|----------|-------|-------------|
|
||||
| **≥ 60%** to threshold | Info | CLI shows a cyan progress bar; gateway sends an informational notice |
|
||||
| **≥ 85%** to threshold | Warning | CLI shows a bold yellow bar; gateway warns compaction is imminent |
|
||||
|
||||
In the CLI, context pressure appears as a progress bar in the tool output feed:
|
||||
|
||||
```
|
||||
◐ context ████████████░░░░░░░░ 62% to compaction 48k threshold (50%) · approaching compaction
|
||||
```
|
||||
|
||||
On messaging platforms, a plain-text notification is sent:
|
||||
|
||||
```
|
||||
◐ Context: ████████████░░░░░░░░ 62% to compaction (threshold: 50% of window).
|
||||
```
|
||||
|
||||
If auto-compression is disabled, the warning tells you context may be truncated instead.
|
||||
|
||||
Context pressure is automatic — no configuration needed. It fires purely as a user-facing notification and does not modify the message stream or inject anything into the model's context.
|
||||
|
||||
## Auxiliary Models
|
||||
|
||||
Hermes uses lightweight "auxiliary" models for side tasks like image analysis, web page summarization, and browser screenshot analysis. By default, these use **Gemini Flash** via auto-detection — you don't need to configure anything.
|
||||
@@ -993,6 +1067,7 @@ tts:
|
||||
openai:
|
||||
model: "gpt-4o-mini-tts"
|
||||
voice: "alloy" # alloy, echo, fable, onyx, nova, shimmer
|
||||
base_url: "https://api.openai.com/v1" # Override for OpenAI-compatible TTS endpoints
|
||||
neutts:
|
||||
ref_audio: ''
|
||||
ref_text: ''
|
||||
@@ -1016,6 +1091,7 @@ display:
|
||||
show_reasoning: false # Show model reasoning/thinking above each response (toggle with /reasoning show|hide)
|
||||
streaming: false # Stream tokens to terminal as they arrive (real-time output)
|
||||
background_process_notifications: all # all | result | error | off (gateway only)
|
||||
show_cost: false # Show estimated $ cost in the CLI status bar
|
||||
```
|
||||
|
||||
### Theme mode
|
||||
|
||||
@@ -42,6 +42,7 @@ tts:
|
||||
openai:
|
||||
model: "gpt-4o-mini-tts"
|
||||
voice: "alloy" # alloy, echo, fable, onyx, nova, shimmer
|
||||
base_url: "https://api.openai.com/v1" # Override for OpenAI-compatible TTS endpoints
|
||||
neutts:
|
||||
ref_audio: ''
|
||||
ref_text: ''
|
||||
|
||||
@@ -404,6 +404,7 @@ tts:
|
||||
openai:
|
||||
model: "gpt-4o-mini-tts"
|
||||
voice: "alloy" # alloy, echo, fable, onyx, nova, shimmer
|
||||
base_url: "https://api.openai.com/v1" # optional: override for self-hosted or OpenAI-compatible endpoints
|
||||
neutts:
|
||||
ref_audio: ''
|
||||
ref_text: ''
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
---
|
||||
sidebar_position: 1
|
||||
title: "Messaging Gateway"
|
||||
description: "Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, SMS, Email, Home Assistant, Mattermost, Matrix, DingTalk, or any OpenAI-compatible frontend via the API server — architecture and setup overview"
|
||||
description: "Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, SMS, Email, Home Assistant, Mattermost, Matrix, DingTalk, Webhooks, or any OpenAI-compatible frontend via the API server — architecture and setup overview"
|
||||
---
|
||||
|
||||
# Messaging Gateway
|
||||
@@ -28,6 +28,7 @@ flowchart TB
|
||||
mx[Matrix]
|
||||
dt[DingTalk]
|
||||
api["API Server<br/>(OpenAI-compatible)"]
|
||||
wh[Webhooks]
|
||||
end
|
||||
|
||||
store["Session store<br/>per chat"]
|
||||
@@ -47,6 +48,7 @@ flowchart TB
|
||||
mx --> store
|
||||
dt --> store
|
||||
api --> store
|
||||
wh --> store
|
||||
store --> agent
|
||||
cron --> store
|
||||
```
|
||||
@@ -88,6 +90,8 @@ hermes gateway status --system # Linux only: inspect the system service
|
||||
| `/undo` | Remove the last exchange |
|
||||
| `/status` | Show session info |
|
||||
| `/stop` | Stop the running agent |
|
||||
| `/approve` | Approve a pending dangerous command |
|
||||
| `/deny` | Reject a pending dangerous command |
|
||||
| `/sethome` | Set this chat as the home channel |
|
||||
| `/compress` | Manually compress conversation context |
|
||||
| `/title [name]` | Set or show the session title |
|
||||
@@ -309,6 +313,7 @@ Each platform has its own toolset:
|
||||
| Matrix | `hermes-matrix` | Full tools including terminal |
|
||||
| DingTalk | `hermes-dingtalk` | Full tools including terminal |
|
||||
| API Server | `hermes` (default) | Full tools including terminal |
|
||||
| Webhooks | `hermes-webhook` | Full tools including terminal |
|
||||
|
||||
## Next Steps
|
||||
|
||||
@@ -324,3 +329,4 @@ Each platform has its own toolset:
|
||||
- [Matrix Setup](matrix.md)
|
||||
- [DingTalk Setup](dingtalk.md)
|
||||
- [Open WebUI + API Server](open-webui.md)
|
||||
- [Webhooks](webhooks.md)
|
||||
|
||||
@@ -177,6 +177,19 @@ All phone numbers are automatically redacted in logs:
|
||||
- `+15551234567` → `+155****4567`
|
||||
- This applies to both Hermes gateway logs and the global redaction system
|
||||
|
||||
### Note to Self (Single-Number Setup)
|
||||
|
||||
If you run signal-cli as a **linked secondary device** on your own phone number (rather than a separate bot number), you can interact with Hermes through Signal's "Note to Self" feature.
|
||||
|
||||
Just send a message to yourself from your phone — signal-cli picks it up and Hermes responds in the same conversation.
|
||||
|
||||
**How it works:**
|
||||
- "Note to Self" messages arrive as `syncMessage.sentMessage` envelopes
|
||||
- The adapter detects when these are addressed to the bot's own account and processes them as regular inbound messages
|
||||
- Echo-back protection (sent-timestamp tracking) prevents infinite loops — the bot's own replies are filtered out automatically
|
||||
|
||||
**No extra configuration needed.** This works automatically as long as `SIGNAL_ACCOUNT` matches your phone number.
|
||||
|
||||
### Health Monitoring
|
||||
|
||||
The adapter monitors the SSE connection and automatically reconnects if:
|
||||
|
||||
@@ -0,0 +1,310 @@
|
||||
---
|
||||
sidebar_position: 13
|
||||
title: "Webhooks"
|
||||
description: "Receive events from GitHub, GitLab, and other services to trigger Hermes agent runs"
|
||||
---
|
||||
|
||||
# Webhooks
|
||||
|
||||
Receive events from external services (GitHub, GitLab, JIRA, Stripe, etc.) and trigger Hermes agent runs automatically. The webhook adapter runs an HTTP server that accepts POST requests, validates HMAC signatures, transforms payloads into agent prompts, and routes responses back to the source or to another configured platform.
|
||||
|
||||
The agent processes the event and can respond by posting comments on PRs, sending messages to Telegram/Discord, or logging the result.
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
1. Enable via `hermes setup gateway` or environment variables
|
||||
2. Define webhook routes in `config.yaml`
|
||||
3. Point your service at `http://your-server:8644/webhooks/<route-name>`
|
||||
|
||||
---
|
||||
|
||||
## Setup
|
||||
|
||||
There are two ways to enable the webhook adapter.
|
||||
|
||||
### Via setup wizard
|
||||
|
||||
```bash
|
||||
hermes setup gateway
|
||||
```
|
||||
|
||||
Follow the prompts to enable webhooks, set the port, and set a global HMAC secret.
|
||||
|
||||
### Via environment variables
|
||||
|
||||
Add to `~/.hermes/.env`:
|
||||
|
||||
```bash
|
||||
WEBHOOK_ENABLED=true
|
||||
WEBHOOK_PORT=8644 # default
|
||||
WEBHOOK_SECRET=your-global-secret
|
||||
```
|
||||
|
||||
### Verify the server
|
||||
|
||||
Once the gateway is running:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8644/health
|
||||
```
|
||||
|
||||
Expected response:
|
||||
|
||||
```json
|
||||
{"status": "ok", "platform": "webhook"}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Configuring Routes {#configuring-routes}
|
||||
|
||||
Routes define how different webhook sources are handled. Each route is a named entry under `platforms.webhook.extra.routes` in your `config.yaml`.
|
||||
|
||||
### Route properties
|
||||
|
||||
| Property | Required | Description |
|
||||
|----------|----------|-------------|
|
||||
| `events` | No | List of event types to accept (e.g. `["pull_request"]`). If empty, all events are accepted. Event type is read from `X-GitHub-Event`, `X-GitLab-Event`, or `event_type` in the payload. |
|
||||
| `secret` | **Yes** | HMAC secret for signature validation. Falls back to the global `secret` if not set on the route. Set to `"INSECURE_NO_AUTH"` for testing only (skips validation). |
|
||||
| `prompt` | No | Template string with dot-notation payload access (e.g. `{pull_request.title}`). If omitted, the full JSON payload is dumped into the prompt. |
|
||||
| `skills` | No | List of skill names to load for the agent run. |
|
||||
| `deliver` | No | Where to send the response: `github_comment`, `telegram`, `discord`, `slack`, `signal`, `sms`, or `log` (default). |
|
||||
| `deliver_extra` | No | Additional delivery config — keys depend on `deliver` type (e.g. `repo`, `pr_number`, `chat_id`). Values support the same `{dot.notation}` templates as `prompt`. |
|
||||
|
||||
### Full example
|
||||
|
||||
```yaml
|
||||
platforms:
|
||||
webhook:
|
||||
enabled: true
|
||||
extra:
|
||||
port: 8644
|
||||
secret: "global-fallback-secret"
|
||||
routes:
|
||||
github-pr:
|
||||
events: ["pull_request"]
|
||||
secret: "github-webhook-secret"
|
||||
prompt: |
|
||||
Review this pull request:
|
||||
Repository: {repository.full_name}
|
||||
PR #{number}: {pull_request.title}
|
||||
Author: {pull_request.user.login}
|
||||
URL: {pull_request.html_url}
|
||||
Diff URL: {pull_request.diff_url}
|
||||
Action: {action}
|
||||
skills: ["github-code-review"]
|
||||
deliver: "github_comment"
|
||||
deliver_extra:
|
||||
repo: "{repository.full_name}"
|
||||
pr_number: "{number}"
|
||||
deploy-notify:
|
||||
events: ["push"]
|
||||
secret: "deploy-secret"
|
||||
prompt: "New push to {repository.full_name} branch {ref}: {head_commit.message}"
|
||||
deliver: "telegram"
|
||||
```
|
||||
|
||||
### Prompt Templates
|
||||
|
||||
Prompts use dot-notation to access nested fields in the webhook payload:
|
||||
|
||||
- `{pull_request.title}` resolves to `payload["pull_request"]["title"]`
|
||||
- `{repository.full_name}` resolves to `payload["repository"]["full_name"]`
|
||||
- Missing keys are left as the literal `{key}` string (no error)
|
||||
- Nested dicts and lists are JSON-serialized and truncated at 2000 characters
|
||||
|
||||
If no `prompt` template is configured for a route, the entire payload is dumped as indented JSON (truncated at 4000 characters).
|
||||
|
||||
The same dot-notation templates work in `deliver_extra` values.
|
||||
|
||||
---
|
||||
|
||||
## GitHub PR Review (Step by Step) {#github-pr-review}
|
||||
|
||||
This walkthrough sets up automatic code review on every pull request.
|
||||
|
||||
### 1. Create the webhook in GitHub
|
||||
|
||||
1. Go to your repository → **Settings** → **Webhooks** → **Add webhook**
|
||||
2. Set **Payload URL** to `http://your-server:8644/webhooks/github-pr`
|
||||
3. Set **Content type** to `application/json`
|
||||
4. Set **Secret** to match your route config (e.g. `github-webhook-secret`)
|
||||
5. Under **Which events?**, select **Let me select individual events** and check **Pull requests**
|
||||
6. Click **Add webhook**
|
||||
|
||||
### 2. Add the route config
|
||||
|
||||
Add the `github-pr` route to your `~/.hermes/config.yaml` as shown in the example above.
|
||||
|
||||
### 3. Ensure `gh` CLI is authenticated
|
||||
|
||||
The `github_comment` delivery type uses the GitHub CLI to post comments:
|
||||
|
||||
```bash
|
||||
gh auth login
|
||||
```
|
||||
|
||||
### 4. Test it
|
||||
|
||||
Open a pull request on the repository. The webhook fires, Hermes processes the event, and posts a review comment on the PR.
|
||||
|
||||
---
|
||||
|
||||
## GitLab Webhook Setup {#gitlab-webhook-setup}
|
||||
|
||||
GitLab webhooks work similarly but use a different authentication mechanism. GitLab sends the secret as a plain `X-Gitlab-Token` header (exact string match, not HMAC).
|
||||
|
||||
### 1. Create the webhook in GitLab
|
||||
|
||||
1. Go to your project → **Settings** → **Webhooks**
|
||||
2. Set the **URL** to `http://your-server:8644/webhooks/gitlab-mr`
|
||||
3. Enter your **Secret token**
|
||||
4. Select **Merge request events** (and any other events you want)
|
||||
5. Click **Add webhook**
|
||||
|
||||
### 2. Add the route config
|
||||
|
||||
```yaml
|
||||
platforms:
|
||||
webhook:
|
||||
enabled: true
|
||||
extra:
|
||||
routes:
|
||||
gitlab-mr:
|
||||
events: ["merge_request"]
|
||||
secret: "your-gitlab-secret-token"
|
||||
prompt: |
|
||||
Review this merge request:
|
||||
Project: {project.path_with_namespace}
|
||||
MR !{object_attributes.iid}: {object_attributes.title}
|
||||
Author: {object_attributes.last_commit.author.name}
|
||||
URL: {object_attributes.url}
|
||||
Action: {object_attributes.action}
|
||||
deliver: "log"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Delivery Options {#delivery-options}
|
||||
|
||||
The `deliver` field controls where the agent's response goes after processing the webhook event.
|
||||
|
||||
| Deliver Type | Description |
|
||||
|-------------|-------------|
|
||||
| `log` | Logs the response to the gateway log output. This is the default and is useful for testing. |
|
||||
| `github_comment` | Posts the response as a PR/issue comment via the `gh` CLI. Requires `deliver_extra.repo` and `deliver_extra.pr_number`. The `gh` CLI must be installed and authenticated on the gateway host (`gh auth login`). |
|
||||
| `telegram` | Routes the response to Telegram. Uses the home channel, or specify `chat_id` in `deliver_extra`. |
|
||||
| `discord` | Routes the response to Discord. Uses the home channel, or specify `chat_id` in `deliver_extra`. |
|
||||
| `slack` | Routes the response to Slack. Uses the home channel, or specify `chat_id` in `deliver_extra`. |
|
||||
| `signal` | Routes the response to Signal. Uses the home channel, or specify `chat_id` in `deliver_extra`. |
|
||||
| `sms` | Routes the response to SMS via Twilio. Uses the home channel, or specify `chat_id` in `deliver_extra`. |
|
||||
|
||||
For cross-platform delivery (telegram, discord, slack, signal, sms), the target platform must also be enabled and connected in the gateway. If no `chat_id` is provided in `deliver_extra`, the response is sent to that platform's configured home channel.
|
||||
|
||||
---
|
||||
|
||||
## Security {#security}
|
||||
|
||||
The webhook adapter includes multiple layers of security:
|
||||
|
||||
### HMAC signature validation
|
||||
|
||||
The adapter validates incoming webhook signatures using the appropriate method for each source:
|
||||
|
||||
- **GitHub**: `X-Hub-Signature-256` header — HMAC-SHA256 hex digest prefixed with `sha256=`
|
||||
- **GitLab**: `X-Gitlab-Token` header — plain secret string match
|
||||
- **Generic**: `X-Webhook-Signature` header — raw HMAC-SHA256 hex digest
|
||||
|
||||
If a secret is configured but no recognized signature header is present, the request is rejected.
|
||||
|
||||
### Secret is required
|
||||
|
||||
Every route must have a secret — either set directly on the route or inherited from the global `secret`. Routes without a secret cause the adapter to fail at startup with an error. For development/testing only, you can set the secret to `"INSECURE_NO_AUTH"` to skip validation entirely.
|
||||
|
||||
### Rate limiting
|
||||
|
||||
Each route is rate-limited to **30 requests per minute** by default (fixed-window). Configure this globally:
|
||||
|
||||
```yaml
|
||||
platforms:
|
||||
webhook:
|
||||
extra:
|
||||
rate_limit: 60 # requests per minute
|
||||
```
|
||||
|
||||
Requests exceeding the limit receive a `429 Too Many Requests` response.
|
||||
|
||||
### Idempotency
|
||||
|
||||
Delivery IDs (from `X-GitHub-Delivery`, `X-Request-ID`, or a timestamp fallback) are cached for **1 hour**. Duplicate deliveries (e.g. webhook retries) are silently skipped with a `200` response, preventing duplicate agent runs.
|
||||
|
||||
### Body size limits
|
||||
|
||||
Payloads exceeding **1 MB** are rejected before the body is read. Configure this:
|
||||
|
||||
```yaml
|
||||
platforms:
|
||||
webhook:
|
||||
extra:
|
||||
max_body_bytes: 2097152 # 2 MB
|
||||
```
|
||||
|
||||
### Prompt injection risk
|
||||
|
||||
:::warning
|
||||
Webhook payloads contain attacker-controlled data — PR titles, commit messages, issue descriptions, etc. can all contain malicious instructions. Run the gateway in a sandboxed environment (Docker, VM) when exposed to the internet. Consider using the Docker or SSH terminal backend for isolation.
|
||||
:::
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting {#troubleshooting}
|
||||
|
||||
### Webhook not arriving
|
||||
|
||||
- Verify the port is exposed and accessible from the webhook source
|
||||
- Check firewall rules — port `8644` (or your configured port) must be open
|
||||
- Verify the URL path matches: `http://your-server:8644/webhooks/<route-name>`
|
||||
- Use the `/health` endpoint to confirm the server is running
|
||||
|
||||
### Signature validation failing
|
||||
|
||||
- Ensure the secret in your route config exactly matches the secret configured in the webhook source
|
||||
- For GitHub, the secret is HMAC-based — check `X-Hub-Signature-256`
|
||||
- For GitLab, the secret is a plain token match — check `X-Gitlab-Token`
|
||||
- Check gateway logs for `Invalid signature` warnings
|
||||
|
||||
### Event being ignored
|
||||
|
||||
- Check that the event type is in your route's `events` list
|
||||
- GitHub events use values like `pull_request`, `push`, `issues` (the `X-GitHub-Event` header value)
|
||||
- GitLab events use values like `merge_request`, `push` (the `X-GitLab-Event` header value)
|
||||
- If `events` is empty or not set, all events are accepted
|
||||
|
||||
### Agent not responding
|
||||
|
||||
- Run the gateway in foreground to see logs: `hermes gateway run`
|
||||
- Check that the prompt template is rendering correctly
|
||||
- Verify the delivery target is configured and connected
|
||||
|
||||
### Duplicate responses
|
||||
|
||||
- The idempotency cache should prevent this — check that the webhook source is sending a delivery ID header (`X-GitHub-Delivery` or `X-Request-ID`)
|
||||
- Delivery IDs are cached for 1 hour
|
||||
|
||||
### `gh` CLI errors (GitHub comment delivery)
|
||||
|
||||
- Run `gh auth login` on the gateway host
|
||||
- Ensure the authenticated GitHub user has write access to the repository
|
||||
- Check that `gh` is installed and on the PATH
|
||||
|
||||
---
|
||||
|
||||
## Environment Variables {#environment-variables}
|
||||
|
||||
| Variable | Description | Default |
|
||||
|----------|-------------|---------|
|
||||
| `WEBHOOK_ENABLED` | Enable the webhook platform adapter | `false` |
|
||||
| `WEBHOOK_PORT` | HTTP server port for receiving webhooks | `8644` |
|
||||
| `WEBHOOK_SECRET` | Global HMAC secret (used as fallback when routes don't specify their own) | _(none)_ |
|
||||
@@ -52,6 +52,7 @@ const sidebars: SidebarsConfig = {
|
||||
'user-guide/messaging/matrix',
|
||||
'user-guide/messaging/dingtalk',
|
||||
'user-guide/messaging/open-webui',
|
||||
'user-guide/messaging/webhooks',
|
||||
],
|
||||
},
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user