Compare commits
285 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 50c8198918 | |||
| 6fb69229ca | |||
| 2edebedc9e | |||
| f9667331e5 | |||
| 9527707f80 | |||
| cf012a05d8 | |||
| 3b69b2fd61 | |||
| 8826d9c197 | |||
| a2c9f5d0a7 | |||
| 8322b42c6c | |||
| 285bb2b915 | |||
| 54e0eb24c0 | |||
| 73bccc94c7 | |||
| 598cba62ad | |||
| 5ff65dbf68 | |||
| c20e236b71 | |||
| 994faacce8 | |||
| 8a59f8a9ed | |||
| 1c352f6b1d | |||
| 11a89cc032 | |||
| 45acd9beb5 | |||
| c5c0bb9a73 | |||
| 20f2258f34 | |||
| 607be54a24 | |||
| e5333e793c | |||
| 148459716c | |||
| 53e4a2f2c6 | |||
| 07db20c72d | |||
| 38436eb4e3 | |||
| 86fd0f846d | |||
| 4459913f40 | |||
| d7ef562a05 | |||
| 47010e0757 | |||
| 213e39463b | |||
| 2297c5f5ce | |||
| c7fece1f9d | |||
| c096a6935f | |||
| a155b4a159 | |||
| b449a0e049 | |||
| 85cdb04bd4 | |||
| 9b14b76eb3 | |||
| 2992802b35 | |||
| 04a0c3cb95 | |||
| 8444f66890 | |||
| bb85404b16 | |||
| 8ab1aa2efc | |||
| 511ed4dacc | |||
| d465fc5869 | |||
| 016ae5c334 | |||
| 304fb921bf | |||
| 64b354719f | |||
| e9b8ece103 | |||
| 3f43aec15d | |||
| aa583cb14e | |||
| 0a83187801 | |||
| 2b60478fc2 | |||
| c6fd2619f7 | |||
| d2206c69cc | |||
| 103beea7a6 | |||
| 287d3e12c7 | |||
| 6fd58e1e4a | |||
| 235e6ecc0e | |||
| 1648e41c17 | |||
| c4cdf3b861 | |||
| 02f5e3dc27 | |||
| b7d330211a | |||
| a5f4d652d3 | |||
| 6358501915 | |||
| 31e7276474 | |||
| 036dacf659 | |||
| 3207b9bda0 | |||
| eb07c05646 | |||
| f362083c64 | |||
| 3b569ff576 | |||
| bd09e42eac | |||
| cc3aa76675 | |||
| 2ff1ef6ae6 | |||
| 1229d8855c | |||
| d49126b987 | |||
| cb883f9e97 | |||
| d5b9db8b4a | |||
| 6a37802476 | |||
| d0e1388ca9 | |||
| 78a74bb097 | |||
| bedbeebbc8 | |||
| f53250b5e1 | |||
| 00591e3801 | |||
| be768db627 | |||
| 42721dbe1c | |||
| 8f553a55b2 | |||
| a82097e7a2 | |||
| 0dd5055d59 | |||
| 5b386ced71 | |||
| 0219da9626 | |||
| 1f37ef2fd1 | |||
| 5435287dec | |||
| 41d3d7afb7 | |||
| 39231f29c6 | |||
| c730ab8ad7 | |||
| c74017f405 | |||
| 40f2368875 | |||
| 319aabbb80 | |||
| 26f3a05c9c | |||
| 15096903c7 | |||
| 26859e3fcb | |||
| aedc767c66 | |||
| 23212d6b40 | |||
| 7ffefc2d6c | |||
| 2812bfe5b9 | |||
| ca30803d89 | |||
| 7f1204840d | |||
| dd2ec6bfa0 | |||
| 3746c60439 | |||
| 727f0eaf74 | |||
| 275256cdb4 | |||
| 9503896aa2 | |||
| 04e36851b7 | |||
| a8e0a1148f | |||
| 842a122964 | |||
| 2d693c865c | |||
| f3920fec0b | |||
| c6ed61430a | |||
| cb2a737bc8 | |||
| 18840bcff8 | |||
| 0478266831 | |||
| beccd1bc04 | |||
| 68ecdb6e26 | |||
| fc0623f0af | |||
| 9c71f3a6ea | |||
| c4b9750bc1 | |||
| 39b1336d1f | |||
| f81dba0da2 | |||
| 8e06db56fd | |||
| cb31732c4f | |||
| 097702c8a7 | |||
| 72aebfbb24 | |||
| c9f78d110a | |||
| baa0de7649 | |||
| 57e4b61155 | |||
| 53a024a941 | |||
| cb7b740e32 | |||
| 4b4b4d47bc | |||
| 46cef4b7fa | |||
| 9931d1d814 | |||
| cc15b55bb9 | |||
| 371166fe26 | |||
| 33c615504d | |||
| 561cea0d4a | |||
| 496bfb3c59 | |||
| 99d859ce4a | |||
| 4cbf54fb33 | |||
| 77cd5bf565 | |||
| bf54f1fb2f | |||
| 3bc661ea29 | |||
| 52c11d172a | |||
| 9804aa7443 | |||
| 7aed09e1ba | |||
| dd2b0b4775 | |||
| ea2d5754ab | |||
| 9a3a2925ed | |||
| c189d5e98b | |||
| 6bbac046a7 | |||
| bbc7316007 | |||
| 35dbb1da3f | |||
| 6d6b3b03ac | |||
| 1b573b7b21 | |||
| 7e4dd6ea02 | |||
| aeb53131f3 | |||
| 783c6b6ed6 | |||
| 4a260b51fe | |||
| ebe3270430 | |||
| 77b97b810a | |||
| 9db94e8521 | |||
| cac1b1b724 | |||
| 56524bb1d9 | |||
| 0642b6cc53 | |||
| eec1db36f7 | |||
| 713a614ea8 | |||
| a27167fb30 | |||
| a2c0597ae4 | |||
| 0fd33a98cd | |||
| ddb0871769 | |||
| e03bef684e | |||
| 4b026d6761 | |||
| 8efd3db1b4 | |||
| ef51bb0091 | |||
| 3bf0f39337 | |||
| 690d62a6d1 | |||
| 2aea75e91e | |||
| 5552e1ffe1 | |||
| 90890f8f04 | |||
| 8e0df1d532 | |||
| 29721fcc58 | |||
| a1d2a0c0fd | |||
| ec553fdb49 | |||
| 24a498eb90 | |||
| 9ccb490cf3 | |||
| 32302c37dd | |||
| 5e5e65f6d5 | |||
| acbf1794f2 | |||
| e2ea8934d4 | |||
| 7e7f78f86c | |||
| 5fb6a4418b | |||
| bf6af95ff5 | |||
| 3fd5cf6e3c | |||
| b04248f4d5 | |||
| 7803d21bcc | |||
| 8760faf991 | |||
| cab6447d58 | |||
| 57e8d44af8 | |||
| cb79018977 | |||
| 90f0aa174d | |||
| 304f1463a9 | |||
| 294c377c0c | |||
| 660379637a | |||
| bc80848e49 | |||
| 658cd2dd4c | |||
| 8c1ba639c6 | |||
| 17a9c47178 | |||
| e1df13cf20 | |||
| 4fe78d5b88 | |||
| aa5b697a9d | |||
| aca479c1ae | |||
| b85ff282bc | |||
| f805323517 | |||
| 4406b4b100 | |||
| 17ecdce936 | |||
| 7e813a30e0 | |||
| 6e24b9947e | |||
| 99fd3b518d | |||
| c5511bbc5a | |||
| b7d4ea1550 | |||
| 74241328f0 | |||
| df5874c119 | |||
| 21afb3fa3c | |||
| 31b2c12f0f | |||
| 405c1b4e84 | |||
| 5ff96551d5 | |||
| 2b4272ef5b | |||
| 670dcea8f4 | |||
| 17f13013eb | |||
| 00e1d42b9e | |||
| b2ea9b4176 | |||
| 0d7c19a42f | |||
| 8755b9dfc0 | |||
| 54bd25ff4a | |||
| b66550ed08 | |||
| c49bbbe8c2 | |||
| 9d8f9765c1 | |||
| f226e6be10 | |||
| a435c7274a | |||
| b597123489 | |||
| af0f4a52fe | |||
| b50d81f212 | |||
| a9fa054df9 | |||
| 31cb23890a | |||
| a3cfb1de86 | |||
| 371efafc46 | |||
| ebd2d83ef2 | |||
| af077b2c0d | |||
| 2d884ff12d | |||
| b397c91d4a | |||
| 9c2c9e3a3e | |||
| c3eeb03e26 | |||
| d9d0ac06b9 | |||
| 29f2610e4b | |||
| dcb97f7465 | |||
| 86308b6de4 | |||
| 2d349bbf7a | |||
| 39878aff00 | |||
| afd670a36f | |||
| e2b3b1c5e4 | |||
| 4c7d5ec778 | |||
| f116c59071 | |||
| 0f556a17f5 | |||
| ee92460763 | |||
| 2893e9df71 | |||
| 5a5d90c85a | |||
| 56a69e519b | |||
| fab4d8d470 | |||
| 1218994992 | |||
| f4bf57ff7a | |||
| bbba9ed4f2 | |||
| 2818dd8611 | |||
| 2ea5345a7b |
@@ -1 +1,5 @@
|
||||
watch_file pyproject.toml uv.lock
|
||||
watch_file ui-tui/package-lock.json ui-tui/package.json
|
||||
watch_file flake.nix flake.lock nix/devShell.nix nix/tui.nix nix/package.nix nix/python.nix
|
||||
|
||||
use flake
|
||||
|
||||
@@ -60,5 +60,6 @@ mini-swe-agent/
|
||||
|
||||
# Nix
|
||||
.direnv/
|
||||
.nix-stamps/
|
||||
result
|
||||
website/static/api/skills-index.json
|
||||
|
||||
@@ -56,6 +56,19 @@ hermes-agent/
|
||||
│ ├── run.py # Main loop, slash commands, message dispatch
|
||||
│ ├── session.py # SessionStore — conversation persistence
|
||||
│ └── platforms/ # Adapters: telegram, discord, slack, whatsapp, homeassistant, signal, qqbot
|
||||
├── ui-tui/ # Ink (React) terminal UI — `hermes --tui`
|
||||
│ ├── src/entry.tsx # TTY gate + render()
|
||||
│ ├── src/app.tsx # Main state machine and UI
|
||||
│ ├── src/gatewayClient.ts # Child process + JSON-RPC bridge
|
||||
│ ├── src/app/ # Decomposed app logic (event handler, slash handler, stores, hooks)
|
||||
│ ├── src/components/ # Ink components (branding, markdown, prompts, pickers, etc.)
|
||||
│ ├── src/hooks/ # useCompletion, useInputHistory, useQueue, useVirtualHistory
|
||||
│ └── src/lib/ # Pure helpers (history, osc52, text, rpc, messages)
|
||||
├── tui_gateway/ # Python JSON-RPC backend for the TUI
|
||||
│ ├── entry.py # stdio entrypoint
|
||||
│ ├── server.py # RPC handlers and session logic
|
||||
│ ├── render.py # Optional rich/ANSI bridge
|
||||
│ └── slash_worker.py # Persistent HermesCLI subprocess for slash commands
|
||||
├── acp_adapter/ # ACP server (VS Code / Zed / JetBrains integration)
|
||||
├── cron/ # Scheduler (jobs.py, scheduler.py)
|
||||
├── environments/ # RL training environments (Atropos)
|
||||
@@ -179,6 +192,59 @@ if canonical == "mycommand":
|
||||
|
||||
---
|
||||
|
||||
## TUI Architecture (ui-tui + tui_gateway)
|
||||
|
||||
The TUI is a full replacement for the classic (prompt_toolkit) CLI, activated via `hermes --tui` or `HERMES_TUI=1`.
|
||||
|
||||
### Process Model
|
||||
|
||||
```
|
||||
hermes --tui
|
||||
└─ Node (Ink) ──stdio JSON-RPC── Python (tui_gateway)
|
||||
│ └─ AIAgent + tools + sessions
|
||||
└─ renders transcript, composer, prompts, activity
|
||||
```
|
||||
|
||||
TypeScript owns the screen. Python owns sessions, tools, model calls, and slash command logic.
|
||||
|
||||
### Transport
|
||||
|
||||
Newline-delimited JSON-RPC over stdio. Requests from Ink, events from Python. See `tui_gateway/server.py` for the full method/event catalog.
|
||||
|
||||
### Key Surfaces
|
||||
|
||||
| Surface | Ink component | Gateway method |
|
||||
|---------|---------------|----------------|
|
||||
| Chat streaming | `app.tsx` + `messageLine.tsx` | `prompt.submit` → `message.delta/complete` |
|
||||
| Tool activity | `thinking.tsx` | `tool.start/progress/complete` |
|
||||
| Approvals | `prompts.tsx` | `approval.respond` ← `approval.request` |
|
||||
| Clarify/sudo/secret | `prompts.tsx`, `maskedPrompt.tsx` | `clarify/sudo/secret.respond` |
|
||||
| Session picker | `sessionPicker.tsx` | `session.list/resume` |
|
||||
| Slash commands | Local handler + fallthrough | `slash.exec` → `_SlashWorker`, `command.dispatch` |
|
||||
| Completions | `useCompletion` hook | `complete.slash`, `complete.path` |
|
||||
| Theming | `theme.ts` + `branding.tsx` | `gateway.ready` with skin data |
|
||||
|
||||
### Slash Command Flow
|
||||
|
||||
1. Built-in client commands (`/help`, `/quit`, `/clear`, `/resume`, `/copy`, `/paste`, etc.) handled locally in `app.tsx`
|
||||
2. Everything else → `slash.exec` (runs in persistent `_SlashWorker` subprocess) → `command.dispatch` fallback
|
||||
|
||||
### Dev Commands
|
||||
|
||||
```bash
|
||||
cd ui-tui
|
||||
npm install # first time
|
||||
npm run dev # watch mode (rebuilds hermes-ink + tsx --watch)
|
||||
npm start # production
|
||||
npm run build # full build (hermes-ink + tsc)
|
||||
npm run type-check # typecheck only (tsc --noEmit)
|
||||
npm run lint # eslint
|
||||
npm run fmt # prettier
|
||||
npm test # vitest
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Adding New Tools
|
||||
|
||||
Requires changes in **2 files**:
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
**The self-improving AI agent built by [Nous Research](https://nousresearch.com).** It's the only agent with a built-in learning loop — it creates skills from experience, improves them during use, nudges itself to persist knowledge, searches its own past conversations, and builds a deepening model of who you are across sessions. Run it on a $5 VPS, a GPU cluster, or serverless infrastructure that costs nearly nothing when idle. It's not tied to your laptop — talk to it from Telegram while it works on a cloud VM.
|
||||
|
||||
Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [Xiaomi MiMo](https://platform.xiaomimimo.com), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), [Hugging Face](https://huggingface.co), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
|
||||
Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [NVIDIA NIM](https://build.nvidia.com) (Nemotron), [Xiaomi MiMo](https://platform.xiaomimimo.com), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), [Hugging Face](https://huggingface.co), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
|
||||
|
||||
<table>
|
||||
<tr><td><b>A real terminal interface</b></td><td>Full TUI with multiline editing, slash-command autocomplete, conversation history, interrupt-and-redirect, and streaming tool output.</td></tr>
|
||||
@@ -141,11 +141,18 @@ See `hermes claw migrate --help` for all options, or use the `openclaw-migration
|
||||
|
||||
We welcome contributions! See the [Contributing Guide](https://hermes-agent.nousresearch.com/docs/developer-guide/contributing) for development setup, code style, and PR process.
|
||||
|
||||
Quick start for contributors:
|
||||
Quick start for contributors — clone and go with `setup-hermes.sh`:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/NousResearch/hermes-agent.git
|
||||
cd hermes-agent
|
||||
./setup-hermes.sh # installs uv, creates venv, installs .[all], symlinks ~/.local/bin/hermes
|
||||
./hermes # auto-detects the venv, no need to `source` first
|
||||
```
|
||||
|
||||
Manual path (equivalent to the above):
|
||||
|
||||
```bash
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
uv venv venv --python 3.11
|
||||
source venv/bin/activate
|
||||
|
||||
+20
-1
@@ -49,6 +49,7 @@ def make_tool_progress_cb(
|
||||
session_id: str,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
tool_call_ids: Dict[str, Deque[str]],
|
||||
tool_call_meta: Dict[str, Dict[str, Any]],
|
||||
) -> Callable:
|
||||
"""Create a ``tool_progress_callback`` for AIAgent.
|
||||
|
||||
@@ -84,6 +85,16 @@ def make_tool_progress_cb(
|
||||
tool_call_ids[name] = queue
|
||||
queue.append(tc_id)
|
||||
|
||||
snapshot = None
|
||||
if name in {"write_file", "patch", "skill_manage"}:
|
||||
try:
|
||||
from agent.display import capture_local_edit_snapshot
|
||||
|
||||
snapshot = capture_local_edit_snapshot(name, args)
|
||||
except Exception:
|
||||
logger.debug("Failed to capture ACP edit snapshot for %s", name, exc_info=True)
|
||||
tool_call_meta[tc_id] = {"args": args, "snapshot": snapshot}
|
||||
|
||||
update = build_tool_start(tc_id, name, args)
|
||||
_send_update(conn, session_id, loop, update)
|
||||
|
||||
@@ -119,6 +130,7 @@ def make_step_cb(
|
||||
session_id: str,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
tool_call_ids: Dict[str, Deque[str]],
|
||||
tool_call_meta: Dict[str, Dict[str, Any]],
|
||||
) -> Callable:
|
||||
"""Create a ``step_callback`` for AIAgent.
|
||||
|
||||
@@ -132,10 +144,12 @@ def make_step_cb(
|
||||
for tool_info in prev_tools:
|
||||
tool_name = None
|
||||
result = None
|
||||
function_args = None
|
||||
|
||||
if isinstance(tool_info, dict):
|
||||
tool_name = tool_info.get("name") or tool_info.get("function_name")
|
||||
result = tool_info.get("result") or tool_info.get("output")
|
||||
function_args = tool_info.get("arguments") or tool_info.get("args")
|
||||
elif isinstance(tool_info, str):
|
||||
tool_name = tool_info
|
||||
|
||||
@@ -145,8 +159,13 @@ def make_step_cb(
|
||||
tool_call_ids[tool_name] = queue
|
||||
if tool_name and queue:
|
||||
tc_id = queue.popleft()
|
||||
meta = tool_call_meta.pop(tc_id, {})
|
||||
update = build_tool_complete(
|
||||
tc_id, tool_name, result=str(result) if result is not None else None
|
||||
tc_id,
|
||||
tool_name,
|
||||
result=str(result) if result is not None else None,
|
||||
function_args=function_args or meta.get("args"),
|
||||
snapshot=meta.get("snapshot"),
|
||||
)
|
||||
_send_update(conn, session_id, loop, update)
|
||||
if not queue:
|
||||
|
||||
+148
-30
@@ -26,6 +26,7 @@ from acp.schema import (
|
||||
McpServerHttp,
|
||||
McpServerSse,
|
||||
McpServerStdio,
|
||||
ModelInfo,
|
||||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
ResumeSessionResponse,
|
||||
@@ -36,6 +37,7 @@ from acp.schema import (
|
||||
SessionCapabilities,
|
||||
SessionForkCapabilities,
|
||||
SessionListCapabilities,
|
||||
SessionModelState,
|
||||
SessionResumeCapabilities,
|
||||
SessionInfo,
|
||||
TextContentBlock,
|
||||
@@ -147,6 +149,98 @@ class HermesACPAgent(acp.Agent):
|
||||
self._conn = conn
|
||||
logger.info("ACP client connected")
|
||||
|
||||
@staticmethod
|
||||
def _encode_model_choice(provider: str | None, model: str | None) -> str:
|
||||
"""Encode a model selection so ACP clients can keep provider context."""
|
||||
raw_model = str(model or "").strip()
|
||||
if not raw_model:
|
||||
return ""
|
||||
raw_provider = str(provider or "").strip().lower()
|
||||
if not raw_provider:
|
||||
return raw_model
|
||||
return f"{raw_provider}:{raw_model}"
|
||||
|
||||
def _build_model_state(self, state: SessionState) -> SessionModelState | None:
|
||||
"""Return the ACP model selector payload for editors like Zed."""
|
||||
model = str(state.model or getattr(state.agent, "model", "") or "").strip()
|
||||
provider = getattr(state.agent, "provider", None) or detect_provider() or "openrouter"
|
||||
|
||||
try:
|
||||
from hermes_cli.models import curated_models_for_provider, normalize_provider, provider_label
|
||||
|
||||
normalized_provider = normalize_provider(provider)
|
||||
provider_name = provider_label(normalized_provider)
|
||||
available_models: list[ModelInfo] = []
|
||||
seen_ids: set[str] = set()
|
||||
|
||||
for model_id, description in curated_models_for_provider(normalized_provider):
|
||||
rendered_model = str(model_id or "").strip()
|
||||
if not rendered_model:
|
||||
continue
|
||||
choice_id = self._encode_model_choice(normalized_provider, rendered_model)
|
||||
if choice_id in seen_ids:
|
||||
continue
|
||||
desc_parts = [f"Provider: {provider_name}"]
|
||||
if description:
|
||||
desc_parts.append(str(description).strip())
|
||||
if rendered_model == model:
|
||||
desc_parts.append("current")
|
||||
available_models.append(
|
||||
ModelInfo(
|
||||
model_id=choice_id,
|
||||
name=rendered_model,
|
||||
description=" • ".join(part for part in desc_parts if part),
|
||||
)
|
||||
)
|
||||
seen_ids.add(choice_id)
|
||||
|
||||
current_model_id = self._encode_model_choice(normalized_provider, model)
|
||||
if current_model_id and current_model_id not in seen_ids:
|
||||
available_models.insert(
|
||||
0,
|
||||
ModelInfo(
|
||||
model_id=current_model_id,
|
||||
name=model,
|
||||
description=f"Provider: {provider_name} • current",
|
||||
),
|
||||
)
|
||||
|
||||
if available_models:
|
||||
return SessionModelState(
|
||||
available_models=available_models,
|
||||
current_model_id=current_model_id or available_models[0].model_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Could not build ACP model state", exc_info=True)
|
||||
|
||||
if not model:
|
||||
return None
|
||||
|
||||
fallback_choice = self._encode_model_choice(provider, model)
|
||||
return SessionModelState(
|
||||
available_models=[ModelInfo(model_id=fallback_choice, name=model)],
|
||||
current_model_id=fallback_choice,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_model_selection(raw_model: str, current_provider: str) -> tuple[str, str]:
|
||||
"""Resolve ``provider:model`` input into the provider and normalized model id."""
|
||||
target_provider = current_provider
|
||||
new_model = raw_model.strip()
|
||||
|
||||
try:
|
||||
from hermes_cli.models import detect_provider_for_model, parse_model_input
|
||||
|
||||
target_provider, new_model = parse_model_input(new_model, current_provider)
|
||||
if target_provider == current_provider:
|
||||
detected = detect_provider_for_model(new_model, current_provider)
|
||||
if detected:
|
||||
target_provider, new_model = detected
|
||||
except Exception:
|
||||
logger.debug("Provider detection failed, using model as-is", exc_info=True)
|
||||
|
||||
return target_provider, new_model
|
||||
|
||||
async def _register_session_mcp_servers(
|
||||
self,
|
||||
state: SessionState,
|
||||
@@ -273,7 +367,10 @@ class HermesACPAgent(acp.Agent):
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("New session %s (cwd=%s)", state.session_id, cwd)
|
||||
self._schedule_available_commands_update(state.session_id)
|
||||
return NewSessionResponse(session_id=state.session_id)
|
||||
return NewSessionResponse(
|
||||
session_id=state.session_id,
|
||||
models=self._build_model_state(state),
|
||||
)
|
||||
|
||||
async def load_session(
|
||||
self,
|
||||
@@ -289,7 +386,7 @@ class HermesACPAgent(acp.Agent):
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("Loaded session %s", session_id)
|
||||
self._schedule_available_commands_update(session_id)
|
||||
return LoadSessionResponse()
|
||||
return LoadSessionResponse(models=self._build_model_state(state))
|
||||
|
||||
async def resume_session(
|
||||
self,
|
||||
@@ -305,7 +402,7 @@ class HermesACPAgent(acp.Agent):
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("Resumed session %s", state.session_id)
|
||||
self._schedule_available_commands_update(state.session_id)
|
||||
return ResumeSessionResponse()
|
||||
return ResumeSessionResponse(models=self._build_model_state(state))
|
||||
|
||||
async def cancel(self, session_id: str, **kwargs: Any) -> None:
|
||||
state = self.session_manager.get_session(session_id)
|
||||
@@ -340,11 +437,20 @@ class HermesACPAgent(acp.Agent):
|
||||
cwd: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ListSessionsResponse:
|
||||
infos = self.session_manager.list_sessions()
|
||||
sessions = [
|
||||
SessionInfo(session_id=s["session_id"], cwd=s["cwd"])
|
||||
for s in infos
|
||||
]
|
||||
infos = self.session_manager.list_sessions(cwd=cwd)
|
||||
sessions = []
|
||||
for s in infos:
|
||||
updated_at = s.get("updated_at")
|
||||
if updated_at is not None and not isinstance(updated_at, str):
|
||||
updated_at = str(updated_at)
|
||||
sessions.append(
|
||||
SessionInfo(
|
||||
session_id=s["session_id"],
|
||||
cwd=s["cwd"],
|
||||
title=s.get("title"),
|
||||
updated_at=updated_at,
|
||||
)
|
||||
)
|
||||
return ListSessionsResponse(sessions=sessions)
|
||||
|
||||
# ---- Prompt (core) ------------------------------------------------------
|
||||
@@ -389,12 +495,13 @@ class HermesACPAgent(acp.Agent):
|
||||
state.cancel_event.clear()
|
||||
|
||||
tool_call_ids: dict[str, Deque[str]] = defaultdict(deque)
|
||||
tool_call_meta: dict[str, dict[str, Any]] = {}
|
||||
previous_approval_cb = None
|
||||
|
||||
if conn:
|
||||
tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids)
|
||||
tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids, tool_call_meta)
|
||||
thinking_cb = make_thinking_cb(conn, session_id, loop)
|
||||
step_cb = make_step_cb(conn, session_id, loop, tool_call_ids)
|
||||
step_cb = make_step_cb(conn, session_id, loop, tool_call_ids, tool_call_meta)
|
||||
message_cb = make_message_cb(conn, session_id, loop)
|
||||
approval_cb = make_approval_callback(conn.request_permission, loop, session_id)
|
||||
else:
|
||||
@@ -449,6 +556,19 @@ class HermesACPAgent(acp.Agent):
|
||||
self.session_manager.save_session(session_id)
|
||||
|
||||
final_response = result.get("final_response", "")
|
||||
if final_response:
|
||||
try:
|
||||
from agent.title_generator import maybe_auto_title
|
||||
|
||||
maybe_auto_title(
|
||||
self.session_manager._get_db(),
|
||||
session_id,
|
||||
user_text,
|
||||
final_response,
|
||||
state.history,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to auto-title ACP session %s", session_id, exc_info=True)
|
||||
if final_response and conn:
|
||||
update = acp.update_agent_message_text(final_response)
|
||||
await conn.session_update(session_id, update)
|
||||
@@ -556,27 +676,15 @@ class HermesACPAgent(acp.Agent):
|
||||
provider = getattr(state.agent, "provider", None) or "auto"
|
||||
return f"Current model: {model}\nProvider: {provider}"
|
||||
|
||||
new_model = args.strip()
|
||||
target_provider = None
|
||||
current_provider = getattr(state.agent, "provider", None) or "openrouter"
|
||||
|
||||
# Auto-detect provider for the requested model
|
||||
try:
|
||||
from hermes_cli.models import parse_model_input, detect_provider_for_model
|
||||
target_provider, new_model = parse_model_input(new_model, current_provider)
|
||||
if target_provider == current_provider:
|
||||
detected = detect_provider_for_model(new_model, current_provider)
|
||||
if detected:
|
||||
target_provider, new_model = detected
|
||||
except Exception:
|
||||
logger.debug("Provider detection failed, using model as-is", exc_info=True)
|
||||
target_provider, new_model = self._resolve_model_selection(args, current_provider)
|
||||
|
||||
state.model = new_model
|
||||
state.agent = self.session_manager._make_agent(
|
||||
session_id=state.session_id,
|
||||
cwd=state.cwd,
|
||||
model=new_model,
|
||||
requested_provider=target_provider or current_provider,
|
||||
requested_provider=target_provider,
|
||||
)
|
||||
self.session_manager.save_session(state.session_id)
|
||||
provider_label = getattr(state.agent, "provider", None) or target_provider or current_provider
|
||||
@@ -678,20 +786,30 @@ class HermesACPAgent(acp.Agent):
|
||||
"""Switch the model for a session (called by ACP protocol)."""
|
||||
state = self.session_manager.get_session(session_id)
|
||||
if state:
|
||||
state.model = model_id
|
||||
current_provider = getattr(state.agent, "provider", None)
|
||||
current_base_url = getattr(state.agent, "base_url", None)
|
||||
current_api_mode = getattr(state.agent, "api_mode", None)
|
||||
requested_provider, resolved_model = self._resolve_model_selection(
|
||||
model_id,
|
||||
current_provider or "openrouter",
|
||||
)
|
||||
state.model = resolved_model
|
||||
provider_changed = bool(current_provider and requested_provider != current_provider)
|
||||
current_base_url = None if provider_changed else getattr(state.agent, "base_url", None)
|
||||
current_api_mode = None if provider_changed else getattr(state.agent, "api_mode", None)
|
||||
state.agent = self.session_manager._make_agent(
|
||||
session_id=session_id,
|
||||
cwd=state.cwd,
|
||||
model=model_id,
|
||||
requested_provider=current_provider,
|
||||
model=resolved_model,
|
||||
requested_provider=requested_provider,
|
||||
base_url=current_base_url,
|
||||
api_mode=current_api_mode,
|
||||
)
|
||||
self.session_manager.save_session(session_id)
|
||||
logger.info("Session %s: model switched to %s", session_id, model_id)
|
||||
logger.info(
|
||||
"Session %s: model switched to %s via provider %s",
|
||||
session_id,
|
||||
resolved_model,
|
||||
requested_provider,
|
||||
)
|
||||
return SetSessionModelResponse()
|
||||
logger.warning("Session %s: model switch requested for missing session", session_id)
|
||||
return None
|
||||
|
||||
+127
-34
@@ -13,8 +13,12 @@ from hermes_constants import get_hermes_home
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -22,6 +26,64 @@ from typing import Any, Dict, List, Optional
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_cwd_for_compare(cwd: str | None) -> str:
|
||||
raw = str(cwd or ".").strip()
|
||||
if not raw:
|
||||
raw = "."
|
||||
expanded = os.path.expanduser(raw)
|
||||
|
||||
# Normalize Windows drive paths into the equivalent WSL mount form so
|
||||
# ACP history filters match the same workspace across Windows and WSL.
|
||||
match = re.match(r"^([A-Za-z]):[\\/](.*)$", expanded)
|
||||
if match:
|
||||
drive = match.group(1).lower()
|
||||
tail = match.group(2).replace("\\", "/")
|
||||
expanded = f"/mnt/{drive}/{tail}"
|
||||
elif re.match(r"^/mnt/[A-Za-z]/", expanded):
|
||||
expanded = f"/mnt/{expanded[5].lower()}/{expanded[7:]}"
|
||||
|
||||
return os.path.normpath(expanded)
|
||||
|
||||
|
||||
def _build_session_title(title: Any, preview: Any, cwd: str | None) -> str:
|
||||
explicit = str(title or "").strip()
|
||||
if explicit:
|
||||
return explicit
|
||||
preview_text = str(preview or "").strip()
|
||||
if preview_text:
|
||||
return preview_text
|
||||
leaf = os.path.basename(str(cwd or "").rstrip("/\\"))
|
||||
return leaf or "New thread"
|
||||
|
||||
|
||||
def _format_updated_at(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value
|
||||
try:
|
||||
return datetime.fromtimestamp(float(value), tz=timezone.utc).isoformat()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _updated_at_sort_key(value: Any) -> float:
|
||||
if value is None:
|
||||
return float("-inf")
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
raw = str(value).strip()
|
||||
if not raw:
|
||||
return float("-inf")
|
||||
try:
|
||||
return datetime.fromisoformat(raw.replace("Z", "+00:00")).timestamp()
|
||||
except Exception:
|
||||
try:
|
||||
return float(raw)
|
||||
except Exception:
|
||||
return float("-inf")
|
||||
|
||||
|
||||
def _acp_stderr_print(*args, **kwargs) -> None:
|
||||
"""Best-effort human-readable output sink for ACP stdio sessions.
|
||||
|
||||
@@ -162,47 +224,78 @@ class SessionManager:
|
||||
logger.info("Forked ACP session %s -> %s", session_id, new_id)
|
||||
return state
|
||||
|
||||
def list_sessions(self) -> List[Dict[str, Any]]:
|
||||
def list_sessions(self, cwd: str | None = None) -> List[Dict[str, Any]]:
|
||||
"""Return lightweight info dicts for all sessions (memory + database)."""
|
||||
normalized_cwd = _normalize_cwd_for_compare(cwd) if cwd else None
|
||||
db = self._get_db()
|
||||
persisted_rows: dict[str, dict[str, Any]] = {}
|
||||
|
||||
if db is not None:
|
||||
try:
|
||||
for row in db.list_sessions_rich(source="acp", limit=1000):
|
||||
persisted_rows[str(row["id"])] = dict(row)
|
||||
except Exception:
|
||||
logger.debug("Failed to load ACP sessions from DB", exc_info=True)
|
||||
|
||||
# Collect in-memory sessions first.
|
||||
with self._lock:
|
||||
seen_ids = set(self._sessions.keys())
|
||||
results = [
|
||||
{
|
||||
"session_id": s.session_id,
|
||||
"cwd": s.cwd,
|
||||
"model": s.model,
|
||||
"history_len": len(s.history),
|
||||
}
|
||||
for s in self._sessions.values()
|
||||
]
|
||||
results = []
|
||||
for s in self._sessions.values():
|
||||
history_len = len(s.history)
|
||||
if history_len <= 0:
|
||||
continue
|
||||
if normalized_cwd and _normalize_cwd_for_compare(s.cwd) != normalized_cwd:
|
||||
continue
|
||||
persisted = persisted_rows.get(s.session_id, {})
|
||||
preview = next(
|
||||
(
|
||||
str(msg.get("content") or "").strip()
|
||||
for msg in s.history
|
||||
if msg.get("role") == "user" and str(msg.get("content") or "").strip()
|
||||
),
|
||||
persisted.get("preview") or "",
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"session_id": s.session_id,
|
||||
"cwd": s.cwd,
|
||||
"model": s.model,
|
||||
"history_len": history_len,
|
||||
"title": _build_session_title(persisted.get("title"), preview, s.cwd),
|
||||
"updated_at": _format_updated_at(
|
||||
persisted.get("last_active") or persisted.get("started_at") or time.time()
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# Merge any persisted sessions not currently in memory.
|
||||
db = self._get_db()
|
||||
if db is not None:
|
||||
try:
|
||||
rows = db.search_sessions(source="acp", limit=1000)
|
||||
for row in rows:
|
||||
sid = row["id"]
|
||||
if sid in seen_ids:
|
||||
continue
|
||||
# Extract cwd from model_config JSON.
|
||||
cwd = "."
|
||||
mc = row.get("model_config")
|
||||
if mc:
|
||||
try:
|
||||
cwd = json.loads(mc).get("cwd", ".")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
results.append({
|
||||
"session_id": sid,
|
||||
"cwd": cwd,
|
||||
"model": row.get("model") or "",
|
||||
"history_len": row.get("message_count") or 0,
|
||||
})
|
||||
except Exception:
|
||||
logger.debug("Failed to list ACP sessions from DB", exc_info=True)
|
||||
for sid, row in persisted_rows.items():
|
||||
if sid in seen_ids:
|
||||
continue
|
||||
message_count = int(row.get("message_count") or 0)
|
||||
if message_count <= 0:
|
||||
continue
|
||||
# Extract cwd from model_config JSON.
|
||||
session_cwd = "."
|
||||
mc = row.get("model_config")
|
||||
if mc:
|
||||
try:
|
||||
session_cwd = json.loads(mc).get("cwd", ".")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
if normalized_cwd and _normalize_cwd_for_compare(session_cwd) != normalized_cwd:
|
||||
continue
|
||||
results.append({
|
||||
"session_id": sid,
|
||||
"cwd": session_cwd,
|
||||
"model": row.get("model") or "",
|
||||
"history_len": message_count,
|
||||
"title": _build_session_title(row.get("title"), row.get("preview"), session_cwd),
|
||||
"updated_at": _format_updated_at(row.get("last_active") or row.get("started_at")),
|
||||
})
|
||||
|
||||
results.sort(key=lambda item: _updated_at_sort_key(item.get("updated_at")), reverse=True)
|
||||
return results
|
||||
|
||||
def update_cwd(self, session_id: str, cwd: str) -> Optional[SessionState]:
|
||||
|
||||
+174
-9
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -96,6 +97,170 @@ def build_tool_title(tool_name: str, args: Dict[str, Any]) -> str:
|
||||
return tool_name
|
||||
|
||||
|
||||
def _build_patch_mode_content(patch_text: str) -> List[Any]:
|
||||
"""Parse V4A patch mode input into ACP diff blocks when possible."""
|
||||
if not patch_text:
|
||||
return [acp.tool_content(acp.text_block(""))]
|
||||
|
||||
try:
|
||||
from tools.patch_parser import OperationType, parse_v4a_patch
|
||||
|
||||
operations, error = parse_v4a_patch(patch_text)
|
||||
if error or not operations:
|
||||
return [acp.tool_content(acp.text_block(patch_text))]
|
||||
|
||||
content: List[Any] = []
|
||||
for op in operations:
|
||||
if op.operation == OperationType.UPDATE:
|
||||
old_chunks: list[str] = []
|
||||
new_chunks: list[str] = []
|
||||
for hunk in op.hunks:
|
||||
old_lines = [line.content for line in hunk.lines if line.prefix in (" ", "-")]
|
||||
new_lines = [line.content for line in hunk.lines if line.prefix in (" ", "+")]
|
||||
if old_lines or new_lines:
|
||||
old_chunks.append("\n".join(old_lines))
|
||||
new_chunks.append("\n".join(new_lines))
|
||||
|
||||
old_text = "\n...\n".join(chunk for chunk in old_chunks if chunk)
|
||||
new_text = "\n...\n".join(chunk for chunk in new_chunks if chunk)
|
||||
if old_text or new_text:
|
||||
content.append(
|
||||
acp.tool_diff_content(
|
||||
path=op.file_path,
|
||||
old_text=old_text or None,
|
||||
new_text=new_text or "",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if op.operation == OperationType.ADD:
|
||||
added_lines = [line.content for hunk in op.hunks for line in hunk.lines if line.prefix == "+"]
|
||||
content.append(
|
||||
acp.tool_diff_content(
|
||||
path=op.file_path,
|
||||
new_text="\n".join(added_lines),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if op.operation == OperationType.DELETE:
|
||||
content.append(
|
||||
acp.tool_diff_content(
|
||||
path=op.file_path,
|
||||
old_text=f"Delete file: {op.file_path}",
|
||||
new_text="",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if op.operation == OperationType.MOVE:
|
||||
content.append(
|
||||
acp.tool_content(acp.text_block(f"Move file: {op.file_path} -> {op.new_path}"))
|
||||
)
|
||||
|
||||
return content or [acp.tool_content(acp.text_block(patch_text))]
|
||||
except Exception:
|
||||
return [acp.tool_content(acp.text_block(patch_text))]
|
||||
|
||||
|
||||
def _strip_diff_prefix(path: str) -> str:
|
||||
raw = str(path or "").strip()
|
||||
if raw.startswith(("a/", "b/")):
|
||||
return raw[2:]
|
||||
return raw
|
||||
|
||||
|
||||
def _parse_unified_diff_content(diff_text: str) -> List[Any]:
|
||||
"""Convert unified diff text into ACP diff content blocks."""
|
||||
if not diff_text:
|
||||
return []
|
||||
|
||||
content: List[Any] = []
|
||||
current_old_path: Optional[str] = None
|
||||
current_new_path: Optional[str] = None
|
||||
old_lines: list[str] = []
|
||||
new_lines: list[str] = []
|
||||
|
||||
def _flush() -> None:
|
||||
nonlocal current_old_path, current_new_path, old_lines, new_lines
|
||||
if current_old_path is None and current_new_path is None:
|
||||
return
|
||||
path = current_new_path if current_new_path and current_new_path != "/dev/null" else current_old_path
|
||||
if not path or path == "/dev/null":
|
||||
current_old_path = None
|
||||
current_new_path = None
|
||||
old_lines = []
|
||||
new_lines = []
|
||||
return
|
||||
content.append(
|
||||
acp.tool_diff_content(
|
||||
path=_strip_diff_prefix(path),
|
||||
old_text="\n".join(old_lines) if old_lines else None,
|
||||
new_text="\n".join(new_lines),
|
||||
)
|
||||
)
|
||||
current_old_path = None
|
||||
current_new_path = None
|
||||
old_lines = []
|
||||
new_lines = []
|
||||
|
||||
for line in diff_text.splitlines():
|
||||
if line.startswith("--- "):
|
||||
_flush()
|
||||
current_old_path = line[4:].strip()
|
||||
continue
|
||||
if line.startswith("+++ "):
|
||||
current_new_path = line[4:].strip()
|
||||
continue
|
||||
if line.startswith("@@"):
|
||||
continue
|
||||
if current_old_path is None and current_new_path is None:
|
||||
continue
|
||||
if line.startswith("+"):
|
||||
new_lines.append(line[1:])
|
||||
elif line.startswith("-"):
|
||||
old_lines.append(line[1:])
|
||||
elif line.startswith(" "):
|
||||
shared = line[1:]
|
||||
old_lines.append(shared)
|
||||
new_lines.append(shared)
|
||||
|
||||
_flush()
|
||||
return content
|
||||
|
||||
|
||||
def _build_tool_complete_content(
|
||||
tool_name: str,
|
||||
result: Optional[str],
|
||||
*,
|
||||
function_args: Optional[Dict[str, Any]] = None,
|
||||
snapshot: Any = None,
|
||||
) -> List[Any]:
|
||||
"""Build structured ACP completion content, falling back to plain text."""
|
||||
display_result = result or ""
|
||||
if len(display_result) > 5000:
|
||||
display_result = display_result[:4900] + f"\n... ({len(result)} chars total, truncated)"
|
||||
|
||||
if tool_name in {"write_file", "patch", "skill_manage"}:
|
||||
try:
|
||||
from agent.display import extract_edit_diff
|
||||
|
||||
diff_text = extract_edit_diff(
|
||||
tool_name,
|
||||
result,
|
||||
function_args=function_args,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
if isinstance(diff_text, str) and diff_text.strip():
|
||||
diff_content = _parse_unified_diff_content(diff_text)
|
||||
if diff_content:
|
||||
return diff_content
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return [acp.tool_content(acp.text_block(display_result))]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build ACP content objects for tool-call events
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -119,9 +284,8 @@ def build_tool_start(
|
||||
new = arguments.get("new_string", "")
|
||||
content = [acp.tool_diff_content(path=path, new_text=new, old_text=old)]
|
||||
else:
|
||||
# Patch mode — show the patch content as text
|
||||
patch_text = arguments.get("patch", "")
|
||||
content = [acp.tool_content(acp.text_block(patch_text))]
|
||||
content = _build_patch_mode_content(patch_text)
|
||||
return acp.start_tool_call(
|
||||
tool_call_id, title, kind=kind, content=content, locations=locations,
|
||||
raw_input=arguments,
|
||||
@@ -178,16 +342,17 @@ def build_tool_complete(
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
result: Optional[str] = None,
|
||||
function_args: Optional[Dict[str, Any]] = None,
|
||||
snapshot: Any = None,
|
||||
) -> ToolCallProgress:
|
||||
"""Create a ToolCallUpdate (progress) event for a completed tool call."""
|
||||
kind = get_tool_kind(tool_name)
|
||||
|
||||
# Truncate very large results for the UI
|
||||
display_result = result or ""
|
||||
if len(display_result) > 5000:
|
||||
display_result = display_result[:4900] + f"\n... ({len(result)} chars total, truncated)"
|
||||
|
||||
content = [acp.tool_content(acp.text_block(display_result))]
|
||||
content = _build_tool_complete_content(
|
||||
tool_name,
|
||||
result,
|
||||
function_args=function_args,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
return acp.update_tool_call(
|
||||
tool_call_id,
|
||||
kind=kind,
|
||||
|
||||
+78
-32
@@ -94,6 +94,17 @@ def _normalize_aux_provider(provider: Optional[str]) -> str:
|
||||
return "custom"
|
||||
return _PROVIDER_ALIASES.get(normalized, normalized)
|
||||
|
||||
|
||||
_FIXED_TEMPERATURE_MODELS: Dict[str, float] = {
|
||||
"kimi-for-coding": 0.6,
|
||||
}
|
||||
|
||||
|
||||
def _fixed_temperature_for_model(model: Optional[str]) -> Optional[float]:
|
||||
"""Return a required temperature override for models with strict contracts."""
|
||||
normalized = (model or "").strip().lower()
|
||||
return _FIXED_TEMPERATURE_MODELS.get(normalized)
|
||||
|
||||
# Default auxiliary models for direct API-key providers (cheap/fast for side tasks)
|
||||
_API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = {
|
||||
"gemini": "gemini-3-flash-preview",
|
||||
@@ -734,6 +745,15 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
extra["default_headers"] = copilot_default_headers()
|
||||
elif "generativelanguage.googleapis.com" in base_url.lower():
|
||||
# Google's OpenAI-compatible endpoint only accepts x-goog-api-key.
|
||||
# Passing api_key= causes the SDK to inject Authorization: Bearer,
|
||||
# which Google rejects with HTTP 400 "Multiple authentication
|
||||
# credentials received". Use a placeholder for api_key and pass
|
||||
# the real key via x-goog-api-key header instead.
|
||||
# Fixes: https://github.com/NousResearch/hermes-agent/issues/7893
|
||||
extra["default_headers"] = {"x-goog-api-key": api_key}
|
||||
api_key = "not-used"
|
||||
return OpenAI(api_key=api_key, base_url=base_url, **extra), model
|
||||
|
||||
creds = resolve_api_key_provider_credentials(provider_id)
|
||||
@@ -755,6 +775,15 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
extra["default_headers"] = copilot_default_headers()
|
||||
elif "generativelanguage.googleapis.com" in base_url.lower():
|
||||
# Google's OpenAI-compatible endpoint only accepts x-goog-api-key.
|
||||
# Passing api_key= causes the SDK to inject Authorization: Bearer,
|
||||
# which Google rejects with HTTP 400 "Multiple authentication
|
||||
# credentials received". Use a placeholder for api_key and pass
|
||||
# the real key via x-goog-api-key header instead.
|
||||
# Fixes: https://github.com/NousResearch/hermes-agent/issues/7893
|
||||
extra["default_headers"] = {"x-goog-api-key": api_key}
|
||||
api_key = "not-used"
|
||||
return OpenAI(api_key=api_key, base_url=base_url, **extra), model
|
||||
|
||||
return None, None
|
||||
@@ -1064,8 +1093,6 @@ _AUTO_PROVIDER_LABELS = {
|
||||
"_resolve_api_key_provider": "api-key",
|
||||
}
|
||||
|
||||
_AGGREGATOR_PROVIDERS = frozenset({"openrouter", "nous"})
|
||||
|
||||
_MAIN_RUNTIME_FIELDS = ("provider", "model", "base_url", "api_key", "api_mode")
|
||||
|
||||
|
||||
@@ -1196,11 +1223,15 @@ def _resolve_auto(main_runtime: Optional[Dict[str, Any]] = None) -> Tuple[Option
|
||||
"""Full auto-detection chain.
|
||||
|
||||
Priority:
|
||||
1. If the user's main provider is NOT an aggregator (OpenRouter / Nous),
|
||||
use their main provider + main model directly. This ensures users on
|
||||
Alibaba, DeepSeek, ZAI, etc. get auxiliary tasks handled by the same
|
||||
provider they already have credentials for — no OpenRouter key needed.
|
||||
2. OpenRouter → Nous → custom → Codex → API-key providers (original chain).
|
||||
1. User's main provider + main model, regardless of provider type.
|
||||
This means auxiliary tasks (compression, vision, web extraction,
|
||||
session search, etc.) use the same model the user configured for
|
||||
chat. Users on OpenRouter/Nous get their chosen chat model; users
|
||||
on DeepSeek/ZAI/Alibaba get theirs; etc. Running aux tasks on the
|
||||
user's picked model keeps behavior predictable — no surprise
|
||||
switches to a cheap fallback model for side tasks.
|
||||
2. OpenRouter → Nous → custom → Codex → API-key providers (fallback
|
||||
chain, only used when the main provider has no working client).
|
||||
"""
|
||||
global auxiliary_is_nous, _stale_base_url_warned
|
||||
auxiliary_is_nous = False # Reset — _try_nous() will set True if it wins
|
||||
@@ -1230,11 +1261,16 @@ def _resolve_auto(main_runtime: Optional[Dict[str, Any]] = None) -> Tuple[Option
|
||||
)
|
||||
_stale_base_url_warned = True
|
||||
|
||||
# ── Step 1: non-aggregator main provider → use main model directly ──
|
||||
# ── Step 1: main provider + main model → use them directly ──
|
||||
#
|
||||
# This is the primary aux backend for every user. "auto" means
|
||||
# "use my main chat model for side tasks as well" — including users
|
||||
# on aggregators (OpenRouter, Nous) who previously got routed to a
|
||||
# cheap provider-side default. Explicit per-task overrides set via
|
||||
# config.yaml (auxiliary.<task>.provider) still win over this.
|
||||
main_provider = runtime_provider or _read_main_provider()
|
||||
main_model = runtime_model or _read_main_model()
|
||||
if (main_provider and main_model
|
||||
and main_provider not in _AGGREGATOR_PROVIDERS
|
||||
and main_provider not in ("auto", "")):
|
||||
resolved_provider = main_provider
|
||||
explicit_base_url = None
|
||||
@@ -1593,6 +1629,15 @@ def resolve_provider_client(
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
headers.update(copilot_default_headers())
|
||||
elif "generativelanguage.googleapis.com" in base_url.lower():
|
||||
# Google's OpenAI-compatible endpoint only accepts x-goog-api-key.
|
||||
# Passing api_key= causes the OpenAI SDK to inject Authorization: Bearer,
|
||||
# which Google rejects with HTTP 400 "Multiple authentication credentials
|
||||
# received". Use a placeholder for api_key and pass the real key via
|
||||
# x-goog-api-key header instead.
|
||||
# Fixes: https://github.com/NousResearch/hermes-agent/issues/7893
|
||||
headers["x-goog-api-key"] = api_key
|
||||
api_key = "not-used"
|
||||
|
||||
client = OpenAI(api_key=api_key, base_url=base_url,
|
||||
**({"default_headers": headers} if headers else {}))
|
||||
@@ -1817,34 +1862,31 @@ def resolve_vision_provider_client(
|
||||
|
||||
if requested == "auto":
|
||||
# Vision auto-detection order:
|
||||
# 1. Active provider + model (user's main chat config)
|
||||
# 2. OpenRouter (known vision-capable default model)
|
||||
# 3. Nous Portal (known vision-capable default model)
|
||||
# 1. User's main provider + main model (including aggregators).
|
||||
# _PROVIDER_VISION_MODELS provides per-provider vision model
|
||||
# overrides when the provider has a dedicated multimodal model
|
||||
# that differs from the chat model (e.g. xiaomi → mimo-v2-omni,
|
||||
# zai → glm-5v-turbo).
|
||||
# 2. OpenRouter (vision-capable aggregator fallback)
|
||||
# 3. Nous Portal (vision-capable aggregator fallback)
|
||||
# 4. Stop
|
||||
main_provider = _read_main_provider()
|
||||
main_model = _read_main_model()
|
||||
if main_provider and main_provider not in ("auto", ""):
|
||||
if main_provider in _VISION_AUTO_PROVIDER_ORDER:
|
||||
# Known strict backend — use its defaults.
|
||||
sync_client, default_model = _resolve_strict_vision_backend(main_provider)
|
||||
if sync_client is not None:
|
||||
return _finalize(main_provider, sync_client, default_model)
|
||||
else:
|
||||
# Exotic provider (DeepSeek, Alibaba, Xiaomi, named custom, etc.)
|
||||
# Use provider-specific vision model if available, otherwise main model.
|
||||
vision_model = _PROVIDER_VISION_MODELS.get(main_provider, main_model)
|
||||
rpc_client, rpc_model = resolve_provider_client(
|
||||
main_provider, vision_model,
|
||||
api_mode=resolved_api_mode)
|
||||
if rpc_client is not None:
|
||||
logger.info(
|
||||
"Vision auto-detect: using active provider %s (%s)",
|
||||
main_provider, rpc_model or vision_model,
|
||||
)
|
||||
return _finalize(
|
||||
main_provider, rpc_client, rpc_model or vision_model)
|
||||
vision_model = _PROVIDER_VISION_MODELS.get(main_provider, main_model)
|
||||
rpc_client, rpc_model = resolve_provider_client(
|
||||
main_provider, vision_model,
|
||||
api_mode=resolved_api_mode)
|
||||
if rpc_client is not None:
|
||||
logger.info(
|
||||
"Vision auto-detect: using main provider %s (%s)",
|
||||
main_provider, rpc_model or vision_model,
|
||||
)
|
||||
return _finalize(
|
||||
main_provider, rpc_client, rpc_model or vision_model)
|
||||
|
||||
# Fall back through aggregators.
|
||||
# Fall back through aggregators (uses their dedicated vision model,
|
||||
# not the user's main model) when main provider has no client.
|
||||
for candidate in _VISION_AUTO_PROVIDER_ORDER:
|
||||
if candidate == main_provider:
|
||||
continue # already tried above
|
||||
@@ -2293,6 +2335,10 @@ def _build_call_kwargs(
|
||||
"timeout": timeout,
|
||||
}
|
||||
|
||||
fixed_temperature = _fixed_temperature_for_model(model)
|
||||
if fixed_temperature is not None:
|
||||
temperature = fixed_temperature
|
||||
|
||||
# Opus 4.7+ rejects any non-default temperature/top_p/top_k — silently
|
||||
# drop here so auxiliary callers that hardcode temperature (e.g. 0.3 on
|
||||
# flush_memories, 0 on structured-JSON extraction) don't 400 the moment
|
||||
|
||||
@@ -1130,6 +1130,14 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
|
||||
state = _load_provider_state(auth_store, "nous")
|
||||
if state:
|
||||
active_sources.add("device_code")
|
||||
# Prefer a user-supplied label embedded in the singleton state
|
||||
# (set by persist_nous_credentials(label=...) when the user ran
|
||||
# `hermes auth add nous --label <name>`). Fall back to the
|
||||
# auto-derived token fingerprint for logins that didn't supply one.
|
||||
custom_label = str(state.get("label") or "").strip()
|
||||
seeded_label = custom_label or label_from_token(
|
||||
state.get("access_token", ""), "device_code"
|
||||
)
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
provider,
|
||||
@@ -1148,7 +1156,7 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
|
||||
"agent_key": state.get("agent_key"),
|
||||
"agent_key_expires_at": state.get("agent_key_expires_at"),
|
||||
"tls": state.get("tls") if isinstance(state.get("tls"), dict) else None,
|
||||
"label": label_from_token(state.get("access_token", ""), "device_code"),
|
||||
"label": seeded_label,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -747,18 +747,149 @@ class GeminiCloudCodeClient:
|
||||
|
||||
|
||||
def _gemini_http_error(response: httpx.Response) -> CodeAssistError:
|
||||
"""Translate an httpx response into a CodeAssistError with rich metadata.
|
||||
|
||||
Parses Google's error envelope (``{"error": {"code", "message", "status",
|
||||
"details": [...]}}``) so the agent's error classifier can reason about
|
||||
the failure — ``status_code`` enables the rate_limit / auth classification
|
||||
paths, and ``response`` lets the main loop honor ``Retry-After`` just
|
||||
like it does for OpenAI SDK exceptions.
|
||||
|
||||
Also lifts a few recognizable Google conditions into human-readable
|
||||
messages so the user sees something better than a 500-char JSON dump:
|
||||
|
||||
MODEL_CAPACITY_EXHAUSTED → "Gemini model capacity exhausted for
|
||||
<model>. This is a Google-side throttle..."
|
||||
RESOURCE_EXHAUSTED w/o reason → quota-style message
|
||||
404 → "Model <name> not found at cloudcode-pa..."
|
||||
"""
|
||||
status = response.status_code
|
||||
|
||||
# Parse the body once, surviving any weird encodings.
|
||||
body_text = ""
|
||||
body_json: Dict[str, Any] = {}
|
||||
try:
|
||||
body = response.text[:500]
|
||||
body_text = response.text
|
||||
except Exception:
|
||||
body = ""
|
||||
# Let run_agent's retry logic see auth errors as rotatable via `api_key`
|
||||
body_text = ""
|
||||
if body_text:
|
||||
try:
|
||||
parsed = json.loads(body_text)
|
||||
if isinstance(parsed, dict):
|
||||
body_json = parsed
|
||||
except (ValueError, TypeError):
|
||||
body_json = {}
|
||||
|
||||
# Dig into Google's error envelope. Shape is:
|
||||
# {"error": {"code": 429, "message": "...", "status": "RESOURCE_EXHAUSTED",
|
||||
# "details": [{"@type": ".../ErrorInfo", "reason": "MODEL_CAPACITY_EXHAUSTED",
|
||||
# "metadata": {...}},
|
||||
# {"@type": ".../RetryInfo", "retryDelay": "30s"}]}}
|
||||
err_obj = body_json.get("error") if isinstance(body_json, dict) else None
|
||||
if not isinstance(err_obj, dict):
|
||||
err_obj = {}
|
||||
err_status = str(err_obj.get("status") or "").strip()
|
||||
err_message = str(err_obj.get("message") or "").strip()
|
||||
err_details_list = err_obj.get("details") if isinstance(err_obj.get("details"), list) else []
|
||||
|
||||
# Extract google.rpc.ErrorInfo reason + metadata. There may be more
|
||||
# than one ErrorInfo (rare), so we pick the first one with a reason.
|
||||
error_reason = ""
|
||||
error_metadata: Dict[str, Any] = {}
|
||||
retry_delay_seconds: Optional[float] = None
|
||||
for detail in err_details_list:
|
||||
if not isinstance(detail, dict):
|
||||
continue
|
||||
type_url = str(detail.get("@type") or "")
|
||||
if not error_reason and type_url.endswith("/google.rpc.ErrorInfo"):
|
||||
reason = detail.get("reason")
|
||||
if isinstance(reason, str) and reason:
|
||||
error_reason = reason
|
||||
md = detail.get("metadata")
|
||||
if isinstance(md, dict):
|
||||
error_metadata = md
|
||||
elif retry_delay_seconds is None and type_url.endswith("/google.rpc.RetryInfo"):
|
||||
# retryDelay is a google.protobuf.Duration string like "30s" or "1.5s".
|
||||
delay_raw = detail.get("retryDelay")
|
||||
if isinstance(delay_raw, str) and delay_raw.endswith("s"):
|
||||
try:
|
||||
retry_delay_seconds = float(delay_raw[:-1])
|
||||
except ValueError:
|
||||
pass
|
||||
elif isinstance(delay_raw, (int, float)):
|
||||
retry_delay_seconds = float(delay_raw)
|
||||
|
||||
# Fall back to the Retry-After header if the body didn't include RetryInfo.
|
||||
if retry_delay_seconds is None:
|
||||
try:
|
||||
header_val = response.headers.get("Retry-After") or response.headers.get("retry-after")
|
||||
except Exception:
|
||||
header_val = None
|
||||
if header_val:
|
||||
try:
|
||||
retry_delay_seconds = float(header_val)
|
||||
except (TypeError, ValueError):
|
||||
retry_delay_seconds = None
|
||||
|
||||
# Classify the error code. ``code_assist_rate_limited`` stays the default
|
||||
# for 429s; a more specific reason tag helps downstream callers (e.g. tests,
|
||||
# logs) without changing the rate_limit classification path.
|
||||
code = f"code_assist_http_{status}"
|
||||
if status == 401:
|
||||
code = "code_assist_unauthorized"
|
||||
elif status == 429:
|
||||
code = "code_assist_rate_limited"
|
||||
if error_reason == "MODEL_CAPACITY_EXHAUSTED":
|
||||
code = "code_assist_capacity_exhausted"
|
||||
|
||||
# Build a human-readable message. Keep the status + a raw-body tail for
|
||||
# debugging, but lead with a friendlier summary when we recognize the
|
||||
# Google signal.
|
||||
model_hint = ""
|
||||
if isinstance(error_metadata, dict):
|
||||
model_hint = str(error_metadata.get("model") or error_metadata.get("modelId") or "").strip()
|
||||
|
||||
if status == 429 and error_reason == "MODEL_CAPACITY_EXHAUSTED":
|
||||
target = model_hint or "this Gemini model"
|
||||
message = (
|
||||
f"Gemini capacity exhausted for {target} (Google-side throttle, "
|
||||
f"not a Hermes issue). Try a different Gemini model or set a "
|
||||
f"fallback_providers entry to a non-Gemini provider."
|
||||
)
|
||||
if retry_delay_seconds is not None:
|
||||
message += f" Google suggests retrying in {retry_delay_seconds:g}s."
|
||||
elif status == 429 and err_status == "RESOURCE_EXHAUSTED":
|
||||
message = (
|
||||
f"Gemini quota exhausted ({err_message or 'RESOURCE_EXHAUSTED'}). "
|
||||
f"Check /gquota for remaining daily requests."
|
||||
)
|
||||
if retry_delay_seconds is not None:
|
||||
message += f" Retry suggested in {retry_delay_seconds:g}s."
|
||||
elif status == 404:
|
||||
# Google returns 404 when a model has been retired or renamed.
|
||||
target = model_hint or (err_message or "model")
|
||||
message = (
|
||||
f"Code Assist 404: {target} is not available at "
|
||||
f"cloudcode-pa.googleapis.com. It may have been renamed or "
|
||||
f"retired. Check hermes_cli/models.py for the current list."
|
||||
)
|
||||
elif err_message:
|
||||
# Generic fallback with the parsed message.
|
||||
message = f"Code Assist HTTP {status} ({err_status or 'error'}): {err_message}"
|
||||
else:
|
||||
# Last-ditch fallback — raw body snippet.
|
||||
message = f"Code Assist returned HTTP {status}: {body_text[:500]}"
|
||||
|
||||
return CodeAssistError(
|
||||
f"Code Assist returned HTTP {status}: {body}",
|
||||
message,
|
||||
code=code,
|
||||
status_code=status,
|
||||
response=response,
|
||||
retry_after=retry_delay_seconds,
|
||||
details={
|
||||
"status": err_status,
|
||||
"reason": error_reason,
|
||||
"metadata": error_metadata,
|
||||
"message": err_message,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -68,9 +68,45 @@ _ONBOARDING_POLL_INTERVAL_SECONDS = 5.0
|
||||
|
||||
|
||||
class CodeAssistError(RuntimeError):
|
||||
def __init__(self, message: str, *, code: str = "code_assist_error") -> None:
|
||||
"""Exception raised by the Code Assist (``cloudcode-pa``) integration.
|
||||
|
||||
Carries HTTP status / response / retry-after metadata so the agent's
|
||||
``error_classifier._extract_status_code`` and the main loop's Retry-After
|
||||
handling (which walks ``error.response.headers``) pick up the right
|
||||
signals. Without these, 429s from the OAuth path look like opaque
|
||||
``RuntimeError`` and skip the rate-limit path.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
code: str = "code_assist_error",
|
||||
status_code: Optional[int] = None,
|
||||
response: Any = None,
|
||||
retry_after: Optional[float] = None,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
# ``status_code`` is picked up by ``agent.error_classifier._extract_status_code``
|
||||
# so a 429 from Code Assist classifies as FailoverReason.rate_limit and
|
||||
# triggers the main loop's fallback_providers chain the same way SDK
|
||||
# errors do.
|
||||
self.status_code = status_code
|
||||
# ``response`` is the underlying ``httpx.Response`` (or a shim with a
|
||||
# ``.headers`` mapping and ``.json()`` method). The main loop reads
|
||||
# ``error.response.headers["Retry-After"]`` to honor Google's retry
|
||||
# hints when the backend throttles us.
|
||||
self.response = response
|
||||
# Parsed ``Retry-After`` seconds (kept separately for convenience —
|
||||
# Google returns retry hints in both the header and the error body's
|
||||
# ``google.rpc.RetryInfo`` details, and we pick whichever we found).
|
||||
self.retry_after = retry_after
|
||||
# Parsed structured error details from the Google error envelope
|
||||
# (e.g. ``{"reason": "MODEL_CAPACITY_EXHAUSTED", "status": "RESOURCE_EXHAUSTED"}``).
|
||||
# Useful for logging and for tests that want to assert on specifics.
|
||||
self.details = details or {}
|
||||
|
||||
|
||||
class ProjectIdRequiredError(CodeAssistError):
|
||||
|
||||
@@ -38,6 +38,7 @@ _PROVIDER_PREFIXES: frozenset[str] = frozenset({
|
||||
"mimo", "xiaomi-mimo",
|
||||
"arcee-ai", "arceeai",
|
||||
"xai", "x-ai", "x.ai", "grok",
|
||||
"nvidia", "nim", "nvidia-nim", "nemotron",
|
||||
"qwen-portal",
|
||||
})
|
||||
|
||||
@@ -124,7 +125,6 @@ DEFAULT_CONTEXT_LENGTHS = {
|
||||
"gemini": 1048576,
|
||||
# Gemma (open models served via AI Studio)
|
||||
"gemma-4-31b": 256000,
|
||||
"gemma-4-26b": 256000,
|
||||
"gemma-3": 131072,
|
||||
"gemma": 8192, # fallback for older gemma models
|
||||
# DeepSeek
|
||||
@@ -158,6 +158,8 @@ DEFAULT_CONTEXT_LENGTHS = {
|
||||
"grok": 131072, # catch-all (grok-beta, unknown grok-*)
|
||||
# Kimi
|
||||
"kimi": 262144,
|
||||
# Nemotron — NVIDIA's open-weights series (128K context across all sizes)
|
||||
"nemotron": 131072,
|
||||
# Arcee
|
||||
"trinity": 262144,
|
||||
# OpenRouter
|
||||
@@ -240,6 +242,7 @@ _URL_TO_PROVIDER: Dict[str, str] = {
|
||||
"api.fireworks.ai": "fireworks",
|
||||
"opencode.ai": "opencode-go",
|
||||
"api.x.ai": "xai",
|
||||
"integrate.api.nvidia.com": "nvidia",
|
||||
"api.xiaomimimo.com": "xiaomi",
|
||||
"xiaomimimo.com": "xiaomi",
|
||||
"ollama.com": "ollama-cloud",
|
||||
|
||||
@@ -654,7 +654,7 @@ def build_skills_system_prompt(
|
||||
):
|
||||
continue
|
||||
skills_by_category.setdefault(category, []).append(
|
||||
(skill_name, entry.get("description", ""))
|
||||
(frontmatter_name, entry.get("description", ""))
|
||||
)
|
||||
category_descriptions = {
|
||||
str(k): str(v)
|
||||
@@ -679,7 +679,7 @@ def build_skills_system_prompt(
|
||||
):
|
||||
continue
|
||||
skills_by_category.setdefault(entry["category"], []).append(
|
||||
(skill_name, entry["description"])
|
||||
(entry["frontmatter_name"], entry["description"])
|
||||
)
|
||||
|
||||
# Read category-level DESCRIPTION.md files
|
||||
@@ -722,9 +722,10 @@ def build_skills_system_prompt(
|
||||
continue
|
||||
entry = _build_snapshot_entry(skill_file, ext_dir, frontmatter, desc)
|
||||
skill_name = entry["skill_name"]
|
||||
if skill_name in seen_skill_names:
|
||||
frontmatter_name = entry["frontmatter_name"]
|
||||
if frontmatter_name in seen_skill_names:
|
||||
continue
|
||||
if entry["frontmatter_name"] in disabled or skill_name in disabled:
|
||||
if frontmatter_name in disabled or skill_name in disabled:
|
||||
continue
|
||||
if not _skill_should_show(
|
||||
extract_skill_conditions(frontmatter),
|
||||
@@ -732,9 +733,9 @@ def build_skills_system_prompt(
|
||||
available_toolsets,
|
||||
):
|
||||
continue
|
||||
seen_skill_names.add(skill_name)
|
||||
seen_skill_names.add(frontmatter_name)
|
||||
skills_by_category.setdefault(entry["category"], []).append(
|
||||
(skill_name, entry["description"])
|
||||
(frontmatter_name, entry["description"])
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Error reading external skill %s: %s", skill_file, e)
|
||||
|
||||
@@ -24,6 +24,7 @@ model:
|
||||
# "minimax" - MiniMax global (requires: MINIMAX_API_KEY)
|
||||
# "minimax-cn" - MiniMax China (requires: MINIMAX_CN_API_KEY)
|
||||
# "huggingface" - Hugging Face Inference (requires: HF_TOKEN)
|
||||
# "nvidia" - NVIDIA NIM / build.nvidia.com (requires: NVIDIA_API_KEY)
|
||||
# "xiaomi" - Xiaomi MiMo (requires: XIAOMI_API_KEY)
|
||||
# "arcee" - Arcee AI Trinity models (requires: ARCEEAI_API_KEY)
|
||||
# "ollama-cloud" - Ollama Cloud (requires: OLLAMA_API_KEY — https://ollama.com/settings)
|
||||
|
||||
@@ -18,6 +18,8 @@ import os
|
||||
import shutil
|
||||
import sys
|
||||
import json
|
||||
import re
|
||||
import base64
|
||||
import atexit
|
||||
import tempfile
|
||||
import time
|
||||
@@ -78,6 +80,42 @@ _project_env = Path(__file__).parent / '.env'
|
||||
load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env)
|
||||
|
||||
|
||||
_REASONING_TAGS = (
|
||||
"REASONING_SCRATCHPAD",
|
||||
"think",
|
||||
"reasoning",
|
||||
"THINKING",
|
||||
"thinking",
|
||||
)
|
||||
|
||||
|
||||
def _strip_reasoning_tags(text: str) -> str:
|
||||
cleaned = text
|
||||
for tag in _REASONING_TAGS:
|
||||
cleaned = re.sub(rf"<{tag}>.*?</{tag}>\s*", "", cleaned, flags=re.DOTALL)
|
||||
cleaned = re.sub(rf"<{tag}>.*$", "", cleaned, flags=re.DOTALL)
|
||||
return cleaned.strip()
|
||||
|
||||
|
||||
def _assistant_content_as_text(content: Any) -> str:
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
str(part.get("text", ""))
|
||||
for part in content
|
||||
if isinstance(part, dict) and part.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _assistant_copy_text(content: Any) -> str:
|
||||
return _strip_reasoning_tags(_assistant_content_as_text(content))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Configuration Loading
|
||||
# =============================================================================
|
||||
@@ -1172,6 +1210,10 @@ def _resolve_attachment_path(raw_path: str) -> Path | None:
|
||||
return None
|
||||
|
||||
expanded = os.path.expandvars(os.path.expanduser(token))
|
||||
if os.name != "nt":
|
||||
normalized = expanded.replace("\\", "/")
|
||||
if len(normalized) >= 3 and normalized[1] == ":" and normalized[2] == "/" and normalized[0].isalpha():
|
||||
expanded = f"/mnt/{normalized[0].lower()}/{normalized[3:]}"
|
||||
path = Path(expanded)
|
||||
if not path.is_absolute():
|
||||
base_dir = Path(os.getenv("TERMINAL_CWD", os.getcwd()))
|
||||
@@ -1254,10 +1296,12 @@ def _detect_file_drop(user_input: str) -> "dict | None":
|
||||
or stripped.startswith("~")
|
||||
or stripped.startswith("./")
|
||||
or stripped.startswith("../")
|
||||
or (len(stripped) >= 3 and stripped[1] == ":" and stripped[2] in ("\\", "/") and stripped[0].isalpha())
|
||||
or stripped.startswith('"/')
|
||||
or stripped.startswith('"~')
|
||||
or stripped.startswith("'/")
|
||||
or stripped.startswith("'~")
|
||||
or (len(stripped) >= 4 and stripped[0] in ("'", '"') and stripped[2] == ":" and stripped[3] in ("\\", "/") and stripped[1].isalpha())
|
||||
)
|
||||
if not starts_like_path:
|
||||
return None
|
||||
@@ -3125,21 +3169,6 @@ class HermesCLI:
|
||||
MAX_ASST_LEN = 200 # truncate assistant text
|
||||
MAX_ASST_LINES = 3 # max lines of assistant text
|
||||
|
||||
def _strip_reasoning(text: str) -> str:
|
||||
"""Remove <REASONING_SCRATCHPAD>...</REASONING_SCRATCHPAD> blocks
|
||||
from displayed text (reasoning model internal thoughts)."""
|
||||
import re
|
||||
cleaned = re.sub(
|
||||
r"<REASONING_SCRATCHPAD>.*?</REASONING_SCRATCHPAD>\s*",
|
||||
"", text, flags=re.DOTALL,
|
||||
)
|
||||
# Also strip unclosed reasoning tags at the end
|
||||
cleaned = re.sub(
|
||||
r"<REASONING_SCRATCHPAD>.*$",
|
||||
"", cleaned, flags=re.DOTALL,
|
||||
)
|
||||
return cleaned.strip()
|
||||
|
||||
# Collect displayable entries (skip system, tool-result messages)
|
||||
entries = [] # list of (role, display_text)
|
||||
_last_asst_idx = None # index of last assistant entry
|
||||
@@ -3171,7 +3200,7 @@ class HermesCLI:
|
||||
|
||||
elif role == "assistant":
|
||||
text = "" if content is None else str(content)
|
||||
text = _strip_reasoning(text)
|
||||
text = _strip_reasoning_tags(text)
|
||||
parts = []
|
||||
full_parts = [] # un-truncated version
|
||||
if text:
|
||||
@@ -3510,6 +3539,26 @@ class HermesCLI:
|
||||
killed = process_registry.kill_all()
|
||||
print(f" ✅ Stopped {killed} process(es).")
|
||||
|
||||
def _handle_agents_command(self):
|
||||
"""Handle /agents — show background processes and agent status."""
|
||||
from tools.process_registry import format_uptime_short, process_registry
|
||||
|
||||
processes = process_registry.list_sessions()
|
||||
running = [p for p in processes if p.get("status") == "running"]
|
||||
finished = [p for p in processes if p.get("status") != "running"]
|
||||
|
||||
_cprint(f" Running processes: {len(running)}")
|
||||
for p in running:
|
||||
cmd = p.get("command", "")[:80]
|
||||
up = format_uptime_short(p.get("uptime_seconds", 0))
|
||||
_cprint(f" {p.get('session_id', '?')} · {up} · {cmd}")
|
||||
|
||||
if finished:
|
||||
_cprint(f" Recently finished: {len(finished)}")
|
||||
|
||||
agent_running = getattr(self, "_agent_running", False)
|
||||
_cprint(f" Agent: {'running' if agent_running else 'idle'}")
|
||||
|
||||
def _handle_paste_command(self):
|
||||
"""Handle /paste — explicitly check clipboard for an image.
|
||||
|
||||
@@ -3535,6 +3584,61 @@ class HermesCLI:
|
||||
else:
|
||||
_cprint(f" {_DIM}(._.) No image found in clipboard{_RST}")
|
||||
|
||||
def _write_osc52_clipboard(self, text: str) -> None:
|
||||
"""Copy *text* to terminal clipboard via OSC 52."""
|
||||
payload = base64.b64encode(text.encode("utf-8")).decode("ascii")
|
||||
seq = f"\x1b]52;c;{payload}\x07"
|
||||
out = getattr(self, "_app", None)
|
||||
output = getattr(out, "output", None) if out else None
|
||||
if output and hasattr(output, "write_raw"):
|
||||
output.write_raw(seq)
|
||||
output.flush()
|
||||
return
|
||||
if output and hasattr(output, "write"):
|
||||
output.write(seq)
|
||||
output.flush()
|
||||
return
|
||||
sys.stdout.write(seq)
|
||||
sys.stdout.flush()
|
||||
|
||||
def _handle_copy_command(self, cmd_original: str) -> None:
|
||||
"""Handle /copy [number] — copy assistant output to clipboard."""
|
||||
parts = cmd_original.split(maxsplit=1)
|
||||
arg = parts[1].strip() if len(parts) > 1 else ""
|
||||
|
||||
assistant = [m for m in self.conversation_history if m.get("role") == "assistant"]
|
||||
if not assistant:
|
||||
_cprint(" Nothing to copy yet.")
|
||||
return
|
||||
|
||||
if arg:
|
||||
try:
|
||||
idx = int(arg) - 1
|
||||
except ValueError:
|
||||
_cprint(" Usage: /copy [number]")
|
||||
return
|
||||
if idx < 0 or idx >= len(assistant):
|
||||
_cprint(f" Invalid response number. Use 1-{len(assistant)}.")
|
||||
return
|
||||
else:
|
||||
idx = len(assistant) - 1
|
||||
while idx >= 0 and not _assistant_copy_text(assistant[idx].get("content")):
|
||||
idx -= 1
|
||||
if idx < 0:
|
||||
_cprint(" Nothing to copy in assistant responses yet.")
|
||||
return
|
||||
|
||||
text = _assistant_copy_text(assistant[idx].get("content"))
|
||||
if not text:
|
||||
_cprint(" Nothing to copy in that assistant response.")
|
||||
return
|
||||
|
||||
try:
|
||||
self._write_osc52_clipboard(text)
|
||||
_cprint(f" Copied assistant response #{idx + 1} to clipboard")
|
||||
except Exception as e:
|
||||
_cprint(f" Clipboard copy failed: {e}")
|
||||
|
||||
def _handle_image_command(self, cmd_original: str):
|
||||
"""Handle /image <path> — attach a local image file for the next prompt."""
|
||||
raw_args = (cmd_original.split(None, 1)[1].strip() if " " in cmd_original else "")
|
||||
@@ -3671,7 +3775,7 @@ class HermesCLI:
|
||||
skin = get_active_skin()
|
||||
separator_color = skin.get_color("banner_dim", "#B8860B")
|
||||
accent_color = skin.get_color("ui_accent", "#FFBF00")
|
||||
label_color = skin.get_color("ui_label", "#4dd0e1")
|
||||
label_color = skin.get_color("ui_label", "#DAA520")
|
||||
except Exception:
|
||||
separator_color, accent_color, label_color = "#B8860B", "#FFBF00", "cyan"
|
||||
toolsets_info = ""
|
||||
@@ -5553,6 +5657,8 @@ class HermesCLI:
|
||||
self._show_usage()
|
||||
elif canonical == "insights":
|
||||
self._show_insights(cmd_original)
|
||||
elif canonical == "copy":
|
||||
self._handle_copy_command(cmd_original)
|
||||
elif canonical == "debug":
|
||||
self._handle_debug_command()
|
||||
elif canonical == "paste":
|
||||
@@ -5596,6 +5702,8 @@ class HermesCLI:
|
||||
self._handle_snapshot_command(cmd_original)
|
||||
elif canonical == "stop":
|
||||
self._handle_stop_command()
|
||||
elif canonical == "agents":
|
||||
self._handle_agents_command()
|
||||
elif canonical == "background":
|
||||
self._handle_background_command(cmd_original)
|
||||
elif canonical == "btw":
|
||||
@@ -5612,6 +5720,30 @@ class HermesCLI:
|
||||
_cprint(f" Queued for the next turn: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
else:
|
||||
_cprint(f" Queued: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
elif canonical == "steer":
|
||||
# Inject a message after the next tool call without interrupting.
|
||||
# If the agent is actively running, push the text into the agent's
|
||||
# pending_steer slot — the drain hook in _execute_tool_calls_*
|
||||
# will append it to the next tool result's content. If no agent
|
||||
# is running, fall back to queue semantics (same as /queue).
|
||||
parts = cmd_original.split(None, 1)
|
||||
payload = parts[1].strip() if len(parts) > 1 else ""
|
||||
if not payload:
|
||||
_cprint(" Usage: /steer <prompt>")
|
||||
elif self._agent_running and self.agent is not None and hasattr(self.agent, "steer"):
|
||||
try:
|
||||
accepted = self.agent.steer(payload)
|
||||
except Exception as exc:
|
||||
_cprint(f" Steer failed: {exc}")
|
||||
else:
|
||||
if accepted:
|
||||
_cprint(f" ⏩ Steer queued — arrives after the next tool call: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
else:
|
||||
_cprint(" Steer rejected (empty payload).")
|
||||
else:
|
||||
# No active run — treat as a normal next-turn message.
|
||||
self._pending_input.put(payload)
|
||||
_cprint(f" No agent running; queued as next turn: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
elif canonical == "skin":
|
||||
self._handle_skin_command(cmd_original)
|
||||
elif canonical == "voice":
|
||||
@@ -6909,8 +7041,7 @@ class HermesCLI:
|
||||
)
|
||||
raise RuntimeError(
|
||||
"Voice mode requires sounddevice and numpy.\n"
|
||||
"Install with: pip install sounddevice numpy\n"
|
||||
"Or: pip install hermes-agent[voice]"
|
||||
f"Install with: {sys.executable} -m pip install sounddevice numpy"
|
||||
)
|
||||
if not reqs.get("stt_available", reqs.get("stt_key_set")):
|
||||
raise RuntimeError(
|
||||
@@ -7186,8 +7317,7 @@ class HermesCLI:
|
||||
_cprint(f" {_DIM}Then install/update the Termux:API Android app for microphone capture{_RST}")
|
||||
_cprint(f" {_BOLD}Option 2: pkg install python-numpy portaudio && python -m pip install sounddevice{_RST}")
|
||||
else:
|
||||
_cprint(f"\n {_BOLD}Install: pip install {' '.join(reqs['missing_packages'])}{_RST}")
|
||||
_cprint(f" {_DIM}Or: pip install hermes-agent[voice]{_RST}")
|
||||
_cprint(f"\n {_BOLD}Install: {sys.executable} -m pip install {' '.join(reqs['missing_packages'])}{_RST}")
|
||||
return
|
||||
|
||||
with self._voice_lock:
|
||||
@@ -8138,7 +8268,15 @@ class HermesCLI:
|
||||
else:
|
||||
print(f"\n⚡ Sending after interrupt: '{preview}'")
|
||||
self._pending_input.put(combined)
|
||||
|
||||
|
||||
# If a /steer was left over (agent finished before another tool
|
||||
# batch could absorb it), deliver it as the next user turn.
|
||||
_leftover_steer = result.get("pending_steer") if result else None
|
||||
if _leftover_steer and hasattr(self, '_pending_input'):
|
||||
preview = _leftover_steer[:60] + ("..." if len(_leftover_steer) > 60 else "")
|
||||
print(f"\n⏩ Delivering leftover /steer as next turn: '{preview}'")
|
||||
self._pending_input.put(_leftover_steer)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -9959,8 +10097,36 @@ class HermesCLI:
|
||||
|
||||
# Register signal handlers for graceful shutdown on SSH disconnect / SIGTERM
|
||||
def _signal_handler(signum, frame):
|
||||
"""Handle SIGHUP/SIGTERM by triggering graceful cleanup."""
|
||||
"""Handle SIGHUP/SIGTERM by triggering graceful cleanup.
|
||||
|
||||
Calls ``self.agent.interrupt()`` first so the agent daemon
|
||||
thread's poll loop sees the per-thread interrupt and kills the
|
||||
tool's subprocess group via ``_kill_process`` (os.killpg).
|
||||
Without this, the main thread dies from KeyboardInterrupt and
|
||||
the daemon thread is killed with it — before it can run one
|
||||
more poll iteration to clean up the subprocess, which was
|
||||
spawned with ``os.setsid`` and therefore survives as an orphan
|
||||
with PPID=1.
|
||||
|
||||
Grace window (``HERMES_SIGTERM_GRACE``, default 1.5 s) gives
|
||||
the daemon time to: detect the interrupt (next 200 ms poll) →
|
||||
call _kill_process (SIGTERM + 1 s wait + SIGKILL if needed) →
|
||||
return from _wait_for_process. ``time.sleep`` releases the
|
||||
GIL so the daemon actually runs during the window.
|
||||
"""
|
||||
logger.debug("Received signal %s, triggering graceful shutdown", signum)
|
||||
try:
|
||||
if getattr(self, "agent", None) and getattr(self, "_agent_running", False):
|
||||
self.agent.interrupt(f"received signal {signum}")
|
||||
import time as _t
|
||||
try:
|
||||
_grace = float(os.getenv("HERMES_SIGTERM_GRACE", "1.5"))
|
||||
except (TypeError, ValueError):
|
||||
_grace = 1.5
|
||||
if _grace > 0:
|
||||
_t.sleep(_grace)
|
||||
except Exception:
|
||||
pass # never block signal handling
|
||||
raise KeyboardInterrupt()
|
||||
|
||||
try:
|
||||
@@ -10263,6 +10429,45 @@ def main(
|
||||
|
||||
# Register cleanup for single-query mode (interactive mode registers in run())
|
||||
atexit.register(_run_cleanup)
|
||||
|
||||
# Also install signal handlers in single-query / `-q` mode. Interactive
|
||||
# mode registers its own inside HermesCLI.run(), but `-q` runs
|
||||
# cli.agent.run_conversation() below and AIAgent spawns worker threads
|
||||
# for tools — so when SIGTERM arrives on the main thread, raising
|
||||
# KeyboardInterrupt only unwinds the main thread, not the worker
|
||||
# running _wait_for_process. Python then exits, the child subprocess
|
||||
# (spawned with os.setsid, its own process group) is reparented to
|
||||
# init and keeps running as an orphan.
|
||||
#
|
||||
# Fix: route SIGTERM/SIGHUP through agent.interrupt() which sets the
|
||||
# per-thread interrupt flag the worker's poll loop checks every 200 ms.
|
||||
# Give the worker a grace window to call _kill_process (SIGTERM to the
|
||||
# process group, then SIGKILL after 1 s), then raise KeyboardInterrupt
|
||||
# so main unwinds normally. HERMES_SIGTERM_GRACE overrides the 1.5 s
|
||||
# default for debugging.
|
||||
def _signal_handler_q(signum, frame):
|
||||
logger.debug("Received signal %s in single-query mode", signum)
|
||||
try:
|
||||
_agent = getattr(cli, "agent", None)
|
||||
if _agent is not None:
|
||||
_agent.interrupt(f"received signal {signum}")
|
||||
import time as _t
|
||||
try:
|
||||
_grace = float(os.getenv("HERMES_SIGTERM_GRACE", "1.5"))
|
||||
except (TypeError, ValueError):
|
||||
_grace = 1.5
|
||||
if _grace > 0:
|
||||
_t.sleep(_grace)
|
||||
except Exception:
|
||||
pass # never block signal handling
|
||||
raise KeyboardInterrupt()
|
||||
try:
|
||||
import signal as _signal
|
||||
_signal.signal(_signal.SIGTERM, _signal_handler_q)
|
||||
if hasattr(_signal, "SIGHUP"):
|
||||
_signal.signal(_signal.SIGHUP, _signal_handler_q)
|
||||
except Exception:
|
||||
pass # signal handler may fail in restricted environments
|
||||
|
||||
# Handle single query mode
|
||||
if query or image:
|
||||
|
||||
+15
-2
@@ -65,7 +65,15 @@ _HOME_TARGET_ENV_VARS = {
|
||||
"wecom": "WECOM_HOME_CHANNEL",
|
||||
"weixin": "WEIXIN_HOME_CHANNEL",
|
||||
"bluebubbles": "BLUEBUBBLES_HOME_CHANNEL",
|
||||
"qqbot": "QQ_HOME_CHANNEL",
|
||||
"qqbot": "QQBOT_HOME_CHANNEL",
|
||||
}
|
||||
|
||||
# Legacy env var names kept for back-compat. Each entry is the current
|
||||
# primary env var → the previous name. _get_home_target_chat_id falls
|
||||
# back to the legacy name if the primary is unset, so users who set the
|
||||
# old name before the rename keep working until they migrate.
|
||||
_LEGACY_HOME_TARGET_ENV_VARS = {
|
||||
"QQBOT_HOME_CHANNEL": "QQ_HOME_CHANNEL",
|
||||
}
|
||||
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run
|
||||
@@ -100,7 +108,12 @@ def _get_home_target_chat_id(platform_name: str) -> str:
|
||||
env_var = _HOME_TARGET_ENV_VARS.get(platform_name.lower())
|
||||
if not env_var:
|
||||
return ""
|
||||
return os.getenv(env_var, "")
|
||||
value = os.getenv(env_var, "")
|
||||
if not value:
|
||||
legacy = _LEGACY_HOME_TARGET_ENV_VARS.get(env_var)
|
||||
if legacy:
|
||||
value = os.getenv(legacy, "")
|
||||
return value
|
||||
|
||||
|
||||
def _resolve_single_delivery_target(job: dict, deliver_value: str) -> Optional[dict]:
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
# Ink Gateway TUI Migration — Post-mortem
|
||||
|
||||
Planned: 2026-04-01 · Delivered: 2026-04 · Status: shipped, classic (prompt_toolkit) CLI still present
|
||||
|
||||
## What Shipped
|
||||
|
||||
Three layers, same repo, Python runtime unchanged.
|
||||
|
||||
```
|
||||
ui-tui (Node/TS) ──stdio JSON-RPC──▶ tui_gateway (Py) ──▶ AIAgent (run_agent.py)
|
||||
```
|
||||
|
||||
### Backend — `tui_gateway/`
|
||||
|
||||
```
|
||||
tui_gateway/
|
||||
├── entry.py # subprocess entrypoint, stdio read/write loop
|
||||
├── server.py # everything: sessions dict, @method handlers, _emit
|
||||
├── render.py # stream renderer, diff rendering, message rendering
|
||||
├── slash_worker.py # subprocess that runs hermes_cli slash commands
|
||||
└── __init__.py
|
||||
```
|
||||
|
||||
`server.py` owns the full runtime-control surface: session store (`_sessions: dict[str, dict]`), method registry (`@method("…")` decorator), event emitter (`_emit`), agent lifecycle (`_make_agent`, `_init_session`, `_wire_callbacks`), approval/sudo/clarify round-trips, and JSON-RPC dispatch.
|
||||
|
||||
Protocol methods (`@method(...)` in `server.py`):
|
||||
|
||||
- session: `session.{create, resume, list, close, interrupt, usage, history, compress, branch, title, save, undo}`
|
||||
- prompt: `prompt.{submit, background, btw}`
|
||||
- tools: `tools.{list, show, configure}`
|
||||
- slash: `slash.exec`, `command.{dispatch, resolve}`, `commands.catalog`, `complete.{path, slash}`
|
||||
- approvals: `approval.respond`, `sudo.respond`, `clarify.respond`, `secret.respond`
|
||||
- config/state: `config.{get, set, show}`, `model.options`, `reload.mcp`
|
||||
- ops: `shell.exec`, `cli.exec`, `terminal.resize`, `input.detect_drop`, `clipboard.paste`, `paste.collapse`, `image.attach`, `process.stop`
|
||||
- misc: `agents.list`, `skills.manage`, `plugins.list`, `cron.manage`, `insights.get`, `rollback.{list, diff, restore}`, `browser.manage`
|
||||
|
||||
Protocol events (`_emit(…)` → handled in `ui-tui/src/app/createGatewayEventHandler.ts`):
|
||||
|
||||
- lifecycle: `gateway.{ready, stderr}`, `session.info`, `skin.changed`
|
||||
- stream: `message.{start, delta, complete}`, `thinking.delta`, `reasoning.{delta, available}`, `status.update`
|
||||
- tools: `tool.{start, progress, complete, generating}`, `subagent.{start, thinking, tool, progress, complete}`
|
||||
- interactive: `approval.request`, `sudo.request`, `clarify.request`, `secret.request`
|
||||
- async: `background.complete`, `btw.complete`, `error`
|
||||
|
||||
### Frontend — `ui-tui/src/`
|
||||
|
||||
```
|
||||
src/
|
||||
├── entry.tsx # node bootstrap: bootBanner → spawn python → dynamic-import Ink → render(<App/>)
|
||||
├── app.tsx # <GatewayProvider> wraps <AppLayout>
|
||||
├── bootBanner.ts # raw-ANSI banner to stdout in ~2ms, pre-React
|
||||
├── gatewayClient.ts # JSON-RPC client over child_process stdio
|
||||
├── gatewayTypes.ts # typed RPC responses + GatewayEvent union
|
||||
├── theme.ts # DEFAULT_THEME + fromSkin
|
||||
│
|
||||
├── app/ # hooks + stores — the orchestration layer
|
||||
│ ├── uiStore.ts # nanostore: sid, info, busy, usage, theme, status…
|
||||
│ ├── turnStore.ts # nanostore: per-turn activity / reasoning / tools
|
||||
│ ├── turnController.ts # imperative singleton for stream-time operations
|
||||
│ ├── overlayStore.ts # nanostore: modal/overlay state
|
||||
│ ├── useMainApp.ts # top-level composition hook
|
||||
│ ├── useSessionLifecycle.ts # session.create/resume/close/reset
|
||||
│ ├── useSubmission.ts # shell/slash/prompt dispatch + interpolation
|
||||
│ ├── useConfigSync.ts # config.get + mtime poll
|
||||
│ ├── useComposerState.ts # input buffer, paste snippets, editor mode
|
||||
│ ├── useInputHandlers.ts # key bindings
|
||||
│ ├── createGatewayEventHandler.ts # event-stream dispatcher
|
||||
│ ├── createSlashHandler.ts # slash command router (registry + python fallback)
|
||||
│ └── slash/commands/ # core.ts, ops.ts, session.ts — TS-owned slash commands
|
||||
│
|
||||
├── components/ # AppLayout, AppChrome, AppOverlays, MessageLine, Thinking, Markdown, pickers, prompts, Banner, SessionPanel
|
||||
├── config/ # env, limits, timing constants
|
||||
├── content/ # charms, faces, fortunes, hotkeys, placeholders, verbs
|
||||
├── domain/ # details, messages, paths, roles, slash, usage, viewport
|
||||
├── protocol/ # interpolation, paste regex
|
||||
├── hooks/ # useCompletion, useInputHistory, useQueue, useVirtualHistory
|
||||
└── lib/ # history, messages, osc52, rpc, text
|
||||
```
|
||||
|
||||
### CLI entry points — `hermes_cli/main.py`
|
||||
|
||||
- `hermes --tui` → `node dist/entry.js` (auto-builds when `.ts`/`.tsx` newer than `dist/entry.js`)
|
||||
- `hermes --tui --dev` → `tsx src/entry.tsx` (skip build)
|
||||
- `HERMES_TUI_DIR=…` → external prebuilt dist (nix, distro packaging)
|
||||
|
||||
## Diverged From Original Plan
|
||||
|
||||
| Plan | Reality | Why |
|
||||
|---|---|---|
|
||||
| `tui_gateway/{controller,session_state,events,protocol}.py` | all collapsed into `server.py` | no second consumer ever emerged, keeping one file cheaper than four |
|
||||
| `ui-tui/src/main.tsx` | split into `entry.tsx` (bootstrap) + `app.tsx` (shell) | boot banner + early python spawn wanted a pre-React moment |
|
||||
| `ui-tui/src/state/store.ts` | three nanostores (`uiStore`, `turnStore`, `overlayStore`) | separate lifetimes: ui persists, turn resets per reply, overlay is modal |
|
||||
| `approval.requested` / `sudo.requested` / `clarify.requested` | `*.request` (no `-ed`) | cosmetic |
|
||||
| `session.cancel` | dropped | `session.interrupt` covers it |
|
||||
| `HERMES_EXPERIMENTAL_TUI=1`, `display.experimental_tui: true`, `/tui on/off/status` | none shipped | `--tui` went from opt-in to first-class without an experimental phase |
|
||||
|
||||
## Post-migration Additions (not in original plan)
|
||||
|
||||
- **Async `session.create`** — returns sid in ~1ms, agent builds on a background thread, `session.info` broadcasts when ready; `_wait_agent()` gates every agent-touching handler via `_sess`
|
||||
- **`bootBanner`** — raw-ANSI logo painted to stdout at T≈2ms, before Ink loads; `<AlternateScreen>` wipes it seamlessly when React mounts
|
||||
- **Selection uniform bg** — `theme.color.selectionBg` wired via `useSelection().setSelectionBgColor`; replaces SGR-inverse per-cell swap that fragmented over amber/gold fg
|
||||
- **Slash command registry** — TS-owned commands in `app/slash/commands/{core,ops,session}.ts`, everything else falls through to `slash.exec` (python worker)
|
||||
- **Turn store + controller split** — imperative singleton (`turnController`) holds refs/timers, nanostore (`turnStore`) holds render-visible state
|
||||
|
||||
## What's Still Open
|
||||
|
||||
- **Classic CLI not deleted.** `cli.py` still has ~80 `prompt_toolkit` references; classic REPL is still the default when `--tui` is absent. The original plan's "Cut 4 · prompt_toolkit removal later" hasn't happened.
|
||||
- **No config-file opt-in.** `HERMES_EXPERIMENTAL_TUI` and `display.experimental_tui` were never built; only the CLI flag exists. Fine for now — if we want "default to TUI", a single line in `main.py` flips it.
|
||||
@@ -6,6 +6,11 @@
|
||||
# All fields are optional — missing values inherit from the default skin.
|
||||
# Activate with: /skin <name> or display.skin: <name> in config.yaml
|
||||
#
|
||||
# Keys are marked:
|
||||
# (both) — applies to both the classic CLI and the TUI
|
||||
# (classic) — classic CLI only (see hermes --tui in user-guide/tui.md)
|
||||
# (tui) — TUI only
|
||||
#
|
||||
# See hermes_cli/skin_engine.py for the full schema reference.
|
||||
# ============================================================================
|
||||
|
||||
@@ -14,43 +19,47 @@ name: example
|
||||
description: An example custom skin — copy and modify this template
|
||||
|
||||
# ── Colors ──────────────────────────────────────────────────────────────────
|
||||
# Hex color values for Rich markup. These control the CLI's visual palette.
|
||||
# Hex color values. These control the visual palette.
|
||||
colors:
|
||||
# Banner panel (the startup welcome box)
|
||||
# Banner panel (the startup welcome box) — (both)
|
||||
banner_border: "#CD7F32" # Panel border
|
||||
banner_title: "#FFD700" # Panel title text
|
||||
banner_accent: "#FFBF00" # Section headers (Available Tools, Skills, etc.)
|
||||
banner_dim: "#B8860B" # Dim/muted text (separators, model info)
|
||||
banner_text: "#FFF8DC" # Body text (tool names, skill names)
|
||||
|
||||
# UI elements
|
||||
ui_accent: "#FFBF00" # General accent color
|
||||
# UI elements — (both)
|
||||
ui_accent: "#FFBF00" # General accent (falls back to banner_accent)
|
||||
ui_label: "#4dd0e1" # Labels
|
||||
ui_ok: "#4caf50" # Success indicators
|
||||
ui_error: "#ef5350" # Error indicators
|
||||
ui_warn: "#ffa726" # Warning indicators
|
||||
|
||||
# Input area
|
||||
prompt: "#FFF8DC" # Prompt text color
|
||||
input_rule: "#CD7F32" # Horizontal rule around input
|
||||
prompt: "#FFF8DC" # Prompt text / `❯` glyph color (both)
|
||||
input_rule: "#CD7F32" # Horizontal rule above input (classic)
|
||||
|
||||
# Response box
|
||||
response_border: "#FFD700" # Response box border (ANSI color)
|
||||
# Response box — (classic)
|
||||
response_border: "#FFD700" # Response box border
|
||||
|
||||
# Session display
|
||||
session_label: "#DAA520" # Session label
|
||||
session_border: "#8B8682" # Session ID dim color
|
||||
# Session display — (both)
|
||||
session_label: "#DAA520" # "Session: " label
|
||||
session_border: "#8B8682" # Session ID text
|
||||
|
||||
# TUI surfaces
|
||||
status_bar_bg: "#1a1a2e" # Status / usage bar background
|
||||
voice_status_bg: "#1a1a2e" # Voice-mode badge background
|
||||
completion_menu_bg: "#1a1a2e" # Completion list background
|
||||
completion_menu_current_bg: "#333355" # Active completion row background
|
||||
completion_menu_meta_bg: "#1a1a2e" # Completion meta column background
|
||||
completion_menu_meta_current_bg: "#333355" # Active completion meta background
|
||||
# TUI / CLI surfaces — (classic: status bar, voice badge, completion meta)
|
||||
status_bar_bg: "#1a1a2e" # Status / usage bar background (classic)
|
||||
voice_status_bg: "#1a1a2e" # Voice-mode badge background (classic)
|
||||
completion_menu_bg: "#1a1a2e" # Completion list background (both)
|
||||
completion_menu_current_bg: "#333355" # Active completion row background (both)
|
||||
completion_menu_meta_bg: "#1a1a2e" # Completion meta column bg (classic)
|
||||
completion_menu_meta_current_bg: "#333355" # Active meta bg (classic)
|
||||
|
||||
# Drag-to-select background — (tui)
|
||||
selection_bg: "#3a3a55" # Uniform selection highlight in the TUI
|
||||
|
||||
# ── Spinner ─────────────────────────────────────────────────────────────────
|
||||
# Customize the animated spinner shown during API calls and tool execution.
|
||||
# (classic) — the TUI uses its own animated indicators; spinner config here
|
||||
# is only read by the classic prompt_toolkit CLI.
|
||||
spinner:
|
||||
# Faces shown while waiting for the API response
|
||||
waiting_faces:
|
||||
@@ -78,17 +87,17 @@ spinner:
|
||||
# - ["⟪▲", "▲⟫"]
|
||||
|
||||
# ── Branding ────────────────────────────────────────────────────────────────
|
||||
# Text strings used throughout the CLI interface.
|
||||
# Text strings used throughout the interface.
|
||||
branding:
|
||||
agent_name: "Hermes Agent" # Banner title, about display
|
||||
welcome: "Welcome! Type your message or /help for commands."
|
||||
goodbye: "Goodbye! ⚕" # Exit message
|
||||
response_label: " ⚕ Hermes " # Response box header label
|
||||
prompt_symbol: "❯ " # Input prompt symbol
|
||||
help_header: "(^_^)? Available Commands" # /help header text
|
||||
agent_name: "Hermes Agent" # (both) Banner title, about display
|
||||
welcome: "Welcome! Type your message or /help for commands." # (both)
|
||||
goodbye: "Goodbye! ⚕" # (both) Exit message
|
||||
response_label: " ⚕ Hermes " # (classic) Response box header label
|
||||
prompt_symbol: "❯ " # (both) Input prompt glyph
|
||||
help_header: "(^_^)? Available Commands" # (both) /help overlay title
|
||||
|
||||
# ── Tool Output ─────────────────────────────────────────────────────────────
|
||||
# Character used as the prefix for tool output lines.
|
||||
# Character used as the prefix for tool output lines. (both)
|
||||
# Default is "┊" (thin dotted vertical line). Some alternatives:
|
||||
# "╎" (light triple dash vertical)
|
||||
# "▏" (left one-eighth block)
|
||||
|
||||
Generated
+21
@@ -36,6 +36,26 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"npm-lockfile-fix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1775903712,
|
||||
"narHash": "sha256-2GV79U6iVH4gKAPWYrxUReB0S41ty/Y3dBLquU8AlaA=",
|
||||
"owner": "jeslie0",
|
||||
"repo": "npm-lockfile-fix",
|
||||
"rev": "c6093acb0c0548e0f9b8b3d82918823721930fe8",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "jeslie0",
|
||||
"repo": "npm-lockfile-fix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-build-systems": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
@@ -124,6 +144,7 @@
|
||||
"inputs": {
|
||||
"flake-parts": "flake-parts",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"npm-lockfile-fix": "npm-lockfile-fix",
|
||||
"pyproject-build-systems": "pyproject-build-systems",
|
||||
"pyproject-nix": "pyproject-nix_2",
|
||||
"uv2nix": "uv2nix_2"
|
||||
|
||||
@@ -19,11 +19,20 @@
|
||||
url = "github:pyproject-nix/build-system-pkgs";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
npm-lockfile-fix = {
|
||||
url = "github:jeslie0/npm-lockfile-fix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
};
|
||||
|
||||
outputs = inputs:
|
||||
outputs =
|
||||
inputs:
|
||||
inputs.flake-parts.lib.mkFlake { inherit inputs; } {
|
||||
systems = [ "x86_64-linux" "aarch64-linux" "aarch64-darwin" ];
|
||||
systems = [
|
||||
"x86_64-linux"
|
||||
"aarch64-linux"
|
||||
"aarch64-darwin"
|
||||
];
|
||||
|
||||
imports = [
|
||||
./nix/packages.nix
|
||||
|
||||
@@ -100,7 +100,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
|
||||
|
||||
|
||||
def _build_discord(adapter) -> List[Dict[str, str]]:
|
||||
"""Enumerate all text channels the Discord bot can see."""
|
||||
"""Enumerate all text channels and forum channels the Discord bot can see."""
|
||||
channels = []
|
||||
client = getattr(adapter, "_client", None)
|
||||
if not client:
|
||||
@@ -119,6 +119,15 @@ def _build_discord(adapter) -> List[Dict[str, str]]:
|
||||
"guild": guild.name,
|
||||
"type": "channel",
|
||||
})
|
||||
# Forum channels (type 15) — creating a message auto-spawns a thread post.
|
||||
forums = getattr(guild, "forum_channels", None) or []
|
||||
for ch in forums:
|
||||
channels.append({
|
||||
"id": str(ch.id),
|
||||
"name": ch.name,
|
||||
"guild": guild.name,
|
||||
"type": "forum",
|
||||
})
|
||||
# Also include DM-capable users we've interacted with is not
|
||||
# feasible via guild enumeration; those come from sessions.
|
||||
|
||||
@@ -191,6 +200,15 @@ def load_directory() -> Dict[str, Any]:
|
||||
return {"updated_at": None, "platforms": {}}
|
||||
|
||||
|
||||
def lookup_channel_type(platform_name: str, chat_id: str) -> Optional[str]:
|
||||
"""Return the channel ``type`` string (e.g. ``"channel"``, ``"forum"``) for *chat_id*, or *None* if unknown."""
|
||||
directory = load_directory()
|
||||
for ch in directory.get("platforms", {}).get(platform_name, []):
|
||||
if ch.get("id") == chat_id:
|
||||
return ch.get("type")
|
||||
return None
|
||||
|
||||
|
||||
def resolve_channel_name(platform_name: str, name: str) -> Optional[str]:
|
||||
"""
|
||||
Resolve a human-friendly channel name to a numeric ID.
|
||||
|
||||
+30
-2
@@ -258,6 +258,13 @@ class GatewayConfig:
|
||||
# Streaming configuration
|
||||
streaming: StreamingConfig = field(default_factory=StreamingConfig)
|
||||
|
||||
# Session store pruning: drop SessionEntry records older than this many
|
||||
# days from the in-memory dict and sessions.json. Keeps the store from
|
||||
# growing unbounded in gateways serving many chats/threads/users over
|
||||
# months. Pruning is invisible to users — if they resume, they get a
|
||||
# fresh session exactly as if the reset policy had fired. 0 = disabled.
|
||||
session_store_max_age_days: int = 90
|
||||
|
||||
def get_connected_platforms(self) -> List[Platform]:
|
||||
"""Return list of platforms that are enabled and configured."""
|
||||
connected = []
|
||||
@@ -365,6 +372,7 @@ class GatewayConfig:
|
||||
"thread_sessions_per_user": self.thread_sessions_per_user,
|
||||
"unauthorized_dm_behavior": self.unauthorized_dm_behavior,
|
||||
"streaming": self.streaming.to_dict(),
|
||||
"session_store_max_age_days": self.session_store_max_age_days,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -412,6 +420,13 @@ class GatewayConfig:
|
||||
"pair",
|
||||
)
|
||||
|
||||
try:
|
||||
session_store_max_age_days = int(data.get("session_store_max_age_days", 90))
|
||||
if session_store_max_age_days < 0:
|
||||
session_store_max_age_days = 0
|
||||
except (TypeError, ValueError):
|
||||
session_store_max_age_days = 90
|
||||
|
||||
return cls(
|
||||
platforms=platforms,
|
||||
default_reset_policy=default_policy,
|
||||
@@ -426,6 +441,7 @@ class GatewayConfig:
|
||||
thread_sessions_per_user=_coerce_bool(thread_sessions_per_user, False),
|
||||
unauthorized_dm_behavior=unauthorized_dm_behavior,
|
||||
streaming=StreamingConfig.from_dict(data.get("streaming", {})),
|
||||
session_store_max_age_days=session_store_max_age_days,
|
||||
)
|
||||
|
||||
def get_unauthorized_dm_behavior(self, platform: Optional[Platform] = None) -> str:
|
||||
@@ -1213,12 +1229,24 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
qq_group_allowed = os.getenv("QQ_GROUP_ALLOWED_USERS", "").strip()
|
||||
if qq_group_allowed:
|
||||
extra["group_allow_from"] = qq_group_allowed
|
||||
qq_home = os.getenv("QQ_HOME_CHANNEL", "").strip()
|
||||
qq_home = os.getenv("QQBOT_HOME_CHANNEL", "").strip()
|
||||
qq_home_name_env = "QQBOT_HOME_CHANNEL_NAME"
|
||||
if not qq_home:
|
||||
# Back-compat: accept the pre-rename name and log a one-time warning.
|
||||
legacy_home = os.getenv("QQ_HOME_CHANNEL", "").strip()
|
||||
if legacy_home:
|
||||
qq_home = legacy_home
|
||||
qq_home_name_env = "QQ_HOME_CHANNEL_NAME"
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(
|
||||
"QQ_HOME_CHANNEL is deprecated; rename to QQBOT_HOME_CHANNEL "
|
||||
"in your .env for consistency with the platform key."
|
||||
)
|
||||
if qq_home:
|
||||
config.platforms[Platform.QQBOT].home_channel = HomeChannel(
|
||||
platform=Platform.QQBOT,
|
||||
chat_id=qq_home,
|
||||
name=os.getenv("QQ_HOME_CHANNEL_NAME", "Home"),
|
||||
name=os.getenv("QQBOT_HOME_CHANNEL_NAME") or os.getenv(qq_home_name_env, "Home"),
|
||||
)
|
||||
|
||||
# Session settings
|
||||
|
||||
@@ -669,6 +669,15 @@ class MessageEvent:
|
||||
# Original platform data
|
||||
raw_message: Any = None
|
||||
message_id: Optional[str] = None
|
||||
|
||||
# Platform-specific update identifier. For Telegram this is the
|
||||
# ``update_id`` from the PTB Update wrapper; other platforms currently
|
||||
# ignore it. Used by ``/restart`` to record the triggering update so the
|
||||
# new gateway can advance the Telegram offset past it and avoid processing
|
||||
# the same ``/restart`` twice if PTB's graceful-shutdown ACK times out
|
||||
# ("Error while calling `get_updates` one more time to mark all fetched
|
||||
# updates" in gateway.log).
|
||||
platform_update_id: Optional[int] = None
|
||||
|
||||
# Media attachments
|
||||
# media_urls: local file paths (for vision tool access)
|
||||
@@ -1045,16 +1054,40 @@ class BasePlatformAdapter(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
# Default: the adapter treats ``finalize=True`` on edit_message as a
|
||||
# no-op and is happy to have the stream consumer skip redundant final
|
||||
# edits. Subclasses that *require* an explicit finalize call to close
|
||||
# out the message lifecycle (e.g. rich card / AI assistant surfaces
|
||||
# such as DingTalk AI Cards) override this to True (class attribute or
|
||||
# property) so the stream consumer knows not to short-circuit.
|
||||
REQUIRES_EDIT_FINALIZE: bool = False
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
content: str,
|
||||
*,
|
||||
finalize: bool = False,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Edit a previously sent message. Optional — platforms that don't
|
||||
support editing return success=False and callers fall back to
|
||||
sending a new message.
|
||||
|
||||
``finalize`` signals that this is the last edit in a streaming
|
||||
sequence. Most platforms (Telegram, Slack, Discord, Matrix,
|
||||
etc.) treat it as a no-op because their edit APIs have no notion
|
||||
of message lifecycle state — an edit is an edit. Platforms that
|
||||
render streaming updates with a distinct "in progress" state and
|
||||
require explicit closure (e.g. rich card / AI assistant surfaces
|
||||
such as DingTalk AI Cards) use it to finalize the message and
|
||||
transition the UI out of the streaming indicator — those should
|
||||
also set ``REQUIRES_EDIT_FINALIZE = True`` so callers route a
|
||||
final edit through even when content is unchanged. Callers
|
||||
should set ``finalize=True`` on the final edit of a streamed
|
||||
response (typically when ``got_done`` fires in the stream
|
||||
consumer) and leave it ``False`` on intermediate edits.
|
||||
"""
|
||||
return SendResult(success=False, error="Not supported")
|
||||
|
||||
@@ -1579,7 +1612,9 @@ class BasePlatformAdapter(ABC):
|
||||
# session lifecycle and its cleanup races with the running task
|
||||
# (see PR #4926).
|
||||
cmd = event.get_command()
|
||||
if cmd in ("approve", "deny", "status", "stop", "new", "reset", "background", "restart", "queue", "q"):
|
||||
from hermes_cli.commands import should_bypass_active_session
|
||||
|
||||
if should_bypass_active_session(cmd):
|
||||
logger.debug(
|
||||
"[%s] Command '/%s' bypassing active-session guard for %s",
|
||||
self.name, cmd, session_key,
|
||||
|
||||
+912
-71
File diff suppressed because it is too large
Load Diff
@@ -857,6 +857,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
When metadata contains a thread_id, the message is sent to that
|
||||
thread instead of the parent channel identified by chat_id.
|
||||
|
||||
Forum channels (type 15) reject direct messages — a thread post is
|
||||
created automatically.
|
||||
"""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
@@ -882,6 +885,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
# Forum channels reject channel.send() — create a thread post instead.
|
||||
if self._is_forum_parent(channel):
|
||||
return await self._send_to_forum(channel, content)
|
||||
|
||||
# Format and split message if needed
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
@@ -945,6 +952,120 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
logger.error("[%s] Failed to send Discord message: %s", self.name, e, exc_info=True)
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def _send_to_forum(self, forum_channel: Any, content: str) -> SendResult:
|
||||
"""Create a thread post in a forum channel with the message as starter content.
|
||||
|
||||
Forum channels (type 15) don't support direct messages. Instead we
|
||||
POST to /channels/{forum_id}/threads with a thread name derived from
|
||||
the first line of the message. Any follow-up chunk failures are
|
||||
reported in ``raw_response['warnings']`` so the caller can surface
|
||||
partial-send issues.
|
||||
"""
|
||||
from tools.send_message_tool import _derive_forum_thread_name
|
||||
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
|
||||
thread_name = _derive_forum_thread_name(content)
|
||||
|
||||
starter_content = chunks[0] if chunks else thread_name
|
||||
|
||||
try:
|
||||
thread = await forum_channel.create_thread(
|
||||
name=thread_name,
|
||||
content=starter_content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("[%s] Failed to create forum thread in %s: %s", self.name, forum_channel.id, e)
|
||||
return SendResult(success=False, error=f"Forum thread creation failed: {e}")
|
||||
|
||||
thread_channel = thread if hasattr(thread, "send") else getattr(thread, "thread", None)
|
||||
thread_id = str(getattr(thread_channel, "id", getattr(thread, "id", "")))
|
||||
starter_msg = getattr(thread, "message", None)
|
||||
message_id = str(getattr(starter_msg, "id", thread_id)) if starter_msg else thread_id
|
||||
|
||||
# Send remaining chunks into the newly created thread. Track any
|
||||
# per-chunk failures so the caller sees partial-send outcomes.
|
||||
message_ids = [message_id]
|
||||
warnings: list[str] = []
|
||||
for chunk in chunks[1:]:
|
||||
try:
|
||||
msg = await thread_channel.send(content=chunk)
|
||||
message_ids.append(str(msg.id))
|
||||
except Exception as e:
|
||||
warning = f"Failed to send follow-up chunk to forum thread {thread_id}: {e}"
|
||||
logger.warning("[%s] %s", self.name, warning)
|
||||
warnings.append(warning)
|
||||
|
||||
raw_response: Dict[str, Any] = {"message_ids": message_ids, "thread_id": thread_id}
|
||||
if warnings:
|
||||
raw_response["warnings"] = warnings
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_ids[0],
|
||||
raw_response=raw_response,
|
||||
)
|
||||
|
||||
async def _forum_post_file(
|
||||
self,
|
||||
forum_channel: Any,
|
||||
*,
|
||||
thread_name: Optional[str] = None,
|
||||
content: str = "",
|
||||
file: Any = None,
|
||||
files: Optional[list] = None,
|
||||
) -> SendResult:
|
||||
"""Create a forum thread whose starter message carries file attachments.
|
||||
|
||||
Used by the send_voice / send_image_file / send_document paths when
|
||||
the target channel is a forum (type 15). ``create_thread`` on a
|
||||
ForumChannel accepts the same file/files/content kwargs as
|
||||
``channel.send``, creating the thread and starter message atomically.
|
||||
"""
|
||||
from tools.send_message_tool import _derive_forum_thread_name
|
||||
|
||||
if not thread_name:
|
||||
# Prefer the text content, fall back to the first attached
|
||||
# filename, fall back to the generic default.
|
||||
hint = content or ""
|
||||
if not hint.strip():
|
||||
if file is not None:
|
||||
hint = getattr(file, "filename", "") or ""
|
||||
elif files:
|
||||
hint = getattr(files[0], "filename", "") or ""
|
||||
thread_name = _derive_forum_thread_name(hint) if hint.strip() else "New Post"
|
||||
|
||||
kwargs: Dict[str, Any] = {"name": thread_name}
|
||||
if content:
|
||||
kwargs["content"] = content
|
||||
if file is not None:
|
||||
kwargs["file"] = file
|
||||
if files:
|
||||
kwargs["files"] = files
|
||||
|
||||
try:
|
||||
thread = await forum_channel.create_thread(**kwargs)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"[%s] Failed to create forum thread with file in %s: %s",
|
||||
self.name,
|
||||
getattr(forum_channel, "id", "?"),
|
||||
e,
|
||||
)
|
||||
return SendResult(success=False, error=f"Forum thread creation failed: {e}")
|
||||
|
||||
thread_channel = thread if hasattr(thread, "send") else getattr(thread, "thread", None)
|
||||
thread_id = str(getattr(thread_channel, "id", getattr(thread, "id", "")))
|
||||
starter_msg = getattr(thread, "message", None)
|
||||
message_id = str(getattr(starter_msg, "id", thread_id)) if starter_msg else thread_id
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_id,
|
||||
raw_response={"thread_id": thread_id},
|
||||
)
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -975,7 +1096,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send a local file as a Discord attachment."""
|
||||
"""Send a local file as a Discord attachment.
|
||||
|
||||
Forum channels (type 15) get a new thread whose starter message
|
||||
carries the file — they reject direct POST /messages.
|
||||
"""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
@@ -988,6 +1113,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
filename = file_name or os.path.basename(file_path)
|
||||
with open(file_path, "rb") as fh:
|
||||
file = discord.File(fh, filename=filename)
|
||||
if self._is_forum_parent(channel):
|
||||
return await self._forum_post_file(
|
||||
channel,
|
||||
content=(caption or "").strip(),
|
||||
file=file,
|
||||
)
|
||||
msg = await channel.send(content=caption if caption else None, file=file)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
@@ -1036,6 +1167,18 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
with open(audio_path, "rb") as f:
|
||||
file_data = f.read()
|
||||
|
||||
# Forum channels (type 15) reject direct POST /messages — the
|
||||
# native voice flag path also targets /messages so it would fail
|
||||
# too. Create a thread post with the audio as the starter
|
||||
# attachment instead.
|
||||
if self._is_forum_parent(channel):
|
||||
forum_file = discord.File(io.BytesIO(file_data), filename=filename)
|
||||
return await self._forum_post_file(
|
||||
channel,
|
||||
content=(caption or "").strip(),
|
||||
file=forum_file,
|
||||
)
|
||||
|
||||
# Try sending as a native voice message via raw API (flags=8192).
|
||||
try:
|
||||
import base64
|
||||
@@ -1488,6 +1631,13 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
import io
|
||||
file = discord.File(io.BytesIO(image_data), filename=f"image.{ext}")
|
||||
|
||||
if self._is_forum_parent(channel):
|
||||
return await self._forum_post_file(
|
||||
channel,
|
||||
content=(caption or "").strip(),
|
||||
file=file,
|
||||
)
|
||||
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
@@ -1550,6 +1700,13 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
import io
|
||||
file = discord.File(io.BytesIO(animation_data), filename="animation.gif")
|
||||
|
||||
if self._is_forum_parent(channel):
|
||||
return await self._forum_post_file(
|
||||
channel,
|
||||
content=(caption or "").strip(),
|
||||
file=file,
|
||||
)
|
||||
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
@@ -1837,6 +1994,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
async def slash_stop(interaction: discord.Interaction):
|
||||
await self._run_simple_slash(interaction, "/stop", "Stop requested~")
|
||||
|
||||
@tree.command(name="steer", description="Inject a message after the next tool call (no interrupt)")
|
||||
@discord.app_commands.describe(prompt="Text to inject into the agent's next tool result")
|
||||
async def slash_steer(interaction: discord.Interaction, prompt: str):
|
||||
await self._run_simple_slash(interaction, f"/steer {prompt}".strip())
|
||||
|
||||
@tree.command(name="compress", description="Compress conversation context")
|
||||
async def slash_compress(interaction: discord.Interaction):
|
||||
await self._run_simple_slash(interaction, "/compress")
|
||||
|
||||
@@ -1228,6 +1228,10 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
.register_p2_im_chat_member_bot_deleted_v1(self._on_bot_removed_from_chat)
|
||||
.register_p2_im_chat_access_event_bot_p2p_chat_entered_v1(self._on_p2p_chat_entered)
|
||||
.register_p2_im_message_recalled_v1(self._on_message_recalled)
|
||||
.register_p2_customized_event(
|
||||
"drive.notice.comment_add_v1",
|
||||
self._on_drive_comment_event,
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
@@ -1965,6 +1969,25 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
def _on_message_recalled(self, data: Any) -> None:
|
||||
logger.debug("[Feishu] Message recalled by user")
|
||||
|
||||
def _on_drive_comment_event(self, data: Any) -> None:
|
||||
"""Handle drive document comment notification (drive.notice.comment_add_v1).
|
||||
|
||||
Delegates to :mod:`gateway.platforms.feishu_comment` for parsing,
|
||||
logging, and reaction. Scheduling follows the same
|
||||
``run_coroutine_threadsafe`` pattern used by ``_on_message_event``.
|
||||
"""
|
||||
from gateway.platforms.feishu_comment import handle_drive_comment_event
|
||||
|
||||
loop = self._loop
|
||||
if not self._loop_accepts_callbacks(loop):
|
||||
logger.warning("[Feishu] Dropping drive comment event before adapter loop is ready")
|
||||
return
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
handle_drive_comment_event(self._client, data, self_open_id=self._bot_open_id),
|
||||
loop,
|
||||
)
|
||||
future.add_done_callback(self._log_background_failure)
|
||||
|
||||
def _on_reaction_event(self, event_type: str, data: Any) -> None:
|
||||
"""Route user reactions on bot messages as synthetic text events."""
|
||||
event = getattr(data, "event", None)
|
||||
@@ -2590,6 +2613,8 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
self._on_reaction_event(event_type, data)
|
||||
elif event_type == "card.action.trigger":
|
||||
self._on_card_action_trigger(data)
|
||||
elif event_type == "drive.notice.comment_add_v1":
|
||||
self._on_drive_comment_event(data)
|
||||
else:
|
||||
logger.debug("[Feishu] Ignoring webhook event type: %s", event_type or "unknown")
|
||||
return web.json_response({"code": 0, "msg": "ok"})
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,429 @@
|
||||
"""
|
||||
Feishu document comment access-control rules.
|
||||
|
||||
3-tier rule resolution: exact doc > wildcard "*" > top-level > code defaults.
|
||||
Each field (enabled/policy/allow_from) falls back independently.
|
||||
Config: ~/.hermes/feishu_comment_rules.json (mtime-cached, hot-reload).
|
||||
Pairing store: ~/.hermes/feishu_comment_pairing.json.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Paths
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Uses the canonical ``get_hermes_home()`` helper (HERMES_HOME-aware and
|
||||
# profile-safe). Resolved at import time; this module is lazy-imported by
|
||||
# the Feishu comment event handler, which runs long after profile overrides
|
||||
# have been applied, so freezing paths here is safe.
|
||||
|
||||
RULES_FILE = get_hermes_home() / "feishu_comment_rules.json"
|
||||
PAIRING_FILE = get_hermes_home() / "feishu_comment_pairing.json"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_VALID_POLICIES = ("allowlist", "pairing")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CommentDocumentRule:
|
||||
"""Per-document rule. ``None`` means 'inherit from lower tier'."""
|
||||
enabled: Optional[bool] = None
|
||||
policy: Optional[str] = None
|
||||
allow_from: Optional[frozenset] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CommentsConfig:
|
||||
"""Top-level comment access config."""
|
||||
enabled: bool = True
|
||||
policy: str = "pairing"
|
||||
allow_from: frozenset = field(default_factory=frozenset)
|
||||
documents: Dict[str, CommentDocumentRule] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResolvedCommentRule:
|
||||
"""Fully resolved rule after field-by-field fallback."""
|
||||
enabled: bool
|
||||
policy: str
|
||||
allow_from: frozenset
|
||||
match_source: str # e.g. "exact:docx:xxx" | "wildcard" | "top" | "default"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mtime-cached file loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _MtimeCache:
|
||||
"""Generic mtime-based file cache. ``stat()`` per access, re-read only on change."""
|
||||
|
||||
def __init__(self, path: Path):
|
||||
self._path = path
|
||||
self._mtime: float = 0.0
|
||||
self._data: Optional[dict] = None
|
||||
|
||||
def load(self) -> dict:
|
||||
try:
|
||||
st = self._path.stat()
|
||||
mtime = st.st_mtime
|
||||
except FileNotFoundError:
|
||||
self._mtime = 0.0
|
||||
self._data = {}
|
||||
return {}
|
||||
|
||||
if mtime == self._mtime and self._data is not None:
|
||||
return self._data
|
||||
|
||||
try:
|
||||
with open(self._path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if not isinstance(data, dict):
|
||||
data = {}
|
||||
except (json.JSONDecodeError, OSError):
|
||||
logger.warning("[Feishu-Rules] Failed to read %s, using empty config", self._path)
|
||||
data = {}
|
||||
|
||||
self._mtime = mtime
|
||||
self._data = data
|
||||
return data
|
||||
|
||||
|
||||
_rules_cache = _MtimeCache(RULES_FILE)
|
||||
_pairing_cache = _MtimeCache(PAIRING_FILE)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _parse_frozenset(raw: Any) -> Optional[frozenset]:
|
||||
"""Parse a list of strings into a frozenset; return None if key absent."""
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, (list, tuple)):
|
||||
return frozenset(str(u).strip() for u in raw if str(u).strip())
|
||||
return None
|
||||
|
||||
|
||||
def _parse_document_rule(raw: dict) -> CommentDocumentRule:
|
||||
enabled = raw.get("enabled")
|
||||
if enabled is not None:
|
||||
enabled = bool(enabled)
|
||||
policy = raw.get("policy")
|
||||
if policy is not None:
|
||||
policy = str(policy).strip().lower()
|
||||
if policy not in _VALID_POLICIES:
|
||||
policy = None
|
||||
allow_from = _parse_frozenset(raw.get("allow_from"))
|
||||
return CommentDocumentRule(enabled=enabled, policy=policy, allow_from=allow_from)
|
||||
|
||||
|
||||
def load_config() -> CommentsConfig:
|
||||
"""Load comment rules from disk (mtime-cached)."""
|
||||
raw = _rules_cache.load()
|
||||
if not raw:
|
||||
return CommentsConfig()
|
||||
|
||||
documents: Dict[str, CommentDocumentRule] = {}
|
||||
raw_docs = raw.get("documents", {})
|
||||
if isinstance(raw_docs, dict):
|
||||
for key, rule_raw in raw_docs.items():
|
||||
if isinstance(rule_raw, dict):
|
||||
documents[str(key)] = _parse_document_rule(rule_raw)
|
||||
|
||||
policy = str(raw.get("policy", "pairing")).strip().lower()
|
||||
if policy not in _VALID_POLICIES:
|
||||
policy = "pairing"
|
||||
|
||||
return CommentsConfig(
|
||||
enabled=raw.get("enabled", True),
|
||||
policy=policy,
|
||||
allow_from=_parse_frozenset(raw.get("allow_from")) or frozenset(),
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rule resolution (§8.4 field-by-field fallback)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def has_wiki_keys(cfg: CommentsConfig) -> bool:
|
||||
"""Check if any document rule key starts with 'wiki:'."""
|
||||
return any(k.startswith("wiki:") for k in cfg.documents)
|
||||
|
||||
|
||||
def resolve_rule(
|
||||
cfg: CommentsConfig,
|
||||
file_type: str,
|
||||
file_token: str,
|
||||
wiki_token: str = "",
|
||||
) -> ResolvedCommentRule:
|
||||
"""Resolve effective rule: exact doc → wiki key → wildcard → top-level → defaults."""
|
||||
exact_key = f"{file_type}:{file_token}"
|
||||
|
||||
exact = cfg.documents.get(exact_key)
|
||||
exact_src = f"exact:{exact_key}"
|
||||
if exact is None and wiki_token:
|
||||
wiki_key = f"wiki:{wiki_token}"
|
||||
exact = cfg.documents.get(wiki_key)
|
||||
exact_src = f"exact:{wiki_key}"
|
||||
|
||||
wildcard = cfg.documents.get("*")
|
||||
|
||||
layers = []
|
||||
if exact is not None:
|
||||
layers.append((exact, exact_src))
|
||||
if wildcard is not None:
|
||||
layers.append((wildcard, "wildcard"))
|
||||
|
||||
def _pick(field_name: str):
|
||||
for layer, source in layers:
|
||||
val = getattr(layer, field_name)
|
||||
if val is not None:
|
||||
return val, source
|
||||
return getattr(cfg, field_name), "top"
|
||||
|
||||
enabled, en_src = _pick("enabled")
|
||||
policy, pol_src = _pick("policy")
|
||||
allow_from, _ = _pick("allow_from")
|
||||
|
||||
# match_source = highest-priority tier that contributed any field
|
||||
priority_order = {"exact": 0, "wildcard": 1, "top": 2}
|
||||
best_src = min(
|
||||
[en_src, pol_src],
|
||||
key=lambda s: priority_order.get(s.split(":")[0], 3),
|
||||
)
|
||||
|
||||
return ResolvedCommentRule(
|
||||
enabled=enabled,
|
||||
policy=policy,
|
||||
allow_from=allow_from,
|
||||
match_source=best_src,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pairing store
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_pairing_approved() -> set:
|
||||
"""Return set of approved user open_ids (mtime-cached)."""
|
||||
data = _pairing_cache.load()
|
||||
approved = data.get("approved", {})
|
||||
if isinstance(approved, dict):
|
||||
return set(approved.keys())
|
||||
if isinstance(approved, list):
|
||||
return set(str(u) for u in approved if u)
|
||||
return set()
|
||||
|
||||
|
||||
def _save_pairing(data: dict) -> None:
|
||||
PAIRING_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = PAIRING_FILE.with_suffix(".tmp")
|
||||
with open(tmp, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
tmp.replace(PAIRING_FILE)
|
||||
# Invalidate cache so next load picks up change
|
||||
_pairing_cache._mtime = 0.0
|
||||
_pairing_cache._data = None
|
||||
|
||||
|
||||
def pairing_add(user_open_id: str) -> bool:
|
||||
"""Add a user to the pairing-approved list. Returns True if newly added."""
|
||||
data = _pairing_cache.load()
|
||||
approved = data.get("approved", {})
|
||||
if not isinstance(approved, dict):
|
||||
approved = {}
|
||||
if user_open_id in approved:
|
||||
return False
|
||||
approved[user_open_id] = {"approved_at": time.time()}
|
||||
data["approved"] = approved
|
||||
_save_pairing(data)
|
||||
return True
|
||||
|
||||
|
||||
def pairing_remove(user_open_id: str) -> bool:
|
||||
"""Remove a user from the pairing-approved list. Returns True if removed."""
|
||||
data = _pairing_cache.load()
|
||||
approved = data.get("approved", {})
|
||||
if not isinstance(approved, dict):
|
||||
return False
|
||||
if user_open_id not in approved:
|
||||
return False
|
||||
del approved[user_open_id]
|
||||
data["approved"] = approved
|
||||
_save_pairing(data)
|
||||
return True
|
||||
|
||||
|
||||
def pairing_list() -> Dict[str, Any]:
|
||||
"""Return the approved dict {user_open_id: {approved_at: ...}}."""
|
||||
data = _pairing_cache.load()
|
||||
approved = data.get("approved", {})
|
||||
return dict(approved) if isinstance(approved, dict) else {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Access check (public API for feishu_comment.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def is_user_allowed(rule: ResolvedCommentRule, user_open_id: str) -> bool:
|
||||
"""Check if user passes the resolved rule's policy gate."""
|
||||
if user_open_id in rule.allow_from:
|
||||
return True
|
||||
if rule.policy == "pairing":
|
||||
return user_open_id in _load_pairing_approved()
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _print_status() -> None:
|
||||
cfg = load_config()
|
||||
print(f"Rules file: {RULES_FILE}")
|
||||
print(f" exists: {RULES_FILE.exists()}")
|
||||
print(f"Pairing file: {PAIRING_FILE}")
|
||||
print(f" exists: {PAIRING_FILE.exists()}")
|
||||
print()
|
||||
print(f"Top-level:")
|
||||
print(f" enabled: {cfg.enabled}")
|
||||
print(f" policy: {cfg.policy}")
|
||||
print(f" allow_from: {sorted(cfg.allow_from) if cfg.allow_from else '[]'}")
|
||||
print()
|
||||
if cfg.documents:
|
||||
print(f"Document rules ({len(cfg.documents)}):")
|
||||
for key, rule in sorted(cfg.documents.items()):
|
||||
parts = []
|
||||
if rule.enabled is not None:
|
||||
parts.append(f"enabled={rule.enabled}")
|
||||
if rule.policy is not None:
|
||||
parts.append(f"policy={rule.policy}")
|
||||
if rule.allow_from is not None:
|
||||
parts.append(f"allow_from={sorted(rule.allow_from)}")
|
||||
print(f" [{key}] {', '.join(parts) if parts else '(empty — inherits all)'}")
|
||||
else:
|
||||
print("Document rules: (none)")
|
||||
print()
|
||||
approved = pairing_list()
|
||||
print(f"Pairing approved ({len(approved)}):")
|
||||
for uid, meta in sorted(approved.items()):
|
||||
ts = meta.get("approved_at", 0)
|
||||
print(f" {uid} (approved_at={ts})")
|
||||
|
||||
|
||||
def _do_check(doc_key: str, user_open_id: str) -> None:
|
||||
cfg = load_config()
|
||||
parts = doc_key.split(":", 1)
|
||||
if len(parts) != 2:
|
||||
print(f"Error: doc_key must be 'fileType:fileToken', got '{doc_key}'")
|
||||
return
|
||||
file_type, file_token = parts
|
||||
rule = resolve_rule(cfg, file_type, file_token)
|
||||
allowed = is_user_allowed(rule, user_open_id)
|
||||
print(f"Document: {doc_key}")
|
||||
print(f"User: {user_open_id}")
|
||||
print(f"Resolved rule:")
|
||||
print(f" enabled: {rule.enabled}")
|
||||
print(f" policy: {rule.policy}")
|
||||
print(f" allow_from: {sorted(rule.allow_from) if rule.allow_from else '[]'}")
|
||||
print(f" match_source: {rule.match_source}")
|
||||
print(f"Result: {'ALLOWED' if allowed else 'DENIED'}")
|
||||
|
||||
|
||||
def _main() -> int:
|
||||
import sys
|
||||
|
||||
try:
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
load_hermes_dotenv()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
usage = (
|
||||
"Usage: python -m gateway.platforms.feishu_comment_rules <command> [args]\n"
|
||||
"\n"
|
||||
"Commands:\n"
|
||||
" status Show rules config and pairing state\n"
|
||||
" check <fileType:token> <user> Simulate access check\n"
|
||||
" pairing add <user_open_id> Add user to pairing-approved list\n"
|
||||
" pairing remove <user_open_id> Remove user from pairing-approved list\n"
|
||||
" pairing list List pairing-approved users\n"
|
||||
"\n"
|
||||
f"Rules config file: {RULES_FILE}\n"
|
||||
" Edit this JSON file directly to configure policies and document rules.\n"
|
||||
" Changes take effect on the next comment event (no restart needed).\n"
|
||||
)
|
||||
|
||||
args = sys.argv[1:]
|
||||
if not args:
|
||||
print(usage)
|
||||
return 1
|
||||
|
||||
cmd = args[0]
|
||||
|
||||
if cmd == "status":
|
||||
_print_status()
|
||||
|
||||
elif cmd == "check":
|
||||
if len(args) < 3:
|
||||
print("Usage: check <fileType:fileToken> <user_open_id>")
|
||||
return 1
|
||||
_do_check(args[1], args[2])
|
||||
|
||||
elif cmd == "pairing":
|
||||
if len(args) < 2:
|
||||
print("Usage: pairing <add|remove|list> [args]")
|
||||
return 1
|
||||
sub = args[1]
|
||||
if sub == "add":
|
||||
if len(args) < 3:
|
||||
print("Usage: pairing add <user_open_id>")
|
||||
return 1
|
||||
if pairing_add(args[2]):
|
||||
print(f"Added: {args[2]}")
|
||||
else:
|
||||
print(f"Already approved: {args[2]}")
|
||||
elif sub == "remove":
|
||||
if len(args) < 3:
|
||||
print("Usage: pairing remove <user_open_id>")
|
||||
return 1
|
||||
if pairing_remove(args[2]):
|
||||
print(f"Removed: {args[2]}")
|
||||
else:
|
||||
print(f"Not in approved list: {args[2]}")
|
||||
elif sub == "list":
|
||||
approved = pairing_list()
|
||||
if not approved:
|
||||
print("(no approved users)")
|
||||
for uid, meta in sorted(approved.items()):
|
||||
print(f" {uid} approved_at={meta.get('approved_at', '?')}")
|
||||
else:
|
||||
print(f"Unknown pairing subcommand: {sub}")
|
||||
return 1
|
||||
else:
|
||||
print(f"Unknown command: {cmd}\n")
|
||||
print(usage)
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(_main())
|
||||
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
QQBot platform package.
|
||||
|
||||
Re-exports the main adapter symbols from ``adapter.py`` (the original
|
||||
``qqbot.py``) so that **all existing import paths remain unchanged**::
|
||||
|
||||
from gateway.platforms.qqbot import QQAdapter # works
|
||||
from gateway.platforms.qqbot import check_qq_requirements # works
|
||||
|
||||
New modules:
|
||||
- ``constants`` — shared constants (API URLs, timeouts, message types)
|
||||
- ``utils`` — User-Agent builder, config helpers
|
||||
- ``crypto`` — AES-256-GCM key generation and decryption
|
||||
- ``onboard`` — QR-code scan-to-configure flow
|
||||
"""
|
||||
|
||||
# -- Adapter (original qqbot.py) ------------------------------------------
|
||||
from .adapter import ( # noqa: F401
|
||||
QQAdapter,
|
||||
QQCloseError,
|
||||
check_qq_requirements,
|
||||
_coerce_list,
|
||||
_ssrf_redirect_guard,
|
||||
)
|
||||
|
||||
# -- Onboard (QR-code scan-to-configure) -----------------------------------
|
||||
from .onboard import ( # noqa: F401
|
||||
BindStatus,
|
||||
create_bind_task,
|
||||
poll_bind_result,
|
||||
build_connect_url,
|
||||
)
|
||||
from .crypto import decrypt_secret, generate_bind_key # noqa: F401
|
||||
|
||||
# -- Utils -----------------------------------------------------------------
|
||||
from .utils import build_user_agent, get_api_headers, coerce_list # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
# adapter
|
||||
"QQAdapter",
|
||||
"QQCloseError",
|
||||
"check_qq_requirements",
|
||||
"_coerce_list",
|
||||
"_ssrf_redirect_guard",
|
||||
# onboard
|
||||
"BindStatus",
|
||||
"create_bind_task",
|
||||
"poll_bind_result",
|
||||
"build_connect_url",
|
||||
# crypto
|
||||
"decrypt_secret",
|
||||
"generate_bind_key",
|
||||
# utils
|
||||
"build_user_agent",
|
||||
"get_api_headers",
|
||||
"coerce_list",
|
||||
]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,74 @@
|
||||
"""QQBot package-level constants shared across adapter, onboard, and other modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QQBot adapter version — bump on functional changes to the adapter package.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
QQBOT_VERSION = "1.1.0"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# The portal domain is configurable via QQ_API_HOST for corporate proxies
|
||||
# or test environments. Default: q.qq.com (production).
|
||||
PORTAL_HOST = os.getenv("QQ_PORTAL_HOST", "q.qq.com")
|
||||
|
||||
API_BASE = "https://api.sgroup.qq.com"
|
||||
TOKEN_URL = "https://bots.qq.com/app/getAppAccessToken"
|
||||
GATEWAY_URL_PATH = "/gateway"
|
||||
|
||||
# QR-code onboard endpoints (on the portal host)
|
||||
ONBOARD_CREATE_PATH = "/lite/create_bind_task"
|
||||
ONBOARD_POLL_PATH = "/lite/poll_bind_result"
|
||||
QR_URL_TEMPLATE = (
|
||||
"https://q.qq.com/qqbot/openclaw/connect.html"
|
||||
"?task_id={task_id}&_wv=2&source=hermes"
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Timeouts & retry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_API_TIMEOUT = 30.0
|
||||
FILE_UPLOAD_TIMEOUT = 120.0
|
||||
CONNECT_TIMEOUT_SECONDS = 20.0
|
||||
|
||||
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
|
||||
MAX_RECONNECT_ATTEMPTS = 100
|
||||
RATE_LIMIT_DELAY = 60 # seconds
|
||||
QUICK_DISCONNECT_THRESHOLD = 5.0 # seconds
|
||||
MAX_QUICK_DISCONNECT_COUNT = 3
|
||||
|
||||
ONBOARD_POLL_INTERVAL = 2.0 # seconds between poll_bind_result calls
|
||||
ONBOARD_API_TIMEOUT = 10.0
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message limits
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MAX_MESSAGE_LENGTH = 4000
|
||||
DEDUP_WINDOW_SECONDS = 300
|
||||
DEDUP_MAX_SIZE = 1000
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QQ Bot message types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MSG_TYPE_TEXT = 0
|
||||
MSG_TYPE_MARKDOWN = 2
|
||||
MSG_TYPE_MEDIA = 7
|
||||
MSG_TYPE_INPUT_NOTIFY = 6
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QQ Bot file media types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MEDIA_TYPE_IMAGE = 1
|
||||
MEDIA_TYPE_VIDEO = 2
|
||||
MEDIA_TYPE_VOICE = 3
|
||||
MEDIA_TYPE_FILE = 4
|
||||
@@ -0,0 +1,45 @@
|
||||
"""AES-256-GCM utilities for QQBot scan-to-configure credential decryption."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
|
||||
|
||||
def generate_bind_key() -> str:
|
||||
"""Generate a 256-bit random AES key and return it as base64.
|
||||
|
||||
The key is passed to ``create_bind_task`` so the server can encrypt
|
||||
the bot's *client_secret* before returning it. Only this CLI holds
|
||||
the key, ensuring the secret never travels in plaintext.
|
||||
"""
|
||||
return base64.b64encode(os.urandom(32)).decode()
|
||||
|
||||
|
||||
def decrypt_secret(encrypted_base64: str, key_base64: str) -> str:
|
||||
"""Decrypt a base64-encoded AES-256-GCM ciphertext.
|
||||
|
||||
Ciphertext layout (after base64-decoding)::
|
||||
|
||||
IV (12 bytes) ‖ ciphertext (N bytes) ‖ AuthTag (16 bytes)
|
||||
|
||||
Args:
|
||||
encrypted_base64: The ``bot_encrypt_secret`` value from
|
||||
``poll_bind_result``.
|
||||
key_base64: The base64 AES key generated by
|
||||
:func:`generate_bind_key`.
|
||||
|
||||
Returns:
|
||||
The decrypted *client_secret* as a UTF-8 string.
|
||||
"""
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
key = base64.b64decode(key_base64)
|
||||
raw = base64.b64decode(encrypted_base64)
|
||||
|
||||
iv = raw[:12]
|
||||
ciphertext_with_tag = raw[12:] # AESGCM expects ciphertext + tag concatenated
|
||||
|
||||
aesgcm = AESGCM(key)
|
||||
plaintext = aesgcm.decrypt(iv, ciphertext_with_tag, None)
|
||||
return plaintext.decode("utf-8")
|
||||
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
QQBot scan-to-configure (QR code onboard) module.
|
||||
|
||||
Calls the ``q.qq.com`` ``create_bind_task`` / ``poll_bind_result`` APIs to
|
||||
generate a QR-code URL and poll for scan completion. On success the caller
|
||||
receives the bot's *app_id*, *client_secret* (decrypted locally), and the
|
||||
scanner's *user_openid* — enough to fully configure the QQBot gateway.
|
||||
|
||||
Reference: https://bot.q.qq.com/wiki/develop/api-v2/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from enum import IntEnum
|
||||
from typing import Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
from .constants import (
|
||||
ONBOARD_API_TIMEOUT,
|
||||
ONBOARD_CREATE_PATH,
|
||||
ONBOARD_POLL_PATH,
|
||||
PORTAL_HOST,
|
||||
QR_URL_TEMPLATE,
|
||||
)
|
||||
from .crypto import generate_bind_key
|
||||
from .utils import get_api_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bind status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BindStatus(IntEnum):
|
||||
"""Status codes returned by ``poll_bind_result``."""
|
||||
|
||||
NONE = 0
|
||||
PENDING = 1
|
||||
COMPLETED = 2
|
||||
EXPIRED = 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def create_bind_task(
|
||||
timeout: float = ONBOARD_API_TIMEOUT,
|
||||
) -> Tuple[str, str]:
|
||||
"""Create a bind task and return *(task_id, aes_key_base64)*.
|
||||
|
||||
The AES key is generated locally and sent to the server so it can
|
||||
encrypt the bot credentials before returning them.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the API returns a non-zero ``retcode``.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
url = f"https://{PORTAL_HOST}{ONBOARD_CREATE_PATH}"
|
||||
key = generate_bind_key()
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
|
||||
resp = await client.post(url, json={"key": key}, headers=get_api_headers())
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
if data.get("retcode") != 0:
|
||||
raise RuntimeError(data.get("msg", "create_bind_task failed"))
|
||||
|
||||
task_id = data.get("data", {}).get("task_id")
|
||||
if not task_id:
|
||||
raise RuntimeError("create_bind_task: missing task_id in response")
|
||||
|
||||
logger.debug("create_bind_task ok: task_id=%s", task_id)
|
||||
return task_id, key
|
||||
|
||||
|
||||
async def poll_bind_result(
|
||||
task_id: str,
|
||||
timeout: float = ONBOARD_API_TIMEOUT,
|
||||
) -> Tuple[BindStatus, str, str, str]:
|
||||
"""Poll the bind result for *task_id*.
|
||||
|
||||
Returns:
|
||||
A 4-tuple of ``(status, bot_appid, bot_encrypt_secret, user_openid)``.
|
||||
|
||||
* ``bot_encrypt_secret`` is AES-256-GCM encrypted — decrypt it with
|
||||
:func:`~gateway.platforms.qqbot.crypto.decrypt_secret` using the
|
||||
key from :func:`create_bind_task`.
|
||||
* ``user_openid`` is the OpenID of the person who scanned the code
|
||||
(available when ``status == COMPLETED``).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the API returns a non-zero ``retcode``.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
url = f"https://{PORTAL_HOST}{ONBOARD_POLL_PATH}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
|
||||
resp = await client.post(url, json={"task_id": task_id}, headers=get_api_headers())
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
if data.get("retcode") != 0:
|
||||
raise RuntimeError(data.get("msg", "poll_bind_result failed"))
|
||||
|
||||
d = data.get("data", {})
|
||||
return (
|
||||
BindStatus(d.get("status", 0)),
|
||||
str(d.get("bot_appid", "")),
|
||||
d.get("bot_encrypt_secret", ""),
|
||||
d.get("user_openid", ""),
|
||||
)
|
||||
|
||||
|
||||
def build_connect_url(task_id: str) -> str:
|
||||
"""Build the QR-code target URL for a given *task_id*."""
|
||||
return QR_URL_TEMPLATE.format(task_id=quote(task_id))
|
||||
@@ -0,0 +1,71 @@
|
||||
"""QQBot shared utilities — User-Agent, HTTP helpers, config coercion."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import platform
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .constants import QQBOT_VERSION
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User-Agent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_hermes_version() -> str:
|
||||
"""Return the hermes-agent package version, or 'dev' if unavailable."""
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
return version("hermes-agent")
|
||||
except Exception:
|
||||
return "dev"
|
||||
|
||||
|
||||
def build_user_agent() -> str:
|
||||
"""Build a descriptive User-Agent string.
|
||||
|
||||
Format::
|
||||
|
||||
QQBotAdapter/<qqbot_version> (Python/<py_version>; <os>; Hermes/<hermes_version>)
|
||||
|
||||
Example::
|
||||
|
||||
QQBotAdapter/1.0.0 (Python/3.11.15; darwin; Hermes/0.9.0)
|
||||
"""
|
||||
py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
|
||||
os_name = platform.system().lower()
|
||||
hermes_version = _get_hermes_version()
|
||||
return f"QQBotAdapter/{QQBOT_VERSION} (Python/{py_version}; {os_name}; Hermes/{hermes_version})"
|
||||
|
||||
|
||||
def get_api_headers() -> Dict[str, str]:
|
||||
"""Return standard HTTP headers for QQBot API requests.
|
||||
|
||||
Includes ``Content-Type``, ``Accept``, and a dynamic ``User-Agent``.
|
||||
``q.qq.com`` requires ``Accept: application/json`` — without it,
|
||||
the server returns a JavaScript anti-bot challenge page.
|
||||
"""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"User-Agent": build_user_agent(),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def coerce_list(value: Any) -> List[str]:
|
||||
"""Coerce config values into a trimmed string list.
|
||||
|
||||
Accepts comma-separated strings, lists, tuples, sets, or single values.
|
||||
"""
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, str):
|
||||
return [item.strip() for item in value.split(",") if item.strip()]
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
return [str(item).strip() for item in value if str(item).strip()]
|
||||
return [str(value).strip()] if str(value).strip() else []
|
||||
@@ -160,6 +160,14 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
self._sse_task: Optional[asyncio.Task] = None
|
||||
self._health_monitor_task: Optional[asyncio.Task] = None
|
||||
self._typing_tasks: Dict[str, asyncio.Task] = {}
|
||||
# Per-chat typing-indicator backoff. When signal-cli reports
|
||||
# NETWORK_FAILURE (recipient offline / unroutable), base.py's
|
||||
# _keep_typing refresh loop would otherwise hammer sendTyping every
|
||||
# ~2s indefinitely, producing WARNING-level log spam and pointless
|
||||
# RPC traffic. We track consecutive failures per chat and skip the
|
||||
# RPC during a cooldown window instead.
|
||||
self._typing_failures: Dict[str, int] = {}
|
||||
self._typing_skip_until: Dict[str, float] = {}
|
||||
self._running = False
|
||||
self._last_sse_activity = 0.0
|
||||
self._sse_response: Optional[httpx.Response] = None
|
||||
@@ -548,8 +556,22 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
# JSON-RPC Communication
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _rpc(self, method: str, params: dict, rpc_id: str = None) -> Any:
|
||||
"""Send a JSON-RPC 2.0 request to signal-cli daemon."""
|
||||
async def _rpc(
|
||||
self,
|
||||
method: str,
|
||||
params: dict,
|
||||
rpc_id: str = None,
|
||||
*,
|
||||
log_failures: bool = True,
|
||||
) -> Any:
|
||||
"""Send a JSON-RPC 2.0 request to signal-cli daemon.
|
||||
|
||||
When ``log_failures=False``, error and exception paths log at DEBUG
|
||||
instead of WARNING — used by the typing-indicator path to silence
|
||||
repeated NETWORK_FAILURE spam for unreachable recipients while
|
||||
still preserving visibility for the first occurrence and for
|
||||
unrelated RPCs.
|
||||
"""
|
||||
if not self.client:
|
||||
logger.warning("Signal: RPC called but client not connected")
|
||||
return None
|
||||
@@ -574,13 +596,19 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
data = resp.json()
|
||||
|
||||
if "error" in data:
|
||||
logger.warning("Signal RPC error (%s): %s", method, data["error"])
|
||||
if log_failures:
|
||||
logger.warning("Signal RPC error (%s): %s", method, data["error"])
|
||||
else:
|
||||
logger.debug("Signal RPC error (%s): %s", method, data["error"])
|
||||
return None
|
||||
|
||||
return data.get("result")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Signal RPC %s failed: %s", method, e)
|
||||
if log_failures:
|
||||
logger.warning("Signal RPC %s failed: %s", method, e)
|
||||
else:
|
||||
logger.debug("Signal RPC %s failed: %s", method, e)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -627,7 +655,28 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
self._recent_sent_timestamps.pop()
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata=None) -> None:
|
||||
"""Send a typing indicator."""
|
||||
"""Send a typing indicator.
|
||||
|
||||
base.py's ``_keep_typing`` refresh loop calls this every ~2s while
|
||||
the agent is processing. If signal-cli returns NETWORK_FAILURE for
|
||||
this recipient (offline, unroutable, group membership lost, etc.)
|
||||
the unmitigated behaviour is: a WARNING log every 2 seconds for as
|
||||
long as the agent keeps running. Instead we:
|
||||
|
||||
- silence the WARNING after the first consecutive failure (subsequent
|
||||
attempts log at DEBUG) so transport issues are still visible once
|
||||
but don't flood the log,
|
||||
- skip the RPC entirely during an exponential cooldown window once
|
||||
three consecutive failures have happened, so we stop hammering
|
||||
signal-cli with requests it can't deliver.
|
||||
|
||||
A successful sendTyping clears the counters.
|
||||
"""
|
||||
now = time.monotonic()
|
||||
skip_until = self._typing_skip_until.get(chat_id, 0.0)
|
||||
if now < skip_until:
|
||||
return
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
"account": self.account,
|
||||
}
|
||||
@@ -637,7 +686,26 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
params["recipient"] = [chat_id]
|
||||
|
||||
await self._rpc("sendTyping", params, rpc_id="typing")
|
||||
fails = self._typing_failures.get(chat_id, 0)
|
||||
result = await self._rpc(
|
||||
"sendTyping",
|
||||
params,
|
||||
rpc_id="typing",
|
||||
log_failures=(fails == 0),
|
||||
)
|
||||
|
||||
if result is None:
|
||||
fails += 1
|
||||
self._typing_failures[chat_id] = fails
|
||||
# After 3 consecutive failures, back off exponentially (16s,
|
||||
# 32s, 60s cap) to stop spamming signal-cli for a recipient
|
||||
# that clearly isn't reachable right now.
|
||||
if fails >= 3:
|
||||
backoff = min(60.0, 16.0 * (2 ** (fails - 3)))
|
||||
self._typing_skip_until[chat_id] = now + backoff
|
||||
else:
|
||||
self._typing_failures.pop(chat_id, None)
|
||||
self._typing_skip_until.pop(chat_id, None)
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
@@ -789,6 +857,10 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
# Reset per-chat typing backoff state so the next agent turn starts
|
||||
# fresh rather than inheriting a cooldown from a prior conversation.
|
||||
self._typing_failures.pop(chat_id, None)
|
||||
self._typing_skip_until.pop(chat_id, None)
|
||||
|
||||
async def stop_typing(self, chat_id: str) -> None:
|
||||
"""Public interface for stopping typing — called by base adapter's
|
||||
|
||||
@@ -118,6 +118,84 @@ def _strip_mdv2(text: str) -> str:
|
||||
return cleaned
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Markdown table → code block conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
# Telegram's MarkdownV2 has no table syntax — '|' is just an escaped literal,
|
||||
# so pipe tables render as noisy backslash-pipe text with no alignment.
|
||||
# Wrapping the table in a fenced code block makes Telegram render it as
|
||||
# monospace preformatted text with columns intact.
|
||||
|
||||
# Matches a GFM table delimiter row: optional outer pipes, cells containing
|
||||
# only dashes (with optional leading/trailing colons for alignment) separated
|
||||
# by '|'. Requires at least one internal '|' so lone '---' horizontal rules
|
||||
# are NOT matched.
|
||||
_TABLE_SEPARATOR_RE = re.compile(
|
||||
r'^\s*\|?\s*:?-+:?\s*(?:\|\s*:?-+:?\s*){1,}\|?\s*$'
|
||||
)
|
||||
|
||||
|
||||
def _is_table_row(line: str) -> bool:
|
||||
"""Return True if *line* could plausibly be a table data row."""
|
||||
stripped = line.strip()
|
||||
return bool(stripped) and '|' in stripped
|
||||
|
||||
|
||||
def _wrap_markdown_tables(text: str) -> str:
|
||||
"""Wrap GFM-style pipe tables in ``` fences so Telegram renders them.
|
||||
|
||||
Detected by a row containing '|' immediately followed by a delimiter
|
||||
row matching :data:`_TABLE_SEPARATOR_RE`. Subsequent pipe-containing
|
||||
non-blank lines are consumed as the table body and included in the
|
||||
wrapped block. Tables inside existing fenced code blocks are left
|
||||
alone.
|
||||
"""
|
||||
if '|' not in text or '-' not in text:
|
||||
return text
|
||||
|
||||
lines = text.split('\n')
|
||||
out: list[str] = []
|
||||
in_fence = False
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
stripped = line.lstrip()
|
||||
|
||||
# Track existing fenced code blocks — never touch content inside.
|
||||
if stripped.startswith('```'):
|
||||
in_fence = not in_fence
|
||||
out.append(line)
|
||||
i += 1
|
||||
continue
|
||||
if in_fence:
|
||||
out.append(line)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Look for a header row (contains '|') immediately followed by a
|
||||
# delimiter row.
|
||||
if (
|
||||
'|' in line
|
||||
and i + 1 < len(lines)
|
||||
and _TABLE_SEPARATOR_RE.match(lines[i + 1])
|
||||
):
|
||||
table_block = [line, lines[i + 1]]
|
||||
j = i + 2
|
||||
while j < len(lines) and _is_table_row(lines[j]):
|
||||
table_block.append(lines[j])
|
||||
j += 1
|
||||
out.append('```')
|
||||
out.extend(table_block)
|
||||
out.append('```')
|
||||
i = j
|
||||
continue
|
||||
|
||||
out.append(line)
|
||||
i += 1
|
||||
|
||||
return '\n'.join(out)
|
||||
|
||||
|
||||
class TelegramAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Telegram bot adapter.
|
||||
@@ -1916,6 +1994,12 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
|
||||
text = content
|
||||
|
||||
# 0) Pre-wrap GFM-style pipe tables in ``` fences. Telegram can't
|
||||
# render tables natively, but fenced code blocks render as
|
||||
# monospace preformatted text with columns intact. The wrapped
|
||||
# tables then flow through step (1) below as protected regions.
|
||||
text = _wrap_markdown_tables(text)
|
||||
|
||||
# 1) Protect fenced code blocks (``` ... ```)
|
||||
# Per MarkdownV2 spec, \ and ` inside pre/code must be escaped.
|
||||
def _protect_fenced(m):
|
||||
@@ -2242,7 +2326,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
if not self._should_process_message(update.message):
|
||||
return
|
||||
|
||||
event = self._build_message_event(update.message, MessageType.TEXT)
|
||||
event = self._build_message_event(update.message, MessageType.TEXT, update_id=update.update_id)
|
||||
event.text = self._clean_bot_trigger_text(event.text)
|
||||
self._enqueue_text_event(event)
|
||||
|
||||
@@ -2253,7 +2337,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
if not self._should_process_message(update.message, is_command=True):
|
||||
return
|
||||
|
||||
event = self._build_message_event(update.message, MessageType.COMMAND)
|
||||
event = self._build_message_event(update.message, MessageType.COMMAND, update_id=update.update_id)
|
||||
await self.handle_message(event)
|
||||
|
||||
async def _handle_location_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
@@ -2289,7 +2373,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
parts.append(f"Map: https://www.google.com/maps/search/?api=1&query={lat},{lon}")
|
||||
parts.append("Ask what they'd like to find nearby (restaurants, cafes, etc.) and any preferences.")
|
||||
|
||||
event = self._build_message_event(msg, MessageType.LOCATION)
|
||||
event = self._build_message_event(msg, MessageType.LOCATION, update_id=update.update_id)
|
||||
event.text = "\n".join(parts)
|
||||
await self.handle_message(event)
|
||||
|
||||
@@ -2440,7 +2524,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
|
||||
event = self._build_message_event(msg, msg_type)
|
||||
event = self._build_message_event(msg, msg_type, update_id=update.update_id)
|
||||
|
||||
# Add caption as text
|
||||
if msg.caption:
|
||||
@@ -2779,8 +2863,19 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
self.name, cache_key, thread_id,
|
||||
)
|
||||
|
||||
def _build_message_event(self, message: Message, msg_type: MessageType) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Telegram message."""
|
||||
def _build_message_event(
|
||||
self,
|
||||
message: Message,
|
||||
msg_type: MessageType,
|
||||
update_id: Optional[int] = None,
|
||||
) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Telegram message.
|
||||
|
||||
``update_id`` is the ``Update.update_id`` from PTB; passing it through
|
||||
lets ``/restart`` record the triggering offset so the new gateway
|
||||
process can advance past it (prevents ``/restart`` being re-delivered
|
||||
when PTB's graceful-shutdown ACK fails).
|
||||
"""
|
||||
chat = message.chat
|
||||
user = message.from_user
|
||||
|
||||
@@ -2859,6 +2954,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
source=source,
|
||||
raw_message=message,
|
||||
message_id=str(message.message_id),
|
||||
platform_update_id=update_id,
|
||||
reply_to_message_id=reply_to_id,
|
||||
reply_to_text=reply_to_text,
|
||||
auto_skill=topic_skill,
|
||||
|
||||
+41
-10
@@ -180,6 +180,8 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
self._text_batch_split_delay_seconds = float(os.getenv("HERMES_WECOM_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0"))
|
||||
self._pending_text_batches: Dict[str, MessageEvent] = {}
|
||||
self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._device_id = uuid.uuid4().hex
|
||||
self._last_chat_req_ids: Dict[str, str] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connection lifecycle
|
||||
@@ -277,7 +279,11 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
{
|
||||
"cmd": APP_CMD_SUBSCRIBE,
|
||||
"headers": {"req_id": req_id},
|
||||
"body": {"bot_id": self._bot_id, "secret": self._secret},
|
||||
"body": {
|
||||
"bot_id": self._bot_id,
|
||||
"secret": self._secret,
|
||||
"device_id": self._device_id,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -496,6 +502,11 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
logger.debug("[%s] DM sender %s blocked by policy", self.name, sender_id)
|
||||
return
|
||||
|
||||
# Cache the inbound req_id after policy checks so proactive sends to
|
||||
# this chat can fall back to APP_CMD_RESPONSE (required for groups —
|
||||
# WeCom AI Bots cannot initiate APP_CMD_SEND in group chats).
|
||||
self._remember_chat_req_id(chat_id, self._payload_req_id(payload))
|
||||
|
||||
text, reply_text = self._extract_text(body)
|
||||
media_urls, media_types = await self._extract_media(body)
|
||||
message_type = self._derive_message_type(body, text, media_types)
|
||||
@@ -847,6 +858,23 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
while len(self._reply_req_ids) > DEDUP_MAX_SIZE:
|
||||
self._reply_req_ids.pop(next(iter(self._reply_req_ids)))
|
||||
|
||||
def _remember_chat_req_id(self, chat_id: str, req_id: str) -> None:
|
||||
"""Cache the most recent inbound req_id per chat.
|
||||
|
||||
Used as a fallback reply target when we need to send into a group
|
||||
without an explicit ``reply_to`` — WeCom AI Bots are blocked from
|
||||
APP_CMD_SEND in groups and must use APP_CMD_RESPONSE bound to some
|
||||
prior req_id. Bounded like _reply_req_ids so long-running gateways
|
||||
don't leak memory across many chats.
|
||||
"""
|
||||
normalized_chat_id = str(chat_id or "").strip()
|
||||
normalized_req_id = str(req_id or "").strip()
|
||||
if not normalized_chat_id or not normalized_req_id:
|
||||
return
|
||||
self._last_chat_req_ids[normalized_chat_id] = normalized_req_id
|
||||
while len(self._last_chat_req_ids) > DEDUP_MAX_SIZE:
|
||||
self._last_chat_req_ids.pop(next(iter(self._last_chat_req_ids)))
|
||||
|
||||
def _reply_req_id_for_message(self, reply_to: Optional[str]) -> Optional[str]:
|
||||
normalized = str(reply_to or "").strip()
|
||||
if not normalized or normalized.startswith("quote:"):
|
||||
@@ -1163,19 +1191,15 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
self._raise_for_wecom_error(response, "send media message")
|
||||
return response
|
||||
|
||||
async def _send_reply_stream(self, reply_req_id: str, content: str) -> Dict[str, Any]:
|
||||
async def _send_reply_markdown(self, reply_req_id: str, content: str) -> Dict[str, Any]:
|
||||
response = await self._send_reply_request(
|
||||
reply_req_id,
|
||||
{
|
||||
"msgtype": "stream",
|
||||
"stream": {
|
||||
"id": self._new_req_id("stream"),
|
||||
"finish": True,
|
||||
"content": content[:self.MAX_MESSAGE_LENGTH],
|
||||
},
|
||||
"msgtype": "markdown",
|
||||
"markdown": {"content": content[:self.MAX_MESSAGE_LENGTH]},
|
||||
},
|
||||
)
|
||||
self._raise_for_wecom_error(response, "send reply stream")
|
||||
self._raise_for_wecom_error(response, "send reply markdown")
|
||||
return response
|
||||
|
||||
async def _send_reply_media_message(
|
||||
@@ -1235,6 +1259,9 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
return SendResult(success=False, error=prepared["reject_reason"])
|
||||
|
||||
reply_req_id = self._reply_req_id_for_message(reply_to)
|
||||
if not reply_req_id and chat_id in self._last_chat_req_ids:
|
||||
reply_req_id = self._last_chat_req_ids[chat_id]
|
||||
|
||||
try:
|
||||
upload_result = await self._upload_media_bytes(
|
||||
prepared["data"],
|
||||
@@ -1302,8 +1329,12 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
|
||||
try:
|
||||
reply_req_id = self._reply_req_id_for_message(reply_to)
|
||||
|
||||
if not reply_req_id and chat_id in self._last_chat_req_ids:
|
||||
reply_req_id = self._last_chat_req_ids[chat_id]
|
||||
|
||||
if reply_req_id:
|
||||
response = await self._send_reply_stream(reply_req_id, content)
|
||||
response = await self._send_reply_markdown(reply_req_id, content)
|
||||
else:
|
||||
response = await self._send_request(
|
||||
APP_CMD_SEND,
|
||||
|
||||
+394
-25
@@ -2178,6 +2178,30 @@ class GatewayRunner:
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.debug("Idle agent sweep failed: %s", _e)
|
||||
|
||||
# Periodically prune stale SessionStore entries. The
|
||||
# in-memory dict (and sessions.json) would otherwise grow
|
||||
# unbounded in gateways serving many rotating chats /
|
||||
# threads / users over long time windows. Pruning is
|
||||
# invisible to users — a resumed session just gets a
|
||||
# fresh session_id, exactly as if the reset policy fired.
|
||||
_last_prune_ts = getattr(self, "_last_session_store_prune_ts", 0.0)
|
||||
_prune_interval = 3600.0 # once per hour
|
||||
if time.time() - _last_prune_ts > _prune_interval:
|
||||
try:
|
||||
_max_age = int(
|
||||
getattr(self.config, "session_store_max_age_days", 0) or 0
|
||||
)
|
||||
if _max_age > 0:
|
||||
_pruned = self.session_store.prune_old_entries(_max_age)
|
||||
if _pruned:
|
||||
logger.info(
|
||||
"SessionStore prune: dropped %d stale entries",
|
||||
_pruned,
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.debug("SessionStore prune failed: %s", _e)
|
||||
self._last_session_store_prune_ts = time.time()
|
||||
except Exception as e:
|
||||
logger.debug("Session expiry watcher error: %s", e)
|
||||
# Sleep in small increments so we can stop quickly
|
||||
@@ -2384,6 +2408,7 @@ class GatewayRunner:
|
||||
|
||||
self.adapters.clear()
|
||||
self._running_agents.clear()
|
||||
self._running_agents_ts.clear()
|
||||
self._pending_messages.clear()
|
||||
self._pending_approvals.clear()
|
||||
if hasattr(self, '_busy_ack_ts'):
|
||||
@@ -2408,6 +2433,20 @@ class GatewayRunner:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Close SQLite session DBs so the WAL write lock is released.
|
||||
# Without this, --replace and similar restart flows leave the
|
||||
# old gateway's connection holding the WAL lock until Python
|
||||
# actually exits — causing 'database is locked' errors when
|
||||
# the new gateway tries to open the same file.
|
||||
for _db_holder in (self, getattr(self, "session_store", None)):
|
||||
_db = getattr(_db_holder, "_db", None) if _db_holder else None
|
||||
if _db is None or not hasattr(_db, "close"):
|
||||
continue
|
||||
try:
|
||||
_db.close()
|
||||
except Exception as _e:
|
||||
logger.debug("SessionDB close error: %s", _e)
|
||||
|
||||
from gateway.status import remove_pid_file
|
||||
remove_pid_file()
|
||||
|
||||
@@ -2906,16 +2945,17 @@ class GatewayRunner:
|
||||
_quick_key[:30], _stale_age, _stale_idle,
|
||||
_raw_stale_timeout, _stale_detail,
|
||||
)
|
||||
del self._running_agents[_quick_key]
|
||||
self._running_agents_ts.pop(_quick_key, None)
|
||||
self._busy_ack_ts.pop(_quick_key, None)
|
||||
self._release_running_agent_state(_quick_key)
|
||||
|
||||
if _quick_key in self._running_agents:
|
||||
if event.get_command() == "status":
|
||||
return await self._handle_status_command(event)
|
||||
|
||||
# Resolve the command once for all early-intercept checks below.
|
||||
from hermes_cli.commands import resolve_command as _resolve_cmd_inner
|
||||
from hermes_cli.commands import (
|
||||
resolve_command as _resolve_cmd_inner,
|
||||
should_bypass_active_session as _should_bypass_active_inner,
|
||||
)
|
||||
_evt_cmd = event.get_command()
|
||||
_cmd_def_inner = _resolve_cmd_inner(_evt_cmd) if _evt_cmd else None
|
||||
|
||||
@@ -2936,8 +2976,7 @@ class GatewayRunner:
|
||||
if adapter and hasattr(adapter, 'get_pending_message'):
|
||||
adapter.get_pending_message(_quick_key) # consume and discard
|
||||
self._pending_messages.pop(_quick_key, None)
|
||||
if _quick_key in self._running_agents:
|
||||
del self._running_agents[_quick_key]
|
||||
self._release_running_agent_state(_quick_key)
|
||||
logger.info("STOP for session %s — agent interrupted, session lock released", _quick_key[:20])
|
||||
return "⚡ Stopped. You can continue this session."
|
||||
|
||||
@@ -2959,8 +2998,7 @@ class GatewayRunner:
|
||||
self._pending_messages.pop(_quick_key, None)
|
||||
# Clean up the running agent entry so the reset handler
|
||||
# doesn't think an agent is still active.
|
||||
if _quick_key in self._running_agents:
|
||||
del self._running_agents[_quick_key]
|
||||
self._release_running_agent_state(_quick_key)
|
||||
return await self._handle_reset_command(event)
|
||||
|
||||
# /queue <prompt> — queue without interrupting
|
||||
@@ -2981,6 +3019,54 @@ class GatewayRunner:
|
||||
adapter._pending_messages[_quick_key] = queued_event
|
||||
return "Queued for the next turn."
|
||||
|
||||
# /steer <prompt> — inject mid-run after the next tool call.
|
||||
# Unlike /queue (turn boundary), /steer lands BETWEEN tool-call
|
||||
# iterations inside the same agent run, by appending to the
|
||||
# last tool result's content. No interrupt, no new user turn,
|
||||
# no role-alternation violation.
|
||||
if _cmd_def_inner and _cmd_def_inner.name == "steer":
|
||||
steer_text = event.get_command_args().strip()
|
||||
if not steer_text:
|
||||
return "Usage: /steer <prompt>"
|
||||
running_agent = self._running_agents.get(_quick_key)
|
||||
if running_agent is _AGENT_PENDING_SENTINEL:
|
||||
# Agent hasn't started yet — queue as turn-boundary fallback.
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter:
|
||||
from gateway.platforms.base import MessageEvent as _ME, MessageType as _MT
|
||||
queued_event = _ME(
|
||||
text=steer_text,
|
||||
message_type=_MT.TEXT,
|
||||
source=event.source,
|
||||
message_id=event.message_id,
|
||||
channel_prompt=event.channel_prompt,
|
||||
)
|
||||
adapter._pending_messages[_quick_key] = queued_event
|
||||
return "Agent still starting — /steer queued for the next turn."
|
||||
if running_agent and hasattr(running_agent, "steer"):
|
||||
try:
|
||||
accepted = running_agent.steer(steer_text)
|
||||
except Exception as exc:
|
||||
logger.warning("Steer failed for session %s: %s", _quick_key[:20], exc)
|
||||
return f"⚠️ Steer failed: {exc}"
|
||||
if accepted:
|
||||
preview = steer_text[:60] + ("..." if len(steer_text) > 60 else "")
|
||||
return f"⏩ Steer queued — arrives after the next tool call: '{preview}'"
|
||||
return "Steer rejected (empty payload)."
|
||||
# Running agent is missing or lacks steer() — fall back to queue.
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter:
|
||||
from gateway.platforms.base import MessageEvent as _ME, MessageType as _MT
|
||||
queued_event = _ME(
|
||||
text=steer_text,
|
||||
message_type=_MT.TEXT,
|
||||
source=event.source,
|
||||
message_id=event.message_id,
|
||||
channel_prompt=event.channel_prompt,
|
||||
)
|
||||
adapter._pending_messages[_quick_key] = queued_event
|
||||
return "No active agent — /steer queued for the next turn."
|
||||
|
||||
# /model must not be used while the agent is running.
|
||||
if _cmd_def_inner and _cmd_def_inner.name == "model":
|
||||
return "Agent is running — wait or /stop first, then switch models."
|
||||
@@ -2994,11 +3080,29 @@ class GatewayRunner:
|
||||
return await self._handle_approve_command(event)
|
||||
return await self._handle_deny_command(event)
|
||||
|
||||
# /agents (/tasks alias) should be query-only and never interrupt.
|
||||
if _cmd_def_inner and _cmd_def_inner.name == "agents":
|
||||
return await self._handle_agents_command(event)
|
||||
|
||||
# /background must bypass the running-agent guard — it starts a
|
||||
# parallel task and must never interrupt the active conversation.
|
||||
if _cmd_def_inner and _cmd_def_inner.name == "background":
|
||||
return await self._handle_background_command(event)
|
||||
|
||||
# Gateway-handled info/control commands must never fall through to
|
||||
# the interrupt path. If they are queued as pending text, the
|
||||
# slash-command safety net discards them before the user sees any
|
||||
# response.
|
||||
if _cmd_def_inner and _should_bypass_active_inner(_cmd_def_inner.name):
|
||||
if _cmd_def_inner.name == "help":
|
||||
return await self._handle_help_command(event)
|
||||
if _cmd_def_inner.name == "commands":
|
||||
return await self._handle_commands_command(event)
|
||||
if _cmd_def_inner.name == "profile":
|
||||
return await self._handle_profile_command(event)
|
||||
if _cmd_def_inner.name == "update":
|
||||
return await self._handle_update_command(event)
|
||||
|
||||
if event.message_type == MessageType.PHOTO:
|
||||
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
|
||||
adapter = self.adapters.get(source.platform)
|
||||
@@ -3037,8 +3141,7 @@ class GatewayRunner:
|
||||
# Agent is being set up but not ready yet.
|
||||
if event.get_command() == "stop":
|
||||
# Force-clean the sentinel so the session is unlocked.
|
||||
if _quick_key in self._running_agents:
|
||||
del self._running_agents[_quick_key]
|
||||
self._release_running_agent_state(_quick_key)
|
||||
logger.info("HARD STOP (pending) for session %s — sentinel cleared", _quick_key[:20])
|
||||
return "⚡ Force-stopped. The agent was still starting — session unlocked."
|
||||
# Queue the message so it will be picked up after the
|
||||
@@ -3102,6 +3205,9 @@ class GatewayRunner:
|
||||
if canonical == "status":
|
||||
return await self._handle_status_command(event)
|
||||
|
||||
if canonical == "agents":
|
||||
return await self._handle_agents_command(event)
|
||||
|
||||
if canonical == "restart":
|
||||
return await self._handle_restart_command(event)
|
||||
|
||||
@@ -3202,6 +3308,21 @@ class GatewayRunner:
|
||||
if canonical == "btw":
|
||||
return await self._handle_btw_command(event)
|
||||
|
||||
if canonical == "steer":
|
||||
# No active agent — /steer has no tool call to inject into.
|
||||
# Strip the prefix so downstream treats it as a normal user
|
||||
# message. If the payload is empty, surface the usage hint.
|
||||
steer_payload = event.get_command_args().strip()
|
||||
if not steer_payload:
|
||||
return "Usage: /steer <prompt> (no agent is running; sending as a normal message)"
|
||||
try:
|
||||
event.text = steer_payload
|
||||
except Exception:
|
||||
pass
|
||||
# Do NOT return — fall through to _handle_message_with_agent
|
||||
# at the end of this function so the rewritten text is sent
|
||||
# to the agent as a regular user turn.
|
||||
|
||||
if canonical == "voice":
|
||||
return await self._handle_voice_command(event)
|
||||
|
||||
@@ -3354,8 +3475,13 @@ class GatewayRunner:
|
||||
# (exception, command fallthrough, etc.) the sentinel must
|
||||
# not linger or the session would be permanently locked out.
|
||||
if self._running_agents.get(_quick_key) is _AGENT_PENDING_SENTINEL:
|
||||
del self._running_agents[_quick_key]
|
||||
self._running_agents_ts.pop(_quick_key, None)
|
||||
self._release_running_agent_state(_quick_key)
|
||||
else:
|
||||
# Agent path already cleaned _running_agents; make sure
|
||||
# the paired metadata dicts are gone too.
|
||||
self._running_agents_ts.pop(_quick_key, None)
|
||||
if hasattr(self, "_busy_ack_ts"):
|
||||
self._busy_ack_ts.pop(_quick_key, None)
|
||||
|
||||
async def _prepare_inbound_message_text(
|
||||
self,
|
||||
@@ -4552,6 +4678,96 @@ class GatewayRunner:
|
||||
])
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _handle_agents_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /agents command - list active agents and running tasks."""
|
||||
from tools.process_registry import format_uptime_short, process_registry
|
||||
|
||||
now = time.time()
|
||||
current_session_key = self._session_key_for_source(event.source)
|
||||
|
||||
running_agents: dict = getattr(self, "_running_agents", {}) or {}
|
||||
running_started: dict = getattr(self, "_running_agents_ts", {}) or {}
|
||||
|
||||
agent_rows: list[dict] = []
|
||||
for session_key, agent in running_agents.items():
|
||||
started = float(running_started.get(session_key, now))
|
||||
elapsed = max(0, int(now - started))
|
||||
is_pending = agent is _AGENT_PENDING_SENTINEL
|
||||
agent_rows.append(
|
||||
{
|
||||
"session_key": session_key,
|
||||
"elapsed": elapsed,
|
||||
"state": "starting" if is_pending else "running",
|
||||
"session_id": "" if is_pending else str(getattr(agent, "session_id", "") or ""),
|
||||
"model": "" if is_pending else str(getattr(agent, "model", "") or ""),
|
||||
}
|
||||
)
|
||||
|
||||
agent_rows.sort(key=lambda row: row["elapsed"], reverse=True)
|
||||
|
||||
running_processes: list[dict] = []
|
||||
try:
|
||||
running_processes = [
|
||||
p for p in process_registry.list_sessions()
|
||||
if p.get("status") == "running"
|
||||
]
|
||||
except Exception:
|
||||
running_processes = []
|
||||
|
||||
background_tasks = [
|
||||
t for t in (getattr(self, "_background_tasks", set()) or set())
|
||||
if hasattr(t, "done") and not t.done()
|
||||
]
|
||||
|
||||
lines = [
|
||||
"🤖 **Active Agents & Tasks**",
|
||||
"",
|
||||
f"**Active agents:** {len(agent_rows)}",
|
||||
]
|
||||
|
||||
if agent_rows:
|
||||
for idx, row in enumerate(agent_rows[:12], 1):
|
||||
current = " · this chat" if row["session_key"] == current_session_key else ""
|
||||
sid = f" · `{row['session_id']}`" if row["session_id"] else ""
|
||||
model = f" · `{row['model']}`" if row["model"] else ""
|
||||
lines.append(
|
||||
f"{idx}. `{row['session_key']}` · {row['state']} · "
|
||||
f"{format_uptime_short(row['elapsed'])}{sid}{model}{current}"
|
||||
)
|
||||
if len(agent_rows) > 12:
|
||||
lines.append(f"... and {len(agent_rows) - 12} more")
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
f"**Running background processes:** {len(running_processes)}",
|
||||
]
|
||||
)
|
||||
if running_processes:
|
||||
for proc in running_processes[:12]:
|
||||
cmd = " ".join(str(proc.get("command", "")).split())
|
||||
if len(cmd) > 90:
|
||||
cmd = cmd[:87] + "..."
|
||||
lines.append(
|
||||
f"- `{proc.get('session_id', '?')}` · "
|
||||
f"{format_uptime_short(int(proc.get('uptime_seconds', 0)))} · `{cmd}`"
|
||||
)
|
||||
if len(running_processes) > 12:
|
||||
lines.append(f"... and {len(running_processes) - 12} more")
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
f"**Gateway async jobs:** {len(background_tasks)}",
|
||||
]
|
||||
)
|
||||
|
||||
if not agent_rows and not running_processes and not background_tasks:
|
||||
lines.append("")
|
||||
lines.append("No active agents or running tasks.")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _handle_stop_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /stop command - interrupt a running agent.
|
||||
@@ -4571,22 +4787,40 @@ class GatewayRunner:
|
||||
agent = self._running_agents.get(session_key)
|
||||
if agent is _AGENT_PENDING_SENTINEL:
|
||||
# Force-clean the sentinel so the session is unlocked.
|
||||
if session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
self._release_running_agent_state(session_key)
|
||||
logger.info("STOP (pending) for session %s — sentinel cleared", session_key[:20])
|
||||
return "⚡ Stopped. The agent hadn't started yet — you can continue this session."
|
||||
if agent:
|
||||
agent.interrupt("Stop requested")
|
||||
# Force-clean the session lock so a truly hung agent doesn't
|
||||
# keep it locked forever.
|
||||
if session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
self._release_running_agent_state(session_key)
|
||||
return "⚡ Stopped. You can continue this session."
|
||||
else:
|
||||
return "No active task to stop."
|
||||
|
||||
async def _handle_restart_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /restart command - drain active work, then restart the gateway."""
|
||||
# Defensive idempotency check: if the previous gateway process
|
||||
# recorded this same /restart (same platform + update_id) and the new
|
||||
# process is seeing it *again*, this is a re-delivery caused by PTB's
|
||||
# graceful-shutdown `get_updates` ACK failing on the way out ("Error
|
||||
# while calling `get_updates` one more time to mark all fetched
|
||||
# updates. Suppressing error to ensure graceful shutdown. When
|
||||
# polling for updates is restarted, updates may be received twice."
|
||||
# in gateway.log). Ignoring the stale redelivery prevents a
|
||||
# self-perpetuating restart loop where every fresh gateway
|
||||
# re-processes the same /restart command and immediately restarts
|
||||
# again.
|
||||
if self._is_stale_restart_redelivery(event):
|
||||
logger.info(
|
||||
"Ignoring redelivered /restart (platform=%s, update_id=%s) — "
|
||||
"already processed by a previous gateway instance.",
|
||||
event.source.platform.value if event.source and event.source.platform else "?",
|
||||
event.platform_update_id,
|
||||
)
|
||||
return ""
|
||||
|
||||
if self._restart_requested or self._draining:
|
||||
count = self._running_agent_count()
|
||||
if count:
|
||||
@@ -4609,6 +4843,26 @@ class GatewayRunner:
|
||||
except Exception as e:
|
||||
logger.debug("Failed to write restart notify file: %s", e)
|
||||
|
||||
# Record the triggering platform + update_id in a dedicated dedup
|
||||
# marker. Unlike .restart_notify.json (which gets unlinked once the
|
||||
# new gateway sends the "gateway restarted" notification), this
|
||||
# marker persists so the new gateway can still detect a delayed
|
||||
# /restart redelivery from Telegram. Overwritten on every /restart.
|
||||
try:
|
||||
import json as _json
|
||||
import time as _time
|
||||
dedup_data = {
|
||||
"platform": event.source.platform.value if event.source.platform else None,
|
||||
"requested_at": _time.time(),
|
||||
}
|
||||
if event.platform_update_id is not None:
|
||||
dedup_data["update_id"] = event.platform_update_id
|
||||
(_hermes_home / ".restart_last_processed.json").write_text(
|
||||
_json.dumps(dedup_data)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to write restart dedup marker: %s", e)
|
||||
|
||||
active_agents = self._running_agent_count()
|
||||
# When running under a service manager (systemd/launchd), use the
|
||||
# service restart path: exit with code 75 so the service manager
|
||||
@@ -4624,6 +4878,58 @@ class GatewayRunner:
|
||||
return f"⏳ Draining {active_agents} active agent(s) before restart..."
|
||||
return "♻ Restarting gateway. If you aren't notified within 60 seconds, restart from the console with `hermes gateway restart`."
|
||||
|
||||
def _is_stale_restart_redelivery(self, event: MessageEvent) -> bool:
|
||||
"""Return True if this /restart is a Telegram re-delivery we already handled.
|
||||
|
||||
The previous gateway wrote ``.restart_last_processed.json`` with the
|
||||
triggering platform + update_id when it processed the /restart. If
|
||||
we now see a /restart on the same platform with an update_id <= that
|
||||
recorded value AND the marker is recent (< 5 minutes), it's a
|
||||
redelivery and should be ignored.
|
||||
|
||||
Only applies to Telegram today (the only platform that exposes a
|
||||
numeric cross-session update ordering); other platforms return False.
|
||||
"""
|
||||
if event is None or event.source is None:
|
||||
return False
|
||||
if event.platform_update_id is None:
|
||||
return False
|
||||
if event.source.platform is None:
|
||||
return False
|
||||
# Only Telegram populates platform_update_id currently; be explicit
|
||||
# so future platforms aren't accidentally gated by this check.
|
||||
try:
|
||||
platform_value = event.source.platform.value
|
||||
except Exception:
|
||||
return False
|
||||
if platform_value != "telegram":
|
||||
return False
|
||||
|
||||
try:
|
||||
import json as _json
|
||||
import time as _time
|
||||
marker_path = _hermes_home / ".restart_last_processed.json"
|
||||
if not marker_path.exists():
|
||||
return False
|
||||
data = _json.loads(marker_path.read_text())
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
if data.get("platform") != platform_value:
|
||||
return False
|
||||
recorded_uid = data.get("update_id")
|
||||
if not isinstance(recorded_uid, int):
|
||||
return False
|
||||
# Staleness guard: ignore markers older than 5 minutes. A legitimately
|
||||
# old marker (e.g. crash recovery where notify never fired) should not
|
||||
# swallow a fresh /restart from the user.
|
||||
requested_at = data.get("requested_at")
|
||||
if isinstance(requested_at, (int, float)):
|
||||
if _time.time() - requested_at > 300:
|
||||
return False
|
||||
return event.platform_update_id <= recorded_uid
|
||||
|
||||
|
||||
async def _handle_help_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /help command - list available commands."""
|
||||
from hermes_cli.commands import gateway_help_lines
|
||||
@@ -5369,8 +5675,7 @@ class GatewayRunner:
|
||||
if "pynacl" in err_lower or "nacl" in err_lower or "davey" in err_lower:
|
||||
return (
|
||||
"Voice dependencies are missing (PyNaCl / davey). "
|
||||
"Install or reinstall Hermes with the messaging extra, e.g. "
|
||||
"`pip install hermes-agent[messaging]`."
|
||||
f"Install with: `{sys.executable} -m pip install PyNaCl`"
|
||||
)
|
||||
return f"Failed to join voice channel: {e}"
|
||||
|
||||
@@ -6496,8 +6801,7 @@ class GatewayRunner:
|
||||
logger.debug("Memory flush on resume failed: %s", e)
|
||||
|
||||
# Clear any running agent for this session key
|
||||
if session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
self._release_running_agent_state(session_key)
|
||||
|
||||
# Switch the session entry to point at the old session
|
||||
new_entry = self.session_store.switch_session(session_key, target_id)
|
||||
@@ -7913,6 +8217,30 @@ class GatewayRunner:
|
||||
override = self._session_model_overrides.get(session_key)
|
||||
return override is not None and override.get("model") == agent_model
|
||||
|
||||
def _release_running_agent_state(self, session_key: str) -> None:
|
||||
"""Pop ALL per-running-agent state entries for ``session_key``.
|
||||
|
||||
Replaces ad-hoc ``del self._running_agents[key]`` calls scattered
|
||||
across the gateway. Those sites had drifted: some popped only
|
||||
``_running_agents``; some also ``_running_agents_ts``; only one
|
||||
path also cleared ``_busy_ack_ts``. Each missed entry was a
|
||||
small, persistent leak — a (str_key → float) tuple per session
|
||||
per gateway lifetime.
|
||||
|
||||
Use this at every site that ends a running turn, regardless of
|
||||
cause (normal completion, /stop, /reset, /resume, sentinel
|
||||
cleanup, stale-eviction). Per-session state that PERSISTS
|
||||
across turns (``_session_model_overrides``, ``_voice_mode``,
|
||||
``_pending_approvals``, ``_update_prompt_pending``) is NOT
|
||||
touched here — those have their own lifecycles.
|
||||
"""
|
||||
if not session_key:
|
||||
return
|
||||
self._running_agents.pop(session_key, None)
|
||||
self._running_agents_ts.pop(session_key, None)
|
||||
if hasattr(self, "_busy_ack_ts"):
|
||||
self._busy_ack_ts.pop(session_key, None)
|
||||
|
||||
def _evict_cached_agent(self, session_key: str) -> None:
|
||||
"""Remove a cached agent for a session (called on /new, /model, etc)."""
|
||||
_lock = getattr(self, "_agent_cache_lock", None)
|
||||
@@ -9748,10 +10076,8 @@ class GatewayRunner:
|
||||
|
||||
# Clean up tracking
|
||||
tracking_task.cancel()
|
||||
if session_key and session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
if session_key:
|
||||
self._running_agents_ts.pop(session_key, None)
|
||||
self._release_running_agent_state(session_key)
|
||||
if self._draining:
|
||||
self._update_runtime_status("draining")
|
||||
|
||||
@@ -9880,6 +10206,16 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
||||
"Replacing existing gateway instance (PID %d) with --replace.",
|
||||
existing_pid,
|
||||
)
|
||||
# Record a takeover marker so the target's shutdown handler
|
||||
# recognises its SIGTERM as a planned takeover and exits 0
|
||||
# (rather than exit 1, which would trigger systemd's
|
||||
# Restart=on-failure and start a flap loop against us).
|
||||
# Best-effort — proceed even if the write fails.
|
||||
try:
|
||||
from gateway.status import write_takeover_marker
|
||||
write_takeover_marker(existing_pid)
|
||||
except Exception as e:
|
||||
logger.debug("Could not write takeover marker: %s", e)
|
||||
try:
|
||||
terminate_pid(existing_pid, force=False)
|
||||
except ProcessLookupError:
|
||||
@@ -9889,6 +10225,13 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
||||
"Permission denied killing PID %d. Cannot replace.",
|
||||
existing_pid,
|
||||
)
|
||||
# Marker is scoped to a specific target; clean it up on
|
||||
# give-up so it doesn't grief an unrelated future shutdown.
|
||||
try:
|
||||
from gateway.status import clear_takeover_marker
|
||||
clear_takeover_marker()
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
# Wait up to 10 seconds for the old process to exit
|
||||
for _ in range(20):
|
||||
@@ -9909,6 +10252,13 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
||||
except (ProcessLookupError, PermissionError, OSError):
|
||||
pass
|
||||
remove_pid_file()
|
||||
# Clean up any takeover marker the old process didn't consume
|
||||
# (e.g. SIGKILL'd before its shutdown handler could read it).
|
||||
try:
|
||||
from gateway.status import clear_takeover_marker
|
||||
clear_takeover_marker()
|
||||
except Exception:
|
||||
pass
|
||||
# Also release all scoped locks left by the old process.
|
||||
# Stopped (Ctrl+Z) processes don't release locks on exit,
|
||||
# leaving stale lock files that block the new gateway from starting.
|
||||
@@ -9976,8 +10326,27 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
||||
# Set up signal handlers
|
||||
def shutdown_signal_handler():
|
||||
nonlocal _signal_initiated_shutdown
|
||||
_signal_initiated_shutdown = True
|
||||
logger.info("Received SIGTERM/SIGINT — initiating shutdown")
|
||||
# Planned --replace takeover check: when a sibling gateway is
|
||||
# taking over via --replace, it wrote a marker naming this PID
|
||||
# before sending SIGTERM. If present, treat the signal as a
|
||||
# planned shutdown and exit 0 so systemd's Restart=on-failure
|
||||
# doesn't revive us (which would flap-fight the replacer when
|
||||
# both services are enabled, e.g. hermes.service + hermes-
|
||||
# gateway.service from pre-rename installs).
|
||||
planned_takeover = False
|
||||
try:
|
||||
from gateway.status import consume_takeover_marker_for_self
|
||||
planned_takeover = consume_takeover_marker_for_self()
|
||||
except Exception as e:
|
||||
logger.debug("Takeover marker check failed: %s", e)
|
||||
|
||||
if planned_takeover:
|
||||
logger.info(
|
||||
"Received SIGTERM as a planned --replace takeover — exiting cleanly"
|
||||
)
|
||||
else:
|
||||
_signal_initiated_shutdown = True
|
||||
logger.info("Received SIGTERM/SIGINT — initiating shutdown")
|
||||
# Diagnostic: log all hermes-related processes so we can identify
|
||||
# what triggered the signal (hermes update, hermes gateway restart,
|
||||
# a stale detached subprocess, etc.).
|
||||
|
||||
@@ -802,6 +802,57 @@ class SessionStore:
|
||||
return True
|
||||
return False
|
||||
|
||||
def prune_old_entries(self, max_age_days: int) -> int:
|
||||
"""Drop SessionEntry records older than max_age_days.
|
||||
|
||||
Pruning is based on ``updated_at`` (last activity), not ``created_at``.
|
||||
A session that's been active within the window is kept regardless of
|
||||
how old it is. Entries marked ``suspended`` are kept — the user
|
||||
explicitly paused them for later resume. Entries held by an active
|
||||
process (via has_active_processes_fn) are also kept so long-running
|
||||
background work isn't orphaned.
|
||||
|
||||
Pruning is functionally identical to a natural reset-policy expiry:
|
||||
the transcript in SQLite stays, but the session_key → session_id
|
||||
mapping is dropped and the user starts a fresh session on return.
|
||||
|
||||
``max_age_days <= 0`` disables pruning; returns 0 immediately.
|
||||
Returns the number of entries removed.
|
||||
"""
|
||||
if max_age_days is None or max_age_days <= 0:
|
||||
return 0
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff = _now() - timedelta(days=max_age_days)
|
||||
removed_keys: list[str] = []
|
||||
|
||||
with self._lock:
|
||||
self._ensure_loaded_locked()
|
||||
for key, entry in list(self._entries.items()):
|
||||
if entry.suspended:
|
||||
continue
|
||||
# Never prune sessions with an active background process
|
||||
# attached — the user may still be waiting on output.
|
||||
if self._has_active_processes_fn is not None:
|
||||
try:
|
||||
if self._has_active_processes_fn(entry.session_id):
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
if entry.updated_at < cutoff:
|
||||
removed_keys.append(key)
|
||||
for key in removed_keys:
|
||||
self._entries.pop(key, None)
|
||||
if removed_keys:
|
||||
self._save()
|
||||
|
||||
if removed_keys:
|
||||
logger.info(
|
||||
"SessionStore pruned %d entries older than %d days",
|
||||
len(removed_keys), max_age_days,
|
||||
)
|
||||
return len(removed_keys)
|
||||
|
||||
def suspend_recently_active(self, max_age_seconds: int = 120) -> int:
|
||||
"""Mark recently-active sessions as suspended.
|
||||
|
||||
|
||||
+159
-11
@@ -188,8 +188,8 @@ def _write_json_file(path: Path, payload: dict[str, Any]) -> None:
|
||||
path.write_text(json.dumps(payload))
|
||||
|
||||
|
||||
def _read_pid_record() -> Optional[dict]:
|
||||
pid_path = _get_pid_path()
|
||||
def _read_pid_record(pid_path: Optional[Path] = None) -> Optional[dict]:
|
||||
pid_path = pid_path or _get_pid_path()
|
||||
if not pid_path.exists():
|
||||
return None
|
||||
|
||||
@@ -212,6 +212,18 @@ def _read_pid_record() -> Optional[dict]:
|
||||
return None
|
||||
|
||||
|
||||
def _cleanup_invalid_pid_path(pid_path: Path, *, cleanup_stale: bool) -> None:
|
||||
if not cleanup_stale:
|
||||
return
|
||||
try:
|
||||
if pid_path == _get_pid_path():
|
||||
remove_pid_file()
|
||||
else:
|
||||
pid_path.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def write_pid_file() -> None:
|
||||
"""Write the current process PID and metadata to the gateway PID file."""
|
||||
_write_json_file(_get_pid_path(), _build_pid_record())
|
||||
@@ -413,43 +425,179 @@ def release_all_scoped_locks() -> int:
|
||||
return removed
|
||||
|
||||
|
||||
def get_running_pid() -> Optional[int]:
|
||||
# ── --replace takeover marker ─────────────────────────────────────────
|
||||
#
|
||||
# When a new gateway starts with ``--replace``, it SIGTERMs the existing
|
||||
# gateway so it can take over the bot token. PR #5646 made SIGTERM exit
|
||||
# the gateway with code 1 so ``Restart=on-failure`` can revive it after
|
||||
# unexpected kills — but that also means a --replace takeover target
|
||||
# exits 1, which tricks systemd into reviving it 30 seconds later,
|
||||
# starting a flap loop against the replacer when both services are
|
||||
# enabled in the user's systemd (e.g. ``hermes.service`` + ``hermes-
|
||||
# gateway.service``).
|
||||
#
|
||||
# The takeover marker breaks the loop: the replacer writes a short-lived
|
||||
# file naming the target PID + start_time BEFORE sending SIGTERM.
|
||||
# The target's shutdown handler reads the marker and, if it names
|
||||
# this process, treats the SIGTERM as a planned takeover and exits 0.
|
||||
# The marker is unlinked after the target has consumed it, so a stale
|
||||
# marker left by a crashed replacer can grief at most one future
|
||||
# shutdown on the same PID — and only within _TAKEOVER_MARKER_TTL_S.
|
||||
|
||||
_TAKEOVER_MARKER_FILENAME = ".gateway-takeover.json"
|
||||
_TAKEOVER_MARKER_TTL_S = 60 # Marker older than this is treated as stale
|
||||
|
||||
|
||||
def _get_takeover_marker_path() -> Path:
|
||||
"""Return the path to the --replace takeover marker file."""
|
||||
home = get_hermes_home()
|
||||
return home / _TAKEOVER_MARKER_FILENAME
|
||||
|
||||
|
||||
def write_takeover_marker(target_pid: int) -> bool:
|
||||
"""Record that ``target_pid`` is being replaced by the current process.
|
||||
|
||||
Captures the target's ``start_time`` so that PID reuse after the
|
||||
target exits cannot later match the marker. Also records the
|
||||
replacer's PID and a UTC timestamp for TTL-based staleness checks.
|
||||
|
||||
Returns True on successful write, False on any failure. The caller
|
||||
should proceed with the SIGTERM even if the write fails (the marker
|
||||
is a best-effort signal, not a correctness requirement).
|
||||
"""
|
||||
try:
|
||||
target_start_time = _get_process_start_time(target_pid)
|
||||
record = {
|
||||
"target_pid": target_pid,
|
||||
"target_start_time": target_start_time,
|
||||
"replacer_pid": os.getpid(),
|
||||
"written_at": _utc_now_iso(),
|
||||
}
|
||||
_write_json_file(_get_takeover_marker_path(), record)
|
||||
return True
|
||||
except (OSError, PermissionError):
|
||||
return False
|
||||
|
||||
|
||||
def consume_takeover_marker_for_self() -> bool:
|
||||
"""Check & unlink the takeover marker if it names the current process.
|
||||
|
||||
Returns True only when a valid (non-stale) marker names this PID +
|
||||
start_time. A returning True indicates the current SIGTERM is a
|
||||
planned --replace takeover; the caller should exit 0 instead of
|
||||
signalling ``_signal_initiated_shutdown``.
|
||||
|
||||
Always unlinks the marker on match (and on detected staleness) so
|
||||
subsequent unrelated signals don't re-trigger.
|
||||
"""
|
||||
path = _get_takeover_marker_path()
|
||||
record = _read_json_file(path)
|
||||
if not record:
|
||||
return False
|
||||
|
||||
# Any malformed or stale marker → drop it and return False
|
||||
try:
|
||||
target_pid = int(record["target_pid"])
|
||||
target_start_time = record.get("target_start_time")
|
||||
written_at = record.get("written_at") or ""
|
||||
except (KeyError, TypeError, ValueError):
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
|
||||
# TTL guard: a stale marker older than _TAKEOVER_MARKER_TTL_S is ignored.
|
||||
stale = False
|
||||
try:
|
||||
written_dt = datetime.fromisoformat(written_at)
|
||||
age = (datetime.now(timezone.utc) - written_dt).total_seconds()
|
||||
if age > _TAKEOVER_MARKER_TTL_S:
|
||||
stale = True
|
||||
except (TypeError, ValueError):
|
||||
stale = True # Unparseable timestamp — treat as stale
|
||||
|
||||
if stale:
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
|
||||
# Does the marker name THIS process?
|
||||
our_pid = os.getpid()
|
||||
our_start_time = _get_process_start_time(our_pid)
|
||||
matches = (
|
||||
target_pid == our_pid
|
||||
and target_start_time is not None
|
||||
and our_start_time is not None
|
||||
and target_start_time == our_start_time
|
||||
)
|
||||
|
||||
# Consume the marker whether it matched or not — a marker that doesn't
|
||||
# match our identity is stale-for-us anyway.
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def clear_takeover_marker() -> None:
|
||||
"""Remove the takeover marker unconditionally. Safe to call repeatedly."""
|
||||
try:
|
||||
_get_takeover_marker_path().unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def get_running_pid(
|
||||
pid_path: Optional[Path] = None,
|
||||
*,
|
||||
cleanup_stale: bool = True,
|
||||
) -> Optional[int]:
|
||||
"""Return the PID of a running gateway instance, or ``None``.
|
||||
|
||||
Checks the PID file and verifies the process is actually alive.
|
||||
Cleans up stale PID files automatically.
|
||||
"""
|
||||
record = _read_pid_record()
|
||||
resolved_pid_path = pid_path or _get_pid_path()
|
||||
record = _read_pid_record(resolved_pid_path)
|
||||
if not record:
|
||||
remove_pid_file()
|
||||
_cleanup_invalid_pid_path(resolved_pid_path, cleanup_stale=cleanup_stale)
|
||||
return None
|
||||
|
||||
try:
|
||||
pid = int(record["pid"])
|
||||
except (KeyError, TypeError, ValueError):
|
||||
remove_pid_file()
|
||||
_cleanup_invalid_pid_path(resolved_pid_path, cleanup_stale=cleanup_stale)
|
||||
return None
|
||||
|
||||
try:
|
||||
os.kill(pid, 0) # signal 0 = existence check, no actual signal sent
|
||||
except (ProcessLookupError, PermissionError):
|
||||
remove_pid_file()
|
||||
_cleanup_invalid_pid_path(resolved_pid_path, cleanup_stale=cleanup_stale)
|
||||
return None
|
||||
|
||||
recorded_start = record.get("start_time")
|
||||
current_start = _get_process_start_time(pid)
|
||||
if recorded_start is not None and current_start is not None and current_start != recorded_start:
|
||||
remove_pid_file()
|
||||
_cleanup_invalid_pid_path(resolved_pid_path, cleanup_stale=cleanup_stale)
|
||||
return None
|
||||
|
||||
if not _looks_like_gateway_process(pid):
|
||||
if not _record_looks_like_gateway(record):
|
||||
remove_pid_file()
|
||||
_cleanup_invalid_pid_path(resolved_pid_path, cleanup_stale=cleanup_stale)
|
||||
return None
|
||||
|
||||
return pid
|
||||
|
||||
|
||||
def is_gateway_running() -> bool:
|
||||
def is_gateway_running(
|
||||
pid_path: Optional[Path] = None,
|
||||
*,
|
||||
cleanup_stale: bool = True,
|
||||
) -> bool:
|
||||
"""Check if the gateway daemon is currently running."""
|
||||
return get_running_pid() is not None
|
||||
return get_running_pid(pid_path, cleanup_stale=cleanup_stale) is not None
|
||||
|
||||
@@ -100,6 +100,14 @@ class GatewayStreamConsumer:
|
||||
self._flood_strikes = 0 # Consecutive flood-control edit failures
|
||||
self._current_edit_interval = self.cfg.edit_interval # Adaptive backoff
|
||||
self._final_response_sent = False
|
||||
# Cache adapter lifecycle capability: only platforms that need an
|
||||
# explicit finalize call (e.g. DingTalk AI Cards) force us to make
|
||||
# a redundant final edit. Everyone else keeps the fast path.
|
||||
# Use ``is True`` (not ``bool(...)``) so MagicMock attribute access
|
||||
# in tests doesn't incorrectly enable this path.
|
||||
self._adapter_requires_finalize: bool = (
|
||||
getattr(adapter, "REQUIRES_EDIT_FINALIZE", False) is True
|
||||
)
|
||||
|
||||
# Think-block filter state (mirrors CLI's _stream_delta tag suppression)
|
||||
self._in_think_block = False
|
||||
@@ -361,7 +369,16 @@ class GatewayStreamConsumer:
|
||||
if not got_done and not got_segment_break and commentary_text is None:
|
||||
display_text += self.cfg.cursor
|
||||
|
||||
current_update_visible = await self._send_or_edit(display_text)
|
||||
# Segment break: finalize the current message so platforms
|
||||
# that need explicit closure (e.g. DingTalk AI Cards) don't
|
||||
# leave the previous segment stuck in a loading state when
|
||||
# the next segment (tool progress, next chunk) creates a
|
||||
# new message below it. got_done has its own finalize
|
||||
# path below so we don't finalize here for it.
|
||||
current_update_visible = await self._send_or_edit(
|
||||
display_text,
|
||||
finalize=got_segment_break,
|
||||
)
|
||||
self._last_edit_time = time.monotonic()
|
||||
|
||||
if got_done:
|
||||
@@ -372,10 +389,22 @@ class GatewayStreamConsumer:
|
||||
if self._accumulated:
|
||||
if self._fallback_final_send:
|
||||
await self._send_fallback_final(self._accumulated)
|
||||
elif current_update_visible:
|
||||
elif (
|
||||
current_update_visible
|
||||
and not self._adapter_requires_finalize
|
||||
):
|
||||
# Mid-stream edit above already delivered the
|
||||
# final accumulated content. Skip the redundant
|
||||
# final edit — but only for adapters that don't
|
||||
# need an explicit finalize signal.
|
||||
self._final_response_sent = True
|
||||
elif self._message_id:
|
||||
self._final_response_sent = await self._send_or_edit(self._accumulated)
|
||||
# Either the mid-stream edit didn't run (no
|
||||
# visible update this tick) OR the adapter needs
|
||||
# explicit finalize=True to close the stream.
|
||||
self._final_response_sent = await self._send_or_edit(
|
||||
self._accumulated, finalize=True,
|
||||
)
|
||||
elif not self._already_sent:
|
||||
self._final_response_sent = await self._send_or_edit(self._accumulated)
|
||||
return
|
||||
@@ -633,12 +662,15 @@ class GatewayStreamConsumer:
|
||||
logger.error("Commentary send error: %s", e)
|
||||
return False
|
||||
|
||||
async def _send_or_edit(self, text: str) -> bool:
|
||||
async def _send_or_edit(self, text: str, *, finalize: bool = False) -> bool:
|
||||
"""Send or edit the streaming message.
|
||||
|
||||
Returns True if the text was successfully delivered (sent or edited),
|
||||
False otherwise. Callers like the overflow split loop use this to
|
||||
decide whether to advance past the delivered chunk.
|
||||
|
||||
``finalize`` is True when this is the last edit in a streaming
|
||||
sequence.
|
||||
"""
|
||||
# Strip MEDIA: directives so they don't appear as visible text.
|
||||
# Media files are delivered as native attachments after the stream
|
||||
@@ -672,14 +704,22 @@ class GatewayStreamConsumer:
|
||||
try:
|
||||
if self._message_id is not None:
|
||||
if self._edit_supported:
|
||||
# Skip if text is identical to what we last sent
|
||||
if text == self._last_sent_text:
|
||||
# Skip if text is identical to what we last sent.
|
||||
# Exception: adapters that require an explicit finalize
|
||||
# call (REQUIRES_EDIT_FINALIZE) must still receive the
|
||||
# finalize=True edit even when content is unchanged, so
|
||||
# their streaming UI can transition out of the in-
|
||||
# progress state. Everyone else short-circuits.
|
||||
if text == self._last_sent_text and not (
|
||||
finalize and self._adapter_requires_finalize
|
||||
):
|
||||
return True
|
||||
# Edit existing message
|
||||
result = await self.adapter.edit_message(
|
||||
chat_id=self.chat_id,
|
||||
message_id=self._message_id,
|
||||
content=text,
|
||||
finalize=finalize,
|
||||
)
|
||||
if result.success:
|
||||
self._already_sent = True
|
||||
|
||||
@@ -233,6 +233,14 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
api_key_env_vars=("XAI_API_KEY",),
|
||||
base_url_env_var="XAI_BASE_URL",
|
||||
),
|
||||
"nvidia": ProviderConfig(
|
||||
id="nvidia",
|
||||
name="NVIDIA NIM",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://integrate.api.nvidia.com/v1",
|
||||
api_key_env_vars=("NVIDIA_API_KEY",),
|
||||
base_url_env_var="NVIDIA_BASE_URL",
|
||||
),
|
||||
"ai-gateway": ProviderConfig(
|
||||
id="ai-gateway",
|
||||
name="Vercel AI Gateway",
|
||||
@@ -2151,6 +2159,62 @@ def refresh_nous_oauth_from_state(
|
||||
)
|
||||
|
||||
|
||||
NOUS_DEVICE_CODE_SOURCE = "device_code"
|
||||
|
||||
|
||||
def persist_nous_credentials(
|
||||
creds: Dict[str, Any],
|
||||
*,
|
||||
label: Optional[str] = None,
|
||||
):
|
||||
"""Persist minted Nous OAuth credentials as the singleton provider state
|
||||
and ensure the credential pool is in sync.
|
||||
|
||||
Nous credentials are read at runtime from two independent locations:
|
||||
|
||||
- ``providers.nous``: singleton state read by
|
||||
``resolve_nous_runtime_credentials()`` during 401 recovery and by
|
||||
``_seed_from_singletons()`` during pool load.
|
||||
- ``credential_pool.nous``: used by the runtime ``pool.select()`` path.
|
||||
|
||||
Historically ``hermes auth add nous`` wrote a ``manual:device_code`` pool
|
||||
entry only, skipping ``providers.nous``. When the 24h agent_key TTL
|
||||
expired, the recovery path read the empty singleton state and raised
|
||||
``AuthError`` silently (``logger.debug`` at INFO level).
|
||||
|
||||
This helper writes ``providers.nous`` then calls ``load_pool("nous")`` so
|
||||
``_seed_from_singletons`` materialises the canonical ``device_code`` pool
|
||||
entry from the singleton. Re-running login upserts the same entry in
|
||||
place; the pool never accumulates duplicate device_code rows.
|
||||
|
||||
``label`` is an optional user-chosen display name (from
|
||||
``hermes auth add nous --label <name>``). It gets embedded in the
|
||||
singleton state so that ``_seed_from_singletons`` uses it as the pool
|
||||
entry's label on every subsequent ``load_pool("nous")`` instead of the
|
||||
auto-derived token fingerprint. When ``None``, the auto-derived label
|
||||
via ``label_from_token`` is used (unchanged default behaviour).
|
||||
|
||||
Returns the upserted :class:`PooledCredential` entry (or ``None`` if
|
||||
seeding somehow produced no match — shouldn't happen).
|
||||
"""
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
state = dict(creds)
|
||||
if label and str(label).strip():
|
||||
state["label"] = str(label).strip()
|
||||
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
_save_provider_state(auth_store, "nous", state)
|
||||
_save_auth_store(auth_store)
|
||||
|
||||
pool = load_pool("nous")
|
||||
return next(
|
||||
(e for e in pool.entries() if e.source == NOUS_DEVICE_CODE_SOURCE),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def resolve_nous_runtime_credentials(
|
||||
*,
|
||||
min_key_ttl_seconds: int = DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
|
||||
|
||||
@@ -217,19 +217,15 @@ def auth_add_command(args) -> None:
|
||||
ca_bundle=getattr(args, "ca_bundle", None),
|
||||
min_key_ttl_seconds=max(60, int(getattr(args, "min_key_ttl_seconds", 5 * 60))),
|
||||
)
|
||||
label = (getattr(args, "label", None) or "").strip() or label_from_token(
|
||||
creds.get("access_token", ""),
|
||||
_oauth_default_label(provider, len(pool.entries()) + 1),
|
||||
# Honor `--label <name>` so nous matches other providers' UX. The
|
||||
# helper embeds this into providers.nous so that label_from_token
|
||||
# doesn't overwrite it on every subsequent load_pool("nous").
|
||||
custom_label = (getattr(args, "label", None) or "").strip() or None
|
||||
entry = auth_mod.persist_nous_credentials(creds, label=custom_label)
|
||||
shown_label = entry.label if entry is not None else label_from_token(
|
||||
creds.get("access_token", ""), _oauth_default_label(provider, 1),
|
||||
)
|
||||
entry = PooledCredential.from_dict(provider, {
|
||||
**creds,
|
||||
"label": label,
|
||||
"auth_type": AUTH_TYPE_OAUTH,
|
||||
"source": f"{SOURCE_MANUAL}:device_code",
|
||||
"base_url": creds.get("inference_base_url"),
|
||||
})
|
||||
pool.add_entry(entry)
|
||||
print(f'Added {provider} OAuth credential #{len(pool.entries())}: "{entry.label}"')
|
||||
print(f'Saved {provider} OAuth device-code credentials: "{shown_label}"')
|
||||
return
|
||||
|
||||
if provider == "openai-codex":
|
||||
|
||||
+119
-70
@@ -7,8 +7,8 @@ CLI tools that ship with the platform (or are commonly installed).
|
||||
|
||||
Platform support:
|
||||
macOS — osascript (always available), pngpaste (if installed)
|
||||
Windows — PowerShell via .NET System.Windows.Forms.Clipboard
|
||||
WSL2 — powershell.exe via .NET System.Windows.Forms.Clipboard
|
||||
Windows — PowerShell via WinForms, Get-Clipboard, file-drop fallback
|
||||
WSL2 — powershell.exe via WinForms, Get-Clipboard, file-drop fallback
|
||||
Linux — wl-paste (Wayland), xclip (X11)
|
||||
"""
|
||||
|
||||
@@ -46,10 +46,11 @@ def has_clipboard_image() -> bool:
|
||||
return _macos_has_image()
|
||||
if sys.platform == "win32":
|
||||
return _windows_has_image()
|
||||
if _is_wsl():
|
||||
return _wsl_has_image()
|
||||
if os.environ.get("WAYLAND_DISPLAY"):
|
||||
return _wayland_has_image()
|
||||
# Match _linux_save fallthrough order: WSL → Wayland → X11
|
||||
if _is_wsl() and _wsl_has_image():
|
||||
return True
|
||||
if os.environ.get("WAYLAND_DISPLAY") and _wayland_has_image():
|
||||
return True
|
||||
return _xclip_has_image()
|
||||
|
||||
|
||||
@@ -135,6 +136,114 @@ _PS_EXTRACT_IMAGE = (
|
||||
"[System.Convert]::ToBase64String($ms.ToArray())"
|
||||
)
|
||||
|
||||
_PS_CHECK_IMAGE_GET_CLIPBOARD = (
|
||||
"try { "
|
||||
"$img = Get-Clipboard -Format Image -ErrorAction Stop;"
|
||||
"if ($null -ne $img) { 'True' } else { 'False' }"
|
||||
"} catch { 'False' }"
|
||||
)
|
||||
|
||||
_PS_EXTRACT_IMAGE_GET_CLIPBOARD = (
|
||||
"try { "
|
||||
"Add-Type -AssemblyName System.Drawing;"
|
||||
"Add-Type -AssemblyName PresentationCore;"
|
||||
"Add-Type -AssemblyName WindowsBase;"
|
||||
"$img = Get-Clipboard -Format Image -ErrorAction Stop;"
|
||||
"if ($null -eq $img) { exit 1 }"
|
||||
"$ms = New-Object System.IO.MemoryStream;"
|
||||
"if ($img -is [System.Drawing.Image]) {"
|
||||
"$img.Save($ms, [System.Drawing.Imaging.ImageFormat]::Png)"
|
||||
"} elseif ($img -is [System.Windows.Media.Imaging.BitmapSource]) {"
|
||||
"$enc = New-Object System.Windows.Media.Imaging.PngBitmapEncoder;"
|
||||
"$enc.Frames.Add([System.Windows.Media.Imaging.BitmapFrame]::Create($img));"
|
||||
"$enc.Save($ms)"
|
||||
"} else { exit 2 }"
|
||||
"[System.Convert]::ToBase64String($ms.ToArray())"
|
||||
"} catch { exit 1 }"
|
||||
)
|
||||
|
||||
_FILEDROP_IMAGE_EXTS = "'.png','.jpg','.jpeg','.gif','.webp','.bmp','.tiff','.tif'"
|
||||
|
||||
_PS_CHECK_FILEDROP_IMAGE = (
|
||||
"try { "
|
||||
"$files = Get-Clipboard -Format FileDropList -ErrorAction Stop;"
|
||||
f"$exts = @({_FILEDROP_IMAGE_EXTS});"
|
||||
"$hit = $files | Where-Object { $exts -contains ([System.IO.Path]::GetExtension($_).ToLowerInvariant()) } | Select-Object -First 1;"
|
||||
"if ($null -ne $hit) { 'True' } else { 'False' }"
|
||||
"} catch { 'False' }"
|
||||
)
|
||||
|
||||
_PS_EXTRACT_FILEDROP_IMAGE = (
|
||||
"try { "
|
||||
"$files = Get-Clipboard -Format FileDropList -ErrorAction Stop;"
|
||||
f"$exts = @({_FILEDROP_IMAGE_EXTS});"
|
||||
"$hit = $files | Where-Object { $exts -contains ([System.IO.Path]::GetExtension($_).ToLowerInvariant()) } | Select-Object -First 1;"
|
||||
"if ($null -eq $hit) { exit 1 }"
|
||||
"[System.Convert]::ToBase64String([System.IO.File]::ReadAllBytes($hit))"
|
||||
"} catch { exit 1 }"
|
||||
)
|
||||
|
||||
_POWERSHELL_HAS_IMAGE_SCRIPTS = (
|
||||
_PS_CHECK_IMAGE,
|
||||
_PS_CHECK_IMAGE_GET_CLIPBOARD,
|
||||
_PS_CHECK_FILEDROP_IMAGE,
|
||||
)
|
||||
|
||||
_POWERSHELL_EXTRACT_IMAGE_SCRIPTS = (
|
||||
_PS_EXTRACT_IMAGE,
|
||||
_PS_EXTRACT_IMAGE_GET_CLIPBOARD,
|
||||
_PS_EXTRACT_FILEDROP_IMAGE,
|
||||
)
|
||||
|
||||
|
||||
def _run_powershell(exe: str, script: str, timeout: int) -> subprocess.CompletedProcess:
|
||||
return subprocess.run(
|
||||
[exe, "-NoProfile", "-NonInteractive", "-Command", script],
|
||||
capture_output=True, text=True, timeout=timeout,
|
||||
)
|
||||
|
||||
|
||||
def _write_base64_image(dest: Path, b64_data: str) -> bool:
|
||||
image_bytes = base64.b64decode(b64_data, validate=True)
|
||||
dest.write_bytes(image_bytes)
|
||||
return dest.exists() and dest.stat().st_size > 0
|
||||
|
||||
|
||||
def _powershell_has_image(exe: str, *, timeout: int, label: str) -> bool:
|
||||
for script in _POWERSHELL_HAS_IMAGE_SCRIPTS:
|
||||
try:
|
||||
r = _run_powershell(exe, script, timeout=timeout)
|
||||
if r.returncode == 0 and "True" in r.stdout:
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
logger.debug("%s not found — clipboard unavailable", exe)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug("%s clipboard image check failed: %s", label, e)
|
||||
return False
|
||||
|
||||
|
||||
def _powershell_save_image(exe: str, dest: Path, *, timeout: int, label: str) -> bool:
|
||||
for script in _POWERSHELL_EXTRACT_IMAGE_SCRIPTS:
|
||||
try:
|
||||
r = _run_powershell(exe, script, timeout=timeout)
|
||||
if r.returncode != 0:
|
||||
continue
|
||||
|
||||
b64_data = r.stdout.strip()
|
||||
if not b64_data:
|
||||
continue
|
||||
|
||||
if _write_base64_image(dest, b64_data):
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
logger.debug("%s not found — clipboard unavailable", exe)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug("%s clipboard image extraction failed: %s", label, e)
|
||||
dest.unlink(missing_ok=True)
|
||||
return False
|
||||
|
||||
|
||||
# ── Native Windows ────────────────────────────────────────────────────────
|
||||
|
||||
@@ -175,15 +284,7 @@ def _windows_has_image() -> bool:
|
||||
ps = _get_ps_exe()
|
||||
if ps is None:
|
||||
return False
|
||||
try:
|
||||
r = subprocess.run(
|
||||
[ps, "-NoProfile", "-NonInteractive", "-Command", _PS_CHECK_IMAGE],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
return r.returncode == 0 and "True" in r.stdout
|
||||
except Exception as e:
|
||||
logger.debug("Windows clipboard image check failed: %s", e)
|
||||
return False
|
||||
return _powershell_has_image(ps, timeout=5, label="Windows")
|
||||
|
||||
|
||||
def _windows_save(dest: Path) -> bool:
|
||||
@@ -192,26 +293,7 @@ def _windows_save(dest: Path) -> bool:
|
||||
if ps is None:
|
||||
logger.debug("No PowerShell found — Windows clipboard image paste unavailable")
|
||||
return False
|
||||
try:
|
||||
r = subprocess.run(
|
||||
[ps, "-NoProfile", "-NonInteractive", "-Command", _PS_EXTRACT_IMAGE],
|
||||
capture_output=True, text=True, timeout=15,
|
||||
)
|
||||
if r.returncode != 0:
|
||||
return False
|
||||
|
||||
b64_data = r.stdout.strip()
|
||||
if not b64_data:
|
||||
return False
|
||||
|
||||
png_bytes = base64.b64decode(b64_data)
|
||||
dest.write_bytes(png_bytes)
|
||||
return dest.exists() and dest.stat().st_size > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Windows clipboard image extraction failed: %s", e)
|
||||
dest.unlink(missing_ok=True)
|
||||
return False
|
||||
return _powershell_save_image(ps, dest, timeout=15, label="Windows")
|
||||
|
||||
|
||||
# ── Linux ────────────────────────────────────────────────────────────────
|
||||
@@ -235,45 +317,12 @@ def _linux_save(dest: Path) -> bool:
|
||||
|
||||
def _wsl_has_image() -> bool:
|
||||
"""Check if Windows clipboard has an image (via powershell.exe)."""
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
|
||||
_PS_CHECK_IMAGE],
|
||||
capture_output=True, text=True, timeout=8,
|
||||
)
|
||||
return r.returncode == 0 and "True" in r.stdout
|
||||
except FileNotFoundError:
|
||||
logger.debug("powershell.exe not found — WSL clipboard unavailable")
|
||||
except Exception as e:
|
||||
logger.debug("WSL clipboard check failed: %s", e)
|
||||
return False
|
||||
return _powershell_has_image("powershell.exe", timeout=8, label="WSL")
|
||||
|
||||
|
||||
def _wsl_save(dest: Path) -> bool:
|
||||
"""Extract clipboard image via powershell.exe → base64 → decode to PNG."""
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
|
||||
_PS_EXTRACT_IMAGE],
|
||||
capture_output=True, text=True, timeout=15,
|
||||
)
|
||||
if r.returncode != 0:
|
||||
return False
|
||||
|
||||
b64_data = r.stdout.strip()
|
||||
if not b64_data:
|
||||
return False
|
||||
|
||||
png_bytes = base64.b64decode(b64_data)
|
||||
dest.write_bytes(png_bytes)
|
||||
return dest.exists() and dest.stat().st_size > 0
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.debug("powershell.exe not found — WSL clipboard unavailable")
|
||||
except Exception as e:
|
||||
logger.debug("WSL clipboard extraction failed: %s", e)
|
||||
dest.unlink(missing_ok=True)
|
||||
return False
|
||||
return _powershell_save_image("powershell.exe", dest, timeout=15, label="WSL")
|
||||
|
||||
|
||||
# ── Wayland (wl-paste) ──────────────────────────────────────────────────
|
||||
|
||||
+95
-7
@@ -87,8 +87,12 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
aliases=("bg",), args_hint="<prompt>"),
|
||||
CommandDef("btw", "Ephemeral side question using session context (no tools, not persisted)", "Session",
|
||||
args_hint="<question>"),
|
||||
CommandDef("agents", "Show active agents and running tasks", "Session",
|
||||
aliases=("tasks",)),
|
||||
CommandDef("queue", "Queue a prompt for the next turn (doesn't interrupt)", "Session",
|
||||
aliases=("q",), args_hint="<prompt>"),
|
||||
CommandDef("steer", "Inject a message after the next tool call without interrupting", "Session",
|
||||
args_hint="<prompt>"),
|
||||
CommandDef("status", "Show session info", "Session"),
|
||||
CommandDef("profile", "Show active profile name and home directory", "Info"),
|
||||
CommandDef("sethome", "Set this chat as the home channel", "Session",
|
||||
@@ -99,7 +103,7 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
# Configuration
|
||||
CommandDef("config", "Show current configuration", "Configuration",
|
||||
cli_only=True),
|
||||
CommandDef("model", "Switch model for this session", "Configuration", args_hint="[model] [--global]"),
|
||||
CommandDef("model", "Switch model for this session", "Configuration", args_hint="[model] [--provider name] [--global]"),
|
||||
CommandDef("provider", "Show available providers and current provider",
|
||||
"Configuration"),
|
||||
CommandDef("gquota", "Show Google Gemini Code Assist quota usage", "Info"),
|
||||
@@ -120,7 +124,7 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
args_hint="[normal|fast|status]",
|
||||
subcommands=("normal", "fast", "status", "on", "off")),
|
||||
CommandDef("skin", "Show or change the display skin/theme", "Configuration",
|
||||
cli_only=True, args_hint="[name]"),
|
||||
args_hint="[name]"),
|
||||
CommandDef("voice", "Toggle voice mode", "Configuration",
|
||||
args_hint="[on|off|tts|status]", subcommands=("on", "off", "tts", "status")),
|
||||
|
||||
@@ -155,7 +159,9 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
args_hint="[days]"),
|
||||
CommandDef("platforms", "Show gateway/messaging platform status", "Info",
|
||||
cli_only=True, aliases=("gateway",)),
|
||||
CommandDef("paste", "Check clipboard for an image and attach it", "Info",
|
||||
CommandDef("copy", "Copy the last assistant response to clipboard", "Info",
|
||||
cli_only=True, args_hint="[number]"),
|
||||
CommandDef("paste", "Attach clipboard image from your clipboard", "Info",
|
||||
cli_only=True),
|
||||
CommandDef("image", "Attach a local image file for your next prompt", "Info",
|
||||
cli_only=True, args_hint="<path>"),
|
||||
@@ -254,6 +260,36 @@ GATEWAY_KNOWN_COMMANDS: frozenset[str] = frozenset(
|
||||
)
|
||||
|
||||
|
||||
# Commands that must never be queued behind an active gateway session.
|
||||
# These are explicit control/info commands handled by the gateway itself;
|
||||
# if they get queued as pending text, the safety net in gateway.run will
|
||||
# discard them before they ever reach the user.
|
||||
ACTIVE_SESSION_BYPASS_COMMANDS: frozenset[str] = frozenset(
|
||||
{
|
||||
"agents",
|
||||
"approve",
|
||||
"background",
|
||||
"commands",
|
||||
"deny",
|
||||
"help",
|
||||
"new",
|
||||
"profile",
|
||||
"queue",
|
||||
"restart",
|
||||
"status",
|
||||
"steer",
|
||||
"stop",
|
||||
"update",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def should_bypass_active_session(command_name: str | None) -> bool:
|
||||
"""Return True when a slash command must bypass active-session queuing."""
|
||||
cmd = resolve_command(command_name) if command_name else None
|
||||
return bool(cmd and cmd.name in ACTIVE_SESSION_BYPASS_COMMANDS)
|
||||
|
||||
|
||||
def _resolve_config_gates() -> set[str]:
|
||||
"""Return canonical names of commands whose ``gateway_config_gate`` is truthy.
|
||||
|
||||
@@ -1044,6 +1080,51 @@ class SlashCommandCompleter(Completer):
|
||||
display_meta=f"{fp} {meta}" if meta else fp,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _skin_completions(sub_text: str, sub_lower: str):
|
||||
"""Yield completions for /skin from available skins."""
|
||||
try:
|
||||
from hermes_cli.skin_engine import list_skins
|
||||
for s in list_skins():
|
||||
name = s["name"]
|
||||
if name.startswith(sub_lower) and name != sub_lower:
|
||||
yield Completion(
|
||||
name,
|
||||
start_position=-len(sub_text),
|
||||
display=name,
|
||||
display_meta=s.get("description", "") or s.get("source", ""),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _personality_completions(sub_text: str, sub_lower: str):
|
||||
"""Yield completions for /personality from configured personalities."""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
personalities = load_config().get("agent", {}).get("personalities", {})
|
||||
if "none".startswith(sub_lower) and "none" != sub_lower:
|
||||
yield Completion(
|
||||
"none",
|
||||
start_position=-len(sub_text),
|
||||
display="none",
|
||||
display_meta="clear personality overlay",
|
||||
)
|
||||
for name, prompt in personalities.items():
|
||||
if name.startswith(sub_lower) and name != sub_lower:
|
||||
if isinstance(prompt, dict):
|
||||
meta = prompt.get("description") or prompt.get("system_prompt", "")[:50]
|
||||
else:
|
||||
meta = str(prompt)[:50]
|
||||
yield Completion(
|
||||
name,
|
||||
start_position=-len(sub_text),
|
||||
display=name,
|
||||
display_meta=meta,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _model_completions(self, sub_text: str, sub_lower: str):
|
||||
"""Yield completions for /model from config aliases + built-in aliases."""
|
||||
seen = set()
|
||||
@@ -1098,10 +1179,17 @@ class SlashCommandCompleter(Completer):
|
||||
sub_text = parts[1] if len(parts) > 1 else ""
|
||||
sub_lower = sub_text.lower()
|
||||
|
||||
# Dynamic model alias completions for /model
|
||||
if " " not in sub_text and base_cmd == "/model":
|
||||
yield from self._model_completions(sub_text, sub_lower)
|
||||
return
|
||||
# Dynamic completions for commands with runtime lists
|
||||
if " " not in sub_text:
|
||||
if base_cmd == "/model":
|
||||
yield from self._model_completions(sub_text, sub_lower)
|
||||
return
|
||||
if base_cmd == "/skin":
|
||||
yield from self._skin_completions(sub_text, sub_lower)
|
||||
return
|
||||
if base_cmd == "/personality":
|
||||
yield from self._personality_completions(sub_text, sub_lower)
|
||||
return
|
||||
|
||||
# Static subcommand completions
|
||||
if " " not in sub_text and base_cmd in SUBCOMMANDS and self._command_allowed(base_cmd):
|
||||
|
||||
+139
-8
@@ -12,6 +12,7 @@ This module provides:
|
||||
- hermes config wizard - Re-run setup wizard
|
||||
"""
|
||||
|
||||
import copy
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
@@ -26,6 +27,7 @@ from typing import Dict, Any, Optional, List, Tuple
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
_ENV_VAR_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
_LAST_EXPANDED_CONFIG_BY_PATH: Dict[str, Any] = {}
|
||||
# Env var names written to .env that aren't in OPTIONAL_ENV_VARS
|
||||
# (managed by setup/provider flows directly).
|
||||
_EXTRA_ENV_KEYS = frozenset({
|
||||
@@ -44,7 +46,8 @@ _EXTRA_ENV_KEYS = frozenset({
|
||||
"WEIXIN_HOME_CHANNEL", "WEIXIN_HOME_CHANNEL_NAME", "WEIXIN_DM_POLICY", "WEIXIN_GROUP_POLICY",
|
||||
"WEIXIN_ALLOWED_USERS", "WEIXIN_GROUP_ALLOWED_USERS", "WEIXIN_ALLOW_ALL_USERS",
|
||||
"BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_PASSWORD",
|
||||
"QQ_APP_ID", "QQ_CLIENT_SECRET", "QQ_HOME_CHANNEL", "QQ_HOME_CHANNEL_NAME",
|
||||
"QQ_APP_ID", "QQ_CLIENT_SECRET", "QQBOT_HOME_CHANNEL", "QQBOT_HOME_CHANNEL_NAME",
|
||||
"QQ_HOME_CHANNEL", "QQ_HOME_CHANNEL_NAME", # legacy aliases (pre-rename, still read for back-compat)
|
||||
"QQ_ALLOWED_USERS", "QQ_GROUP_ALLOWED_USERS", "QQ_ALLOW_ALL_USERS", "QQ_MARKDOWN_SUPPORT",
|
||||
"QQ_STT_API_KEY", "QQ_STT_BASE_URL", "QQ_STT_MODEL",
|
||||
"TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT",
|
||||
@@ -417,6 +420,7 @@ DEFAULT_CONFIG = {
|
||||
"command_timeout": 30, # Timeout for browser commands in seconds (screenshot, navigate, etc.)
|
||||
"record_sessions": False, # Auto-record browser sessions as WebM videos
|
||||
"allow_private_urls": False, # Allow navigating to private/internal IPs (localhost, 192.168.x.x, etc.)
|
||||
"cdp_url": "", # Optional persistent CDP endpoint for attaching to an existing Chromium/Chrome
|
||||
"camofox": {
|
||||
# When true, Hermes sends a stable profile-scoped userId to Camofox
|
||||
# so the server maps it to a persistent Firefox profile automatically.
|
||||
@@ -537,6 +541,13 @@ DEFAULT_CONFIG = {
|
||||
"api_key": "",
|
||||
"timeout": 30,
|
||||
},
|
||||
"title_generation": {
|
||||
"provider": "auto",
|
||||
"model": "",
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
"timeout": 30,
|
||||
},
|
||||
},
|
||||
|
||||
"display": {
|
||||
@@ -760,6 +771,20 @@ DEFAULT_CONFIG = {
|
||||
"wrap_response": True,
|
||||
},
|
||||
|
||||
# execute_code settings — controls the tool used for programmatic tool calls.
|
||||
"code_execution": {
|
||||
# Execution mode:
|
||||
# project (default) — scripts run in the session's working directory
|
||||
# with the active virtualenv/conda env's python, so project deps
|
||||
# (pandas, torch, project packages) and relative paths resolve.
|
||||
# strict — scripts run in an isolated temp directory with
|
||||
# hermes-agent's own python (sys.executable). Maximum isolation
|
||||
# and reproducibility; project deps and relative paths won't work.
|
||||
# Env scrubbing (strips *_API_KEY, *_TOKEN, *_SECRET, ...) and the
|
||||
# tool whitelist apply identically in both modes.
|
||||
"mode": "project",
|
||||
},
|
||||
|
||||
# Logging — controls file logging to ~/.hermes/logs/.
|
||||
# agent.log captures INFO+ (all agent activity); errors.log captures WARNING+.
|
||||
"logging": {
|
||||
@@ -777,7 +802,7 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
|
||||
# Config schema version - bump this when adding new required fields
|
||||
"_config_version": 18,
|
||||
"_config_version": 19,
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
@@ -861,6 +886,22 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"NVIDIA_API_KEY": {
|
||||
"description": "NVIDIA NIM API key (build.nvidia.com or local NIM endpoint)",
|
||||
"prompt": "NVIDIA NIM API key",
|
||||
"url": "https://build.nvidia.com/",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"NVIDIA_BASE_URL": {
|
||||
"description": "NVIDIA NIM base URL override (e.g. http://localhost:8000/v1 for local NIM)",
|
||||
"prompt": "NVIDIA NIM base URL (leave empty for default)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"GLM_API_KEY": {
|
||||
"description": "Z.AI / GLM API key (also recognized as ZAI_API_KEY / Z_AI_API_KEY)",
|
||||
"prompt": "Z.AI / GLM API key",
|
||||
@@ -1518,12 +1559,12 @@ OPTIONAL_ENV_VARS = {
|
||||
"prompt": "Allow All QQ Users",
|
||||
"category": "messaging",
|
||||
},
|
||||
"QQ_HOME_CHANNEL": {
|
||||
"QQBOT_HOME_CHANNEL": {
|
||||
"description": "Default QQ channel/group for cron delivery and notifications",
|
||||
"prompt": "QQ Home Channel",
|
||||
"category": "messaging",
|
||||
},
|
||||
"QQ_HOME_CHANNEL_NAME": {
|
||||
"QQBOT_HOME_CHANNEL_NAME": {
|
||||
"description": "Display name for the QQ home channel",
|
||||
"prompt": "QQ Home Channel Name",
|
||||
"category": "messaging",
|
||||
@@ -2610,6 +2651,85 @@ def _expand_env_vars(obj):
|
||||
return obj
|
||||
|
||||
|
||||
def _items_by_unique_name(items):
|
||||
"""Return a name-indexed dict only when all items have unique string names."""
|
||||
if not isinstance(items, list):
|
||||
return None
|
||||
indexed = {}
|
||||
for item in items:
|
||||
if not isinstance(item, dict) or not isinstance(item.get("name"), str):
|
||||
return None
|
||||
name = item["name"]
|
||||
if name in indexed:
|
||||
return None
|
||||
indexed[name] = item
|
||||
return indexed
|
||||
|
||||
|
||||
def _preserve_env_ref_templates(current, raw, loaded_expanded=None):
|
||||
"""Restore raw ``${VAR}`` templates when a value is otherwise unchanged.
|
||||
|
||||
``load_config()`` expands env refs for runtime use. When a caller later
|
||||
persists that config after modifying some unrelated setting, keep the
|
||||
original on-disk template instead of writing the expanded plaintext
|
||||
secret back to ``config.yaml``.
|
||||
|
||||
Prefer preserving the raw template when ``current`` still matches either
|
||||
the value previously returned by ``load_config()`` for this config path or
|
||||
the current environment expansion of ``raw``. This handles env-var
|
||||
rotation between load and save while still treating mixed literal/template
|
||||
string edits as caller-owned once their rendered value diverges.
|
||||
"""
|
||||
if isinstance(current, str) and isinstance(raw, str) and re.search(r"\${[^}]+}", raw):
|
||||
if current == raw:
|
||||
return raw
|
||||
if isinstance(loaded_expanded, str) and current == loaded_expanded:
|
||||
return raw
|
||||
if _expand_env_vars(raw) == current:
|
||||
return raw
|
||||
return current
|
||||
|
||||
if isinstance(current, dict) and isinstance(raw, dict):
|
||||
return {
|
||||
key: _preserve_env_ref_templates(
|
||||
value,
|
||||
raw.get(key),
|
||||
loaded_expanded.get(key) if isinstance(loaded_expanded, dict) else None,
|
||||
)
|
||||
for key, value in current.items()
|
||||
}
|
||||
|
||||
if isinstance(current, list) and isinstance(raw, list):
|
||||
# Prefer matching named config objects (e.g. custom_providers) by name
|
||||
# so harmless reordering doesn't drop the original template. If names
|
||||
# are duplicated, fall back to positional matching instead of silently
|
||||
# shadowing one entry.
|
||||
current_by_name = _items_by_unique_name(current)
|
||||
raw_by_name = _items_by_unique_name(raw)
|
||||
loaded_by_name = _items_by_unique_name(loaded_expanded)
|
||||
if current_by_name is not None and raw_by_name is not None:
|
||||
return [
|
||||
_preserve_env_ref_templates(
|
||||
item,
|
||||
raw_by_name.get(item.get("name")),
|
||||
loaded_by_name.get(item.get("name")) if loaded_by_name is not None else None,
|
||||
)
|
||||
for item in current
|
||||
]
|
||||
return [
|
||||
_preserve_env_ref_templates(
|
||||
item,
|
||||
raw[index] if index < len(raw) else None,
|
||||
loaded_expanded[index]
|
||||
if isinstance(loaded_expanded, list) and index < len(loaded_expanded)
|
||||
else None,
|
||||
)
|
||||
for index, item in enumerate(current)
|
||||
]
|
||||
|
||||
return current
|
||||
|
||||
|
||||
def _normalize_root_model_keys(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Move stale root-level provider/base_url into model section.
|
||||
|
||||
@@ -2677,7 +2797,6 @@ def read_raw_config() -> Dict[str, Any]:
|
||||
|
||||
def load_config() -> Dict[str, Any]:
|
||||
"""Load configuration from ~/.hermes/config.yaml."""
|
||||
import copy
|
||||
ensure_hermes_home()
|
||||
config_path = get_config_path()
|
||||
|
||||
@@ -2698,8 +2817,11 @@ def load_config() -> Dict[str, Any]:
|
||||
config = _deep_merge(config, user_config)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load config: {e}")
|
||||
|
||||
return _expand_env_vars(_normalize_root_model_keys(_normalize_max_turns_config(config)))
|
||||
|
||||
normalized = _normalize_root_model_keys(_normalize_max_turns_config(config))
|
||||
expanded = _expand_env_vars(normalized)
|
||||
_LAST_EXPANDED_CONFIG_BY_PATH[str(config_path)] = copy.deepcopy(expanded)
|
||||
return expanded
|
||||
|
||||
|
||||
_SECURITY_COMMENT = """
|
||||
@@ -2808,7 +2930,15 @@ def save_config(config: Dict[str, Any]):
|
||||
|
||||
ensure_hermes_home()
|
||||
config_path = get_config_path()
|
||||
normalized = _normalize_root_model_keys(_normalize_max_turns_config(config))
|
||||
current_normalized = _normalize_root_model_keys(_normalize_max_turns_config(config))
|
||||
normalized = current_normalized
|
||||
raw_existing = _normalize_root_model_keys(_normalize_max_turns_config(read_raw_config()))
|
||||
if raw_existing:
|
||||
normalized = _preserve_env_ref_templates(
|
||||
normalized,
|
||||
raw_existing,
|
||||
_LAST_EXPANDED_CONFIG_BY_PATH.get(str(config_path)),
|
||||
)
|
||||
|
||||
# Build optional commented-out sections for features that are off by
|
||||
# default or only relevant when explicitly configured.
|
||||
@@ -2826,6 +2956,7 @@ def save_config(config: Dict[str, Any]):
|
||||
extra_content="".join(parts) if parts else None,
|
||||
)
|
||||
_secure_file(config_path)
|
||||
_LAST_EXPANDED_CONFIG_BY_PATH[str(config_path)] = copy.deepcopy(current_normalized)
|
||||
|
||||
|
||||
def load_env() -> Dict[str, str]:
|
||||
|
||||
+137
-29
@@ -6,7 +6,10 @@ Currently supports:
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
@@ -31,6 +34,119 @@ _MAX_LOG_BYTES = 512_000
|
||||
_AUTO_DELETE_SECONDS = 21600
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pending-deletion tracking (replaces the old fork-and-sleep subprocess).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _pending_file() -> Path:
|
||||
"""Path to ``~/.hermes/pastes/pending.json``.
|
||||
|
||||
Each entry: ``{"url": "...", "expire_at": <unix_ts>}``. Scheduled
|
||||
DELETEs used to be handled by spawning a detached Python process per
|
||||
paste that slept for 6 hours; those accumulated forever if the user
|
||||
ran ``hermes debug share`` repeatedly. We now persist the schedule
|
||||
to disk and sweep expired entries on the next debug invocation.
|
||||
"""
|
||||
return get_hermes_home() / "pastes" / "pending.json"
|
||||
|
||||
|
||||
def _load_pending() -> list[dict]:
|
||||
path = _pending_file()
|
||||
if not path.exists():
|
||||
return []
|
||||
try:
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, list):
|
||||
# Filter to well-formed entries only
|
||||
return [
|
||||
e for e in data
|
||||
if isinstance(e, dict) and "url" in e and "expire_at" in e
|
||||
]
|
||||
except (OSError, ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
return []
|
||||
|
||||
|
||||
def _save_pending(entries: list[dict]) -> None:
|
||||
path = _pending_file()
|
||||
try:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = path.with_suffix(".json.tmp")
|
||||
tmp.write_text(json.dumps(entries, indent=2), encoding="utf-8")
|
||||
os.replace(tmp, path)
|
||||
except OSError:
|
||||
# Non-fatal — worst case the user has to run ``hermes debug delete``
|
||||
# manually.
|
||||
pass
|
||||
|
||||
|
||||
def _record_pending(urls: list[str], delay_seconds: int = _AUTO_DELETE_SECONDS) -> None:
|
||||
"""Record *urls* for deletion at ``now + delay_seconds``.
|
||||
|
||||
Only paste.rs URLs are recorded (dpaste.com auto-expires). Entries
|
||||
are merged into any existing pending.json.
|
||||
"""
|
||||
paste_rs_urls = [u for u in urls if _extract_paste_id(u)]
|
||||
if not paste_rs_urls:
|
||||
return
|
||||
|
||||
entries = _load_pending()
|
||||
# Dedupe by URL: keep the later expire_at if same URL appears twice
|
||||
by_url: dict[str, float] = {e["url"]: float(e["expire_at"]) for e in entries}
|
||||
expire_at = time.time() + delay_seconds
|
||||
for u in paste_rs_urls:
|
||||
by_url[u] = max(expire_at, by_url.get(u, 0.0))
|
||||
merged = [{"url": u, "expire_at": ts} for u, ts in by_url.items()]
|
||||
_save_pending(merged)
|
||||
|
||||
|
||||
def _sweep_expired_pastes(now: Optional[float] = None) -> tuple[int, int]:
|
||||
"""Synchronously DELETE any pending pastes whose ``expire_at`` has passed.
|
||||
|
||||
Returns ``(deleted, remaining)``. Best-effort: failed deletes stay in
|
||||
the pending file and will be retried on the next sweep. Silent —
|
||||
intended to be called from every ``hermes debug`` invocation with
|
||||
minimal noise.
|
||||
"""
|
||||
entries = _load_pending()
|
||||
if not entries:
|
||||
return (0, 0)
|
||||
|
||||
current = time.time() if now is None else now
|
||||
deleted = 0
|
||||
remaining: list[dict] = []
|
||||
|
||||
for entry in entries:
|
||||
try:
|
||||
expire_at = float(entry.get("expire_at", 0))
|
||||
except (TypeError, ValueError):
|
||||
continue # drop malformed entries
|
||||
if expire_at > current:
|
||||
remaining.append(entry)
|
||||
continue
|
||||
|
||||
url = entry.get("url", "")
|
||||
try:
|
||||
if delete_paste(url):
|
||||
deleted += 1
|
||||
continue
|
||||
except Exception:
|
||||
# Network hiccup, 404 (already gone), etc. — drop the entry
|
||||
# after a grace period; don't retry forever.
|
||||
pass
|
||||
|
||||
# Retain failed deletes for up to 24h past expiration, then give up.
|
||||
if expire_at + 86400 > current:
|
||||
remaining.append(entry)
|
||||
else:
|
||||
deleted += 1 # count as reaped (paste.rs will GC eventually)
|
||||
|
||||
if deleted:
|
||||
_save_pending(remaining)
|
||||
|
||||
return (deleted, len(remaining))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Privacy / delete helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -90,37 +206,19 @@ def delete_paste(url: str) -> bool:
|
||||
|
||||
|
||||
def _schedule_auto_delete(urls: list[str], delay_seconds: int = _AUTO_DELETE_SECONDS):
|
||||
"""Spawn a detached process to delete paste.rs pastes after *delay_seconds*.
|
||||
"""Record *urls* for deletion ``delay_seconds`` from now.
|
||||
|
||||
The child process is fully detached (``start_new_session=True``) so it
|
||||
survives the parent exiting (important for CLI mode). Only paste.rs
|
||||
URLs are attempted — dpaste.com pastes auto-expire on their own.
|
||||
Previously this spawned a detached Python subprocess per call that slept
|
||||
for 6 hours and then issued DELETE requests. Those subprocesses leaked —
|
||||
every ``hermes debug share`` invocation added ~20 MB of resident Python
|
||||
interpreters that never exited until the sleep completed.
|
||||
|
||||
The replacement is stateless: we append to ``~/.hermes/pastes/pending.json``
|
||||
and rely on opportunistic sweeps (``_sweep_expired_pastes``) called from
|
||||
every ``hermes debug`` invocation. If the user never runs ``hermes debug``
|
||||
again, paste.rs's own retention policy handles cleanup.
|
||||
"""
|
||||
import subprocess
|
||||
|
||||
paste_rs_urls = [u for u in urls if _extract_paste_id(u)]
|
||||
if not paste_rs_urls:
|
||||
return
|
||||
|
||||
# Build a tiny inline Python script. No imports beyond stdlib.
|
||||
url_list = ", ".join(f'"{u}"' for u in paste_rs_urls)
|
||||
script = (
|
||||
"import time, urllib.request; "
|
||||
f"time.sleep({delay_seconds}); "
|
||||
f"[urllib.request.urlopen(urllib.request.Request(u, method='DELETE', "
|
||||
f"headers={{'User-Agent': 'hermes-agent/auto-delete'}}), timeout=15) "
|
||||
f"for u in [{url_list}]]"
|
||||
)
|
||||
|
||||
try:
|
||||
subprocess.Popen(
|
||||
[sys.executable, "-c", script],
|
||||
start_new_session=True,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
except Exception:
|
||||
pass # Best-effort; manual delete still available.
|
||||
_record_pending(urls, delay_seconds=delay_seconds)
|
||||
|
||||
|
||||
def _delete_hint(url: str) -> str:
|
||||
@@ -455,6 +553,16 @@ def run_debug_delete(args):
|
||||
|
||||
def run_debug(args):
|
||||
"""Route debug subcommands."""
|
||||
# Opportunistic sweep of expired pastes on every ``hermes debug`` call.
|
||||
# Replaces the old per-paste sleeping subprocess that used to leak as
|
||||
# one orphaned Python interpreter per scheduled deletion. Silent and
|
||||
# best-effort — any failure is swallowed so ``hermes debug`` stays
|
||||
# reliable even when offline.
|
||||
try:
|
||||
_sweep_expired_pastes()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
subcmd = getattr(args, "debug_command", None)
|
||||
if subcmd == "share":
|
||||
run_debug_share(args)
|
||||
|
||||
@@ -825,6 +825,7 @@ def run_doctor(args):
|
||||
("Arcee AI", ("ARCEEAI_API_KEY",), "https://api.arcee.ai/api/v1/models", "ARCEE_BASE_URL", True),
|
||||
("DeepSeek", ("DEEPSEEK_API_KEY",), "https://api.deepseek.com/v1/models", "DEEPSEEK_BASE_URL", True),
|
||||
("Hugging Face", ("HF_TOKEN",), "https://router.huggingface.co/v1/models", "HF_BASE_URL", True),
|
||||
("NVIDIA NIM", ("NVIDIA_API_KEY",), "https://integrate.api.nvidia.com/v1/models", "NVIDIA_BASE_URL", True),
|
||||
("Alibaba/DashScope", ("DASHSCOPE_API_KEY",), "https://dashscope-intl.aliyuncs.com/compatible-mode/v1/models", "DASHSCOPE_BASE_URL", True),
|
||||
# MiniMax: the /anthropic endpoint doesn't support /models, but the /v1 endpoint does.
|
||||
("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL", True),
|
||||
@@ -894,8 +895,8 @@ def run_doctor(args):
|
||||
_model_count = len(_br_resp.get("modelSummaries", []))
|
||||
print(f"\r {color('✓', Colors.GREEN)} {_label} {color(f'({_auth_var}, {_region}, {_model_count} models)', Colors.DIM)} ")
|
||||
except ImportError:
|
||||
print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color('(boto3 not installed — pip install hermes-agent[bedrock])', Colors.DIM)} ")
|
||||
issues.append("Install boto3 for Bedrock: pip install hermes-agent[bedrock]")
|
||||
print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'(boto3 not installed — {sys.executable} -m pip install boto3)', Colors.DIM)} ")
|
||||
issues.append(f"Install boto3 for Bedrock: {sys.executable} -m pip install boto3")
|
||||
except Exception as _e:
|
||||
_err_name = type(_e).__name__
|
||||
print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'({_err_name}: {_e})', Colors.DIM)} ")
|
||||
|
||||
+15
-35
@@ -43,41 +43,20 @@ def _redact(value: str) -> str:
|
||||
|
||||
def _gateway_status() -> str:
|
||||
"""Return a short gateway status string."""
|
||||
if sys.platform.startswith("linux"):
|
||||
from hermes_constants import is_container
|
||||
if is_container():
|
||||
try:
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
pids = find_gateway_pids()
|
||||
if pids:
|
||||
return f"running (docker, pid {pids[0]})"
|
||||
return "stopped (docker)"
|
||||
except Exception:
|
||||
return "stopped (docker)"
|
||||
try:
|
||||
from hermes_cli.gateway import get_service_name
|
||||
svc = get_service_name()
|
||||
except Exception:
|
||||
svc = "hermes-gateway"
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["systemctl", "--user", "is-active", svc],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
return "running (systemd)" if r.stdout.strip() == "active" else "stopped"
|
||||
except Exception:
|
||||
return "unknown"
|
||||
elif sys.platform == "darwin":
|
||||
try:
|
||||
from hermes_cli.gateway import get_launchd_label
|
||||
r = subprocess.run(
|
||||
["launchctl", "list", get_launchd_label()],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
return "loaded (launchd)" if r.returncode == 0 else "not loaded"
|
||||
except Exception:
|
||||
return "unknown"
|
||||
return "N/A"
|
||||
try:
|
||||
from hermes_cli.gateway import get_gateway_runtime_snapshot
|
||||
|
||||
snapshot = get_gateway_runtime_snapshot()
|
||||
if snapshot.running:
|
||||
mode = snapshot.manager
|
||||
if snapshot.has_process_service_mismatch:
|
||||
mode = "manual"
|
||||
return f"running ({mode}, pid {snapshot.gateway_pids[0]})"
|
||||
if snapshot.service_installed and not snapshot.service_running:
|
||||
return f"stopped ({snapshot.manager})"
|
||||
return f"stopped ({snapshot.manager})"
|
||||
except Exception:
|
||||
return "unknown" if sys.platform.startswith(("linux", "darwin")) else "N/A"
|
||||
|
||||
|
||||
def _count_skills(hermes_home: Path) -> int:
|
||||
@@ -296,6 +275,7 @@ def run_dump(args):
|
||||
("DEEPSEEK_API_KEY", "deepseek"),
|
||||
("DASHSCOPE_API_KEY", "dashscope"),
|
||||
("HF_TOKEN", "huggingface"),
|
||||
("NVIDIA_API_KEY", "nvidia"),
|
||||
("AI_GATEWAY_API_KEY", "ai_gateway"),
|
||||
("OPENCODE_ZEN_API_KEY", "opencode_zen"),
|
||||
("OPENCODE_GO_API_KEY", "opencode_go"),
|
||||
|
||||
+634
-32
@@ -10,6 +10,7 @@ import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
@@ -41,6 +42,23 @@ from hermes_cli.colors import Colors, color
|
||||
# Process Management (for manual gateway runs)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GatewayRuntimeSnapshot:
|
||||
manager: str
|
||||
service_installed: bool = False
|
||||
service_running: bool = False
|
||||
gateway_pids: tuple[int, ...] = ()
|
||||
service_scope: str | None = None
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
return self.service_running or bool(self.gateway_pids)
|
||||
|
||||
@property
|
||||
def has_process_service_mismatch(self) -> bool:
|
||||
return self.service_installed and self.running and not self.service_running
|
||||
|
||||
def _get_service_pids() -> set:
|
||||
"""Return PIDs currently managed by systemd or launchd gateway services.
|
||||
|
||||
@@ -157,20 +175,22 @@ def _request_gateway_self_restart(pid: int) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = False) -> list:
|
||||
"""Find PIDs of running gateway processes.
|
||||
def _append_unique_pid(pids: list[int], pid: int | None, exclude_pids: set[int]) -> None:
|
||||
if pid is None or pid <= 0:
|
||||
return
|
||||
if pid == os.getpid() or pid in exclude_pids or pid in pids:
|
||||
return
|
||||
pids.append(pid)
|
||||
|
||||
Args:
|
||||
exclude_pids: PIDs to exclude from the result (e.g. service-managed
|
||||
PIDs that should not be killed during a stale-process sweep).
|
||||
all_profiles: When ``True``, return gateway PIDs across **all**
|
||||
profiles (the pre-7923 global behaviour). ``hermes update``
|
||||
needs this because a code update affects every profile.
|
||||
When ``False`` (default), only PIDs belonging to the current
|
||||
Hermes profile are returned.
|
||||
|
||||
def _scan_gateway_pids(exclude_pids: set[int], all_profiles: bool = False) -> list[int]:
|
||||
"""Best-effort process-table scan for gateway PIDs.
|
||||
|
||||
This supplements the profile-scoped PID file so status views can still spot
|
||||
a live gateway when the PID file is stale/missing, and ``--all`` sweeps can
|
||||
discover gateways outside the current profile.
|
||||
"""
|
||||
_exclude = exclude_pids or set()
|
||||
pids = [pid for pid in _get_service_pids() if pid not in _exclude]
|
||||
pids: list[int] = []
|
||||
patterns = [
|
||||
"hermes_cli.main gateway",
|
||||
"hermes_cli.main --profile",
|
||||
@@ -203,20 +223,24 @@ def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = Fals
|
||||
if is_windows():
|
||||
result = subprocess.run(
|
||||
["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"],
|
||||
capture_output=True, text=True, timeout=10
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return []
|
||||
current_cmd = ""
|
||||
for line in result.stdout.split('\n'):
|
||||
for line in result.stdout.split("\n"):
|
||||
line = line.strip()
|
||||
if line.startswith("CommandLine="):
|
||||
current_cmd = line[len("CommandLine="):]
|
||||
elif line.startswith("ProcessId="):
|
||||
pid_str = line[len("ProcessId="):]
|
||||
if any(p in current_cmd for p in patterns) and (all_profiles or _matches_current_profile(current_cmd)):
|
||||
if any(p in current_cmd for p in patterns) and (
|
||||
all_profiles or _matches_current_profile(current_cmd)
|
||||
):
|
||||
try:
|
||||
pid = int(pid_str)
|
||||
if pid != os.getpid() and pid not in pids and pid not in _exclude:
|
||||
pids.append(pid)
|
||||
_append_unique_pid(pids, int(pid_str), exclude_pids)
|
||||
except ValueError:
|
||||
pass
|
||||
current_cmd = ""
|
||||
@@ -227,9 +251,11 @@ def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = Fals
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
for line in result.stdout.split('\n'):
|
||||
if result.returncode != 0:
|
||||
return []
|
||||
for line in result.stdout.split("\n"):
|
||||
stripped = line.strip()
|
||||
if not stripped or 'grep' in stripped:
|
||||
if not stripped or "grep" in stripped:
|
||||
continue
|
||||
|
||||
pid = None
|
||||
@@ -251,16 +277,137 @@ def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = Fals
|
||||
|
||||
if pid is None:
|
||||
continue
|
||||
if pid == os.getpid() or pid in pids or pid in _exclude:
|
||||
continue
|
||||
if any(pattern in command for pattern in patterns) and (all_profiles or _matches_current_profile(command)):
|
||||
pids.append(pid)
|
||||
if any(pattern in command for pattern in patterns) and (
|
||||
all_profiles or _matches_current_profile(command)
|
||||
):
|
||||
_append_unique_pid(pids, pid, exclude_pids)
|
||||
except (OSError, subprocess.TimeoutExpired):
|
||||
pass
|
||||
return []
|
||||
|
||||
return pids
|
||||
|
||||
|
||||
def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = False) -> list:
|
||||
"""Find PIDs of running gateway processes.
|
||||
|
||||
Args:
|
||||
exclude_pids: PIDs to exclude from the result (e.g. service-managed
|
||||
PIDs that should not be killed during a stale-process sweep).
|
||||
all_profiles: When ``True``, return gateway PIDs across **all**
|
||||
profiles (the pre-7923 global behaviour). ``hermes update``
|
||||
needs this because a code update affects every profile.
|
||||
When ``False`` (default), only PIDs belonging to the current
|
||||
Hermes profile are returned.
|
||||
"""
|
||||
_exclude = set(exclude_pids or set())
|
||||
pids: list[int] = []
|
||||
if not all_profiles:
|
||||
try:
|
||||
from gateway.status import get_running_pid
|
||||
|
||||
_append_unique_pid(pids, get_running_pid(), _exclude)
|
||||
except Exception:
|
||||
pass
|
||||
for pid in _get_service_pids():
|
||||
_append_unique_pid(pids, pid, _exclude)
|
||||
for pid in _scan_gateway_pids(_exclude, all_profiles=all_profiles):
|
||||
_append_unique_pid(pids, pid, _exclude)
|
||||
return pids
|
||||
|
||||
|
||||
def _probe_systemd_service_running(system: bool = False) -> tuple[bool, bool]:
|
||||
selected_system = _select_systemd_scope(system)
|
||||
unit_exists = get_systemd_unit_path(system=selected_system).exists()
|
||||
if not unit_exists:
|
||||
return selected_system, False
|
||||
try:
|
||||
result = _run_systemctl(
|
||||
["is-active", get_service_name()],
|
||||
system=selected_system,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
except (RuntimeError, subprocess.TimeoutExpired):
|
||||
return selected_system, False
|
||||
return selected_system, result.stdout.strip() == "active"
|
||||
|
||||
|
||||
def _probe_launchd_service_running() -> bool:
|
||||
if not get_launchd_plist_path().exists():
|
||||
return False
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["launchctl", "list", get_launchd_label()],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
return False
|
||||
return result.returncode == 0
|
||||
|
||||
|
||||
def get_gateway_runtime_snapshot(system: bool = False) -> GatewayRuntimeSnapshot:
|
||||
"""Return a unified view of gateway liveness for the current profile."""
|
||||
gateway_pids = tuple(find_gateway_pids())
|
||||
if is_termux():
|
||||
return GatewayRuntimeSnapshot(
|
||||
manager="Termux / manual process",
|
||||
gateway_pids=gateway_pids,
|
||||
)
|
||||
|
||||
from hermes_constants import is_container
|
||||
|
||||
if is_linux() and is_container():
|
||||
return GatewayRuntimeSnapshot(
|
||||
manager="docker (foreground)",
|
||||
gateway_pids=gateway_pids,
|
||||
)
|
||||
|
||||
if supports_systemd_services():
|
||||
selected_system, service_running = _probe_systemd_service_running(system=system)
|
||||
scope_label = _service_scope_label(selected_system)
|
||||
return GatewayRuntimeSnapshot(
|
||||
manager=f"systemd ({scope_label})",
|
||||
service_installed=get_systemd_unit_path(system=selected_system).exists(),
|
||||
service_running=service_running,
|
||||
gateway_pids=gateway_pids,
|
||||
service_scope=scope_label,
|
||||
)
|
||||
|
||||
if is_macos():
|
||||
return GatewayRuntimeSnapshot(
|
||||
manager="launchd",
|
||||
service_installed=get_launchd_plist_path().exists(),
|
||||
service_running=_probe_launchd_service_running(),
|
||||
gateway_pids=gateway_pids,
|
||||
service_scope="launchd",
|
||||
)
|
||||
|
||||
return GatewayRuntimeSnapshot(
|
||||
manager="manual process",
|
||||
gateway_pids=gateway_pids,
|
||||
)
|
||||
|
||||
|
||||
def _format_gateway_pids(pids: tuple[int, ...] | list[int], *, limit: int | None = 3) -> str:
|
||||
rendered = [str(pid) for pid in pids[:limit] if pid > 0] if limit is not None else [str(pid) for pid in pids if pid > 0]
|
||||
if limit is not None and len(pids) > limit:
|
||||
rendered.append("...")
|
||||
return ", ".join(rendered)
|
||||
|
||||
|
||||
def _print_gateway_process_mismatch(snapshot: GatewayRuntimeSnapshot) -> None:
|
||||
if not snapshot.has_process_service_mismatch:
|
||||
return
|
||||
print()
|
||||
print("⚠ Gateway process is running for this profile, but the service is not active")
|
||||
print(f" PID(s): {_format_gateway_pids(snapshot.gateway_pids, limit=None)}")
|
||||
print(" This is usually a manual foreground/tmux/nohup run, so `hermes gateway`")
|
||||
print(" can refuse to start another copy until this process stops.")
|
||||
|
||||
|
||||
def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None,
|
||||
all_profiles: bool = False) -> int:
|
||||
"""Kill any running gateway processes. Returns count killed.
|
||||
@@ -340,25 +487,44 @@ def _wsl_systemd_operational() -> bool:
|
||||
WSL2 with ``systemd=true`` in wsl.conf has working systemd.
|
||||
WSL2 without it (or WSL1) does not — systemctl commands fail.
|
||||
"""
|
||||
return _systemd_operational(system=True)
|
||||
|
||||
|
||||
def _systemd_operational(system: bool = False) -> bool:
|
||||
"""Return True when the requested systemd scope is usable."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["systemctl", "is-system-running"],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
result = _run_systemctl(
|
||||
["is-system-running"],
|
||||
system=system,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
# "running", "degraded", "starting" all mean systemd is PID 1
|
||||
status = result.stdout.strip().lower()
|
||||
return status in ("running", "degraded", "starting", "initializing")
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired, OSError):
|
||||
except (RuntimeError, subprocess.TimeoutExpired, OSError):
|
||||
return False
|
||||
|
||||
|
||||
def _container_systemd_operational() -> bool:
|
||||
"""Return True when a container exposes working user or system systemd."""
|
||||
if _systemd_operational(system=False):
|
||||
return True
|
||||
if _systemd_operational(system=True):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def supports_systemd_services() -> bool:
|
||||
if not is_linux() or is_termux() or is_container():
|
||||
if not is_linux() or is_termux():
|
||||
return False
|
||||
if shutil.which("systemctl") is None:
|
||||
return False
|
||||
if is_wsl():
|
||||
return _wsl_systemd_operational()
|
||||
if is_container():
|
||||
return _container_systemd_operational()
|
||||
return True
|
||||
|
||||
|
||||
@@ -521,6 +687,195 @@ def has_conflicting_systemd_units() -> bool:
|
||||
return len(get_installed_systemd_scopes()) > 1
|
||||
|
||||
|
||||
# Legacy service names from older Hermes installs that predate the
|
||||
# hermes-gateway rename. Kept as an explicit allowlist (NOT a glob) so
|
||||
# profile units (hermes-gateway-*.service) and unrelated third-party
|
||||
# "hermes" units are never matched.
|
||||
_LEGACY_SERVICE_NAMES: tuple[str, ...] = ("hermes.service",)
|
||||
|
||||
# ExecStart content markers that identify a unit as running our gateway.
|
||||
# A legacy unit is only flagged when its file contains one of these.
|
||||
_LEGACY_UNIT_EXECSTART_MARKERS: tuple[str, ...] = (
|
||||
"hermes_cli.main gateway",
|
||||
"hermes_cli/main.py gateway",
|
||||
"gateway/run.py",
|
||||
" hermes gateway ",
|
||||
"/hermes gateway ",
|
||||
)
|
||||
|
||||
|
||||
def _legacy_unit_search_paths() -> list[tuple[bool, Path]]:
|
||||
"""Return ``[(is_system, base_dir), ...]`` — directories to scan for legacy units.
|
||||
|
||||
Factored out so tests can monkeypatch the search roots without touching
|
||||
real filesystem paths.
|
||||
"""
|
||||
return [
|
||||
(False, Path.home() / ".config" / "systemd" / "user"),
|
||||
(True, Path("/etc/systemd/system")),
|
||||
]
|
||||
|
||||
|
||||
def _find_legacy_hermes_units() -> list[tuple[str, Path, bool]]:
|
||||
"""Return ``[(unit_name, unit_path, is_system)]`` for legacy Hermes gateway units.
|
||||
|
||||
Detects unit files installed by older Hermes versions that used a
|
||||
different service name (e.g. ``hermes.service`` before the rename to
|
||||
``hermes-gateway.service``). When both a legacy unit and the current
|
||||
``hermes-gateway.service`` are active, they fight over the same bot
|
||||
token — the PR #5646 signal-recovery change turns this into a 30-second
|
||||
SIGTERM flap loop.
|
||||
|
||||
Safety guards:
|
||||
|
||||
* Explicit allowlist of legacy names (no globbing). Profile units such
|
||||
as ``hermes-gateway-coder.service`` and unrelated third-party
|
||||
``hermes-*`` services are never matched.
|
||||
* ExecStart content check — only flag units that invoke our gateway
|
||||
entrypoint. A user-created ``hermes.service`` running an unrelated
|
||||
binary is left untouched.
|
||||
* Results are returned purely for caller inspection; this function
|
||||
never mutates or removes anything.
|
||||
"""
|
||||
results: list[tuple[str, Path, bool]] = []
|
||||
for is_system, base in _legacy_unit_search_paths():
|
||||
for name in _LEGACY_SERVICE_NAMES:
|
||||
unit_path = base / name
|
||||
try:
|
||||
if not unit_path.exists():
|
||||
continue
|
||||
text = unit_path.read_text(encoding="utf-8", errors="ignore")
|
||||
except (OSError, PermissionError):
|
||||
continue
|
||||
if not any(marker in text for marker in _LEGACY_UNIT_EXECSTART_MARKERS):
|
||||
# Not our gateway — leave alone
|
||||
continue
|
||||
results.append((name, unit_path, is_system))
|
||||
return results
|
||||
|
||||
|
||||
def has_legacy_hermes_units() -> bool:
|
||||
"""Return True when any legacy Hermes gateway unit files exist."""
|
||||
return bool(_find_legacy_hermes_units())
|
||||
|
||||
|
||||
def print_legacy_unit_warning() -> None:
|
||||
"""Warn about legacy Hermes gateway unit files if any are installed.
|
||||
|
||||
Idempotent: prints nothing when no legacy units are detected. Safe to
|
||||
call from any status/install/setup path.
|
||||
"""
|
||||
legacy = _find_legacy_hermes_units()
|
||||
if not legacy:
|
||||
return
|
||||
print_warning("Legacy Hermes gateway unit(s) detected from an older install:")
|
||||
for name, path, is_system in legacy:
|
||||
scope = "system" if is_system else "user"
|
||||
print_info(f" {path} ({scope} scope)")
|
||||
print_info(" These run alongside the current hermes-gateway service and")
|
||||
print_info(" cause SIGTERM flap loops — both try to use the same bot token.")
|
||||
print_info(" Remove them with:")
|
||||
print_info(" hermes gateway migrate-legacy")
|
||||
|
||||
|
||||
def remove_legacy_hermes_units(
|
||||
interactive: bool = True,
|
||||
dry_run: bool = False,
|
||||
) -> tuple[int, list[Path]]:
|
||||
"""Stop, disable, and remove legacy Hermes gateway unit files.
|
||||
|
||||
Iterates over whatever ``_find_legacy_hermes_units()`` returns — which is
|
||||
an explicit allowlist of legacy names (not a glob). Profile units and
|
||||
unrelated third-party services are never touched.
|
||||
|
||||
Args:
|
||||
interactive: When True, prompt before removing. When False, remove
|
||||
without asking (used when another prompt has already confirmed,
|
||||
e.g. from the install flow).
|
||||
dry_run: When True, list what would be removed and return.
|
||||
|
||||
Returns:
|
||||
``(removed_count, remaining_paths)`` — remaining includes units we
|
||||
couldn't remove (typically system-scope when not running as root).
|
||||
"""
|
||||
legacy = _find_legacy_hermes_units()
|
||||
if not legacy:
|
||||
print("No legacy Hermes gateway units found.")
|
||||
return 0, []
|
||||
|
||||
user_units = [(n, p) for n, p, is_sys in legacy if not is_sys]
|
||||
system_units = [(n, p) for n, p, is_sys in legacy if is_sys]
|
||||
|
||||
print()
|
||||
print("Legacy Hermes gateway unit(s) found:")
|
||||
for name, path, is_system in legacy:
|
||||
scope = "system" if is_system else "user"
|
||||
print(f" {path} ({scope} scope)")
|
||||
print()
|
||||
|
||||
if dry_run:
|
||||
print("(dry-run — nothing removed)")
|
||||
return 0, [p for _, p, _ in legacy]
|
||||
|
||||
if interactive and not prompt_yes_no("Remove these legacy units?", True):
|
||||
print("Skipped. Run again with: hermes gateway migrate-legacy")
|
||||
return 0, [p for _, p, _ in legacy]
|
||||
|
||||
removed = 0
|
||||
remaining: list[Path] = []
|
||||
|
||||
# User-scope removal
|
||||
for name, path in user_units:
|
||||
try:
|
||||
_run_systemctl(["stop", name], system=False, check=False, timeout=90)
|
||||
_run_systemctl(["disable", name], system=False, check=False, timeout=30)
|
||||
path.unlink(missing_ok=True)
|
||||
print(f" ✓ Removed {path}")
|
||||
removed += 1
|
||||
except (OSError, RuntimeError) as e:
|
||||
print(f" ⚠ Could not remove {path}: {e}")
|
||||
remaining.append(path)
|
||||
|
||||
if user_units:
|
||||
try:
|
||||
_run_systemctl(["daemon-reload"], system=False, check=False, timeout=30)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
# System-scope removal (needs root)
|
||||
if system_units:
|
||||
if os.geteuid() != 0:
|
||||
print()
|
||||
print_warning("System-scope legacy units require root to remove.")
|
||||
print_info(" Re-run with: sudo hermes gateway migrate-legacy")
|
||||
for _, path in system_units:
|
||||
remaining.append(path)
|
||||
else:
|
||||
for name, path in system_units:
|
||||
try:
|
||||
_run_systemctl(["stop", name], system=True, check=False, timeout=90)
|
||||
_run_systemctl(["disable", name], system=True, check=False, timeout=30)
|
||||
path.unlink(missing_ok=True)
|
||||
print(f" ✓ Removed {path}")
|
||||
removed += 1
|
||||
except (OSError, RuntimeError) as e:
|
||||
print(f" ⚠ Could not remove {path}: {e}")
|
||||
remaining.append(path)
|
||||
|
||||
try:
|
||||
_run_systemctl(["daemon-reload"], system=True, check=False, timeout=30)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
print()
|
||||
if remaining:
|
||||
print_warning(f"{len(remaining)} legacy unit(s) still present — see messages above.")
|
||||
else:
|
||||
print_success(f"Removed {removed} legacy unit(s).")
|
||||
|
||||
return removed, remaining
|
||||
|
||||
|
||||
def print_systemd_scope_conflict_warning() -> None:
|
||||
scopes = get_installed_systemd_scopes()
|
||||
if len(scopes) < 2:
|
||||
@@ -1054,6 +1409,19 @@ def systemd_install(force: bool = False, system: bool = False, run_as_user: str
|
||||
if system:
|
||||
_require_root_for_system_service("install")
|
||||
|
||||
# Offer to remove legacy units (hermes.service from pre-rename installs)
|
||||
# before installing the new hermes-gateway.service. If both remain, they
|
||||
# flap-fight for the Telegram bot token on every gateway startup.
|
||||
# Only removes units matching _LEGACY_SERVICE_NAMES + our ExecStart
|
||||
# signature — profile units are never touched.
|
||||
if has_legacy_hermes_units():
|
||||
print()
|
||||
print_legacy_unit_warning()
|
||||
print()
|
||||
if prompt_yes_no("Remove the legacy unit(s) before installing?", True):
|
||||
remove_legacy_hermes_units(interactive=False)
|
||||
print()
|
||||
|
||||
unit_path = get_systemd_unit_path(system=system)
|
||||
scope_flag = " --system" if system else ""
|
||||
|
||||
@@ -1092,6 +1460,7 @@ def systemd_install(force: bool = False, system: bool = False, run_as_user: str
|
||||
_ensure_linger_enabled()
|
||||
|
||||
print_systemd_scope_conflict_warning()
|
||||
print_legacy_unit_warning()
|
||||
|
||||
|
||||
def systemd_uninstall(system: bool = False):
|
||||
@@ -1215,6 +1584,10 @@ def systemd_status(deep: bool = False, system: bool = False):
|
||||
print_systemd_scope_conflict_warning()
|
||||
print()
|
||||
|
||||
if has_legacy_hermes_units():
|
||||
print_legacy_unit_warning()
|
||||
print()
|
||||
|
||||
if not systemd_unit_is_current(system=system):
|
||||
print("⚠ Installed gateway service definition is outdated")
|
||||
print(f" Run: {'sudo ' if system else ''}hermes gateway restart{scope_flag} # auto-refreshes the unit")
|
||||
@@ -1998,7 +2371,7 @@ _PLATFORMS = [
|
||||
{"name": "QQ_ALLOWED_USERS", "prompt": "Allowed user OpenIDs (comma-separated, leave empty for open access)", "password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Optional — restrict DM access to specific user OpenIDs."},
|
||||
{"name": "QQ_HOME_CHANNEL", "prompt": "Home channel (user/group OpenID for cron delivery, or empty)", "password": False,
|
||||
{"name": "QQBOT_HOME_CHANNEL", "prompt": "Home channel (user/group OpenID for cron delivery, or empty)", "password": False,
|
||||
"help": "OpenID to deliver cron results and notifications to."},
|
||||
],
|
||||
},
|
||||
@@ -2625,6 +2998,215 @@ def _setup_feishu():
|
||||
print_info(f" Bot: {bot_name}")
|
||||
|
||||
|
||||
def _setup_qqbot():
|
||||
"""Interactive setup for QQ Bot — scan-to-configure or manual credentials."""
|
||||
print()
|
||||
print(color(" ─── 🐧 QQ Bot Setup ───", Colors.CYAN))
|
||||
|
||||
existing_app_id = get_env_value("QQ_APP_ID")
|
||||
existing_secret = get_env_value("QQ_CLIENT_SECRET")
|
||||
if existing_app_id and existing_secret:
|
||||
print()
|
||||
print_success("QQ Bot is already configured.")
|
||||
if not prompt_yes_no(" Reconfigure QQ Bot?", False):
|
||||
return
|
||||
|
||||
# ── Choose setup method ──
|
||||
print()
|
||||
method_choices = [
|
||||
"Scan QR code to add bot automatically (recommended)",
|
||||
"Enter existing App ID and App Secret manually",
|
||||
]
|
||||
method_idx = prompt_choice(" How would you like to set up QQ Bot?", method_choices, 0)
|
||||
|
||||
credentials = None
|
||||
used_qr = False
|
||||
|
||||
if method_idx == 0:
|
||||
# ── QR scan-to-configure ──
|
||||
try:
|
||||
credentials = _qqbot_qr_flow()
|
||||
except KeyboardInterrupt:
|
||||
print()
|
||||
print_warning(" QQ Bot setup cancelled.")
|
||||
return
|
||||
if credentials:
|
||||
used_qr = True
|
||||
if not credentials:
|
||||
print_info(" QR setup did not complete. Continuing with manual input.")
|
||||
|
||||
# ── Manual credential input ──
|
||||
if not credentials:
|
||||
print()
|
||||
print_info(" Go to https://q.qq.com to register a QQ Bot application.")
|
||||
print_info(" Note your App ID and App Secret from the application page.")
|
||||
print()
|
||||
app_id = prompt(" App ID", password=False)
|
||||
if not app_id:
|
||||
print_warning(" Skipped — QQ Bot won't work without an App ID.")
|
||||
return
|
||||
app_secret = prompt(" App Secret", password=True)
|
||||
if not app_secret:
|
||||
print_warning(" Skipped — QQ Bot won't work without an App Secret.")
|
||||
return
|
||||
credentials = {"app_id": app_id.strip(), "client_secret": app_secret.strip(), "user_openid": ""}
|
||||
|
||||
# ── Save core credentials ──
|
||||
save_env_value("QQ_APP_ID", credentials["app_id"])
|
||||
save_env_value("QQ_CLIENT_SECRET", credentials["client_secret"])
|
||||
|
||||
user_openid = credentials.get("user_openid", "")
|
||||
|
||||
# ── DM security policy ──
|
||||
print()
|
||||
access_choices = [
|
||||
"Use DM pairing approval (recommended)",
|
||||
"Allow all direct messages",
|
||||
"Only allow listed user OpenIDs",
|
||||
]
|
||||
access_idx = prompt_choice(" How should direct messages be authorized?", access_choices, 0)
|
||||
if access_idx == 0:
|
||||
save_env_value("QQ_ALLOW_ALL_USERS", "false")
|
||||
if user_openid:
|
||||
print()
|
||||
if prompt_yes_no(f" Add yourself ({user_openid}) to the allow list?", True):
|
||||
save_env_value("QQ_ALLOWED_USERS", user_openid)
|
||||
print_success(f" Allow list set to {user_openid}")
|
||||
else:
|
||||
save_env_value("QQ_ALLOWED_USERS", "")
|
||||
else:
|
||||
save_env_value("QQ_ALLOWED_USERS", "")
|
||||
print_success(" DM pairing enabled.")
|
||||
print_info(" Unknown users can request access; approve with `hermes pairing approve`.")
|
||||
elif access_idx == 1:
|
||||
save_env_value("QQ_ALLOW_ALL_USERS", "true")
|
||||
save_env_value("QQ_ALLOWED_USERS", "")
|
||||
print_warning(" Open DM access enabled for QQ Bot.")
|
||||
else:
|
||||
default_allow = user_openid or ""
|
||||
allowlist = prompt(" Allowed user OpenIDs (comma-separated)", default_allow, password=False).replace(" ", "")
|
||||
save_env_value("QQ_ALLOW_ALL_USERS", "false")
|
||||
save_env_value("QQ_ALLOWED_USERS", allowlist)
|
||||
print_success(" Allowlist saved.")
|
||||
|
||||
# ── Home channel ──
|
||||
if user_openid:
|
||||
print()
|
||||
if prompt_yes_no(f" Use your QQ user ID ({user_openid}) as the home channel?", True):
|
||||
save_env_value("QQBOT_HOME_CHANNEL", user_openid)
|
||||
print_success(f" Home channel set to {user_openid}")
|
||||
else:
|
||||
print()
|
||||
home_channel = prompt(" Home channel OpenID (for cron/notifications, or empty)", password=False)
|
||||
if home_channel:
|
||||
save_env_value("QQBOT_HOME_CHANNEL", home_channel.strip())
|
||||
print_success(f" Home channel set to {home_channel.strip()}")
|
||||
|
||||
print()
|
||||
print_success("🐧 QQ Bot configured!")
|
||||
print_info(f" App ID: {credentials['app_id']}")
|
||||
|
||||
|
||||
def _qqbot_render_qr(url: str) -> bool:
|
||||
"""Try to render a QR code in the terminal. Returns True if successful."""
|
||||
try:
|
||||
import qrcode as _qr
|
||||
qr = _qr.QRCode(border=1,error_correction=_qr.constants.ERROR_CORRECT_L)
|
||||
qr.add_data(url)
|
||||
qr.make(fit=True)
|
||||
qr.print_ascii(invert=True)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _qqbot_qr_flow():
|
||||
"""Run the QR-code scan-to-configure flow.
|
||||
|
||||
Returns a dict with app_id, client_secret, user_openid on success,
|
||||
or None on failure/cancel.
|
||||
"""
|
||||
try:
|
||||
from gateway.platforms.qqbot import (
|
||||
create_bind_task, poll_bind_result, build_connect_url,
|
||||
decrypt_secret, BindStatus,
|
||||
)
|
||||
from gateway.platforms.qqbot.constants import ONBOARD_POLL_INTERVAL
|
||||
except Exception as exc:
|
||||
print_error(f" QQBot onboard import failed: {exc}")
|
||||
return None
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
MAX_REFRESHES = 3
|
||||
refresh_count = 0
|
||||
|
||||
while refresh_count <= MAX_REFRESHES:
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
# ── Create bind task ──
|
||||
try:
|
||||
task_id, aes_key = loop.run_until_complete(create_bind_task())
|
||||
except Exception as e:
|
||||
print_warning(f" Failed to create bind task: {e}")
|
||||
loop.close()
|
||||
return None
|
||||
|
||||
url = build_connect_url(task_id)
|
||||
|
||||
# ── Display QR code + URL ──
|
||||
print()
|
||||
if _qqbot_render_qr(url):
|
||||
print(f" Scan the QR code above, or open this URL directly:\n {url}")
|
||||
else:
|
||||
print(f" Open this URL in QQ on your phone:\n {url}")
|
||||
print_info(" Tip: pip install qrcode to show a scannable QR code here")
|
||||
|
||||
# ── Poll loop (silent — keep QR visible at bottom) ──
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
status, app_id, encrypted_secret, user_openid = loop.run_until_complete(
|
||||
poll_bind_result(task_id)
|
||||
)
|
||||
except Exception:
|
||||
time.sleep(ONBOARD_POLL_INTERVAL)
|
||||
continue
|
||||
|
||||
if status == BindStatus.COMPLETED:
|
||||
client_secret = decrypt_secret(encrypted_secret, aes_key)
|
||||
print()
|
||||
print_success(f" QR scan complete! (App ID: {app_id})")
|
||||
if user_openid:
|
||||
print_info(f" Scanner's OpenID: {user_openid}")
|
||||
return {
|
||||
"app_id": app_id,
|
||||
"client_secret": client_secret,
|
||||
"user_openid": user_openid,
|
||||
}
|
||||
|
||||
if status == BindStatus.EXPIRED:
|
||||
refresh_count += 1
|
||||
if refresh_count > MAX_REFRESHES:
|
||||
print()
|
||||
print_warning(f" QR code expired {MAX_REFRESHES} times — giving up.")
|
||||
return None
|
||||
print()
|
||||
print_warning(f" QR code expired, refreshing... ({refresh_count}/{MAX_REFRESHES})")
|
||||
loop.close()
|
||||
break # outer while creates a new task
|
||||
|
||||
time.sleep(ONBOARD_POLL_INTERVAL)
|
||||
except KeyboardInterrupt:
|
||||
loop.close()
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _setup_signal():
|
||||
"""Interactive setup for Signal messenger."""
|
||||
import shutil
|
||||
@@ -2762,6 +3344,10 @@ def gateway_setup():
|
||||
print_systemd_scope_conflict_warning()
|
||||
print()
|
||||
|
||||
if supports_systemd_services() and has_legacy_hermes_units():
|
||||
print_legacy_unit_warning()
|
||||
print()
|
||||
|
||||
if service_installed and service_running:
|
||||
print_success("Gateway service is installed and running.")
|
||||
elif service_installed:
|
||||
@@ -2806,6 +3392,8 @@ def gateway_setup():
|
||||
_setup_dingtalk()
|
||||
elif platform["key"] == "feishu":
|
||||
_setup_feishu()
|
||||
elif platform["key"] == "qqbot":
|
||||
_setup_qqbot()
|
||||
else:
|
||||
_setup_standard_platform(platform)
|
||||
|
||||
@@ -3165,15 +3753,18 @@ def gateway_command(args):
|
||||
elif subcmd == "status":
|
||||
deep = getattr(args, 'deep', False)
|
||||
system = getattr(args, 'system', False)
|
||||
snapshot = get_gateway_runtime_snapshot(system=system)
|
||||
|
||||
# Check for service first
|
||||
if supports_systemd_services() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
systemd_status(deep, system=system)
|
||||
_print_gateway_process_mismatch(snapshot)
|
||||
elif is_macos() and get_launchd_plist_path().exists():
|
||||
launchd_status(deep)
|
||||
_print_gateway_process_mismatch(snapshot)
|
||||
else:
|
||||
# Check for manually running processes
|
||||
pids = find_gateway_pids()
|
||||
pids = list(snapshot.gateway_pids)
|
||||
if pids:
|
||||
print(f"✓ Gateway is running (PID: {', '.join(map(str, pids))})")
|
||||
print(" (Running manually, not as a system service)")
|
||||
@@ -3214,3 +3805,14 @@ def gateway_command(args):
|
||||
else:
|
||||
print(" hermes gateway install # Install as user service")
|
||||
print(" sudo hermes gateway install --system # Install as boot-time system service")
|
||||
|
||||
elif subcmd == "migrate-legacy":
|
||||
# Stop, disable, and remove legacy Hermes gateway unit files from
|
||||
# pre-rename installs (e.g. hermes.service). Profile units and
|
||||
# unrelated third-party services are never touched.
|
||||
dry_run = getattr(args, 'dry_run', False)
|
||||
yes = getattr(args, 'yes', False)
|
||||
if not supports_systemd_services() and not is_macos():
|
||||
print("Legacy unit migration only applies to systemd-based Linux hosts.")
|
||||
return
|
||||
remove_legacy_hermes_units(interactive=not yes, dry_run=dry_run)
|
||||
|
||||
+2585
-687
File diff suppressed because it is too large
Load Diff
@@ -692,12 +692,12 @@ def switch_model(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
validation = {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": None,
|
||||
"message": f"Could not validate `{new_model}`: {e}",
|
||||
}
|
||||
|
||||
if not validation.get("accepted"):
|
||||
|
||||
+45
-29
@@ -26,7 +26,8 @@ COPILOT_REASONING_EFFORTS_O_SERIES = ["low", "medium", "high"]
|
||||
# Fallback OpenRouter snapshot used when the live catalog is unavailable.
|
||||
# (model_id, display description shown in menus)
|
||||
OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("anthropic/claude-opus-4.7", "recommended"),
|
||||
("moonshotai/kimi-k2.5", "recommended"),
|
||||
("anthropic/claude-opus-4.7", ""),
|
||||
("anthropic/claude-opus-4.6", ""),
|
||||
("anthropic/claude-sonnet-4.6", ""),
|
||||
("qwen/qwen3.6-plus", ""),
|
||||
@@ -49,7 +50,6 @@ OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("z-ai/glm-5.1", ""),
|
||||
("z-ai/glm-5v-turbo", ""),
|
||||
("z-ai/glm-5-turbo", ""),
|
||||
("moonshotai/kimi-k2.5", ""),
|
||||
("x-ai/grok-4.20", ""),
|
||||
("nvidia/nemotron-3-super-120b-a12b", ""),
|
||||
("nvidia/nemotron-3-super-120b-a12b:free", "free"),
|
||||
@@ -75,6 +75,7 @@ def _codex_curated_models() -> list[str]:
|
||||
|
||||
_PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"nous": [
|
||||
"moonshotai/kimi-k2.5",
|
||||
"xiaomi/mimo-v2-pro",
|
||||
"anthropic/claude-opus-4.7",
|
||||
"anthropic/claude-opus-4.6",
|
||||
@@ -96,7 +97,6 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"z-ai/glm-5.1",
|
||||
"z-ai/glm-5v-turbo",
|
||||
"z-ai/glm-5-turbo",
|
||||
"moonshotai/kimi-k2.5",
|
||||
"x-ai/grok-4.20-beta",
|
||||
"nvidia/nemotron-3-super-120b-a12b",
|
||||
"nvidia/nemotron-3-super-120b-a12b:free",
|
||||
@@ -135,7 +135,6 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"gemini-2.5-flash-lite",
|
||||
# Gemma open models (also served via AI Studio)
|
||||
"gemma-4-31b-it",
|
||||
"gemma-4-26b-it",
|
||||
],
|
||||
"google-gemini-cli": [
|
||||
"gemini-2.5-pro",
|
||||
@@ -155,9 +154,23 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"grok-4.20-reasoning",
|
||||
"grok-4-1-fast-reasoning",
|
||||
],
|
||||
"nvidia": [
|
||||
# NVIDIA flagship reasoning models
|
||||
"nvidia/nemotron-3-super-120b-a12b",
|
||||
"nvidia/nemotron-3-nano-30b-a3b",
|
||||
"nvidia/llama-3.3-nemotron-super-49b-v1.5",
|
||||
# Third-party agentic models hosted on build.nvidia.com
|
||||
# (map to OpenRouter defaults — users get familiar picks on NIM)
|
||||
"qwen/qwen3.5-397b-a17b",
|
||||
"deepseek-ai/deepseek-v3.2",
|
||||
"moonshotai/kimi-k2.5",
|
||||
"minimaxai/minimax-m2.5",
|
||||
"z-ai/glm5",
|
||||
"openai/gpt-oss-120b",
|
||||
],
|
||||
"kimi-coding": [
|
||||
"kimi-for-coding",
|
||||
"kimi-k2.5",
|
||||
"kimi-for-coding",
|
||||
"kimi-k2-thinking",
|
||||
"kimi-k2-thinking-turbo",
|
||||
"kimi-k2-turbo-preview",
|
||||
@@ -212,6 +225,7 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"trinity-mini",
|
||||
],
|
||||
"opencode-zen": [
|
||||
"kimi-k2.5",
|
||||
"gpt-5.4-pro",
|
||||
"gpt-5.4",
|
||||
"gpt-5.3-codex",
|
||||
@@ -243,16 +257,15 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"glm-5",
|
||||
"glm-4.7",
|
||||
"glm-4.6",
|
||||
"kimi-k2.5",
|
||||
"kimi-k2-thinking",
|
||||
"kimi-k2",
|
||||
"qwen3-coder",
|
||||
"big-pickle",
|
||||
],
|
||||
"opencode-go": [
|
||||
"kimi-k2.5",
|
||||
"glm-5.1",
|
||||
"glm-5",
|
||||
"kimi-k2.5",
|
||||
"mimo-v2-pro",
|
||||
"mimo-v2-omni",
|
||||
"minimax-m2.7",
|
||||
@@ -285,21 +298,21 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
# to https://dashscope-intl.aliyuncs.com/compatible-mode/v1 (OpenAI-compat)
|
||||
# or https://dashscope-intl.aliyuncs.com/apps/anthropic (Anthropic-compat).
|
||||
"alibaba": [
|
||||
"kimi-k2.5",
|
||||
"qwen3.5-plus",
|
||||
"qwen3-coder-plus",
|
||||
"qwen3-coder-next",
|
||||
# Third-party models available on coding-intl
|
||||
"glm-5",
|
||||
"glm-4.7",
|
||||
"kimi-k2.5",
|
||||
"MiniMax-M2.5",
|
||||
],
|
||||
# Curated HF model list — only agentic models that map to OpenRouter defaults.
|
||||
"huggingface": [
|
||||
"moonshotai/Kimi-K2.5",
|
||||
"Qwen/Qwen3.5-397B-A17B",
|
||||
"Qwen/Qwen3.5-35B-A3B",
|
||||
"deepseek-ai/DeepSeek-V3.2",
|
||||
"moonshotai/Kimi-K2.5",
|
||||
"MiniMaxAI/MiniMax-M2.5",
|
||||
"zai-org/GLM-5",
|
||||
"XiaomiMiMo/MiMo-V2-Flash",
|
||||
@@ -536,6 +549,7 @@ CANONICAL_PROVIDERS: list[ProviderEntry] = [
|
||||
ProviderEntry("anthropic", "Anthropic", "Anthropic (Claude models — API key or Claude Code)"),
|
||||
ProviderEntry("openai-codex", "OpenAI Codex", "OpenAI Codex"),
|
||||
ProviderEntry("xiaomi", "Xiaomi MiMo", "Xiaomi MiMo (MiMo-V2 models — pro, omni, flash)"),
|
||||
ProviderEntry("nvidia", "NVIDIA NIM", "NVIDIA NIM (Nemotron models — build.nvidia.com or local NIM)"),
|
||||
ProviderEntry("qwen-oauth", "Qwen OAuth (Portal)", "Qwen OAuth (reuses local Qwen CLI login)"),
|
||||
ProviderEntry("copilot", "GitHub Copilot", "GitHub Copilot (uses GITHUB_TOKEN or gh auth token)"),
|
||||
ProviderEntry("copilot-acp", "GitHub Copilot ACP", "GitHub Copilot ACP (spawns `copilot --acp --stdio`)"),
|
||||
@@ -618,6 +632,10 @@ _PROVIDER_ALIASES = {
|
||||
"grok": "xai",
|
||||
"x-ai": "xai",
|
||||
"x.ai": "xai",
|
||||
"nim": "nvidia",
|
||||
"nvidia-nim": "nvidia",
|
||||
"build-nvidia": "nvidia",
|
||||
"nemotron": "nvidia",
|
||||
"ollama": "custom", # bare "ollama" = local; use "ollama-cloud" for cloud
|
||||
"ollama_cloud": "ollama-cloud",
|
||||
}
|
||||
@@ -2032,8 +2050,8 @@ def validate_requested_model(
|
||||
)
|
||||
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": message,
|
||||
}
|
||||
@@ -2046,8 +2064,8 @@ def validate_requested_model(
|
||||
message += f"\n If this server expects `/v1`, try base URL: `{probe.get('suggested_base_url')}`"
|
||||
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": message,
|
||||
}
|
||||
@@ -2081,12 +2099,11 @@ def validate_requested_model(
|
||||
if suggestions:
|
||||
suggestion_text = "\n Similar models: " + ", ".join(f"`{s}`" for s in suggestions)
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": (
|
||||
f"Note: `{requested}` was not found in the OpenAI Codex model listing. "
|
||||
f"It may still work if your account has access to it."
|
||||
f"Model `{requested}` was not found in the OpenAI Codex model listing."
|
||||
f"{suggestion_text}"
|
||||
),
|
||||
}
|
||||
@@ -2125,16 +2142,15 @@ def validate_requested_model(
|
||||
if suggestions:
|
||||
suggestion_text = "\n Similar models: " + ", ".join(f"`{s}`" for s in suggestions)
|
||||
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"recognized": False,
|
||||
"message": (
|
||||
f"Note: `{requested}` was not found in this provider's model listing. "
|
||||
f"It may still work if your plan supports it."
|
||||
f"{suggestion_text}"
|
||||
),
|
||||
}
|
||||
return {
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": (
|
||||
f"Model `{requested}` was not found in this provider's model listing."
|
||||
f"{suggestion_text}"
|
||||
),
|
||||
}
|
||||
|
||||
# api_models is None — couldn't reach API. Accept and persist,
|
||||
# but warn so typos don't silently break things.
|
||||
@@ -2176,8 +2192,8 @@ def validate_requested_model(
|
||||
|
||||
provider_label = _PROVIDER_LABELS.get(normalized, normalized)
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": (
|
||||
f"Could not reach the {provider_label} API to validate `{requested}`. "
|
||||
|
||||
+3
-12
@@ -300,19 +300,10 @@ def _read_config_model(profile_dir: Path) -> tuple:
|
||||
|
||||
def _check_gateway_running(profile_dir: Path) -> bool:
|
||||
"""Check if a gateway is running for a given profile directory."""
|
||||
pid_file = profile_dir / "gateway.pid"
|
||||
if not pid_file.exists():
|
||||
return False
|
||||
try:
|
||||
raw = pid_file.read_text().strip()
|
||||
if not raw:
|
||||
return False
|
||||
data = json.loads(raw) if raw.startswith("{") else {"pid": int(raw)}
|
||||
pid = int(data["pid"])
|
||||
os.kill(pid, 0) # existence check
|
||||
return True
|
||||
except (json.JSONDecodeError, KeyError, ValueError, TypeError,
|
||||
ProcessLookupError, PermissionError, OSError):
|
||||
from gateway.status import get_running_pid
|
||||
return get_running_pid(profile_dir / "gateway.pid", cleanup_stale=False) is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -137,6 +137,11 @@ HERMES_OVERLAYS: Dict[str, HermesOverlay] = {
|
||||
base_url_override="https://api.x.ai/v1",
|
||||
base_url_env_var="XAI_BASE_URL",
|
||||
),
|
||||
"nvidia": HermesOverlay(
|
||||
transport="openai_chat",
|
||||
base_url_override="https://integrate.api.nvidia.com/v1",
|
||||
base_url_env_var="NVIDIA_BASE_URL",
|
||||
),
|
||||
"xiaomi": HermesOverlay(
|
||||
transport="openai_chat",
|
||||
base_url_env_var="XIAOMI_BASE_URL",
|
||||
@@ -191,6 +196,12 @@ ALIASES: Dict[str, str] = {
|
||||
"x.ai": "xai",
|
||||
"grok": "xai",
|
||||
|
||||
# nvidia
|
||||
"nim": "nvidia",
|
||||
"nvidia-nim": "nvidia",
|
||||
"build-nvidia": "nvidia",
|
||||
"nemotron": "nvidia",
|
||||
|
||||
# kimi-for-coding (models.dev ID)
|
||||
"kimi": "kimi-for-coding",
|
||||
"kimi-coding": "kimi-for-coding",
|
||||
|
||||
+13
-54
@@ -91,7 +91,7 @@ _DEFAULT_PROVIDER_MODELS = {
|
||||
"gemini": [
|
||||
"gemini-3.1-pro-preview", "gemini-3-flash-preview", "gemini-3.1-flash-lite-preview",
|
||||
"gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite",
|
||||
"gemma-4-31b-it", "gemma-4-26b-it",
|
||||
"gemma-4-31b-it",
|
||||
],
|
||||
"zai": ["glm-5.1", "glm-5", "glm-4.7", "glm-4.5", "glm-4.5-flash"],
|
||||
"kimi-coding": ["kimi-k2.5", "kimi-k2-thinking", "kimi-k2-turbo-preview"],
|
||||
@@ -2005,52 +2005,6 @@ def _setup_wecom_callback():
|
||||
_gw_setup()
|
||||
|
||||
|
||||
def _setup_qqbot():
|
||||
"""Configure QQ Bot gateway."""
|
||||
print_header("QQ Bot")
|
||||
existing = get_env_value("QQ_APP_ID")
|
||||
if existing:
|
||||
print_info("QQ Bot: already configured")
|
||||
if not prompt_yes_no("Reconfigure QQ Bot?", False):
|
||||
return
|
||||
|
||||
print_info("Connects Hermes to QQ via the Official QQ Bot API (v2).")
|
||||
print_info(" Requires a QQ Bot application at q.qq.com")
|
||||
print_info(" Reference: https://bot.q.qq.com/wiki/develop/api-v2/")
|
||||
print()
|
||||
|
||||
app_id = prompt("QQ Bot App ID")
|
||||
if not app_id:
|
||||
print_warning("App ID is required — skipping QQ Bot setup")
|
||||
return
|
||||
save_env_value("QQ_APP_ID", app_id.strip())
|
||||
|
||||
client_secret = prompt("QQ Bot App Secret", password=True)
|
||||
if not client_secret:
|
||||
print_warning("App Secret is required — skipping QQ Bot setup")
|
||||
return
|
||||
save_env_value("QQ_CLIENT_SECRET", client_secret)
|
||||
print_success("QQ Bot credentials saved")
|
||||
|
||||
print()
|
||||
print_info("🔒 Security: Restrict who can DM your bot")
|
||||
print_info(" Use QQ user OpenIDs (found in event payloads)")
|
||||
print()
|
||||
allowed_users = prompt("Allowed user OpenIDs (comma-separated, leave empty for open access)")
|
||||
if allowed_users:
|
||||
save_env_value("QQ_ALLOWED_USERS", allowed_users.replace(" ", ""))
|
||||
print_success("QQ Bot allowlist configured")
|
||||
else:
|
||||
print_info("⚠️ No allowlist set — anyone can DM the bot!")
|
||||
|
||||
print()
|
||||
print_info("📬 Home Channel: OpenID for cron job delivery and notifications.")
|
||||
home_channel = prompt("Home channel OpenID (leave empty to set later)")
|
||||
if home_channel:
|
||||
save_env_value("QQ_HOME_CHANNEL", home_channel)
|
||||
|
||||
print()
|
||||
print_success("QQ Bot configured!")
|
||||
|
||||
|
||||
def _setup_bluebubbles():
|
||||
@@ -2119,12 +2073,9 @@ def _setup_bluebubbles():
|
||||
|
||||
|
||||
def _setup_qqbot():
|
||||
"""Configure QQ Bot (Official API v2) via standard platform setup."""
|
||||
from hermes_cli.gateway import _PLATFORMS
|
||||
qq_platform = next((p for p in _PLATFORMS if p["key"] == "qqbot"), None)
|
||||
if qq_platform:
|
||||
from hermes_cli.gateway import _setup_standard_platform
|
||||
_setup_standard_platform(qq_platform)
|
||||
"""Configure QQ Bot (Official API v2) via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_qqbot as _gateway_setup_qqbot
|
||||
_gateway_setup_qqbot()
|
||||
|
||||
|
||||
def _setup_webhooks():
|
||||
@@ -2264,7 +2215,9 @@ def setup_gateway(config: dict):
|
||||
missing_home.append("Slack")
|
||||
if get_env_value("BLUEBUBBLES_SERVER_URL") and not get_env_value("BLUEBUBBLES_HOME_CHANNEL"):
|
||||
missing_home.append("BlueBubbles")
|
||||
if get_env_value("QQ_APP_ID") and not get_env_value("QQ_HOME_CHANNEL"):
|
||||
if get_env_value("QQ_APP_ID") and not (
|
||||
get_env_value("QQBOT_HOME_CHANNEL") or get_env_value("QQ_HOME_CHANNEL")
|
||||
):
|
||||
missing_home.append("QQBot")
|
||||
|
||||
if missing_home:
|
||||
@@ -2289,8 +2242,10 @@ def setup_gateway(config: dict):
|
||||
_is_service_running,
|
||||
supports_systemd_services,
|
||||
has_conflicting_systemd_units,
|
||||
has_legacy_hermes_units,
|
||||
install_linux_gateway_from_setup,
|
||||
print_systemd_scope_conflict_warning,
|
||||
print_legacy_unit_warning,
|
||||
systemd_start,
|
||||
systemd_restart,
|
||||
launchd_install,
|
||||
@@ -2308,6 +2263,10 @@ def setup_gateway(config: dict):
|
||||
print_systemd_scope_conflict_warning()
|
||||
print()
|
||||
|
||||
if supports_systemd and has_legacy_hermes_units():
|
||||
print_legacy_unit_warning()
|
||||
print()
|
||||
|
||||
if service_running:
|
||||
if prompt_yes_no(" Restart the gateway to pick up changes?", True):
|
||||
try:
|
||||
|
||||
@@ -515,6 +515,90 @@ def do_inspect(identifier: str, console: Optional[Console] = None) -> None:
|
||||
c.print()
|
||||
|
||||
|
||||
def browse_skills(page: int = 1, page_size: int = 20, source: str = "all") -> dict:
|
||||
"""Paginated hub browse for programmatic callers (e.g. TUI gateway).
|
||||
|
||||
Returns ``{"items": [...], "page": int, "total_pages": int, "total": int}``.
|
||||
"""
|
||||
from tools.skills_hub import GitHubAuth, create_source_router
|
||||
|
||||
page_size = max(1, min(page_size, 100))
|
||||
_TRUST_RANK = {"builtin": 3, "trusted": 2, "community": 1}
|
||||
_PER_SOURCE_LIMIT = {"official": 100, "skills-sh": 100, "well-known": 25, "github": 100, "clawhub": 50,
|
||||
"claude-marketplace": 50, "lobehub": 50}
|
||||
auth = GitHubAuth()
|
||||
sources = create_source_router(auth)
|
||||
all_results: list = []
|
||||
for src in sources:
|
||||
sid = src.source_id()
|
||||
if source != "all" and sid != source and sid != "official":
|
||||
continue
|
||||
try:
|
||||
limit = _PER_SOURCE_LIMIT.get(sid, 50)
|
||||
all_results.extend(src.search("", limit=limit))
|
||||
except Exception:
|
||||
continue
|
||||
if not all_results:
|
||||
return {"items": [], "page": 1, "total_pages": 1, "total": 0}
|
||||
seen: dict = {}
|
||||
for r in all_results:
|
||||
rank = _TRUST_RANK.get(r.trust_level, 0)
|
||||
if r.name not in seen or rank > _TRUST_RANK.get(seen[r.name].trust_level, 0):
|
||||
seen[r.name] = r
|
||||
deduped = list(seen.values())
|
||||
deduped.sort(key=lambda r: (-_TRUST_RANK.get(r.trust_level, 0), r.source != "official", r.name.lower()))
|
||||
total = len(deduped)
|
||||
total_pages = max(1, (total + page_size - 1) // page_size)
|
||||
page = max(1, min(page, total_pages))
|
||||
start = (page - 1) * page_size
|
||||
page_items = deduped[start : min(start + page_size, total)]
|
||||
return {
|
||||
"items": [{"name": r.name, "description": r.description, "source": r.source,
|
||||
"trust": r.trust_level} for r in page_items],
|
||||
"page": page,
|
||||
"total_pages": total_pages,
|
||||
"total": total,
|
||||
}
|
||||
|
||||
|
||||
def inspect_skill(identifier: str) -> Optional[dict]:
|
||||
"""Skill metadata (+ SKILL.md preview) for programmatic callers."""
|
||||
from tools.skills_hub import GitHubAuth, create_source_router
|
||||
|
||||
class _Q:
|
||||
def print(self, *a, **k):
|
||||
pass
|
||||
|
||||
c = _Q()
|
||||
auth = GitHubAuth()
|
||||
sources = create_source_router(auth)
|
||||
ident = identifier
|
||||
if "/" not in ident:
|
||||
ident = _resolve_short_name(ident, sources, c)
|
||||
if not ident:
|
||||
return None
|
||||
meta, bundle, _ = _resolve_source_meta_and_bundle(ident, sources)
|
||||
if not meta:
|
||||
return None
|
||||
out: dict = {
|
||||
"name": meta.name,
|
||||
"description": meta.description,
|
||||
"source": meta.source,
|
||||
"identifier": meta.identifier,
|
||||
"tags": list(meta.tags) if meta.tags else [],
|
||||
}
|
||||
if bundle and "SKILL.md" in bundle.files:
|
||||
content = bundle.files["SKILL.md"]
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8", errors="replace")
|
||||
lines = content.split("\n")
|
||||
preview = "\n".join(lines[:50])
|
||||
if len(lines) > 50:
|
||||
preview += f"\n\n... ({len(lines) - 50} more lines)"
|
||||
out["skill_md_preview"] = preview
|
||||
return out
|
||||
|
||||
|
||||
def do_list(source_filter: str = "all", console: Optional[Console] = None) -> None:
|
||||
"""List installed skills, distinguishing hub, builtin, and local skills."""
|
||||
from tools.skills_hub import HubLockFile, ensure_hub_dirs
|
||||
|
||||
@@ -23,7 +23,7 @@ All fields are optional. Missing values inherit from the ``default`` skin.
|
||||
banner_dim: "#B8860B" # Dim/muted text (separators, labels)
|
||||
banner_text: "#FFF8DC" # Body text (tool names, skill names)
|
||||
ui_accent: "#FFBF00" # General UI accent
|
||||
ui_label: "#4dd0e1" # UI labels
|
||||
ui_label: "#DAA520" # UI labels (warm gold; teal clashed w/ default banner gold)
|
||||
ui_ok: "#4caf50" # Success indicators
|
||||
ui_error: "#ef5350" # Error indicators
|
||||
ui_warn: "#ffa726" # Warning indicators
|
||||
@@ -163,7 +163,7 @@ _BUILTIN_SKINS: Dict[str, Dict[str, Any]] = {
|
||||
"banner_dim": "#B8860B",
|
||||
"banner_text": "#FFF8DC",
|
||||
"ui_accent": "#FFBF00",
|
||||
"ui_label": "#4dd0e1",
|
||||
"ui_label": "#DAA520",
|
||||
"ui_ok": "#4caf50",
|
||||
"ui_error": "#ef5350",
|
||||
"ui_warn": "#ffa726",
|
||||
|
||||
+30
-64
@@ -317,7 +317,7 @@ def show_status(args):
|
||||
"WeCom Callback": ("WECOM_CALLBACK_CORP_ID", None),
|
||||
"Weixin": ("WEIXIN_ACCOUNT_ID", "WEIXIN_HOME_CHANNEL"),
|
||||
"BlueBubbles": ("BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_HOME_CHANNEL"),
|
||||
"QQBot": ("QQ_APP_ID", "QQ_HOME_CHANNEL"),
|
||||
"QQBot": ("QQ_APP_ID", "QQBOT_HOME_CHANNEL"),
|
||||
}
|
||||
|
||||
for name, (token_var, home_var) in platforms.items():
|
||||
@@ -327,6 +327,9 @@ def show_status(args):
|
||||
home_channel = ""
|
||||
if home_var:
|
||||
home_channel = os.getenv(home_var, "")
|
||||
# Back-compat: QQBot home channel was renamed from QQ_HOME_CHANNEL to QQBOT_HOME_CHANNEL
|
||||
if not home_channel and home_var == "QQBOT_HOME_CHANNEL":
|
||||
home_channel = os.getenv("QQ_HOME_CHANNEL", "")
|
||||
|
||||
status = "configured" if has_token else "not configured"
|
||||
if home_channel:
|
||||
@@ -339,73 +342,36 @@ def show_status(args):
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Gateway Service", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
if _is_termux():
|
||||
try:
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
gateway_pids = find_gateway_pids()
|
||||
except Exception:
|
||||
gateway_pids = []
|
||||
is_running = bool(gateway_pids)
|
||||
|
||||
try:
|
||||
from hermes_cli.gateway import get_gateway_runtime_snapshot, _format_gateway_pids
|
||||
|
||||
snapshot = get_gateway_runtime_snapshot()
|
||||
is_running = snapshot.running
|
||||
print(f" Status: {check_mark(is_running)} {'running' if is_running else 'stopped'}")
|
||||
print(" Manager: Termux / manual process")
|
||||
if gateway_pids:
|
||||
rendered = ", ".join(str(pid) for pid in gateway_pids[:3])
|
||||
if len(gateway_pids) > 3:
|
||||
rendered += ", ..."
|
||||
print(f" PID(s): {rendered}")
|
||||
else:
|
||||
print(f" Manager: {snapshot.manager}")
|
||||
if snapshot.gateway_pids:
|
||||
print(f" PID(s): {_format_gateway_pids(snapshot.gateway_pids)}")
|
||||
if snapshot.has_process_service_mismatch:
|
||||
print(" Service: installed but not managing the current running gateway")
|
||||
elif _is_termux() and not snapshot.gateway_pids:
|
||||
print(" Start with: hermes gateway")
|
||||
print(" Note: Android may stop background jobs when Termux is suspended")
|
||||
|
||||
elif sys.platform.startswith('linux'):
|
||||
from hermes_constants import is_container
|
||||
if is_container():
|
||||
# Docker/Podman: no systemd — check for running gateway processes
|
||||
try:
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
gateway_pids = find_gateway_pids()
|
||||
is_active = len(gateway_pids) > 0
|
||||
except Exception:
|
||||
is_active = False
|
||||
print(f" Status: {check_mark(is_active)} {'running' if is_active else 'stopped'}")
|
||||
print(" Manager: docker (foreground)")
|
||||
elif snapshot.service_installed and not snapshot.service_running:
|
||||
print(" Service: installed but stopped")
|
||||
except Exception:
|
||||
if _is_termux():
|
||||
print(f" Status: {color('unknown', Colors.DIM)}")
|
||||
print(" Manager: Termux / manual process")
|
||||
elif sys.platform.startswith('linux'):
|
||||
print(f" Status: {color('unknown', Colors.DIM)}")
|
||||
print(" Manager: systemd/manual")
|
||||
elif sys.platform == 'darwin':
|
||||
print(f" Status: {color('unknown', Colors.DIM)}")
|
||||
print(" Manager: launchd")
|
||||
else:
|
||||
try:
|
||||
from hermes_cli.gateway import get_service_name
|
||||
_gw_svc = get_service_name()
|
||||
except Exception:
|
||||
_gw_svc = "hermes-gateway"
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["systemctl", "--user", "is-active", _gw_svc],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
is_active = result.stdout.strip() == "active"
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
is_active = False
|
||||
print(f" Status: {check_mark(is_active)} {'running' if is_active else 'stopped'}")
|
||||
print(" Manager: systemd (user)")
|
||||
|
||||
elif sys.platform == 'darwin':
|
||||
from hermes_cli.gateway import get_launchd_label
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["launchctl", "list", get_launchd_label()],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
is_loaded = result.returncode == 0
|
||||
except subprocess.TimeoutExpired:
|
||||
is_loaded = False
|
||||
print(f" Status: {check_mark(is_loaded)} {'loaded' if is_loaded else 'not loaded'}")
|
||||
print(" Manager: launchd")
|
||||
else:
|
||||
print(f" Status: {color('N/A', Colors.DIM)}")
|
||||
print(" Manager: (not supported on this platform)")
|
||||
print(f" Status: {color('N/A', Colors.DIM)}")
|
||||
print(" Manager: (not supported on this platform)")
|
||||
|
||||
# =========================================================================
|
||||
# Cron Jobs
|
||||
|
||||
@@ -56,10 +56,10 @@ try:
|
||||
except ImportError:
|
||||
raise SystemExit(
|
||||
"Web UI requires fastapi and uvicorn.\n"
|
||||
"Run 'hermes web' to auto-install, or: pip install hermes-agent[web]"
|
||||
f"Install with: {sys.executable} -m pip install 'fastapi' 'uvicorn[standard]'"
|
||||
)
|
||||
|
||||
WEB_DIST = Path(__file__).parent / "web_dist"
|
||||
WEB_DIST = Path(os.environ["HERMES_WEB_DIST"]) if "HERMES_WEB_DIST" in os.environ else Path(__file__).parent / "web_dist"
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(title="Hermes Agent", version=__version__)
|
||||
@@ -1444,38 +1444,8 @@ def _nous_poller(session_id: str) -> None:
|
||||
auth_state, min_key_ttl_seconds=300, timeout_seconds=15.0,
|
||||
force_refresh=False, force_mint=True,
|
||||
)
|
||||
# Save into credential pool same as auth_commands.py does
|
||||
from agent.credential_pool import (
|
||||
PooledCredential,
|
||||
load_pool,
|
||||
AUTH_TYPE_OAUTH,
|
||||
SOURCE_MANUAL,
|
||||
)
|
||||
pool = load_pool("nous")
|
||||
entry = PooledCredential.from_dict("nous", {
|
||||
**full_state,
|
||||
"label": "dashboard device_code",
|
||||
"auth_type": AUTH_TYPE_OAUTH,
|
||||
"source": f"{SOURCE_MANUAL}:dashboard_device_code",
|
||||
"base_url": full_state.get("inference_base_url"),
|
||||
})
|
||||
pool.add_entry(entry)
|
||||
# Also persist to auth store so get_nous_auth_status() sees it
|
||||
# (matches what _login_nous in auth.py does for the CLI flow).
|
||||
try:
|
||||
from hermes_cli.auth import (
|
||||
_load_auth_store, _save_provider_state, _save_auth_store,
|
||||
_auth_store_lock,
|
||||
)
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
_save_provider_state(auth_store, "nous", full_state)
|
||||
_save_auth_store(auth_store)
|
||||
except Exception as store_exc:
|
||||
_log.warning(
|
||||
"oauth/device: credential pool saved but auth store write failed "
|
||||
"(session=%s): %s", session_id, store_exc,
|
||||
)
|
||||
from hermes_cli.auth import persist_nous_credentials
|
||||
persist_nous_credentials(full_state)
|
||||
with _oauth_sessions_lock:
|
||||
sess["status"] = "approved"
|
||||
_log.info("oauth/device: nous login completed (session=%s)", session_id)
|
||||
|
||||
+2
-1
@@ -14,7 +14,8 @@ def get_hermes_home() -> Path:
|
||||
Reads HERMES_HOME env var, falls back to ~/.hermes.
|
||||
This is the single source of truth — all other copies should import this.
|
||||
"""
|
||||
return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
val = os.environ.get("HERMES_HOME", "").strip()
|
||||
return Path(val) if val else Path.home() / ".hermes"
|
||||
|
||||
|
||||
def get_default_hermes_root() -> Path:
|
||||
|
||||
+57
-2
@@ -987,6 +987,22 @@ class SessionDB:
|
||||
|
||||
return sanitized.strip()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _contains_cjk(text: str) -> bool:
|
||||
"""Check if text contains CJK (Chinese, Japanese, Korean) characters."""
|
||||
for ch in text:
|
||||
cp = ord(ch)
|
||||
if (0x4E00 <= cp <= 0x9FFF or # CJK Unified Ideographs
|
||||
0x3400 <= cp <= 0x4DBF or # CJK Extension A
|
||||
0x20000 <= cp <= 0x2A6DF or # CJK Extension B
|
||||
0x3000 <= cp <= 0x303F or # CJK Symbols
|
||||
0x3040 <= cp <= 0x309F or # Hiragana
|
||||
0x30A0 <= cp <= 0x30FF or # Katakana
|
||||
0xAC00 <= cp <= 0xD7AF): # Hangul Syllables
|
||||
return True
|
||||
return False
|
||||
|
||||
def search_messages(
|
||||
self,
|
||||
query: str,
|
||||
@@ -1062,8 +1078,47 @@ class SessionDB:
|
||||
cursor = self._conn.execute(sql, params)
|
||||
except sqlite3.OperationalError:
|
||||
# FTS5 query syntax error despite sanitization — return empty
|
||||
return []
|
||||
matches = [dict(row) for row in cursor.fetchall()]
|
||||
# unless query contains CJK (fall back to LIKE below)
|
||||
if not self._contains_cjk(query):
|
||||
return []
|
||||
matches = []
|
||||
else:
|
||||
matches = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
# LIKE fallback for CJK queries: FTS5 default tokenizer splits CJK
|
||||
# characters individually, causing multi-character queries to fail.
|
||||
if not matches and self._contains_cjk(query):
|
||||
raw_query = query.strip('"').strip()
|
||||
like_where = ["m.content LIKE ?"]
|
||||
like_params: list = [f"%{raw_query}%"]
|
||||
if source_filter is not None:
|
||||
like_where.append(f"s.source IN ({','.join('?' for _ in source_filter)})")
|
||||
like_params.extend(source_filter)
|
||||
if exclude_sources is not None:
|
||||
like_where.append(f"s.source NOT IN ({','.join('?' for _ in exclude_sources)})")
|
||||
like_params.extend(exclude_sources)
|
||||
if role_filter:
|
||||
like_where.append(f"m.role IN ({','.join('?' for _ in role_filter)})")
|
||||
like_params.extend(role_filter)
|
||||
like_sql = f"""
|
||||
SELECT m.id, m.session_id, m.role,
|
||||
substr(m.content,
|
||||
max(1, instr(m.content, ?) - 40),
|
||||
120) AS snippet,
|
||||
m.content, m.timestamp, m.tool_name,
|
||||
s.source, s.model, s.started_at AS session_started
|
||||
FROM messages m
|
||||
JOIN sessions s ON s.id = m.session_id
|
||||
WHERE {' AND '.join(like_where)}
|
||||
ORDER BY m.timestamp DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
like_params.extend([limit, offset])
|
||||
# instr() parameter goes first in the bound list
|
||||
like_params = [raw_query] + like_params
|
||||
with self._lock:
|
||||
like_cursor = self._conn.execute(like_sql, like_params)
|
||||
matches = [dict(row) for row in like_cursor.fetchall()]
|
||||
|
||||
# Add surrounding context (1 message before + after each match).
|
||||
# Done outside the lock so we don't hold it across N sequential queries.
|
||||
|
||||
+2
-2
@@ -433,7 +433,7 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP":
|
||||
if not _MCP_SERVER_AVAILABLE:
|
||||
raise ImportError(
|
||||
"MCP server requires the 'mcp' package. "
|
||||
"Install with: pip install 'hermes-agent[mcp]'"
|
||||
f"Install with: {sys.executable} -m pip install 'mcp'"
|
||||
)
|
||||
|
||||
mcp = FastMCP(
|
||||
@@ -838,7 +838,7 @@ def run_mcp_server(verbose: bool = False) -> None:
|
||||
if not _MCP_SERVER_AVAILABLE:
|
||||
print(
|
||||
"Error: MCP server requires the 'mcp' package.\n"
|
||||
"Install with: pip install 'hermes-agent[mcp]'",
|
||||
f"Install with: {sys.executable} -m pip install 'mcp'",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
+20
-6
@@ -43,6 +43,15 @@ from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def _effective_temperature_for_model(model: str) -> Optional[float]:
|
||||
"""Return a fixed temperature for models with strict sampling contracts."""
|
||||
try:
|
||||
from agent.auxiliary_client import _fixed_temperature_for_model
|
||||
except Exception:
|
||||
return None
|
||||
return _fixed_temperature_for_model(model)
|
||||
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -442,12 +451,17 @@ Complete the user's task step by step."""
|
||||
|
||||
# Make API call
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=api_messages,
|
||||
tools=self.tools,
|
||||
timeout=300.0
|
||||
)
|
||||
api_kwargs = {
|
||||
"model": self.model,
|
||||
"messages": api_messages,
|
||||
"tools": self.tools,
|
||||
"timeout": 300.0,
|
||||
}
|
||||
fixed_temperature = _effective_temperature_for_model(self.model)
|
||||
if fixed_temperature is not None:
|
||||
api_kwargs["temperature"] = fixed_temperature
|
||||
|
||||
response = self.client.chat.completions.create(**api_kwargs)
|
||||
except Exception as e:
|
||||
self.logger.error(f"API call failed: {e}")
|
||||
break
|
||||
|
||||
+2
-2
@@ -274,9 +274,9 @@ def get_tool_definitions(
|
||||
# execute_code" even when the API key isn't configured or the toolset is
|
||||
# disabled (#560-discord).
|
||||
if "execute_code" in available_tool_names:
|
||||
from tools.code_execution_tool import SANDBOX_ALLOWED_TOOLS, build_execute_code_schema
|
||||
from tools.code_execution_tool import SANDBOX_ALLOWED_TOOLS, build_execute_code_schema, _get_execution_mode
|
||||
sandbox_enabled = SANDBOX_ALLOWED_TOOLS & available_tool_names
|
||||
dynamic_schema = build_execute_code_schema(sandbox_enabled)
|
||||
dynamic_schema = build_execute_code_schema(sandbox_enabled, mode=_get_execution_mode())
|
||||
for i, td in enumerate(filtered_tools):
|
||||
if td.get("function", {}).get("name") == "execute_code":
|
||||
filtered_tools[i] = {"type": "function", "function": dynamic_schema}
|
||||
|
||||
@@ -103,6 +103,51 @@ json.dump(sorted(leaf_paths(DEFAULT_CONFIG)), sys.stdout, indent=2)
|
||||
echo "ok" > $out/result
|
||||
'';
|
||||
|
||||
# Verify bundled TUI is present and compiled
|
||||
bundled-tui = pkgs.runCommand "hermes-bundled-tui" { } ''
|
||||
set -e
|
||||
echo "=== Checking bundled TUI ==="
|
||||
test -d ${hermes-agent}/ui-tui || (echo "FAIL: ui-tui directory missing"; exit 1)
|
||||
echo "PASS: ui-tui directory exists"
|
||||
|
||||
test -f ${hermes-agent}/ui-tui/dist/entry.js || (echo "FAIL: compiled entry.js missing"; exit 1)
|
||||
echo "PASS: compiled entry.js present"
|
||||
|
||||
test -d ${hermes-agent}/ui-tui/node_modules || (echo "FAIL: node_modules missing"; exit 1)
|
||||
echo "PASS: node_modules present"
|
||||
|
||||
grep -q "HERMES_TUI_DIR" ${hermes-agent}/bin/hermes || \
|
||||
(echo "FAIL: HERMES_TUI_DIR not in wrapper"; exit 1)
|
||||
echo "PASS: HERMES_TUI_DIR set in wrapper"
|
||||
|
||||
echo "=== All bundled TUI checks passed ==="
|
||||
mkdir -p $out
|
||||
echo "ok" > $out/result
|
||||
'';
|
||||
|
||||
# Verify HERMES_NODE is set in wrapper and points to Node 20+
|
||||
# (string-width uses the /v regex flag which requires Node 20+)
|
||||
hermes-node = pkgs.runCommand "hermes-node-version" { } ''
|
||||
set -e
|
||||
echo "=== Checking HERMES_NODE in wrapper ==="
|
||||
grep -q "HERMES_NODE" ${hermes-agent}/bin/hermes || \
|
||||
(echo "FAIL: HERMES_NODE not set in wrapper"; exit 1)
|
||||
echo "PASS: HERMES_NODE present in wrapper"
|
||||
|
||||
HERMES_NODE=$(sed -n "s/^export HERMES_NODE='\(.*\)'/\1/p" ${hermes-agent}/bin/hermes)
|
||||
test -x "$HERMES_NODE" || (echo "FAIL: HERMES_NODE=$HERMES_NODE not executable"; exit 1)
|
||||
echo "PASS: HERMES_NODE executable at $HERMES_NODE"
|
||||
|
||||
NODE_MAJOR=$("$HERMES_NODE" --version | sed 's/^v//' | cut -d. -f1)
|
||||
test "$NODE_MAJOR" -ge 20 || \
|
||||
(echo "FAIL: Node v$NODE_MAJOR < 20, TUI needs /v regex flag support"; exit 1)
|
||||
echo "PASS: Node v$NODE_MAJOR >= 20"
|
||||
|
||||
echo "=== All HERMES_NODE checks passed ==="
|
||||
mkdir -p $out
|
||||
echo "ok" > $out/result
|
||||
'';
|
||||
|
||||
# Verify HERMES_MANAGED guard works on all mutation commands
|
||||
managed-guard = pkgs.runCommand "hermes-managed-guard" { } ''
|
||||
set -e
|
||||
|
||||
+15
-38
@@ -1,49 +1,26 @@
|
||||
# nix/devShell.nix — Fast dev shell with stamp-file optimization
|
||||
# nix/devShell.nix — Dev shell that delegates setup to each package
|
||||
#
|
||||
# Each package in inputsFrom exposes passthru.devShellHook — a bash snippet
|
||||
# with stamp-checked setup logic. This file collects and runs them all.
|
||||
{ inputs, ... }: {
|
||||
perSystem = { pkgs, ... }:
|
||||
perSystem = { pkgs, system, ... }:
|
||||
let
|
||||
python = pkgs.python311;
|
||||
hermes-agent = inputs.self.packages.${system}.default;
|
||||
hermes-tui = inputs.self.packages.${system}.tui;
|
||||
packages = [ hermes-agent hermes-tui ];
|
||||
in {
|
||||
devShells.default = pkgs.mkShell {
|
||||
inputsFrom = packages;
|
||||
packages = with pkgs; [
|
||||
python uv nodejs_20 ripgrep git openssh ffmpeg
|
||||
python311 uv nodejs_22 ripgrep git openssh ffmpeg
|
||||
];
|
||||
|
||||
shellHook = ''
|
||||
shellHook = let
|
||||
hooks = map (p: p.passthru.devShellHook or "") packages;
|
||||
combined = pkgs.lib.concatStringsSep "\n" (builtins.filter (h: h != "") hooks);
|
||||
in ''
|
||||
echo "Hermes Agent dev shell"
|
||||
|
||||
# Composite stamp: changes when nix python or uv change
|
||||
STAMP_VALUE="${python}:${pkgs.uv}"
|
||||
STAMP_FILE=".venv/.nix-stamp"
|
||||
|
||||
# Create venv if missing
|
||||
if [ ! -d .venv ]; then
|
||||
echo "Creating Python 3.11 venv..."
|
||||
uv venv .venv --python ${python}/bin/python3
|
||||
fi
|
||||
|
||||
source .venv/bin/activate
|
||||
|
||||
# Only install if stamp is stale or missing
|
||||
if [ ! -f "$STAMP_FILE" ] || [ "$(cat "$STAMP_FILE")" != "$STAMP_VALUE" ]; then
|
||||
echo "Installing Python dependencies..."
|
||||
uv pip install -e ".[all]"
|
||||
if [ -d mini-swe-agent ]; then
|
||||
uv pip install -e ./mini-swe-agent 2>/dev/null || true
|
||||
fi
|
||||
if [ -d tinker-atropos ]; then
|
||||
uv pip install -e ./tinker-atropos 2>/dev/null || true
|
||||
fi
|
||||
|
||||
# Install npm deps
|
||||
if [ -f package.json ] && [ ! -d node_modules ]; then
|
||||
echo "Installing npm dependencies..."
|
||||
npm install
|
||||
fi
|
||||
|
||||
echo "$STAMP_VALUE" > "$STAMP_FILE"
|
||||
fi
|
||||
|
||||
${combined}
|
||||
echo "Ready. Run 'hermes' to start."
|
||||
'';
|
||||
};
|
||||
|
||||
+11
-3
@@ -121,11 +121,19 @@
|
||||
# ── Provision apt packages (first boot only, cached in writable layer) ──
|
||||
# sudo: agent self-modification
|
||||
# nodejs/npm: writable node so npm i -g works (nix store copies are read-only)
|
||||
# curl: needed for uv installer
|
||||
# Node 22 via NodeSource — Ubuntu 24.04 ships Node 18 which is EOL.
|
||||
# curl: needed for uv installer + NodeSource setup
|
||||
if [ ! -f /var/lib/hermes-tools-provisioned ] && command -v apt-get >/dev/null 2>&1; then
|
||||
echo "First boot: provisioning agent tools..."
|
||||
apt-get update -qq
|
||||
apt-get install -y -qq sudo nodejs npm curl
|
||||
apt-get install -y -qq sudo curl ca-certificates gnupg
|
||||
mkdir -p /etc/apt/keyrings
|
||||
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key \
|
||||
| gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg
|
||||
echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_22.x nodistro main" \
|
||||
> /etc/apt/sources.list.d/nodesource.list
|
||||
apt-get update -qq
|
||||
apt-get install -y -qq nodejs
|
||||
touch /var/lib/hermes-tools-provisioned
|
||||
fi
|
||||
|
||||
@@ -171,7 +179,7 @@
|
||||
# Package and entrypoint use stable symlinks (current-package, current-entrypoint)
|
||||
# so they can update without recreation. Env vars go through $HERMES_HOME/.env.
|
||||
containerIdentity = builtins.hashString "sha256" (builtins.toJSON {
|
||||
schema = 3; # bump when identity inputs change
|
||||
schema = 4; # bump when identity inputs change (4: Node 18→22 via NodeSource)
|
||||
image = cfg.container.image;
|
||||
extraVolumes = cfg.container.extraVolumes;
|
||||
extraOptions = cfg.container.extraOptions;
|
||||
|
||||
+91
-29
@@ -1,54 +1,116 @@
|
||||
# nix/packages.nix — Hermes Agent package built with uv2nix
|
||||
{ inputs, ... }: {
|
||||
perSystem = { pkgs, system, ... }:
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ pkgs, inputs', ... }:
|
||||
let
|
||||
hermesVenv = pkgs.callPackage ./python.nix {
|
||||
inherit (inputs) uv2nix pyproject-nix pyproject-build-systems;
|
||||
};
|
||||
|
||||
hermesTui = pkgs.callPackage ./tui.nix {
|
||||
npm-lockfile-fix = inputs'.npm-lockfile-fix.packages.default;
|
||||
};
|
||||
|
||||
# Import bundled skills, excluding runtime caches
|
||||
bundledSkills = pkgs.lib.cleanSourceWith {
|
||||
src = ../skills;
|
||||
filter = path: _type:
|
||||
!(pkgs.lib.hasInfix "/index-cache/" path);
|
||||
filter = path: _type: !(pkgs.lib.hasInfix "/index-cache/" path);
|
||||
};
|
||||
|
||||
hermesWeb = pkgs.callPackage ./web.nix {
|
||||
npm-lockfile-fix = inputs'.npm-lockfile-fix.packages.default;
|
||||
};
|
||||
|
||||
runtimeDeps = with pkgs; [
|
||||
nodejs_20 ripgrep git openssh ffmpeg tirith
|
||||
nodejs_22
|
||||
ripgrep
|
||||
git
|
||||
openssh
|
||||
ffmpeg
|
||||
tirith
|
||||
];
|
||||
|
||||
runtimePath = pkgs.lib.makeBinPath runtimeDeps;
|
||||
in {
|
||||
packages.default = pkgs.stdenv.mkDerivation {
|
||||
pname = "hermes-agent";
|
||||
version = (builtins.fromTOML (builtins.readFile ../pyproject.toml)).project.version;
|
||||
|
||||
dontUnpack = true;
|
||||
dontBuild = true;
|
||||
nativeBuildInputs = [ pkgs.makeWrapper ];
|
||||
# Lockfile hashes for dev shell stamps
|
||||
pyprojectHash = builtins.hashString "sha256" (builtins.readFile ../pyproject.toml);
|
||||
uvLockHash =
|
||||
if builtins.pathExists ../uv.lock then
|
||||
builtins.hashString "sha256" (builtins.readFile ../uv.lock)
|
||||
else
|
||||
"none";
|
||||
in
|
||||
{
|
||||
packages = {
|
||||
default = pkgs.stdenv.mkDerivation {
|
||||
pname = "hermes-agent";
|
||||
version = (fromTOML (builtins.readFile ../pyproject.toml)).project.version;
|
||||
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
dontUnpack = true;
|
||||
dontBuild = true;
|
||||
nativeBuildInputs = [ pkgs.makeWrapper ];
|
||||
|
||||
mkdir -p $out/share/hermes-agent $out/bin
|
||||
cp -r ${bundledSkills} $out/share/hermes-agent/skills
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
|
||||
${pkgs.lib.concatMapStringsSep "\n" (name: ''
|
||||
makeWrapper ${hermesVenv}/bin/${name} $out/bin/${name} \
|
||||
--suffix PATH : "${runtimePath}" \
|
||||
--set HERMES_BUNDLED_SKILLS $out/share/hermes-agent/skills
|
||||
'') [ "hermes" "hermes-agent" "hermes-acp" ]}
|
||||
mkdir -p $out/share/hermes-agent $out/bin
|
||||
cp -r ${bundledSkills} $out/share/hermes-agent/skills
|
||||
cp -r ${hermesWeb} $out/share/hermes-agent/web_dist
|
||||
|
||||
runHook postInstall
|
||||
'';
|
||||
# copy pre-built TUI (same layout as dev: ui-tui/dist/ + node_modules/)
|
||||
mkdir -p $out/ui-tui
|
||||
cp -r ${hermesTui}/lib/hermes-tui/* $out/ui-tui/
|
||||
|
||||
meta = with pkgs.lib; {
|
||||
description = "AI agent with advanced tool-calling capabilities";
|
||||
homepage = "https://github.com/NousResearch/hermes-agent";
|
||||
mainProgram = "hermes";
|
||||
license = licenses.mit;
|
||||
platforms = platforms.unix;
|
||||
${pkgs.lib.concatMapStringsSep "\n"
|
||||
(name: ''
|
||||
makeWrapper ${hermesVenv}/bin/${name} $out/bin/${name} \
|
||||
--suffix PATH : "${runtimePath}" \
|
||||
--set HERMES_BUNDLED_SKILLS $out/share/hermes-agent/skills \
|
||||
--set HERMES_WEB_DIST $out/share/hermes-agent/web_dist \
|
||||
--set HERMES_TUI_DIR $out/ui-tui \
|
||||
--set HERMES_PYTHON ${hermesVenv}/bin/python3 \
|
||||
--set HERMES_NODE ${pkgs.nodejs_22}/bin/node
|
||||
'')
|
||||
[
|
||||
"hermes"
|
||||
"hermes-agent"
|
||||
"hermes-acp"
|
||||
]
|
||||
}
|
||||
|
||||
runHook postInstall
|
||||
'';
|
||||
|
||||
passthru.devShellHook = ''
|
||||
STAMP=".nix-stamps/hermes-agent"
|
||||
STAMP_VALUE="${pyprojectHash}:${uvLockHash}"
|
||||
if [ ! -f "$STAMP" ] || [ "$(cat "$STAMP")" != "$STAMP_VALUE" ]; then
|
||||
echo "hermes-agent: installing Python dependencies..."
|
||||
uv venv .venv --python ${pkgs.python311}/bin/python3 2>/dev/null || true
|
||||
source .venv/bin/activate
|
||||
uv pip install -e ".[all]"
|
||||
[ -d mini-swe-agent ] && uv pip install -e ./mini-swe-agent 2>/dev/null || true
|
||||
[ -d tinker-atropos ] && uv pip install -e ./tinker-atropos 2>/dev/null || true
|
||||
mkdir -p .nix-stamps
|
||||
echo "$STAMP_VALUE" > "$STAMP"
|
||||
else
|
||||
source .venv/bin/activate
|
||||
export HERMES_PYTHON=${hermesVenv}/bin/python3
|
||||
fi
|
||||
'';
|
||||
|
||||
meta = with pkgs.lib; {
|
||||
description = "AI agent with advanced tool-calling capabilities";
|
||||
homepage = "https://github.com/NousResearch/hermes-agent";
|
||||
mainProgram = "hermes";
|
||||
license = licenses.mit;
|
||||
platforms = platforms.unix;
|
||||
};
|
||||
};
|
||||
|
||||
tui = hermesTui;
|
||||
web = hermesWeb;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
@@ -35,6 +35,20 @@ let
|
||||
};
|
||||
};
|
||||
|
||||
# Legacy alibabacloud packages ship only sdists with setup.py/setup.cfg
|
||||
# and no pyproject.toml, so setuptools isn't declared as a build dep.
|
||||
buildSystemOverrides = final: prev: builtins.mapAttrs
|
||||
(name: _: prev.${name}.overrideAttrs (old: {
|
||||
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [ final.setuptools ];
|
||||
}))
|
||||
(lib.genAttrs [
|
||||
"alibabacloud-credentials-api"
|
||||
"alibabacloud-endpoint-util"
|
||||
"alibabacloud-gateway-dingtalk"
|
||||
"alibabacloud-gateway-spi"
|
||||
"alibabacloud-tea"
|
||||
] (_: null));
|
||||
|
||||
pythonPackageOverrides = final: _prev:
|
||||
if isAarch64Darwin then {
|
||||
numpy = mkPrebuiltOverride final python311.pkgs.numpy { };
|
||||
@@ -75,6 +89,7 @@ let
|
||||
(lib.composeManyExtensions [
|
||||
pyproject-build-systems.overlays.default
|
||||
overlay
|
||||
buildSystemOverrides
|
||||
pythonPackageOverrides
|
||||
]);
|
||||
in
|
||||
|
||||
+77
@@ -0,0 +1,77 @@
|
||||
# nix/tui.nix — Hermes TUI (Ink/React) compiled with tsc and bundled
|
||||
{ pkgs, npm-lockfile-fix, ... }:
|
||||
let
|
||||
src = ../ui-tui;
|
||||
npmDeps = pkgs.fetchNpmDeps {
|
||||
inherit src;
|
||||
hash = "sha256-mG3vpgGi4ljt4X3XIf3I/5mIcm+rVTUAmx2DQ6YVA90=";
|
||||
};
|
||||
|
||||
packageJson = builtins.fromJSON (builtins.readFile (src + "/package.json"));
|
||||
version = packageJson.version;
|
||||
|
||||
npmLockHash = builtins.hashString "sha256" (builtins.readFile ../ui-tui/package-lock.json);
|
||||
in
|
||||
pkgs.buildNpmPackage {
|
||||
pname = "hermes-tui";
|
||||
inherit src npmDeps version;
|
||||
|
||||
doCheck = false;
|
||||
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
|
||||
mkdir -p $out/lib/hermes-tui
|
||||
|
||||
cp -r dist $out/lib/hermes-tui/dist
|
||||
|
||||
# runtime node_modules
|
||||
cp -r node_modules $out/lib/hermes-tui/node_modules
|
||||
|
||||
# @hermes/ink is a file: dependency, we need to copy it in fr
|
||||
rm -f $out/lib/hermes-tui/node_modules/@hermes/ink
|
||||
cp -r packages/hermes-ink $out/lib/hermes-tui/node_modules/@hermes/ink
|
||||
|
||||
# package.json needed for "type": "module" resolution
|
||||
cp package.json $out/lib/hermes-tui/
|
||||
|
||||
runHook postInstall
|
||||
'';
|
||||
|
||||
nativeBuildInputs = [
|
||||
(pkgs.writeShellScriptBin "update_tui_lockfile" ''
|
||||
set -euox pipefail
|
||||
|
||||
# get root of repo
|
||||
REPO_ROOT=$(git rev-parse --show-toplevel)
|
||||
|
||||
# cd into ui-tui and reinstall
|
||||
cd "$REPO_ROOT/ui-tui"
|
||||
rm -rf node_modules/
|
||||
npm cache clean --force
|
||||
CI=true npm install # ci env var to suppress annoying unicode install banner lag
|
||||
${pkgs.lib.getExe npm-lockfile-fix} ./package-lock.json
|
||||
|
||||
NIX_FILE="$REPO_ROOT/nix/tui.nix"
|
||||
# compute the new hash
|
||||
sed -i "s/hash = \"[^\"]*\";/hash = \"\";/" $NIX_FILE
|
||||
NIX_OUTPUT=$(nix build .#tui 2>&1 || true)
|
||||
NEW_HASH=$(echo "$NIX_OUTPUT" | grep 'got:' | awk '{print $2}')
|
||||
echo got new hash $NEW_HASH
|
||||
sed -i "s|hash = \"[^\"]*\";|hash = \"$NEW_HASH\";|" $NIX_FILE
|
||||
nix build .#tui
|
||||
echo "Updated npm hash in $NIX_FILE to $NEW_HASH"
|
||||
'')
|
||||
];
|
||||
|
||||
passthru.devShellHook = ''
|
||||
STAMP=".nix-stamps/hermes-tui"
|
||||
STAMP_VALUE="${npmLockHash}"
|
||||
if [ ! -f "$STAMP" ] || [ "$(cat "$STAMP")" != "$STAMP_VALUE" ]; then
|
||||
echo "hermes-tui: installing npm dependencies..."
|
||||
cd ui-tui && CI=true npm install --silent --no-fund --no-audit 2>/dev/null && cd ..
|
||||
mkdir -p .nix-stamps
|
||||
echo "$STAMP_VALUE" > "$STAMP"
|
||||
fi
|
||||
'';
|
||||
}
|
||||
+63
@@ -0,0 +1,63 @@
|
||||
# nix/web.nix — Hermes Web Dashboard (Vite/React) frontend build
|
||||
{ pkgs, npm-lockfile-fix, ... }:
|
||||
let
|
||||
src = ../web;
|
||||
npmDeps = pkgs.fetchNpmDeps {
|
||||
inherit src;
|
||||
hash = "sha256-Y0pOzdFG8BLjfvCLmsvqYpjxFjAQabXp1i7X9W/cCU4=";
|
||||
};
|
||||
|
||||
npmLockHash = builtins.hashString "sha256" (builtins.readFile ../web/package-lock.json);
|
||||
in
|
||||
pkgs.buildNpmPackage {
|
||||
pname = "hermes-web";
|
||||
version = "0.0.0";
|
||||
inherit src npmDeps;
|
||||
|
||||
doCheck = false;
|
||||
|
||||
buildPhase = ''
|
||||
npx tsc -b
|
||||
npx vite build --outDir dist
|
||||
'';
|
||||
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
cp -r dist $out
|
||||
runHook postInstall
|
||||
'';
|
||||
|
||||
nativeBuildInputs = [
|
||||
(pkgs.writeShellScriptBin "update_web_lockfile" ''
|
||||
set -euox pipefail
|
||||
|
||||
REPO_ROOT=$(git rev-parse --show-toplevel)
|
||||
|
||||
cd "$REPO_ROOT/web"
|
||||
rm -rf node_modules/
|
||||
npm cache clean --force
|
||||
CI=true npm install
|
||||
${pkgs.lib.getExe npm-lockfile-fix} ./package-lock.json
|
||||
|
||||
NIX_FILE="$REPO_ROOT/nix/web.nix"
|
||||
sed -i "s/hash = \"[^\"]*\";/hash = \"\";/" $NIX_FILE
|
||||
NIX_OUTPUT=$(nix build .#web 2>&1 || true)
|
||||
NEW_HASH=$(echo "$NIX_OUTPUT" | grep 'got:' | awk '{print $2}')
|
||||
echo got new hash $NEW_HASH
|
||||
sed -i "s|hash = \"[^\"]*\";|hash = \"$NEW_HASH\";|" $NIX_FILE
|
||||
nix build .#web
|
||||
echo "Updated npm hash in $NIX_FILE to $NEW_HASH"
|
||||
'')
|
||||
];
|
||||
|
||||
passthru.devShellHook = ''
|
||||
STAMP=".nix-stamps/hermes-web"
|
||||
STAMP_VALUE="${npmLockHash}"
|
||||
if [ ! -f "$STAMP" ] || [ "$(cat "$STAMP")" != "$STAMP_VALUE" ]; then
|
||||
echo "hermes-web: installing npm dependencies..."
|
||||
cd web && CI=true npm install --silent --no-fund --no-audit 2>/dev/null && cd ..
|
||||
mkdir -p .nix-stamps
|
||||
echo "$STAMP_VALUE" > "$STAMP"
|
||||
fi
|
||||
'';
|
||||
}
|
||||
@@ -7,7 +7,7 @@ license: MIT
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [atropos, rl, environments, training, reinforcement-learning, reward-functions]
|
||||
related_skills: [axolotl, grpo-rl-training, trl-fine-tuning, lm-evaluation-harness]
|
||||
related_skills: [axolotl, fine-tuning-with-trl, lm-evaluation-harness]
|
||||
---
|
||||
|
||||
# Hermes Agent Atropos Environments
|
||||
|
||||
+3
-3
@@ -76,8 +76,8 @@ termux = [
|
||||
"hermes-agent[honcho]",
|
||||
"hermes-agent[acp]",
|
||||
]
|
||||
dingtalk = ["dingtalk-stream>=0.1.0,<1"]
|
||||
feishu = ["lark-oapi>=1.5.3,<2"]
|
||||
dingtalk = ["dingtalk-stream>=0.20,<1", "alibabacloud-dingtalk>=2.0.0", "qrcode>=7.0,<8"]
|
||||
feishu = ["lark-oapi>=1.5.3,<2", "qrcode>=7.0,<8"]
|
||||
web = ["fastapi>=0.104.0,<1", "uvicorn[standard]>=0.24.0,<1"]
|
||||
rl = [
|
||||
"atroposlib @ git+https://github.com/NousResearch/atropos.git@c20c85256e5a45ad31edf8b7276e9c5ee1995a30",
|
||||
@@ -126,7 +126,7 @@ py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajector
|
||||
hermes_cli = ["web_dist/**/*"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["agent", "tools", "tools.*", "hermes_cli", "gateway", "gateway.*", "cron", "acp_adapter", "plugins", "plugins.*"]
|
||||
include = ["agent", "tools", "tools.*", "hermes_cli", "gateway", "gateway.*", "tui_gateway", "tui_gateway.*", "cron", "acp_adapter", "plugins", "plugins.*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
|
||||
+437
-25
@@ -353,12 +353,50 @@ def _sanitize_surrogates(text: str) -> str:
|
||||
return text
|
||||
|
||||
|
||||
def _sanitize_structure_surrogates(payload: Any) -> bool:
|
||||
"""Replace surrogate code points in nested dict/list payloads in-place.
|
||||
|
||||
Mirror of ``_sanitize_structure_non_ascii`` but for surrogate recovery.
|
||||
Used to scrub nested structured fields (e.g. ``reasoning_details`` — an
|
||||
array of dicts with ``summary``/``text`` strings) that flat per-field
|
||||
checks don't reach. Returns True if any surrogates were replaced.
|
||||
"""
|
||||
found = False
|
||||
|
||||
def _walk(node):
|
||||
nonlocal found
|
||||
if isinstance(node, dict):
|
||||
for key, value in node.items():
|
||||
if isinstance(value, str):
|
||||
if _SURROGATE_RE.search(value):
|
||||
node[key] = _SURROGATE_RE.sub('\ufffd', value)
|
||||
found = True
|
||||
elif isinstance(value, (dict, list)):
|
||||
_walk(value)
|
||||
elif isinstance(node, list):
|
||||
for idx, value in enumerate(node):
|
||||
if isinstance(value, str):
|
||||
if _SURROGATE_RE.search(value):
|
||||
node[idx] = _SURROGATE_RE.sub('\ufffd', value)
|
||||
found = True
|
||||
elif isinstance(value, (dict, list)):
|
||||
_walk(value)
|
||||
|
||||
_walk(payload)
|
||||
return found
|
||||
|
||||
|
||||
def _sanitize_messages_surrogates(messages: list) -> bool:
|
||||
"""Sanitize surrogate characters from all string content in a messages list.
|
||||
|
||||
Walks message dicts in-place. Returns True if any surrogates were found
|
||||
and replaced, False otherwise. Covers content/text, name, and tool call
|
||||
metadata/arguments so retries don't fail on a non-content field.
|
||||
and replaced, False otherwise. Covers content/text, name, tool call
|
||||
metadata/arguments, AND any additional string or nested structured fields
|
||||
(``reasoning``, ``reasoning_content``, ``reasoning_details``, etc.) so
|
||||
retries don't fail on a non-content field. Byte-level reasoning models
|
||||
(xiaomi/mimo, kimi, glm) can emit lone surrogates in reasoning output
|
||||
that flow through to ``api_messages["reasoning_content"]`` on the next
|
||||
turn and crash json.dumps inside the OpenAI SDK.
|
||||
"""
|
||||
found = False
|
||||
for msg in messages:
|
||||
@@ -398,6 +436,21 @@ def _sanitize_messages_surrogates(messages: list) -> bool:
|
||||
if isinstance(fn_args, str) and _SURROGATE_RE.search(fn_args):
|
||||
fn["arguments"] = _SURROGATE_RE.sub('\ufffd', fn_args)
|
||||
found = True
|
||||
# Walk any additional string / nested fields (reasoning,
|
||||
# reasoning_content, reasoning_details, etc.) — surrogates from
|
||||
# byte-level reasoning models (xiaomi/mimo, kimi, glm) can lurk
|
||||
# in these fields and aren't covered by the per-field checks above.
|
||||
# Matches _sanitize_messages_non_ascii's coverage (PR #10537).
|
||||
for key, value in msg.items():
|
||||
if key in {"content", "name", "tool_calls", "role"}:
|
||||
continue
|
||||
if isinstance(value, str):
|
||||
if _SURROGATE_RE.search(value):
|
||||
msg[key] = _SURROGATE_RE.sub('\ufffd', value)
|
||||
found = True
|
||||
elif isinstance(value, (dict, list)):
|
||||
if _sanitize_structure_surrogates(value):
|
||||
found = True
|
||||
return found
|
||||
|
||||
|
||||
@@ -778,6 +831,26 @@ class AIAgent:
|
||||
self._execution_thread_id: int | None = None # Set at run_conversation() start
|
||||
self._interrupt_thread_signal_pending = False
|
||||
self._client_lock = threading.RLock()
|
||||
|
||||
# /steer mechanism — inject a user note into the next tool result
|
||||
# without interrupting the agent. Unlike interrupt(), steer() does
|
||||
# NOT set _interrupt_requested; it waits for the current tool batch
|
||||
# to finish naturally, then the drain hook appends the text to the
|
||||
# last tool result's content so the model sees it on its next
|
||||
# iteration. Message-role alternation is preserved (we modify an
|
||||
# existing tool message rather than inserting a new user turn).
|
||||
self._pending_steer: Optional[str] = None
|
||||
self._pending_steer_lock = threading.Lock()
|
||||
|
||||
# Concurrent-tool worker thread tracking. `_execute_tool_calls_concurrent`
|
||||
# runs each tool on its own ThreadPoolExecutor worker — those worker
|
||||
# threads have tids distinct from `_execution_thread_id`, so
|
||||
# `_set_interrupt(True, _execution_thread_id)` alone does NOT cause
|
||||
# `is_interrupted()` inside the worker to return True. Track the
|
||||
# workers here so `interrupt()` / `clear_interrupt()` can fan out to
|
||||
# their tids explicitly.
|
||||
self._tool_worker_threads: set[int] = set()
|
||||
self._tool_worker_threads_lock = threading.Lock()
|
||||
|
||||
# Subagent delegation state
|
||||
self._delegate_depth = 0 # 0 = top-level agent, incremented for children
|
||||
@@ -981,6 +1054,16 @@ class AIAgent:
|
||||
}
|
||||
elif "portal.qwen.ai" in effective_base.lower():
|
||||
client_kwargs["default_headers"] = _qwen_portal_headers()
|
||||
elif "generativelanguage.googleapis.com" in effective_base.lower():
|
||||
# Google's OpenAI-compatible endpoint only accepts x-goog-api-key.
|
||||
# The OpenAI SDK auto-injects Authorization: Bearer when api_key= is
|
||||
# set to a real value, causing HTTP 400 "Multiple authentication
|
||||
# credentials received". Pass a placeholder so the SDK does not
|
||||
# emit Bearer, and carry the real key via x-goog-api-key instead.
|
||||
# Fixes: https://github.com/NousResearch/hermes-agent/issues/7893
|
||||
real_key = client_kwargs["api_key"]
|
||||
client_kwargs["api_key"] = "not-used"
|
||||
client_kwargs["default_headers"] = {"x-goog-api-key": real_key}
|
||||
else:
|
||||
# No explicit creds — use the centralized provider router
|
||||
from agent.auxiliary_client import resolve_provider_client
|
||||
@@ -3138,6 +3221,25 @@ class AIAgent:
|
||||
# interrupt signal until startup completes instead of targeting
|
||||
# the caller thread by mistake.
|
||||
self._interrupt_thread_signal_pending = True
|
||||
# Fan out to concurrent-tool worker threads. Those workers run tools
|
||||
# on their own tids (ThreadPoolExecutor workers), so `is_interrupted()`
|
||||
# inside a tool only sees an interrupt when their specific tid is in
|
||||
# the `_interrupted_threads` set. Without this propagation, an
|
||||
# already-running concurrent tool (e.g. a terminal command hung on
|
||||
# network I/O) never notices the interrupt and has to run to its own
|
||||
# timeout. See `_run_tool` for the matching entry/exit bookkeeping.
|
||||
# `getattr` fallback covers test stubs that build AIAgent via
|
||||
# object.__new__ and skip __init__.
|
||||
_tracker = getattr(self, "_tool_worker_threads", None)
|
||||
_tracker_lock = getattr(self, "_tool_worker_threads_lock", None)
|
||||
if _tracker is not None and _tracker_lock is not None:
|
||||
with _tracker_lock:
|
||||
_worker_tids = list(_tracker)
|
||||
for _wtid in _worker_tids:
|
||||
try:
|
||||
_set_interrupt(True, _wtid)
|
||||
except Exception:
|
||||
pass
|
||||
# Propagate interrupt to any running child agents (subagent delegation)
|
||||
with self._active_children_lock:
|
||||
children_copy = list(self._active_children)
|
||||
@@ -3156,6 +3258,146 @@ class AIAgent:
|
||||
self._interrupt_thread_signal_pending = False
|
||||
if self._execution_thread_id is not None:
|
||||
_set_interrupt(False, self._execution_thread_id)
|
||||
# Also clear any concurrent-tool worker thread bits. Tracked
|
||||
# workers normally clear their own bit on exit, but an explicit
|
||||
# clear here guarantees no stale interrupt can survive a turn
|
||||
# boundary and fire on a subsequent, unrelated tool call that
|
||||
# happens to get scheduled onto the same recycled worker tid.
|
||||
# `getattr` fallback covers test stubs that build AIAgent via
|
||||
# object.__new__ and skip __init__.
|
||||
_tracker = getattr(self, "_tool_worker_threads", None)
|
||||
_tracker_lock = getattr(self, "_tool_worker_threads_lock", None)
|
||||
if _tracker is not None and _tracker_lock is not None:
|
||||
with _tracker_lock:
|
||||
_worker_tids = list(_tracker)
|
||||
for _wtid in _worker_tids:
|
||||
try:
|
||||
_set_interrupt(False, _wtid)
|
||||
except Exception:
|
||||
pass
|
||||
# A hard interrupt supersedes any pending /steer — the steer was
|
||||
# meant for the agent's next tool-call iteration, which will no
|
||||
# longer happen. Drop it instead of surprising the user with a
|
||||
# late injection on the post-interrupt turn.
|
||||
_steer_lock = getattr(self, "_pending_steer_lock", None)
|
||||
if _steer_lock is not None:
|
||||
with _steer_lock:
|
||||
self._pending_steer = None
|
||||
|
||||
def steer(self, text: str) -> bool:
|
||||
"""
|
||||
Inject a user message into the next tool result without interrupting.
|
||||
|
||||
Unlike interrupt(), this does NOT stop the current tool call. The
|
||||
text is stashed and the agent loop appends it to the LAST tool
|
||||
result's content once the current tool batch finishes. The model
|
||||
sees the steer as part of the tool output on its next iteration.
|
||||
|
||||
Thread-safe: callable from gateway/CLI/TUI threads. Multiple calls
|
||||
before the drain point concatenate with newlines.
|
||||
|
||||
Args:
|
||||
text: The user text to inject. Empty strings are ignored.
|
||||
|
||||
Returns:
|
||||
True if the steer was accepted, False if the text was empty.
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return False
|
||||
cleaned = text.strip()
|
||||
_lock = getattr(self, "_pending_steer_lock", None)
|
||||
if _lock is None:
|
||||
# Test stubs that built AIAgent via object.__new__ skip __init__.
|
||||
# Fall back to direct attribute set; no concurrent callers expected
|
||||
# in those stubs.
|
||||
existing = getattr(self, "_pending_steer", None)
|
||||
self._pending_steer = (existing + "\n" + cleaned) if existing else cleaned
|
||||
return True
|
||||
with _lock:
|
||||
if self._pending_steer:
|
||||
self._pending_steer = self._pending_steer + "\n" + cleaned
|
||||
else:
|
||||
self._pending_steer = cleaned
|
||||
return True
|
||||
|
||||
def _drain_pending_steer(self) -> Optional[str]:
|
||||
"""Return the pending steer text (if any) and clear the slot.
|
||||
|
||||
Safe to call from the agent execution thread after appending tool
|
||||
results. Returns None when no steer is pending.
|
||||
"""
|
||||
_lock = getattr(self, "_pending_steer_lock", None)
|
||||
if _lock is None:
|
||||
text = getattr(self, "_pending_steer", None)
|
||||
self._pending_steer = None
|
||||
return text
|
||||
with _lock:
|
||||
text = self._pending_steer
|
||||
self._pending_steer = None
|
||||
return text
|
||||
|
||||
def _apply_pending_steer_to_tool_results(self, messages: list, num_tool_msgs: int) -> None:
|
||||
"""Append any pending /steer text to the last tool result in this turn.
|
||||
|
||||
Called at the end of a tool-call batch, before the next API call.
|
||||
The steer is appended to the last ``role:"tool"`` message's content
|
||||
with a clear marker so the model understands it came from the user
|
||||
and NOT from the tool itself. Role alternation is preserved —
|
||||
nothing new is inserted, we only modify existing content.
|
||||
|
||||
Args:
|
||||
messages: The running messages list.
|
||||
num_tool_msgs: Number of tool results appended in this batch;
|
||||
used to locate the tail slice safely.
|
||||
"""
|
||||
if num_tool_msgs <= 0 or not messages:
|
||||
return
|
||||
steer_text = self._drain_pending_steer()
|
||||
if not steer_text:
|
||||
return
|
||||
# Find the last tool-role message in the recent tail. Skipping
|
||||
# non-tool messages defends against future code appending
|
||||
# something else at the boundary.
|
||||
target_idx = None
|
||||
for j in range(len(messages) - 1, max(len(messages) - num_tool_msgs - 1, -1), -1):
|
||||
msg = messages[j]
|
||||
if isinstance(msg, dict) and msg.get("role") == "tool":
|
||||
target_idx = j
|
||||
break
|
||||
if target_idx is None:
|
||||
# No tool result in this batch (e.g. all skipped by interrupt);
|
||||
# put the steer back so the caller's fallback path can deliver
|
||||
# it as a normal next-turn user message.
|
||||
_lock = getattr(self, "_pending_steer_lock", None)
|
||||
if _lock is not None:
|
||||
with _lock:
|
||||
if self._pending_steer:
|
||||
self._pending_steer = self._pending_steer + "\n" + steer_text
|
||||
else:
|
||||
self._pending_steer = steer_text
|
||||
else:
|
||||
existing = getattr(self, "_pending_steer", None)
|
||||
self._pending_steer = (existing + "\n" + steer_text) if existing else steer_text
|
||||
return
|
||||
marker = f"\n\n[USER STEER (injected mid-run, not tool output): {steer_text}]"
|
||||
existing_content = messages[target_idx].get("content", "")
|
||||
if not isinstance(existing_content, str):
|
||||
# Anthropic multimodal content blocks — preserve them and append
|
||||
# a text block at the end.
|
||||
try:
|
||||
blocks = list(existing_content) if existing_content else []
|
||||
blocks.append({"type": "text", "text": marker.lstrip()})
|
||||
messages[target_idx]["content"] = blocks
|
||||
except Exception:
|
||||
# Fall back to string replacement if content shape is unexpected.
|
||||
messages[target_idx]["content"] = f"{existing_content}{marker}"
|
||||
else:
|
||||
messages[target_idx]["content"] = existing_content + marker
|
||||
logger.info(
|
||||
"Delivered /steer to agent after tool batch (%d chars): %s",
|
||||
len(steer_text),
|
||||
steer_text[:120] + ("..." if len(steer_text) > 120 else ""),
|
||||
)
|
||||
|
||||
def _touch_activity(self, desc: str) -> None:
|
||||
"""Update the last-activity timestamp and description (thread-safe)."""
|
||||
@@ -5003,6 +5245,17 @@ class AIAgent:
|
||||
self._client_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.30.0"}
|
||||
elif "portal.qwen.ai" in normalized:
|
||||
self._client_kwargs["default_headers"] = _qwen_portal_headers()
|
||||
elif "generativelanguage.googleapis.com" in normalized:
|
||||
# Google's endpoint rejects Bearer tokens; use x-goog-api-key instead.
|
||||
# Swap the real key out of api_key and into the header so the OpenAI
|
||||
# SDK does not emit Authorization: Bearer.
|
||||
# Fixes: https://github.com/NousResearch/hermes-agent/issues/7893
|
||||
real_key = self._client_kwargs.get("api_key", "")
|
||||
if real_key and real_key != "not-used":
|
||||
self._client_kwargs["api_key"] = "not-used"
|
||||
self._client_kwargs["default_headers"] = {
|
||||
"x-goog-api-key": real_key or self._client_kwargs.get("api_key", ""),
|
||||
}
|
||||
else:
|
||||
self._client_kwargs.pop("default_headers", None)
|
||||
|
||||
@@ -5459,7 +5712,7 @@ class AIAgent:
|
||||
raise result["error"]
|
||||
return result["response"]
|
||||
|
||||
result = {"response": None, "error": None}
|
||||
result = {"response": None, "error": None, "partial_tool_names": []}
|
||||
request_client_holder = {"client": None}
|
||||
first_delta_fired = {"done": False}
|
||||
deltas_were_sent = {"yes": False} # Track if any deltas were fired (for fallback)
|
||||
@@ -5631,6 +5884,14 @@ class AIAgent:
|
||||
tool_gen_notified.add(idx)
|
||||
_fire_first_delta()
|
||||
self._fire_tool_gen_started(name)
|
||||
# Record the partial tool-call name so the outer
|
||||
# stub-builder can surface a user-visible warning
|
||||
# if streaming dies before this tool's arguments
|
||||
# are fully delivered. Without this, a stall
|
||||
# during tool-call JSON generation lets the stub
|
||||
# at line ~6107 return `tool_calls=None`, silently
|
||||
# discarding the attempted action.
|
||||
result["partial_tool_names"].append(name)
|
||||
|
||||
if chunk.choices[0].finish_reason:
|
||||
finish_reason = chunk.choices[0].finish_reason
|
||||
@@ -5841,6 +6102,7 @@ class AIAgent:
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
self._emit_status("🔄 Reconnected — resuming…")
|
||||
continue
|
||||
self._emit_status(
|
||||
"❌ Connection to provider failed after "
|
||||
@@ -5996,13 +6258,44 @@ class AIAgent:
|
||||
_partial_text = (
|
||||
getattr(self, "_current_streamed_assistant_text", "") or ""
|
||||
).strip() or None
|
||||
logger.warning(
|
||||
"Partial stream delivered before error; returning stub "
|
||||
"response with %s chars of recovered content to prevent "
|
||||
"duplicate messages: %s",
|
||||
len(_partial_text or ""),
|
||||
result["error"],
|
||||
)
|
||||
|
||||
# If the stream died while the model was emitting a tool call,
|
||||
# the stub below will silently set `tool_calls=None` and the
|
||||
# agent loop will treat the turn as complete — the attempted
|
||||
# action is lost with no user-facing signal. Append a
|
||||
# human-visible warning to the stub content so (a) the user
|
||||
# knows something failed, and (b) the next turn's model sees
|
||||
# in conversation history what was attempted and can retry.
|
||||
_partial_names = list(result.get("partial_tool_names") or [])
|
||||
if _partial_names:
|
||||
_name_str = ", ".join(_partial_names[:3])
|
||||
if len(_partial_names) > 3:
|
||||
_name_str += f", +{len(_partial_names) - 3} more"
|
||||
_warn = (
|
||||
f"\n\n⚠ Stream stalled mid tool-call "
|
||||
f"({_name_str}); the action was not executed. "
|
||||
f"Ask me to retry if you want to continue."
|
||||
)
|
||||
_partial_text = (_partial_text or "") + _warn
|
||||
# Also fire as a streaming delta so the user sees it now
|
||||
# instead of only in the persisted transcript.
|
||||
try:
|
||||
self._fire_stream_delta(_warn)
|
||||
except Exception:
|
||||
pass
|
||||
logger.warning(
|
||||
"Partial stream dropped tool call(s) %s after %s chars "
|
||||
"of text; surfaced warning to user: %s",
|
||||
_partial_names, len(_partial_text or ""), result["error"],
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Partial stream delivered before error; returning stub "
|
||||
"response with %s chars of recovered content to prevent "
|
||||
"duplicate messages: %s",
|
||||
len(_partial_text or ""),
|
||||
result["error"],
|
||||
)
|
||||
_stub_msg = SimpleNamespace(
|
||||
role="assistant", content=_partial_text, tool_calls=None,
|
||||
reasoning_content=None,
|
||||
@@ -6744,6 +7037,14 @@ class AIAgent:
|
||||
"messages": sanitized_messages,
|
||||
"timeout": float(os.getenv("HERMES_API_TIMEOUT", 1800.0)),
|
||||
}
|
||||
try:
|
||||
from agent.auxiliary_client import _fixed_temperature_for_model
|
||||
except Exception:
|
||||
_fixed_temperature_for_model = None
|
||||
if _fixed_temperature_for_model is not None:
|
||||
fixed_temperature = _fixed_temperature_for_model(self.model)
|
||||
if fixed_temperature is not None:
|
||||
api_kwargs["temperature"] = fixed_temperature
|
||||
if self._is_qwen_portal():
|
||||
api_kwargs["metadata"] = {
|
||||
"sessionId": self.session_id or "hermes",
|
||||
@@ -6949,7 +7250,7 @@ class AIAgent:
|
||||
# (gateway, batch, quiet) still get reasoning.
|
||||
# Any reasoning that wasn't shown during streaming is caught by the
|
||||
# CLI post-response display fallback (cli.py _reasoning_shown_this_turn).
|
||||
if not self.stream_delta_callback:
|
||||
if not self.stream_delta_callback and not self._stream_callback:
|
||||
try:
|
||||
self.reasoning_callback(reasoning_text)
|
||||
except Exception:
|
||||
@@ -7154,14 +7455,22 @@ class AIAgent:
|
||||
|
||||
# Use auxiliary client for the flush call when available --
|
||||
# it's cheaper and avoids Codex Responses API incompatibility.
|
||||
from agent.auxiliary_client import call_llm as _call_llm
|
||||
from agent.auxiliary_client import (
|
||||
call_llm as _call_llm,
|
||||
_fixed_temperature_for_model,
|
||||
)
|
||||
_aux_available = True
|
||||
# Use the fixed-temperature override (e.g. kimi-for-coding → 0.6) if
|
||||
# the model has a strict contract; otherwise the historical 0.3 default.
|
||||
_flush_temperature = _fixed_temperature_for_model(self.model)
|
||||
if _flush_temperature is None:
|
||||
_flush_temperature = 0.3
|
||||
try:
|
||||
response = _call_llm(
|
||||
task="flush_memories",
|
||||
messages=api_messages,
|
||||
tools=[memory_tool_def],
|
||||
temperature=0.3,
|
||||
temperature=_flush_temperature,
|
||||
max_tokens=5120,
|
||||
# timeout resolved from auxiliary.flush_memories.timeout config
|
||||
)
|
||||
@@ -7173,7 +7482,7 @@ class AIAgent:
|
||||
# No auxiliary client -- use the Codex Responses path directly
|
||||
codex_kwargs = self._build_api_kwargs(api_messages)
|
||||
codex_kwargs["tools"] = self._responses_tools([memory_tool_def])
|
||||
codex_kwargs["temperature"] = 0.3
|
||||
codex_kwargs["temperature"] = _flush_temperature
|
||||
if "max_output_tokens" in codex_kwargs:
|
||||
codex_kwargs["max_output_tokens"] = 5120
|
||||
response = self._run_codex_stream(codex_kwargs)
|
||||
@@ -7192,7 +7501,7 @@ class AIAgent:
|
||||
"model": self.model,
|
||||
"messages": api_messages,
|
||||
"tools": [memory_tool_def],
|
||||
"temperature": 0.3,
|
||||
"temperature": _flush_temperature,
|
||||
**self._max_tokens_param(5120),
|
||||
}
|
||||
from agent.auxiliary_client import _get_task_timeout
|
||||
@@ -7583,6 +7892,22 @@ class AIAgent:
|
||||
|
||||
def _run_tool(index, tool_call, function_name, function_args):
|
||||
"""Worker function executed in a thread."""
|
||||
# Register this worker tid so the agent can fan out an interrupt
|
||||
# to it — see AIAgent.interrupt(). Must happen first thing, and
|
||||
# must be paired with discard + clear in the finally block.
|
||||
_worker_tid = threading.current_thread().ident
|
||||
with self._tool_worker_threads_lock:
|
||||
self._tool_worker_threads.add(_worker_tid)
|
||||
# Race: if the agent was interrupted between fan-out (which
|
||||
# snapshotted an empty/earlier set) and our registration, apply
|
||||
# the interrupt to our own tid now so is_interrupted() inside
|
||||
# the tool returns True on the next poll.
|
||||
if self._interrupt_requested:
|
||||
try:
|
||||
from tools.interrupt import set_interrupt as _sif
|
||||
_sif(True, _worker_tid)
|
||||
except Exception:
|
||||
pass
|
||||
# Set the activity callback on THIS worker thread so
|
||||
# _wait_for_process (terminal commands) can fire heartbeats.
|
||||
# The callback is thread-local; the main thread's callback
|
||||
@@ -7605,6 +7930,16 @@ class AIAgent:
|
||||
else:
|
||||
logger.info("tool %s completed (%.2fs, %d chars)", function_name, duration, len(result))
|
||||
results[index] = (function_name, function_args, result, duration, is_error)
|
||||
# Tear down worker-tid tracking. Clear any interrupt bit we may
|
||||
# have set so the next task scheduled onto this recycled tid
|
||||
# starts with a clean slate.
|
||||
with self._tool_worker_threads_lock:
|
||||
self._tool_worker_threads.discard(_worker_tid)
|
||||
try:
|
||||
from tools.interrupt import set_interrupt as _sif
|
||||
_sif(False, _worker_tid)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Start spinner for CLI mode (skip when TUI handles tool progress)
|
||||
spinner = None
|
||||
@@ -7749,6 +8084,13 @@ class AIAgent:
|
||||
turn_tool_msgs = messages[-num_tools:]
|
||||
enforce_turn_budget(turn_tool_msgs, env=get_active_env(effective_task_id))
|
||||
|
||||
# ── /steer injection ──────────────────────────────────────────────
|
||||
# Append any pending user steer text to the last tool result so the
|
||||
# agent sees it on its next iteration. Runs AFTER budget enforcement
|
||||
# so the steer marker is never truncated. See steer() for details.
|
||||
if num_tools > 0:
|
||||
self._apply_pending_steer_to_tool_results(messages, num_tools)
|
||||
|
||||
def _execute_tool_calls_sequential(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
|
||||
"""Execute tool calls sequentially (original behavior). Used for single calls or interactive tools."""
|
||||
for i, tool_call in enumerate(assistant_message.tool_calls, 1):
|
||||
@@ -8128,6 +8470,12 @@ class AIAgent:
|
||||
if num_tools_seq > 0:
|
||||
enforce_turn_budget(messages[-num_tools_seq:], env=get_active_env(effective_task_id))
|
||||
|
||||
# ── /steer injection ──────────────────────────────────────────────
|
||||
# See _execute_tool_calls_parallel for the rationale. Same hook,
|
||||
# applied to sequential execution as well.
|
||||
if num_tools_seq > 0:
|
||||
self._apply_pending_steer_to_tool_results(messages, num_tools_seq)
|
||||
|
||||
|
||||
|
||||
def _handle_max_iterations(self, messages: list, api_call_count: int) -> str:
|
||||
@@ -8165,6 +8513,15 @@ class AIAgent:
|
||||
api_messages.insert(sys_offset + idx, pfm.copy())
|
||||
|
||||
summary_extra_body = {}
|
||||
try:
|
||||
from agent.auxiliary_client import _fixed_temperature_for_model
|
||||
except Exception:
|
||||
_fixed_temperature_for_model = None
|
||||
_summary_temperature = (
|
||||
_fixed_temperature_for_model(self.model)
|
||||
if _fixed_temperature_for_model is not None
|
||||
else None
|
||||
)
|
||||
_is_nous = "nousresearch" in self._base_url_lower
|
||||
if self._supports_reasoning_extra_body():
|
||||
if self.reasoning_config is not None:
|
||||
@@ -8188,6 +8545,8 @@ class AIAgent:
|
||||
"model": self.model,
|
||||
"messages": api_messages,
|
||||
}
|
||||
if _summary_temperature is not None:
|
||||
summary_kwargs["temperature"] = _summary_temperature
|
||||
if self.max_tokens is not None:
|
||||
summary_kwargs.update(self._max_tokens_param(self.max_tokens))
|
||||
|
||||
@@ -8253,6 +8612,8 @@ class AIAgent:
|
||||
"model": self.model,
|
||||
"messages": api_messages,
|
||||
}
|
||||
if _summary_temperature is not None:
|
||||
summary_kwargs["temperature"] = _summary_temperature
|
||||
if self.max_tokens is not None:
|
||||
summary_kwargs.update(self._max_tokens_param(self.max_tokens))
|
||||
if summary_extra_body:
|
||||
@@ -8688,6 +9049,7 @@ class AIAgent:
|
||||
{
|
||||
"name": tc["function"]["name"],
|
||||
"result": _results_by_id.get(tc.get("id")),
|
||||
"arguments": tc["function"].get("arguments"),
|
||||
}
|
||||
for tc in _m["tool_calls"]
|
||||
if isinstance(tc, dict)
|
||||
@@ -9302,8 +9664,7 @@ class AIAgent:
|
||||
"and had none left for the actual response.\n\n"
|
||||
"To fix this:\n"
|
||||
"→ Lower reasoning effort: `/thinkon low` or `/thinkon minimal`\n"
|
||||
"→ Increase the output token limit: "
|
||||
"set `model.max_tokens` in config.yaml"
|
||||
"→ Or switch to a larger/non-reasoning model with `/model`"
|
||||
)
|
||||
self._cleanup_task_resources(effective_task_id)
|
||||
self._persist_session(messages, conversation_history)
|
||||
@@ -9570,13 +9931,51 @@ class AIAgent:
|
||||
if isinstance(api_error, UnicodeEncodeError) and getattr(self, '_unicode_sanitization_passes', 0) < 2:
|
||||
_err_str = str(api_error).lower()
|
||||
_is_ascii_codec = "'ascii'" in _err_str or "ascii" in _err_str
|
||||
# Detect surrogate errors — utf-8 codec refusing to
|
||||
# encode U+D800..U+DFFF. The error text is:
|
||||
# "'utf-8' codec can't encode characters in position
|
||||
# N-M: surrogates not allowed"
|
||||
_is_surrogate_error = (
|
||||
"surrogate" in _err_str
|
||||
or ("'utf-8'" in _err_str and not _is_ascii_codec)
|
||||
)
|
||||
# Sanitize surrogates from both the canonical `messages`
|
||||
# list AND `api_messages` (the API-copy, which may carry
|
||||
# `reasoning_content`/`reasoning_details` transformed
|
||||
# from `reasoning` — fields the canonical list doesn't
|
||||
# have directly). Also clean `api_kwargs` if built and
|
||||
# `prefill_messages` if present. Mirrors the ASCII
|
||||
# codec recovery below.
|
||||
_surrogates_found = _sanitize_messages_surrogates(messages)
|
||||
if _surrogates_found:
|
||||
if isinstance(api_messages, list):
|
||||
if _sanitize_messages_surrogates(api_messages):
|
||||
_surrogates_found = True
|
||||
if isinstance(api_kwargs, dict):
|
||||
if _sanitize_structure_surrogates(api_kwargs):
|
||||
_surrogates_found = True
|
||||
if isinstance(getattr(self, "prefill_messages", None), list):
|
||||
if _sanitize_messages_surrogates(self.prefill_messages):
|
||||
_surrogates_found = True
|
||||
# Gate the retry on the error type, not on whether we
|
||||
# found anything — _force_ascii_payload / the extended
|
||||
# surrogate walker above cover all known paths, but a
|
||||
# new transformed field could still slip through. If
|
||||
# the error was a surrogate encode failure, always let
|
||||
# the retry run; the proactive sanitizer at line ~8781
|
||||
# runs again on the next iteration. Bounded by
|
||||
# _unicode_sanitization_passes < 2 (outer guard).
|
||||
if _surrogates_found or _is_surrogate_error:
|
||||
self._unicode_sanitization_passes += 1
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Stripped invalid surrogate characters from messages. Retrying...",
|
||||
force=True,
|
||||
)
|
||||
if _surrogates_found:
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Stripped invalid surrogate characters from messages. Retrying...",
|
||||
force=True,
|
||||
)
|
||||
else:
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Surrogate encoding error — retrying after full-payload sanitization...",
|
||||
force=True,
|
||||
)
|
||||
continue
|
||||
if _is_ascii_codec:
|
||||
self._force_ascii_payload = True
|
||||
@@ -10344,9 +10743,9 @@ class AIAgent:
|
||||
pass
|
||||
wait_time = _retry_after if _retry_after else jittered_backoff(retry_count, base_delay=2.0, max_delay=60.0)
|
||||
if is_rate_limited:
|
||||
self._emit_status(f"⏱️ Rate limit reached. Waiting {wait_time}s before retry (attempt {retry_count + 1}/{max_retries})...")
|
||||
self._emit_status(f"⏱️ Rate limited. Waiting {wait_time:.1f}s (attempt {retry_count + 1}/{max_retries})...")
|
||||
else:
|
||||
self._emit_status(f"⏳ Retrying in {wait_time}s (attempt {retry_count}/{max_retries})...")
|
||||
self._emit_status(f"⏳ Retrying in {wait_time:.1f}s (attempt {retry_count}/{max_retries})...")
|
||||
logger.warning(
|
||||
"Retrying API call in %ss (attempt %s/%s) %s error=%s",
|
||||
wait_time,
|
||||
@@ -10762,7 +11161,14 @@ class AIAgent:
|
||||
elif self.quiet_mode:
|
||||
clean = self._strip_think_blocks(turn_content).strip()
|
||||
if clean:
|
||||
self._vprint(f" ┊ 💬 {clean}")
|
||||
relayed = False
|
||||
if (
|
||||
self.tool_progress_callback
|
||||
and getattr(self, "platform", "") == "tui"
|
||||
):
|
||||
relayed = True
|
||||
if not relayed:
|
||||
self._vprint(f" ┊ 💬 {clean}")
|
||||
|
||||
# Pop thinking-only prefill message(s) before appending
|
||||
# (tool-call path — same rationale as the final-response path).
|
||||
@@ -11350,6 +11756,12 @@ class AIAgent:
|
||||
"cost_status": self.session_cost_status,
|
||||
"cost_source": self.session_cost_source,
|
||||
}
|
||||
# If a /steer landed after the final assistant turn (no more tool
|
||||
# batches to drain into), hand it back to the caller so it can be
|
||||
# delivered as the next user turn instead of being silently lost.
|
||||
_leftover_steer = self._drain_pending_steer()
|
||||
if _leftover_steer:
|
||||
result["pending_steer"] = _leftover_steer
|
||||
self._response_was_previewed = False
|
||||
|
||||
# Include interrupt message if one triggered the interrupt
|
||||
|
||||
@@ -721,6 +721,20 @@ function Install-NodeDeps {
|
||||
}
|
||||
}
|
||||
|
||||
# Install TUI dependencies
|
||||
$tuiDir = "$InstallDir\ui-tui"
|
||||
if (Test-Path "$tuiDir\package.json") {
|
||||
Write-Info "Installing TUI dependencies..."
|
||||
Push-Location $tuiDir
|
||||
try {
|
||||
npm install --silent 2>&1 | Out-Null
|
||||
Write-Success "TUI dependencies installed"
|
||||
} catch {
|
||||
Write-Warn "TUI npm install failed (hermes --tui may not work)"
|
||||
}
|
||||
Pop-Location
|
||||
}
|
||||
|
||||
# Install WhatsApp bridge dependencies
|
||||
$bridgeDir = "$InstallDir\scripts\whatsapp-bridge"
|
||||
if (Test-Path "$bridgeDir\package.json") {
|
||||
|
||||
@@ -1194,6 +1194,16 @@ install_node_deps() {
|
||||
log_success "Browser engine setup complete"
|
||||
fi
|
||||
|
||||
# Install TUI dependencies
|
||||
if [ -f "$INSTALL_DIR/ui-tui/package.json" ]; then
|
||||
log_info "Installing TUI dependencies..."
|
||||
cd "$INSTALL_DIR/ui-tui"
|
||||
npm install --silent 2>/dev/null || {
|
||||
log_warn "TUI npm install failed (hermes --tui may not work)"
|
||||
}
|
||||
log_success "TUI dependencies installed"
|
||||
fi
|
||||
|
||||
# Install WhatsApp bridge dependencies
|
||||
if [ -f "$INSTALL_DIR/scripts/whatsapp-bridge/package.json" ]; then
|
||||
log_info "Installing WhatsApp bridge dependencies..."
|
||||
|
||||
@@ -0,0 +1,238 @@
|
||||
#!/usr/bin/env bash
|
||||
# ============================================================================
|
||||
# scripts/lib/node-bootstrap.sh
|
||||
# ----------------------------------------------------------------------------
|
||||
# Sourceable helper: ensure Node.js >= MIN_VERSION is available for the TUI
|
||||
# (React + Ink), browser tools, and the WhatsApp bridge.
|
||||
#
|
||||
# Strategy (first hit wins — respects the user's existing tooling):
|
||||
# 1. modern `node` already on PATH
|
||||
# 2. ~/.hermes/node/ from a prior Hermes-managed install
|
||||
# 3. fnm, proto, nvm (in that order) if the user already uses a version manager
|
||||
# 4. Termux `pkg`, macOS Homebrew
|
||||
# 5. pinned nodejs.org tarball into ~/.hermes/node/ (always works, zero shell rc edits)
|
||||
#
|
||||
# Usage:
|
||||
# source scripts/lib/node-bootstrap.sh
|
||||
# ensure_node # returns 0 on success, non-zero on failure
|
||||
# if [ "$HERMES_NODE_AVAILABLE" = true ]; then ...; fi
|
||||
#
|
||||
# Env inputs (set before sourcing to override defaults):
|
||||
# HERMES_NODE_MIN_VERSION (default: 20) — accepted on PATH
|
||||
# HERMES_NODE_TARGET_MAJOR (default: 22) — installed when we install
|
||||
# HERMES_HOME (default: $HOME/.hermes)
|
||||
# ============================================================================
|
||||
|
||||
HERMES_NODE_MIN_VERSION="${HERMES_NODE_MIN_VERSION:-20}"
|
||||
HERMES_NODE_TARGET_MAJOR="${HERMES_NODE_TARGET_MAJOR:-22}"
|
||||
HERMES_HOME="${HERMES_HOME:-$HOME/.hermes}"
|
||||
HERMES_NODE_AVAILABLE=false
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Logging — prefer the host script's log_* helpers when present
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_nb_log() { declare -F log_info >/dev/null 2>&1 && log_info "$*" || printf '→ %s\n' "$*" >&2; }
|
||||
_nb_ok() { declare -F log_success >/dev/null 2>&1 && log_success "$*" || printf '✓ %s\n' "$*" >&2; }
|
||||
_nb_warn() { declare -F log_warn >/dev/null 2>&1 && log_warn "$*" || printf '⚠ %s\n' "$*" >&2; }
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Platform + version helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_nb_is_termux() {
|
||||
[ -n "${TERMUX_VERSION:-}" ] || [[ "${PREFIX:-}" == *"com.termux/files/usr"* ]]
|
||||
}
|
||||
|
||||
_nb_node_major() {
|
||||
local v
|
||||
v=$(node --version 2>/dev/null | sed 's/^v//' | cut -d. -f1)
|
||||
[[ "$v" =~ ^[0-9]+$ ]] && echo "$v" || echo 0
|
||||
}
|
||||
|
||||
_nb_have_modern_node() {
|
||||
command -v node >/dev/null 2>&1 || return 1
|
||||
[ "$(_nb_node_major)" -ge "$HERMES_NODE_MIN_VERSION" ]
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Version-manager paths — respect what the user already uses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_nb_try_fnm() {
|
||||
command -v fnm >/dev/null 2>&1 || return 1
|
||||
_nb_log "fnm detected — installing Node $HERMES_NODE_TARGET_MAJOR..."
|
||||
eval "$(fnm env 2>/dev/null)" || true
|
||||
fnm install "$HERMES_NODE_TARGET_MAJOR" >/dev/null 2>&1 || return 1
|
||||
fnm use "$HERMES_NODE_TARGET_MAJOR" >/dev/null 2>&1 || return 1
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) activated via fnm"
|
||||
return 0
|
||||
}
|
||||
|
||||
_nb_try_proto() {
|
||||
command -v proto >/dev/null 2>&1 || return 1
|
||||
_nb_log "proto detected — installing Node $HERMES_NODE_TARGET_MAJOR..."
|
||||
proto install node "$HERMES_NODE_TARGET_MAJOR" >/dev/null 2>&1 || return 1
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) activated via proto"
|
||||
return 0
|
||||
}
|
||||
|
||||
_nb_try_nvm() {
|
||||
local nvm_sh="${NVM_DIR:-$HOME/.nvm}/nvm.sh"
|
||||
[ -s "$nvm_sh" ] || return 1
|
||||
# shellcheck source=/dev/null
|
||||
\. "$nvm_sh" >/dev/null 2>&1 || return 1
|
||||
_nb_log "nvm detected — installing Node $HERMES_NODE_TARGET_MAJOR..."
|
||||
nvm install "$HERMES_NODE_TARGET_MAJOR" >/dev/null 2>&1 || return 1
|
||||
nvm use "$HERMES_NODE_TARGET_MAJOR" >/dev/null 2>&1 || return 1
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) activated via nvm"
|
||||
return 0
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Platform package managers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_nb_try_termux_pkg() {
|
||||
_nb_is_termux || return 1
|
||||
_nb_log "Installing Node.js via pkg..."
|
||||
pkg install -y nodejs >/dev/null 2>&1 || return 1
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) installed via pkg"
|
||||
return 0
|
||||
}
|
||||
|
||||
_nb_try_brew() {
|
||||
[ "$(uname -s)" = "Darwin" ] || return 1
|
||||
command -v brew >/dev/null 2>&1 || return 1
|
||||
_nb_log "Installing Node via Homebrew..."
|
||||
brew install "node@${HERMES_NODE_TARGET_MAJOR}" >/dev/null 2>&1 \
|
||||
|| brew install node >/dev/null 2>&1 \
|
||||
|| return 1
|
||||
brew link --overwrite --force "node@${HERMES_NODE_TARGET_MAJOR}" >/dev/null 2>&1 || true
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) installed via Homebrew"
|
||||
return 0
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bundled binary fallback — always works, no shell rc edits
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_nb_install_bundled_node() {
|
||||
local arch node_arch os_name node_os
|
||||
arch=$(uname -m)
|
||||
case "$arch" in
|
||||
x86_64) node_arch="x64" ;;
|
||||
aarch64|arm64) node_arch="arm64" ;;
|
||||
armv7l) node_arch="armv7l" ;;
|
||||
*)
|
||||
_nb_warn "Unsupported arch ($arch) — install Node.js manually: https://nodejs.org/"
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
|
||||
os_name=$(uname -s)
|
||||
case "$os_name" in
|
||||
Linux*) node_os="linux" ;;
|
||||
Darwin*) node_os="darwin" ;;
|
||||
*)
|
||||
_nb_warn "Unsupported OS ($os_name) — install Node.js manually: https://nodejs.org/"
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
|
||||
local index_url="https://nodejs.org/dist/latest-v${HERMES_NODE_TARGET_MAJOR}.x/"
|
||||
local tarball
|
||||
tarball=$(curl -fsSL "$index_url" \
|
||||
| grep -oE "node-v${HERMES_NODE_TARGET_MAJOR}\.[0-9]+\.[0-9]+-${node_os}-${node_arch}\.tar\.xz" \
|
||||
| head -1)
|
||||
if [ -z "$tarball" ]; then
|
||||
tarball=$(curl -fsSL "$index_url" \
|
||||
| grep -oE "node-v${HERMES_NODE_TARGET_MAJOR}\.[0-9]+\.[0-9]+-${node_os}-${node_arch}\.tar\.gz" \
|
||||
| head -1)
|
||||
fi
|
||||
if [ -z "$tarball" ]; then
|
||||
_nb_warn "Could not resolve Node $HERMES_NODE_TARGET_MAJOR binary for $node_os-$node_arch"
|
||||
return 1
|
||||
fi
|
||||
|
||||
local tmp
|
||||
tmp=$(mktemp -d)
|
||||
_nb_log "Downloading $tarball..."
|
||||
curl -fsSL "${index_url}${tarball}" -o "$tmp/$tarball" || {
|
||||
_nb_warn "Download failed"; rm -rf "$tmp"; return 1
|
||||
}
|
||||
|
||||
_nb_log "Extracting to $HERMES_HOME/node/..."
|
||||
if [[ "$tarball" == *.tar.xz ]]; then
|
||||
tar xf "$tmp/$tarball" -C "$tmp" || { rm -rf "$tmp"; return 1; }
|
||||
else
|
||||
tar xzf "$tmp/$tarball" -C "$tmp" || { rm -rf "$tmp"; return 1; }
|
||||
fi
|
||||
|
||||
local extracted
|
||||
extracted=$(find "$tmp" -maxdepth 1 -type d -name 'node-v*' 2>/dev/null | head -1)
|
||||
if [ ! -d "$extracted" ]; then
|
||||
_nb_warn "Extraction produced no node-v* directory"
|
||||
rm -rf "$tmp"
|
||||
return 1
|
||||
fi
|
||||
|
||||
mkdir -p "$HERMES_HOME"
|
||||
rm -rf "$HERMES_HOME/node"
|
||||
mv "$extracted" "$HERMES_HOME/node"
|
||||
rm -rf "$tmp"
|
||||
|
||||
mkdir -p "$HOME/.local/bin"
|
||||
ln -sf "$HERMES_HOME/node/bin/node" "$HOME/.local/bin/node"
|
||||
ln -sf "$HERMES_HOME/node/bin/npm" "$HOME/.local/bin/npm"
|
||||
ln -sf "$HERMES_HOME/node/bin/npx" "$HOME/.local/bin/npx"
|
||||
export PATH="$HERMES_HOME/node/bin:$PATH"
|
||||
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) installed to $HERMES_HOME/node/"
|
||||
return 0
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ensure_node() {
|
||||
HERMES_NODE_AVAILABLE=false
|
||||
|
||||
if _nb_have_modern_node; then
|
||||
_nb_ok "Node $(node --version) found"
|
||||
HERMES_NODE_AVAILABLE=true
|
||||
return 0
|
||||
fi
|
||||
|
||||
if [ -x "$HERMES_HOME/node/bin/node" ]; then
|
||||
export PATH="$HERMES_HOME/node/bin:$PATH"
|
||||
if _nb_have_modern_node; then
|
||||
_nb_ok "Node $(node --version) found (Hermes-managed)"
|
||||
HERMES_NODE_AVAILABLE=true
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
|
||||
# Version managers first — respect the user's existing setup.
|
||||
_nb_try_fnm && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
_nb_try_proto && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
_nb_try_nvm && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
|
||||
# Platform package managers.
|
||||
_nb_try_termux_pkg && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
_nb_try_brew && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
|
||||
# Last resort: pinned nodejs.org tarball.
|
||||
_nb_install_bundled_node && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
|
||||
_nb_warn "Node.js install failed — TUI and browser tools will be unavailable."
|
||||
_nb_warn "Install manually: https://nodejs.org/en/download/ (or: \`brew install node\`, \`fnm install $HERMES_NODE_TARGET_MAJOR\`, etc.)"
|
||||
return 1
|
||||
}
|
||||
@@ -44,6 +44,7 @@ AUTHOR_MAP = {
|
||||
"teknium@nousresearch.com": "teknium1",
|
||||
"127238744+teknium1@users.noreply.github.com": "teknium1",
|
||||
# contributors (from noreply pattern)
|
||||
"snreynolds2506@gmail.com": "snreynolds",
|
||||
"35742124+0xbyt4@users.noreply.github.com": "0xbyt4",
|
||||
"82637225+kshitijk4poor@users.noreply.github.com": "kshitijk4poor",
|
||||
"kshitijk4poor@users.noreply.github.com": "kshitijk4poor",
|
||||
@@ -75,6 +76,7 @@ AUTHOR_MAP = {
|
||||
"Asunfly@users.noreply.github.com": "Asunfly",
|
||||
# contributors (manual mapping from git names)
|
||||
"ahmedsherif95@gmail.com": "asheriif",
|
||||
"liujinkun@bytedance.com": "liujinkun2025",
|
||||
"dmayhem93@gmail.com": "dmahan93",
|
||||
"samherring99@gmail.com": "samherring99",
|
||||
"desaiaum08@gmail.com": "Aum08Desai",
|
||||
@@ -95,6 +97,7 @@ AUTHOR_MAP = {
|
||||
"4317663+helix4u@users.noreply.github.com": "helix4u",
|
||||
"331214+counterposition@users.noreply.github.com": "counterposition",
|
||||
"blspear@gmail.com": "BrennerSpear",
|
||||
"akhater@gmail.com": "akhater",
|
||||
"239876380+handsdiff@users.noreply.github.com": "handsdiff",
|
||||
"gpickett00@gmail.com": "gpickett00",
|
||||
"mcosma@gmail.com": "wakamex",
|
||||
@@ -103,6 +106,7 @@ AUTHOR_MAP = {
|
||||
"dangtc94@gmail.com": "dieutx",
|
||||
"jaisehgal11299@gmail.com": "jaisup",
|
||||
"percydikec@gmail.com": "PercyDikec",
|
||||
"noonou7@gmail.com": "HenkDz",
|
||||
"dean.kerr@gmail.com": "deankerr",
|
||||
"socrates1024@gmail.com": "socrates1024",
|
||||
"satelerd@gmail.com": "satelerd",
|
||||
@@ -115,6 +119,7 @@ AUTHOR_MAP = {
|
||||
"vincentcharlebois@gmail.com": "vincentcharlebois",
|
||||
"aryan@synvoid.com": "aryansingh",
|
||||
"johnsonblake1@gmail.com": "blakejohnson",
|
||||
"hcn518@gmail.com": "pedh",
|
||||
"greer.guthrie@gmail.com": "g-guthrie",
|
||||
"kennyx102@gmail.com": "bobashopcashier",
|
||||
"shokatalishaikh95@gmail.com": "areu01or00",
|
||||
@@ -202,6 +207,7 @@ AUTHOR_MAP = {
|
||||
"cola-runner@users.noreply.github.com": "cola-runner",
|
||||
"ygd58@users.noreply.github.com": "ygd58",
|
||||
"vominh1919@users.noreply.github.com": "vominh1919",
|
||||
"iamagenius00@users.noreply.github.com": "iamagenius00",
|
||||
"trevmanthony@gmail.com": "trevthefoolish",
|
||||
"ziliangpeng@users.noreply.github.com": "ziliangpeng",
|
||||
"centripetal-star@users.noreply.github.com": "centripetal-star",
|
||||
@@ -255,6 +261,9 @@ AUTHOR_MAP = {
|
||||
"anthhub@163.com": "anthhub",
|
||||
"shenuu@gmail.com": "shenuu",
|
||||
"xiayh17@gmail.com": "xiayh0107",
|
||||
"asurla@nvidia.com": "anniesurla",
|
||||
"limkuan24@gmail.com": "WideLee",
|
||||
"aviralarora002@gmail.com": "AviArora02-commits",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,430 +0,0 @@
|
||||
---
|
||||
name: gguf-quantization
|
||||
description: GGUF format and llama.cpp quantization for efficient CPU/GPU inference. Use when deploying models on consumer hardware, Apple Silicon, or when needing flexible quantization from 2-8 bit without GPU requirements.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [llama-cpp-python>=0.2.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [GGUF, Quantization, llama.cpp, CPU Inference, Apple Silicon, Model Compression, Optimization]
|
||||
|
||||
---
|
||||
|
||||
# GGUF - Quantization Format for llama.cpp
|
||||
|
||||
The GGUF (GPT-Generated Unified Format) is the standard file format for llama.cpp, enabling efficient inference on CPUs, Apple Silicon, and GPUs with flexible quantization options.
|
||||
|
||||
## When to use GGUF
|
||||
|
||||
**Use GGUF when:**
|
||||
- Deploying on consumer hardware (laptops, desktops)
|
||||
- Running on Apple Silicon (M1/M2/M3) with Metal acceleration
|
||||
- Need CPU inference without GPU requirements
|
||||
- Want flexible quantization (Q2_K to Q8_0)
|
||||
- Using local AI tools (LM Studio, Ollama, text-generation-webui)
|
||||
|
||||
**Key advantages:**
|
||||
- **Universal hardware**: CPU, Apple Silicon, NVIDIA, AMD support
|
||||
- **No Python runtime**: Pure C/C++ inference
|
||||
- **Flexible quantization**: 2-8 bit with various methods (K-quants)
|
||||
- **Ecosystem support**: LM Studio, Ollama, koboldcpp, and more
|
||||
- **imatrix**: Importance matrix for better low-bit quality
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **AWQ/GPTQ**: Maximum accuracy with calibration on NVIDIA GPUs
|
||||
- **HQQ**: Fast calibration-free quantization for HuggingFace
|
||||
- **bitsandbytes**: Simple integration with transformers library
|
||||
- **TensorRT-LLM**: Production NVIDIA deployment with maximum speed
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Clone llama.cpp
|
||||
git clone https://github.com/ggml-org/llama.cpp
|
||||
cd llama.cpp
|
||||
|
||||
# Build (CPU)
|
||||
make
|
||||
|
||||
# Build with CUDA (NVIDIA)
|
||||
make GGML_CUDA=1
|
||||
|
||||
# Build with Metal (Apple Silicon)
|
||||
make GGML_METAL=1
|
||||
|
||||
# Install Python bindings (optional)
|
||||
pip install llama-cpp-python
|
||||
```
|
||||
|
||||
### Convert model to GGUF
|
||||
|
||||
```bash
|
||||
# Install requirements
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Convert HuggingFace model to GGUF (FP16)
|
||||
python convert_hf_to_gguf.py ./path/to/model --outfile model-f16.gguf
|
||||
|
||||
# Or specify output type
|
||||
python convert_hf_to_gguf.py ./path/to/model \
|
||||
--outfile model-f16.gguf \
|
||||
--outtype f16
|
||||
```
|
||||
|
||||
### Quantize model
|
||||
|
||||
```bash
|
||||
# Basic quantization to Q4_K_M
|
||||
./llama-quantize model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
|
||||
# Quantize with importance matrix (better quality)
|
||||
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
|
||||
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
### Run inference
|
||||
|
||||
```bash
|
||||
# CLI inference
|
||||
./llama-cli -m model-q4_k_m.gguf -p "Hello, how are you?"
|
||||
|
||||
# Interactive mode
|
||||
./llama-cli -m model-q4_k_m.gguf --interactive
|
||||
|
||||
# With GPU offload
|
||||
./llama-cli -m model-q4_k_m.gguf -ngl 35 -p "Hello!"
|
||||
```
|
||||
|
||||
## Quantization types
|
||||
|
||||
### K-quant methods (recommended)
|
||||
|
||||
| Type | Bits | Size (7B) | Quality | Use Case |
|
||||
|------|------|-----------|---------|----------|
|
||||
| Q2_K | 2.5 | ~2.8 GB | Low | Extreme compression |
|
||||
| Q3_K_S | 3.0 | ~3.0 GB | Low-Med | Memory constrained |
|
||||
| Q3_K_M | 3.3 | ~3.3 GB | Medium | Balance |
|
||||
| Q4_K_S | 4.0 | ~3.8 GB | Med-High | Good balance |
|
||||
| Q4_K_M | 4.5 | ~4.1 GB | High | **Recommended default** |
|
||||
| Q5_K_S | 5.0 | ~4.6 GB | High | Quality focused |
|
||||
| Q5_K_M | 5.5 | ~4.8 GB | Very High | High quality |
|
||||
| Q6_K | 6.0 | ~5.5 GB | Excellent | Near-original |
|
||||
| Q8_0 | 8.0 | ~7.2 GB | Best | Maximum quality |
|
||||
|
||||
### Legacy methods
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| Q4_0 | 4-bit, basic |
|
||||
| Q4_1 | 4-bit with delta |
|
||||
| Q5_0 | 5-bit, basic |
|
||||
| Q5_1 | 5-bit with delta |
|
||||
|
||||
**Recommendation**: Use K-quant methods (Q4_K_M, Q5_K_M) for best quality/size ratio.
|
||||
|
||||
## Conversion workflows
|
||||
|
||||
### Workflow 1: HuggingFace to GGUF
|
||||
|
||||
```bash
|
||||
# 1. Download model
|
||||
huggingface-cli download meta-llama/Llama-3.1-8B --local-dir ./llama-3.1-8b
|
||||
|
||||
# 2. Convert to GGUF (FP16)
|
||||
python convert_hf_to_gguf.py ./llama-3.1-8b \
|
||||
--outfile llama-3.1-8b-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# 3. Quantize
|
||||
./llama-quantize llama-3.1-8b-f16.gguf llama-3.1-8b-q4_k_m.gguf Q4_K_M
|
||||
|
||||
# 4. Test
|
||||
./llama-cli -m llama-3.1-8b-q4_k_m.gguf -p "Hello!" -n 50
|
||||
```
|
||||
|
||||
### Workflow 2: With importance matrix (better quality)
|
||||
|
||||
```bash
|
||||
# 1. Convert to GGUF
|
||||
python convert_hf_to_gguf.py ./model --outfile model-f16.gguf
|
||||
|
||||
# 2. Create calibration text (diverse samples)
|
||||
cat > calibration.txt << 'EOF'
|
||||
The quick brown fox jumps over the lazy dog.
|
||||
Machine learning is a subset of artificial intelligence.
|
||||
Python is a popular programming language.
|
||||
# Add more diverse text samples...
|
||||
EOF
|
||||
|
||||
# 3. Generate importance matrix
|
||||
./llama-imatrix -m model-f16.gguf \
|
||||
-f calibration.txt \
|
||||
--chunk 512 \
|
||||
-o model.imatrix \
|
||||
-ngl 35 # GPU layers if available
|
||||
|
||||
# 4. Quantize with imatrix
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf \
|
||||
model-q4_k_m.gguf \
|
||||
Q4_K_M
|
||||
```
|
||||
|
||||
### Workflow 3: Multiple quantizations
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
MODEL="llama-3.1-8b-f16.gguf"
|
||||
IMATRIX="llama-3.1-8b.imatrix"
|
||||
|
||||
# Generate imatrix once
|
||||
./llama-imatrix -m $MODEL -f wiki.txt -o $IMATRIX -ngl 35
|
||||
|
||||
# Create multiple quantizations
|
||||
for QUANT in Q4_K_M Q5_K_M Q6_K Q8_0; do
|
||||
OUTPUT="llama-3.1-8b-${QUANT,,}.gguf"
|
||||
./llama-quantize --imatrix $IMATRIX $MODEL $OUTPUT $QUANT
|
||||
echo "Created: $OUTPUT ($(du -h $OUTPUT | cut -f1))"
|
||||
done
|
||||
```
|
||||
|
||||
## Python usage
|
||||
|
||||
### llama-cpp-python
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# Load model
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096, # Context window
|
||||
n_gpu_layers=35, # GPU offload (0 for CPU only)
|
||||
n_threads=8 # CPU threads
|
||||
)
|
||||
|
||||
# Generate
|
||||
output = llm(
|
||||
"What is machine learning?",
|
||||
max_tokens=256,
|
||||
temperature=0.7,
|
||||
stop=["</s>", "\n\n"]
|
||||
)
|
||||
print(output["choices"][0]["text"])
|
||||
```
|
||||
|
||||
### Chat completion
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35,
|
||||
chat_format="llama-3" # Or "chatml", "mistral", etc.
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is Python?"}
|
||||
]
|
||||
|
||||
response = llm.create_chat_completion(
|
||||
messages=messages,
|
||||
max_tokens=256,
|
||||
temperature=0.7
|
||||
)
|
||||
print(response["choices"][0]["message"]["content"])
|
||||
```
|
||||
|
||||
### Streaming
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(model_path="./model-q4_k_m.gguf", n_gpu_layers=35)
|
||||
|
||||
# Stream tokens
|
||||
for chunk in llm(
|
||||
"Explain quantum computing:",
|
||||
max_tokens=256,
|
||||
stream=True
|
||||
):
|
||||
print(chunk["choices"][0]["text"], end="", flush=True)
|
||||
```
|
||||
|
||||
## Server mode
|
||||
|
||||
### Start OpenAI-compatible server
|
||||
|
||||
```bash
|
||||
# Start server
|
||||
./llama-server -m model-q4_k_m.gguf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080 \
|
||||
-ngl 35 \
|
||||
-c 4096
|
||||
|
||||
# Or with Python bindings
|
||||
python -m llama_cpp.server \
|
||||
--model model-q4_k_m.gguf \
|
||||
--n_gpu_layers 35 \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080
|
||||
```
|
||||
|
||||
### Use with OpenAI client
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1",
|
||||
api_key="not-needed"
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="local-model",
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
max_tokens=256
|
||||
)
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
## Hardware optimization
|
||||
|
||||
### Apple Silicon (Metal)
|
||||
|
||||
```bash
|
||||
# Build with Metal
|
||||
make clean && make GGML_METAL=1
|
||||
|
||||
# Run with Metal acceleration
|
||||
./llama-cli -m model.gguf -ngl 99 -p "Hello"
|
||||
|
||||
# Python with Metal
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=99, # Offload all layers
|
||||
n_threads=1 # Metal handles parallelism
|
||||
)
|
||||
```
|
||||
|
||||
### NVIDIA CUDA
|
||||
|
||||
```bash
|
||||
# Build with CUDA
|
||||
make clean && make GGML_CUDA=1
|
||||
|
||||
# Run with CUDA
|
||||
./llama-cli -m model.gguf -ngl 35 -p "Hello"
|
||||
|
||||
# Specify GPU
|
||||
CUDA_VISIBLE_DEVICES=0 ./llama-cli -m model.gguf -ngl 35
|
||||
```
|
||||
|
||||
### CPU optimization
|
||||
|
||||
```bash
|
||||
# Build with AVX2/AVX512
|
||||
make clean && make
|
||||
|
||||
# Run with optimal threads
|
||||
./llama-cli -m model.gguf -t 8 -p "Hello"
|
||||
|
||||
# Python CPU config
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=0, # CPU only
|
||||
n_threads=8, # Match physical cores
|
||||
n_batch=512 # Batch size for prompt processing
|
||||
)
|
||||
```
|
||||
|
||||
## Integration with tools
|
||||
|
||||
### Ollama
|
||||
|
||||
```bash
|
||||
# Create Modelfile
|
||||
cat > Modelfile << 'EOF'
|
||||
FROM ./model-q4_k_m.gguf
|
||||
TEMPLATE """{{ .System }}
|
||||
{{ .Prompt }}"""
|
||||
PARAMETER temperature 0.7
|
||||
PARAMETER num_ctx 4096
|
||||
EOF
|
||||
|
||||
# Create Ollama model
|
||||
ollama create mymodel -f Modelfile
|
||||
|
||||
# Run
|
||||
ollama run mymodel "Hello!"
|
||||
```
|
||||
|
||||
### LM Studio
|
||||
|
||||
1. Place GGUF file in `~/.cache/lm-studio/models/`
|
||||
2. Open LM Studio and select the model
|
||||
3. Configure context length and GPU offload
|
||||
4. Start inference
|
||||
|
||||
### text-generation-webui
|
||||
|
||||
```bash
|
||||
# Place in models folder
|
||||
cp model-q4_k_m.gguf text-generation-webui/models/
|
||||
|
||||
# Start with llama.cpp loader
|
||||
python server.py --model model-q4_k_m.gguf --loader llama.cpp --n-gpu-layers 35
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use K-quants**: Q4_K_M offers best quality/size balance
|
||||
2. **Use imatrix**: Always use importance matrix for Q4 and below
|
||||
3. **GPU offload**: Offload as many layers as VRAM allows
|
||||
4. **Context length**: Start with 4096, increase if needed
|
||||
5. **Thread count**: Match physical CPU cores, not logical
|
||||
6. **Batch size**: Increase n_batch for faster prompt processing
|
||||
|
||||
## Common issues
|
||||
|
||||
**Model loads slowly:**
|
||||
```bash
|
||||
# Use mmap for faster loading
|
||||
./llama-cli -m model.gguf --mmap
|
||||
```
|
||||
|
||||
**Out of memory:**
|
||||
```bash
|
||||
# Reduce GPU layers
|
||||
./llama-cli -m model.gguf -ngl 20 # Reduce from 35
|
||||
|
||||
# Or use smaller quantization
|
||||
./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M
|
||||
```
|
||||
|
||||
**Poor quality at low bits:**
|
||||
```bash
|
||||
# Always use imatrix for Q4 and below
|
||||
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
|
||||
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Batching, speculative decoding, custom builds
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues, debugging, benchmarks
|
||||
|
||||
## Resources
|
||||
|
||||
- **Repository**: https://github.com/ggml-org/llama.cpp
|
||||
- **Python Bindings**: https://github.com/abetlen/llama-cpp-python
|
||||
- **Pre-quantized Models**: https://huggingface.co/TheBloke
|
||||
- **GGUF Converter**: https://huggingface.co/spaces/ggml-org/gguf-my-repo
|
||||
- **License**: MIT
|
||||
@@ -1,138 +1,271 @@
|
||||
---
|
||||
name: llama-cpp
|
||||
description: Runs LLM inference on CPU, Apple Silicon, and consumer GPUs without NVIDIA hardware. Use for edge deployment, M1/M2/M3 Macs, AMD/Intel GPUs, or when CUDA is unavailable. Supports GGUF quantization (1.5-8 bit) for reduced memory and 4-10× speedup vs PyTorch on CPU.
|
||||
version: 1.0.0
|
||||
description: Run LLM inference with llama.cpp on CPU, Apple Silicon, AMD/Intel GPUs, or NVIDIA — plus GGUF model conversion and quantization (2–8 bit with K-quants and imatrix). Covers CLI, Python bindings, OpenAI-compatible server, and Ollama/LM Studio integration. Use for edge deployment, M1/M2/M3/M4 Macs, CUDA-less environments, or flexible local quantization.
|
||||
version: 2.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [llama-cpp-python]
|
||||
dependencies: [llama-cpp-python>=0.2.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Inference Serving, Llama.cpp, CPU Inference, Apple Silicon, Edge Deployment, GGUF, Quantization, Non-NVIDIA, AMD GPUs, Intel GPUs, Embedded]
|
||||
|
||||
tags: [llama.cpp, GGUF, Quantization, CPU Inference, Apple Silicon, Edge Deployment, Non-NVIDIA, AMD GPUs, Intel GPUs, Embedded, Model Compression]
|
||||
---
|
||||
|
||||
# llama.cpp
|
||||
# llama.cpp + GGUF
|
||||
|
||||
Pure C/C++ LLM inference with minimal dependencies, optimized for CPUs and non-NVIDIA hardware.
|
||||
Pure C/C++ LLM inference with minimal dependencies, plus the GGUF (GPT-Generated Unified Format) standard used for quantized weights. One toolchain covers conversion, quantization, and serving.
|
||||
|
||||
## When to use llama.cpp
|
||||
## When to use
|
||||
|
||||
**Use llama.cpp when:**
|
||||
- Running on CPU-only machines
|
||||
- Deploying on Apple Silicon (M1/M2/M3/M4)
|
||||
- Using AMD or Intel GPUs (no CUDA)
|
||||
- Edge deployment (Raspberry Pi, embedded systems)
|
||||
- Need simple deployment without Docker/Python
|
||||
**Use llama.cpp + GGUF when:**
|
||||
- Running on CPU-only machines or Apple Silicon (M1/M2/M3/M4) with Metal acceleration
|
||||
- Using AMD (ROCm) or Intel GPUs where CUDA isn't available
|
||||
- Edge deployment (Raspberry Pi, embedded systems, consumer laptops)
|
||||
- Need flexible quantization (2–8 bit with K-quants)
|
||||
- Want local AI tools (LM Studio, Ollama, text-generation-webui, koboldcpp)
|
||||
- Want a single binary deploy without Docker/Python
|
||||
|
||||
**Use TensorRT-LLM instead when:**
|
||||
- Have NVIDIA GPUs (A100/H100)
|
||||
- Need maximum throughput (100K+ tok/s)
|
||||
- Running in datacenter with CUDA
|
||||
**Key advantages:**
|
||||
- Universal hardware: CPU, Apple Silicon, NVIDIA, AMD, Intel
|
||||
- No Python runtime required (pure C/C++)
|
||||
- K-quants + imatrix for better low-bit quality
|
||||
- OpenAI-compatible server built in
|
||||
- Rich ecosystem (Ollama, LM Studio, llama-cpp-python)
|
||||
|
||||
**Use vLLM instead when:**
|
||||
- Have NVIDIA GPUs
|
||||
- Need Python-first API
|
||||
- Want PagedAttention
|
||||
**Use alternatives instead:**
|
||||
- **vLLM** — NVIDIA GPUs, PagedAttention, Python-first, max throughput
|
||||
- **TensorRT-LLM** — Production NVIDIA (A100/H100), maximum speed
|
||||
- **AWQ/GPTQ** — Calibrated quantization for NVIDIA-only deployments
|
||||
- **bitsandbytes** — Simple HuggingFace transformers integration
|
||||
- **HQQ** — Fast calibration-free quantization
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
### Install
|
||||
|
||||
```bash
|
||||
# macOS/Linux
|
||||
# macOS / Linux (simplest)
|
||||
brew install llama.cpp
|
||||
|
||||
# Or build from source
|
||||
git clone https://github.com/ggerganov/llama.cpp
|
||||
git clone https://github.com/ggml-org/llama.cpp
|
||||
cd llama.cpp
|
||||
make
|
||||
make # CPU
|
||||
make GGML_METAL=1 # Apple Silicon
|
||||
make GGML_CUDA=1 # NVIDIA CUDA
|
||||
make LLAMA_HIP=1 # AMD ROCm
|
||||
|
||||
# With Metal (Apple Silicon)
|
||||
make LLAMA_METAL=1
|
||||
|
||||
# With CUDA (NVIDIA)
|
||||
make LLAMA_CUDA=1
|
||||
|
||||
# With ROCm (AMD)
|
||||
make LLAMA_HIP=1
|
||||
# Python bindings (optional)
|
||||
pip install llama-cpp-python
|
||||
# With CUDA: CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
|
||||
# With Metal: CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
|
||||
```
|
||||
|
||||
### Download model
|
||||
### Download a pre-quantized GGUF
|
||||
|
||||
```bash
|
||||
# Download from HuggingFace (GGUF format)
|
||||
# TheBloke hosts most popular models pre-quantized
|
||||
huggingface-cli download \
|
||||
TheBloke/Llama-2-7B-Chat-GGUF \
|
||||
llama-2-7b-chat.Q4_K_M.gguf \
|
||||
--local-dir models/
|
||||
```
|
||||
|
||||
# Or convert from HuggingFace
|
||||
python convert_hf_to_gguf.py models/llama-2-7b-chat/
|
||||
### Or convert a HuggingFace model to GGUF
|
||||
|
||||
```bash
|
||||
# 1. Download HF model
|
||||
huggingface-cli download meta-llama/Llama-3.1-8B --local-dir ./llama-3.1-8b
|
||||
|
||||
# 2. Convert to FP16 GGUF
|
||||
python convert_hf_to_gguf.py ./llama-3.1-8b \
|
||||
--outfile llama-3.1-8b-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# 3. Quantize to Q4_K_M
|
||||
./llama-quantize llama-3.1-8b-f16.gguf llama-3.1-8b-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
### Run inference
|
||||
|
||||
```bash
|
||||
# Simple chat
|
||||
./llama-cli \
|
||||
-m models/llama-2-7b-chat.Q4_K_M.gguf \
|
||||
-p "Explain quantum computing" \
|
||||
-n 256 # Max tokens
|
||||
# One-shot prompt
|
||||
./llama-cli -m model.Q4_K_M.gguf -p "Explain quantum computing" -n 256
|
||||
|
||||
# Interactive chat
|
||||
./llama-cli \
|
||||
-m models/llama-2-7b-chat.Q4_K_M.gguf \
|
||||
--interactive
|
||||
./llama-cli -m model.Q4_K_M.gguf --interactive
|
||||
|
||||
# With GPU offload
|
||||
./llama-cli -m model.Q4_K_M.gguf -ngl 35 -p "Hello!"
|
||||
```
|
||||
|
||||
### Server mode
|
||||
### Serve an OpenAI-compatible API
|
||||
|
||||
```bash
|
||||
# Start OpenAI-compatible server
|
||||
./llama-server \
|
||||
-m models/llama-2-7b-chat.Q4_K_M.gguf \
|
||||
-m model.Q4_K_M.gguf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080 \
|
||||
-ngl 32 # Offload 32 layers to GPU
|
||||
-ngl 35 \
|
||||
-c 4096 \
|
||||
--parallel 4 \
|
||||
--cont-batching
|
||||
```
|
||||
|
||||
# Client request
|
||||
```bash
|
||||
curl http://localhost:8080/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "llama-2-7b-chat",
|
||||
"model": "local",
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 100
|
||||
}'
|
||||
```
|
||||
|
||||
## Quantization formats
|
||||
## Quantization formats (GGUF)
|
||||
|
||||
### GGUF format overview
|
||||
### K-quant methods (recommended)
|
||||
|
||||
| Format | Bits | Size (7B) | Speed | Quality | Use Case |
|
||||
|--------|------|-----------|-------|---------|----------|
|
||||
| **Q4_K_M** | 4.5 | 4.1 GB | Fast | Good | **Recommended default** |
|
||||
| Q4_K_S | 4.3 | 3.9 GB | Faster | Lower | Speed critical |
|
||||
| Q5_K_M | 5.5 | 4.8 GB | Medium | Better | Quality critical |
|
||||
| Q6_K | 6.5 | 5.5 GB | Slower | Best | Maximum quality |
|
||||
| Q8_0 | 8.0 | 7.0 GB | Slow | Excellent | Minimal degradation |
|
||||
| Q2_K | 2.5 | 2.7 GB | Fastest | Poor | Testing only |
|
||||
| Type | Bits | Size (7B) | Quality | Use Case |
|
||||
|------|------|-----------|---------|----------|
|
||||
| Q2_K | 2.5 | ~2.8 GB | Low | Extreme compression (testing only) |
|
||||
| Q3_K_S | 3.0 | ~3.0 GB | Low-Med | Memory constrained |
|
||||
| Q3_K_M | 3.3 | ~3.3 GB | Medium | Fits small devices |
|
||||
| Q4_K_S | 4.0 | ~3.8 GB | Med-High | Speed critical |
|
||||
| **Q4_K_M** | 4.5 | ~4.1 GB | High | **Recommended default** |
|
||||
| Q5_K_S | 5.0 | ~4.6 GB | High | Quality focused |
|
||||
| Q5_K_M | 5.5 | ~4.8 GB | Very High | High quality |
|
||||
| Q6_K | 6.0 | ~5.5 GB | Excellent | Near-original |
|
||||
| Q8_0 | 8.0 | ~7.2 GB | Best | Maximum quality, minimal degradation |
|
||||
|
||||
### Choosing quantization
|
||||
**Variant suffixes** — `_S` (Small, faster, lower quality), `_M` (Medium, balanced), `_L` (Large, better quality).
|
||||
|
||||
**Legacy (Q4_0/Q4_1/Q5_0/Q5_1) exist** but always prefer K-quants for better quality/size ratio.
|
||||
|
||||
**IQ quantization** — ultra-low-bit with importance-aware methods: IQ2_XXS, IQ2_XS, IQ2_S, IQ3_XXS, IQ3_XS, IQ3_S, IQ4_XS. Require `--imatrix`.
|
||||
|
||||
**Task-specific defaults:**
|
||||
- General chat / assistants: Q4_K_M, or Q5_K_M if RAM allows
|
||||
- Code generation: Q5_K_M or Q6_K (higher precision helps)
|
||||
- Technical / medical: Q6_K or Q8_0
|
||||
- Very large (70B, 405B) on consumer hardware: Q3_K_M or Q4_K_S
|
||||
- Raspberry Pi / edge: Q2_K or Q3_K_S
|
||||
|
||||
## Conversion workflows
|
||||
|
||||
### Basic: HF → GGUF → quantized
|
||||
|
||||
```bash
|
||||
# General use (balanced)
|
||||
Q4_K_M # 4-bit, medium quality
|
||||
python convert_hf_to_gguf.py ./model --outfile model-f16.gguf --outtype f16
|
||||
./llama-quantize model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
./llama-cli -m model-q4_k_m.gguf -p "Hello!" -n 50
|
||||
```
|
||||
|
||||
# Maximum speed (more degradation)
|
||||
Q2_K or Q3_K_M
|
||||
### With importance matrix (imatrix) — better low-bit quality
|
||||
|
||||
# Maximum quality (slower)
|
||||
Q6_K or Q8_0
|
||||
`imatrix` gives 10–20% perplexity improvement at Q4, essential at Q3 and below.
|
||||
|
||||
# Very large models (70B, 405B)
|
||||
Q3_K_M or Q4_K_S # Lower bits to fit in memory
|
||||
```bash
|
||||
# 1. Convert to FP16 GGUF
|
||||
python convert_hf_to_gguf.py ./model --outfile model-f16.gguf
|
||||
|
||||
# 2. Prepare calibration data (diverse text, ~100MB is ideal)
|
||||
cat > calibration.txt << 'EOF'
|
||||
The quick brown fox jumps over the lazy dog.
|
||||
Machine learning is a subset of artificial intelligence.
|
||||
# Add more diverse text samples...
|
||||
EOF
|
||||
|
||||
# 3. Generate importance matrix
|
||||
./llama-imatrix -m model-f16.gguf \
|
||||
-f calibration.txt \
|
||||
--chunk 512 \
|
||||
-o model.imatrix \
|
||||
-ngl 35
|
||||
|
||||
# 4. Quantize with imatrix
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
### Multi-quant batch
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
MODEL="llama-3.1-8b-f16.gguf"
|
||||
IMATRIX="llama-3.1-8b.imatrix"
|
||||
|
||||
./llama-imatrix -m $MODEL -f wiki.txt -o $IMATRIX -ngl 35
|
||||
|
||||
for QUANT in Q4_K_M Q5_K_M Q6_K Q8_0; do
|
||||
OUTPUT="llama-3.1-8b-${QUANT,,}.gguf"
|
||||
./llama-quantize --imatrix $IMATRIX $MODEL $OUTPUT $QUANT
|
||||
echo "Created: $OUTPUT ($(du -h $OUTPUT | cut -f1))"
|
||||
done
|
||||
```
|
||||
|
||||
### Quality testing (perplexity)
|
||||
|
||||
```bash
|
||||
./llama-perplexity -m model.gguf -f wikitext-2-raw/wiki.test.raw -c 512
|
||||
# Baseline FP16: ~5.96 | Q4_K_M: ~6.06 (+1.7%) | Q2_K: ~6.87 (+15.3%)
|
||||
```
|
||||
|
||||
## Python bindings (llama-cpp-python)
|
||||
|
||||
### Basic generation
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35, # 0 for CPU only, 99 to offload everything
|
||||
n_threads=8,
|
||||
)
|
||||
|
||||
output = llm(
|
||||
"What is machine learning?",
|
||||
max_tokens=256,
|
||||
temperature=0.7,
|
||||
stop=["</s>", "\n\n"],
|
||||
)
|
||||
print(output["choices"][0]["text"])
|
||||
```
|
||||
|
||||
### Chat completion + streaming
|
||||
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35,
|
||||
chat_format="llama-3", # Or "chatml", "mistral", etc.
|
||||
)
|
||||
|
||||
# Non-streaming
|
||||
response = llm.create_chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is Python?"},
|
||||
],
|
||||
max_tokens=256,
|
||||
temperature=0.7,
|
||||
)
|
||||
print(response["choices"][0]["message"]["content"])
|
||||
|
||||
# Streaming
|
||||
for chunk in llm("Explain quantum computing:", max_tokens=256, stream=True):
|
||||
print(chunk["choices"][0]["text"], end="", flush=True)
|
||||
```
|
||||
|
||||
### Embeddings
|
||||
|
||||
```python
|
||||
llm = Llama(model_path="./model-q4_k_m.gguf", embedding=True, n_gpu_layers=35)
|
||||
vec = llm.embed("This is a test sentence.")
|
||||
print(f"Embedding dimension: {len(vec)}")
|
||||
```
|
||||
|
||||
## Hardware acceleration
|
||||
@@ -140,122 +273,166 @@ Q3_K_M or Q4_K_S # Lower bits to fit in memory
|
||||
### Apple Silicon (Metal)
|
||||
|
||||
```bash
|
||||
# Build with Metal
|
||||
make LLAMA_METAL=1
|
||||
|
||||
# Run with GPU acceleration (automatic)
|
||||
./llama-cli -m model.gguf -ngl 999 # Offload all layers
|
||||
|
||||
# Performance: M3 Max 40-60 tokens/sec (Llama 2-7B Q4_K_M)
|
||||
make clean && make GGML_METAL=1
|
||||
./llama-cli -m model.gguf -ngl 99 -p "Hello" # offload all layers
|
||||
```
|
||||
|
||||
### NVIDIA GPUs (CUDA)
|
||||
|
||||
```bash
|
||||
# Build with CUDA
|
||||
make LLAMA_CUDA=1
|
||||
|
||||
# Offload layers to GPU
|
||||
./llama-cli -m model.gguf -ngl 35 # Offload 35/40 layers
|
||||
|
||||
# Hybrid CPU+GPU for large models
|
||||
./llama-cli -m llama-70b.Q4_K_M.gguf -ngl 20 # GPU: 20 layers, CPU: rest
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=99, # Offload everything
|
||||
n_threads=1, # Metal handles parallelism
|
||||
)
|
||||
```
|
||||
|
||||
### AMD GPUs (ROCm)
|
||||
Performance: M3 Max ~40–60 tok/s on Llama 2-7B Q4_K_M.
|
||||
|
||||
### NVIDIA (CUDA)
|
||||
|
||||
```bash
|
||||
make clean && make GGML_CUDA=1
|
||||
./llama-cli -m model.gguf -ngl 35 -p "Hello"
|
||||
|
||||
# Hybrid for large models
|
||||
./llama-cli -m llama-70b.Q4_K_M.gguf -ngl 20 # GPU: 20 layers, CPU: rest
|
||||
|
||||
# Multi-GPU split
|
||||
./llama-cli -m large-model.gguf --tensor-split 0.5,0.5 -ngl 60
|
||||
```
|
||||
|
||||
### AMD (ROCm)
|
||||
|
||||
```bash
|
||||
# Build with ROCm
|
||||
make LLAMA_HIP=1
|
||||
|
||||
# Run with AMD GPU
|
||||
./llama-cli -m model.gguf -ngl 999
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Batch processing
|
||||
### CPU
|
||||
|
||||
```bash
|
||||
# Process multiple prompts from file
|
||||
cat prompts.txt | ./llama-cli \
|
||||
-m model.gguf \
|
||||
--batch-size 512 \
|
||||
-n 100
|
||||
# Match PHYSICAL cores, not logical
|
||||
./llama-cli -m model.gguf -t 8 -p "Hello"
|
||||
|
||||
# BLAS acceleration (2–3× speedup)
|
||||
make LLAMA_OPENBLAS=1
|
||||
```
|
||||
|
||||
### Constrained generation
|
||||
|
||||
```bash
|
||||
# JSON output with grammar
|
||||
./llama-cli \
|
||||
-m model.gguf \
|
||||
-p "Generate a person: " \
|
||||
--grammar-file grammars/json.gbnf
|
||||
|
||||
# Outputs valid JSON only
|
||||
```
|
||||
|
||||
### Context size
|
||||
|
||||
```bash
|
||||
# Increase context (default 512)
|
||||
./llama-cli \
|
||||
-m model.gguf \
|
||||
-c 4096 # 4K context window
|
||||
|
||||
# Very long context (if model supports)
|
||||
./llama-cli -m model.gguf -c 32768 # 32K context
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=0,
|
||||
n_threads=8,
|
||||
n_batch=512, # Larger batch = faster prompt processing
|
||||
)
|
||||
```
|
||||
|
||||
## Performance benchmarks
|
||||
|
||||
### CPU performance (Llama 2-7B Q4_K_M)
|
||||
### CPU (Llama 2-7B Q4_K_M)
|
||||
|
||||
| CPU | Threads | Speed | Cost |
|
||||
|-----|---------|-------|------|
|
||||
| Apple M3 Max | 16 | 50 tok/s | $0 (local) |
|
||||
| AMD Ryzen 9 7950X | 32 | 35 tok/s | $0.50/hour |
|
||||
| Intel i9-13900K | 32 | 30 tok/s | $0.40/hour |
|
||||
| AWS c7i.16xlarge | 64 | 40 tok/s | $2.88/hour |
|
||||
| CPU | Threads | Speed |
|
||||
|-----|---------|-------|
|
||||
| Apple M3 Max (Metal) | 16 | 50 tok/s |
|
||||
| AMD Ryzen 9 7950X | 32 | 35 tok/s |
|
||||
| Intel i9-13900K | 32 | 30 tok/s |
|
||||
|
||||
### GPU acceleration (Llama 2-7B Q4_K_M)
|
||||
### GPU offloading on RTX 4090
|
||||
|
||||
| GPU | Speed | vs CPU | Cost |
|
||||
|-----|-------|--------|------|
|
||||
| NVIDIA RTX 4090 | 120 tok/s | 3-4× | $0 (local) |
|
||||
| NVIDIA A10 | 80 tok/s | 2-3× | $1.00/hour |
|
||||
| AMD MI250 | 70 tok/s | 2× | $2.00/hour |
|
||||
| Apple M3 Max (Metal) | 50 tok/s | ~Same | $0 (local) |
|
||||
| Layers GPU | Speed | VRAM |
|
||||
|------------|-------|------|
|
||||
| 0 (CPU only) | 30 tok/s | 0 GB |
|
||||
| 20 (hybrid) | 80 tok/s | 8 GB |
|
||||
| 35 (all) | 120 tok/s | 12 GB |
|
||||
|
||||
## Supported models
|
||||
|
||||
**LLaMA family**:
|
||||
- Llama 2 (7B, 13B, 70B)
|
||||
- Llama 3 (8B, 70B, 405B)
|
||||
- Code Llama
|
||||
- **LLaMA family**: Llama 2 (7B/13B/70B), Llama 3 (8B/70B/405B), Code Llama
|
||||
- **Mistral family**: Mistral 7B, Mixtral 8x7B/8x22B
|
||||
- **Other**: Falcon, BLOOM, GPT-J, Phi-3, Gemma, Qwen, LLaVA (vision), Whisper (audio)
|
||||
|
||||
**Mistral family**:
|
||||
- Mistral 7B
|
||||
- Mixtral 8x7B, 8x22B
|
||||
Find GGUF models: https://huggingface.co/models?library=gguf
|
||||
|
||||
**Other**:
|
||||
- Falcon, BLOOM, GPT-J
|
||||
- Phi-3, Gemma, Qwen
|
||||
- LLaVA (vision), Whisper (audio)
|
||||
## Ecosystem integrations
|
||||
|
||||
**Find models**: https://huggingface.co/models?library=gguf
|
||||
### Ollama
|
||||
|
||||
```bash
|
||||
cat > Modelfile << 'EOF'
|
||||
FROM ./model-q4_k_m.gguf
|
||||
TEMPLATE """{{ .System }}
|
||||
{{ .Prompt }}"""
|
||||
PARAMETER temperature 0.7
|
||||
PARAMETER num_ctx 4096
|
||||
EOF
|
||||
|
||||
ollama create mymodel -f Modelfile
|
||||
ollama run mymodel "Hello!"
|
||||
```
|
||||
|
||||
### LM Studio
|
||||
|
||||
1. Place GGUF file in `~/.cache/lm-studio/models/`
|
||||
2. Open LM Studio and select the model
|
||||
3. Configure context length and GPU offload, start inference
|
||||
|
||||
### text-generation-webui
|
||||
|
||||
```bash
|
||||
cp model-q4_k_m.gguf text-generation-webui/models/
|
||||
python server.py --model model-q4_k_m.gguf --loader llama.cpp --n-gpu-layers 35
|
||||
```
|
||||
|
||||
### OpenAI client → llama-server
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(base_url="http://localhost:8080/v1", api_key="not-needed")
|
||||
response = client.chat.completions.create(
|
||||
model="local-model",
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
max_tokens=256,
|
||||
)
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use K-quants** — Q4_K_M is the recommended default
|
||||
2. **Use imatrix** for Q4 and below (calibration improves quality substantially)
|
||||
3. **Offload as many layers as VRAM allows** — start high, reduce by 5 on OOM
|
||||
4. **Thread count** — match physical cores, not logical
|
||||
5. **Batch size** — increase `n_batch` (e.g. 512) for faster prompt processing
|
||||
6. **Context** — start at 4096, grow only as needed (memory scales with ctx)
|
||||
7. **Flash Attention** — add `--flash-attn` if your build supports it
|
||||
|
||||
## Common issues (quick fixes)
|
||||
|
||||
**Model loads slowly** — use `--mmap` for memory-mapped loading.
|
||||
|
||||
**Out of memory (GPU)** — reduce `-ngl`, use a smaller quant (Q4_K_S / Q3_K_M), or quantize the KV cache:
|
||||
```python
|
||||
Llama(model_path="...", type_k=2, type_v=2, n_gpu_layers=35) # Q4_0 KV cache
|
||||
```
|
||||
|
||||
**Garbage output** — wrong `chat_format`, temperature too high, or model file corrupted. Test with `temperature=0.1` and verify FP16 baseline works.
|
||||
|
||||
**Connection refused (server)** — bind to `--host 0.0.0.0`, check `lsof -i :8080`.
|
||||
|
||||
See `references/troubleshooting.md` for the full playbook.
|
||||
|
||||
## References
|
||||
|
||||
- **[Quantization Guide](references/quantization.md)** - GGUF formats, conversion, quality comparison
|
||||
- **[Server Deployment](references/server.md)** - API endpoints, Docker, monitoring
|
||||
- **[Optimization](references/optimization.md)** - Performance tuning, hybrid CPU+GPU
|
||||
- **[advanced-usage.md](references/advanced-usage.md)** — speculative decoding, batched inference, grammar-constrained generation, LoRA, multi-GPU, custom builds, benchmark scripts
|
||||
- **[quantization.md](references/quantization.md)** — perplexity tables, use-case guide, model size scaling (7B/13B/70B RAM needs), imatrix deep dive
|
||||
- **[server.md](references/server.md)** — OpenAI API endpoints, Docker deployment, NGINX load balancing, monitoring
|
||||
- **[optimization.md](references/optimization.md)** — CPU threading, BLAS, GPU offload heuristics, batch tuning, benchmarks
|
||||
- **[troubleshooting.md](references/troubleshooting.md)** — install/convert/quantize/inference/server issues, Apple Silicon, debugging
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/ggerganov/llama.cpp
|
||||
- **Models**: https://huggingface.co/models?library=gguf
|
||||
- **Discord**: https://discord.gg/llama-cpp
|
||||
|
||||
|
||||
- **GitHub**: https://github.com/ggml-org/llama.cpp
|
||||
- **Python bindings**: https://github.com/abetlen/llama-cpp-python
|
||||
- **Pre-quantized models**: https://huggingface.co/TheBloke
|
||||
- **GGUF converter Space**: https://huggingface.co/spaces/ggml-org/gguf-my-repo
|
||||
- **License**: MIT
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
# GRPO/RL Training Skill
|
||||
|
||||
**Expert-level guidance for Group Relative Policy Optimization with TRL**
|
||||
|
||||
## 📁 Skill Structure
|
||||
|
||||
```
|
||||
grpo-rl-training/
|
||||
├── SKILL.md # Main skill documentation (READ THIS FIRST)
|
||||
├── README.md # This file
|
||||
├── templates/
|
||||
│ └── basic_grpo_training.py # Production-ready training template
|
||||
└── examples/
|
||||
└── reward_functions_library.py # 20+ reward function examples
|
||||
```
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
1. **Read SKILL.md** - Comprehensive guide with all concepts and patterns
|
||||
2. **Copy `templates/basic_grpo_training.py`** - Start with working code
|
||||
3. **Browse `examples/reward_functions_library.py`** - Pick reward functions for your task
|
||||
4. **Modify for your use case** - Adapt dataset, rewards, and config
|
||||
|
||||
## 💡 What's Inside
|
||||
|
||||
### SKILL.md (Main Documentation)
|
||||
- Core GRPO concepts and algorithm fundamentals
|
||||
- Complete implementation workflow (dataset → rewards → training → deployment)
|
||||
- 10+ reward function examples with code
|
||||
- Hyperparameter tuning guide
|
||||
- Training insights (loss behavior, metrics, debugging)
|
||||
- Troubleshooting guide
|
||||
- Production best practices
|
||||
|
||||
### Templates
|
||||
- **basic_grpo_training.py**: Minimal, production-ready training script
|
||||
- Uses Qwen 2.5 1.5B Instruct
|
||||
- 3 reward functions (format + correctness)
|
||||
- LoRA for efficient training
|
||||
- Fully documented and ready to run
|
||||
|
||||
### Examples
|
||||
- **reward_functions_library.py**: 20+ battle-tested reward functions
|
||||
- Correctness rewards (exact match, fuzzy match, numeric, code execution)
|
||||
- Format rewards (XML, JSON, strict/soft)
|
||||
- Length rewards (ideal length, min/max)
|
||||
- Style rewards (reasoning quality, citations, repetition penalty)
|
||||
- Combined rewards (multi-objective optimization)
|
||||
- Preset collections for common tasks
|
||||
|
||||
## 📖 Usage for Agents
|
||||
|
||||
When this skill is loaded in your agent's context:
|
||||
|
||||
1. **Always read SKILL.md first** before implementing
|
||||
2. **Start simple** - Use length-based reward to validate setup
|
||||
3. **Build incrementally** - Add one reward function at a time
|
||||
4. **Reference examples** - Copy patterns from reward_functions_library.py
|
||||
5. **Monitor training** - Watch reward metrics (not loss!)
|
||||
|
||||
## 🎯 Common Use Cases
|
||||
|
||||
| Task Type | Recommended Rewards | Template |
|
||||
|-----------|---------------------|----------|
|
||||
| Math reasoning | `MATH_REASONING_REWARDS` preset | basic_grpo_training.py |
|
||||
| Code generation | `CODE_GENERATION_REWARDS` preset | Modify dataset in template |
|
||||
| Summarization | `SUMMARIZATION_REWARDS` preset | Adjust prompts + rewards |
|
||||
| Q&A | `QA_REWARDS` preset | Use fuzzy match + citations |
|
||||
|
||||
## ⚠️ Critical Reminders
|
||||
|
||||
- **Loss goes UP during training** - This is normal (it's KL divergence)
|
||||
- **Use 3-5 reward functions** - Single rewards often fail
|
||||
- **Test rewards before training** - Debug each function independently
|
||||
- **Monitor reward_std** - Should stay > 0.1 (avoid mode collapse)
|
||||
- **Start with num_generations=4-8** - Scale up if GPU allows
|
||||
|
||||
## 🔗 External Resources
|
||||
|
||||
- [TRL Documentation](https://huggingface.co/docs/trl)
|
||||
- [DeepSeek R1 Paper](https://arxiv.org/abs/2501.12948)
|
||||
- [Open R1 Implementation](https://github.com/huggingface/open-r1)
|
||||
- [Unsloth (2-3x faster)](https://docs.unsloth.ai/)
|
||||
|
||||
## 📝 Version
|
||||
|
||||
**v1.0.0** - Initial release (January 2025)
|
||||
|
||||
## 👨💻 Maintained By
|
||||
|
||||
Orchestra Research
|
||||
For questions or improvements, see https://orchestra.com
|
||||
|
||||
---
|
||||
|
||||
**License:** MIT
|
||||
**Last Updated:** January 2025
|
||||
@@ -252,6 +252,8 @@ trl dpo \
|
||||
|
||||
Train with reinforcement learning using minimal memory.
|
||||
|
||||
For in-depth GRPO guidance — reward function design, critical training insights (loss behavior, mode collapse, tuning), and advanced multi-stage patterns — see **[references/grpo-training.md](references/grpo-training.md)**. A production-ready training script is in **[templates/basic_grpo_training.py](templates/basic_grpo_training.py)**.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
@@ -428,6 +430,8 @@ config = PPOConfig(
|
||||
|
||||
**Online RL methods**: See [references/online-rl.md](references/online-rl.md) for PPO, GRPO, RLOO, and OnlineDPO with detailed configurations.
|
||||
|
||||
**GRPO deep dive**: See [references/grpo-training.md](references/grpo-training.md) for expert-level GRPO patterns — reward function design philosophy, training insights (why loss increases, mode collapse detection), hyperparameter tuning, multi-stage training, and troubleshooting. Production-ready template in [templates/basic_grpo_training.py](templates/basic_grpo_training.py).
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA (CUDA required)
|
||||
|
||||
+129
-200
@@ -1,51 +1,36 @@
|
||||
---
|
||||
name: grpo-rl-training
|
||||
description: Expert guidance for GRPO/RL fine-tuning with TRL for reasoning and task-specific model training
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [transformers>=4.47.0, trl>=0.14.0, datasets>=3.2.0, peft>=0.14.0, torch]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Post-Training, Reinforcement Learning, GRPO, TRL, RLHF, Reward Modeling, Reasoning, DPO, PPO, Structured Output]
|
||||
# GRPO (Group Relative Policy Optimization) — Deep Guide
|
||||
|
||||
---
|
||||
Expert-level patterns, critical insights, and production-ready workflows for fine-tuning language models with custom reward functions using TRL's `GRPOTrainer`. This is the deep reference for the GRPO workflow summarized in the main skill.
|
||||
|
||||
# GRPO/RL Training with TRL
|
||||
## When to use GRPO
|
||||
|
||||
Expert-level guidance for implementing Group Relative Policy Optimization (GRPO) using the Transformer Reinforcement Learning (TRL) library. This skill provides battle-tested patterns, critical insights, and production-ready workflows for fine-tuning language models with custom reward functions.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use GRPO training when you need to:
|
||||
- **Enforce specific output formats** (e.g., XML tags, JSON, structured reasoning)
|
||||
Use GRPO when you need to:
|
||||
- **Enforce specific output formats** (XML tags, JSON, structured reasoning)
|
||||
- **Teach verifiable tasks** with objective correctness metrics (math, coding, fact-checking)
|
||||
- **Improve reasoning capabilities** by rewarding chain-of-thought patterns
|
||||
- **Align models to domain-specific behaviors** without labeled preference data
|
||||
- **Optimize for multiple objectives** simultaneously (format + correctness + style)
|
||||
|
||||
**Do NOT use GRPO for:**
|
||||
- Simple supervised fine-tuning tasks (use SFT instead)
|
||||
- Simple supervised fine-tuning tasks → use SFT
|
||||
- Tasks without clear reward signals
|
||||
- When you already have high-quality preference pairs (use DPO/PPO instead)
|
||||
- When you already have high-quality preference pairs → use DPO/PPO
|
||||
|
||||
---
|
||||
## Core concepts
|
||||
|
||||
## Core Concepts
|
||||
### 1. GRPO algorithm fundamentals
|
||||
|
||||
### 1. GRPO Algorithm Fundamentals
|
||||
|
||||
**Key Mechanism:**
|
||||
- Generates **multiple completions** for each prompt (group size: 4-16)
|
||||
**Key mechanism:**
|
||||
- Generates **multiple completions** per prompt (group size: 4–16)
|
||||
- Compares completions within each group using reward functions
|
||||
- Updates policy to favor higher-rewarded responses relative to the group
|
||||
|
||||
**Critical Difference from PPO:**
|
||||
**Critical differences from PPO:**
|
||||
- No separate reward model needed
|
||||
- More sample-efficient (learns from within-group comparisons)
|
||||
- Simpler to implement and debug
|
||||
|
||||
**Mathematical Intuition:**
|
||||
**Mathematical intuition:**
|
||||
```
|
||||
For each prompt p:
|
||||
1. Generate N completions: {c₁, c₂, ..., cₙ}
|
||||
@@ -54,35 +39,32 @@ For each prompt p:
|
||||
relative to low-reward ones in the same group
|
||||
```
|
||||
|
||||
### 2. Reward Function Design Philosophy
|
||||
### 2. Reward function design philosophy
|
||||
|
||||
**Golden Rules:**
|
||||
1. **Compose multiple reward functions** - Each handles one aspect (format, correctness, style)
|
||||
2. **Scale rewards appropriately** - Higher weight = stronger signal
|
||||
3. **Use incremental rewards** - Partial credit for partial compliance
|
||||
4. **Test rewards independently** - Debug each reward function in isolation
|
||||
**Golden rules:**
|
||||
1. **Compose multiple reward functions** — each handles one aspect (format, correctness, style)
|
||||
2. **Scale rewards appropriately** — higher weight = stronger signal
|
||||
3. **Use incremental rewards** — partial credit for partial compliance
|
||||
4. **Test rewards independently** — debug each reward function in isolation
|
||||
|
||||
**Reward Function Types:**
|
||||
**Reward function types:**
|
||||
|
||||
| Type | Use Case | Example Weight |
|
||||
|------|----------|----------------|
|
||||
| **Correctness** | Verifiable tasks (math, code) | 2.0 (highest) |
|
||||
| **Format** | Strict structure enforcement | 0.5-1.0 |
|
||||
| **Length** | Encourage verbosity/conciseness | 0.1-0.5 |
|
||||
| **Style** | Penalize unwanted patterns | -0.5 to 0.5 |
|
||||
| **Format** | Strict structure enforcement | 0.5–1.0 |
|
||||
| **Length** | Encourage verbosity/conciseness | 0.1–0.5 |
|
||||
| **Style** | Penalize unwanted patterns | −0.5 to 0.5 |
|
||||
|
||||
---
|
||||
## Implementation workflow
|
||||
|
||||
## Implementation Workflow
|
||||
### Step 1: Dataset preparation
|
||||
|
||||
### Step 1: Dataset Preparation
|
||||
|
||||
**Critical Requirements:**
|
||||
- Prompts in chat format (list of dicts with 'role' and 'content')
|
||||
**Critical requirements:**
|
||||
- Prompts in chat format (list of dicts with `role` and `content`)
|
||||
- Include system prompts to set expectations
|
||||
- For verifiable tasks, include ground truth answers as additional columns
|
||||
|
||||
**Example Structure:**
|
||||
```python
|
||||
from datasets import load_dataset, Dataset
|
||||
|
||||
@@ -97,8 +79,7 @@ Respond in the following format:
|
||||
"""
|
||||
|
||||
def prepare_dataset(raw_data):
|
||||
"""
|
||||
Transform raw data into GRPO-compatible format.
|
||||
"""Transform raw data into GRPO-compatible format.
|
||||
|
||||
Returns: Dataset with columns:
|
||||
- 'prompt': List[Dict] with role/content (system + user messages)
|
||||
@@ -113,14 +94,14 @@ def prepare_dataset(raw_data):
|
||||
})
|
||||
```
|
||||
|
||||
**Pro Tips:**
|
||||
- Use one-shot or few-shot examples in system prompt for complex formats
|
||||
- Keep prompts concise (max_prompt_length: 256-512 tokens)
|
||||
**Pro tips:**
|
||||
- Use one-shot or few-shot examples in the system prompt for complex formats
|
||||
- Keep prompts concise (max_prompt_length: 256–512 tokens)
|
||||
- Validate data quality before training (garbage in = garbage out)
|
||||
|
||||
### Step 2: Reward Function Implementation
|
||||
### Step 2: Reward function implementation
|
||||
|
||||
**Template Structure:**
|
||||
**Template structure:**
|
||||
```python
|
||||
def reward_function_name(
|
||||
prompts, # List[List[Dict]]: Original prompts
|
||||
@@ -128,24 +109,16 @@ def reward_function_name(
|
||||
answer=None, # Optional: Ground truth from dataset
|
||||
**kwargs # Additional dataset columns
|
||||
) -> list[float]:
|
||||
"""
|
||||
Evaluate completions and return rewards.
|
||||
|
||||
Returns: List of floats (one per completion)
|
||||
"""
|
||||
# Extract completion text
|
||||
"""Evaluate completions and return rewards (one per completion)."""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
|
||||
# Compute rewards
|
||||
rewards = []
|
||||
for response in responses:
|
||||
score = compute_score(response)
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
```
|
||||
|
||||
**Example 1: Correctness Reward (Math/Coding)**
|
||||
**Example 1: correctness reward (math/coding)**
|
||||
```python
|
||||
def correctness_reward(prompts, completions, answer, **kwargs):
|
||||
"""Reward correct answers with high score."""
|
||||
@@ -155,7 +128,7 @@ def correctness_reward(prompts, completions, answer, **kwargs):
|
||||
for ans, gt in zip(extracted, answer)]
|
||||
```
|
||||
|
||||
**Example 2: Format Reward (Structured Output)**
|
||||
**Example 2: format reward (structured output)**
|
||||
```python
|
||||
import re
|
||||
|
||||
@@ -167,7 +140,7 @@ def format_reward(completions, **kwargs):
|
||||
for r in responses]
|
||||
```
|
||||
|
||||
**Example 3: Incremental Format Reward (Partial Credit)**
|
||||
**Example 3: incremental format reward (partial credit)**
|
||||
```python
|
||||
def incremental_format_reward(completions, **kwargs):
|
||||
"""Award partial credit for format compliance."""
|
||||
@@ -176,14 +149,10 @@ def incremental_format_reward(completions, **kwargs):
|
||||
|
||||
for r in responses:
|
||||
score = 0.0
|
||||
if '<reasoning>' in r:
|
||||
score += 0.25
|
||||
if '</reasoning>' in r:
|
||||
score += 0.25
|
||||
if '<answer>' in r:
|
||||
score += 0.25
|
||||
if '</answer>' in r:
|
||||
score += 0.25
|
||||
if '<reasoning>' in r: score += 0.25
|
||||
if '</reasoning>' in r: score += 0.25
|
||||
if '<answer>' in r: score += 0.25
|
||||
if '</answer>' in r: score += 0.25
|
||||
# Penalize extra text after closing tag
|
||||
if r.count('</answer>') == 1:
|
||||
extra_text = r.split('</answer>')[-1].strip()
|
||||
@@ -193,12 +162,11 @@ def incremental_format_reward(completions, **kwargs):
|
||||
return rewards
|
||||
```
|
||||
|
||||
**Critical Insight:**
|
||||
Combine 3-5 reward functions for robust training. Order matters less than diversity of signals.
|
||||
**Critical insight:** Combine 3–5 reward functions for robust training. Order matters less than diversity of signals.
|
||||
|
||||
### Step 3: Training Configuration
|
||||
### Step 3: Training configuration
|
||||
|
||||
**Memory-Optimized Config (Small GPU)**
|
||||
**Memory-optimized config (small GPU)**
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
@@ -218,13 +186,13 @@ training_args = GRPOConfig(
|
||||
gradient_accumulation_steps=4, # Effective batch = 4
|
||||
|
||||
# GRPO-specific
|
||||
num_generations=8, # Group size: 8-16 recommended
|
||||
num_generations=8, # Group size: 8–16 recommended
|
||||
max_prompt_length=256,
|
||||
max_completion_length=512,
|
||||
|
||||
# Training duration
|
||||
num_train_epochs=1,
|
||||
max_steps=None, # Or set fixed steps (e.g., 500)
|
||||
max_steps=None,
|
||||
|
||||
# Optimization
|
||||
bf16=True, # Faster on A100/H100
|
||||
@@ -234,11 +202,11 @@ training_args = GRPOConfig(
|
||||
# Logging
|
||||
logging_steps=1,
|
||||
save_steps=100,
|
||||
report_to="wandb", # Or "none" for no logging
|
||||
report_to="wandb",
|
||||
)
|
||||
```
|
||||
|
||||
**High-Performance Config (Large GPU)**
|
||||
**High-performance config (large GPU)**
|
||||
```python
|
||||
training_args = GRPOConfig(
|
||||
output_dir="outputs/grpo-model",
|
||||
@@ -255,31 +223,30 @@ training_args = GRPOConfig(
|
||||
)
|
||||
```
|
||||
|
||||
**Critical Hyperparameters:**
|
||||
**Critical hyperparameters:**
|
||||
|
||||
| Parameter | Impact | Tuning Advice |
|
||||
|-----------|--------|---------------|
|
||||
| `num_generations` | Group size for comparison | Start with 8, increase to 16 if GPU allows |
|
||||
| `num_generations` | Group size for comparison | Start 8, increase to 16 if GPU allows |
|
||||
| `learning_rate` | Convergence speed/stability | 5e-6 (safe), 1e-5 (faster, riskier) |
|
||||
| `max_completion_length` | Output verbosity | Match your task (512 for reasoning, 256 for short answers) |
|
||||
| `max_completion_length` | Output verbosity | Match your task (512 reasoning, 256 short answers) |
|
||||
| `gradient_accumulation_steps` | Effective batch size | Increase if GPU memory limited |
|
||||
|
||||
### Step 4: Model Setup and Training
|
||||
### Step 4: Model setup and training
|
||||
|
||||
**Standard Setup (Transformers)**
|
||||
**Standard setup (Transformers + TRL)**
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import LoraConfig
|
||||
from trl import GRPOTrainer
|
||||
|
||||
# Load model
|
||||
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2", # 2-3x faster
|
||||
device_map="auto"
|
||||
attn_implementation="flash_attention_2", # 2–3× faster
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
@@ -287,17 +254,16 @@ tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Optional: LoRA for parameter-efficient training
|
||||
peft_config = LoraConfig(
|
||||
r=16, # Rank (higher = more capacity)
|
||||
lora_alpha=32, # Scaling factor (typically 2*r)
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"
|
||||
"gate_proj", "up_proj", "down_proj",
|
||||
],
|
||||
task_type="CAUSAL_LM",
|
||||
lora_dropout=0.05,
|
||||
)
|
||||
|
||||
# Initialize trainer
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
@@ -308,17 +274,14 @@ trainer = GRPOTrainer(
|
||||
],
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=peft_config, # Remove for full fine-tuning
|
||||
peft_config=peft_config, # Remove for full fine-tuning
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer.train()
|
||||
|
||||
# Save
|
||||
trainer.save_model("final_model")
|
||||
```
|
||||
|
||||
**Unsloth Setup (2-3x Faster)**
|
||||
**Unsloth setup (2–3× faster)**
|
||||
```python
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
@@ -339,28 +302,26 @@ model = FastLanguageModel.get_peft_model(
|
||||
use_gradient_checkpointing="unsloth",
|
||||
)
|
||||
|
||||
# Rest is identical to standard setup
|
||||
# Rest is identical to the standard setup
|
||||
trainer = GRPOTrainer(model=model, ...)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
---
|
||||
## Critical training insights
|
||||
|
||||
## Critical Training Insights
|
||||
### 1. Loss behavior (EXPECTED pattern)
|
||||
- **Loss starts near 0 and INCREASES during training** — this is CORRECT
|
||||
- Loss measures KL divergence from initial policy; the model is learning (diverging from original behavior to optimize rewards)
|
||||
- **Monitor reward metrics, not loss, for progress**
|
||||
|
||||
### 1. Loss Behavior (EXPECTED PATTERN)
|
||||
- **Loss starts near 0 and INCREASES during training**
|
||||
- This is CORRECT - loss measures KL divergence from initial policy
|
||||
- Model is learning (diverging from original behavior to optimize rewards)
|
||||
- Monitor reward metrics instead of loss for progress
|
||||
### 2. Reward tracking
|
||||
|
||||
### 2. Reward Tracking
|
||||
Key metrics to watch:
|
||||
- `reward`: Average across all completions
|
||||
- `reward_std`: Diversity within groups (should remain > 0)
|
||||
- `kl`: KL divergence from reference (should grow moderately)
|
||||
- `reward` — average across all completions
|
||||
- `reward_std` — diversity within groups (should remain > 0)
|
||||
- `kl` — KL divergence from reference (should grow moderately)
|
||||
|
||||
**Healthy Training Pattern:**
|
||||
**Healthy pattern:**
|
||||
```
|
||||
Step Reward Reward_Std KL
|
||||
100 0.5 0.3 0.02
|
||||
@@ -369,12 +330,12 @@ Step Reward Reward_Std KL
|
||||
400 1.5 0.15 0.12
|
||||
```
|
||||
|
||||
**Warning Signs:**
|
||||
- Reward std → 0 (model collapsing to single response)
|
||||
- KL exploding (> 0.5) (diverging too much, reduce LR)
|
||||
- Reward stuck (reward functions too harsh or model capacity issue)
|
||||
**Warning signs:**
|
||||
- `reward_std` → 0 (model collapsing to a single response)
|
||||
- `kl` exploding (> 0.5) — diverging too much, reduce LR
|
||||
- Reward stuck — reward functions too harsh or model capacity issue
|
||||
|
||||
### 3. Common Pitfalls and Solutions
|
||||
### 3. Common pitfalls and solutions
|
||||
|
||||
| Problem | Symptom | Solution |
|
||||
|---------|---------|----------|
|
||||
@@ -384,15 +345,14 @@ Step Reward Reward_Std KL
|
||||
| **Slow training** | < 1 it/s | Enable `use_vllm=True`, use Unsloth, reduce seq length |
|
||||
| **Format ignored** | Model doesn't follow structure | Increase format reward weight, add incremental rewards |
|
||||
|
||||
---
|
||||
## Advanced patterns
|
||||
|
||||
## Advanced Patterns
|
||||
### 1. Multi-stage training
|
||||
|
||||
### 1. Multi-Stage Training
|
||||
For complex tasks, train in stages:
|
||||
|
||||
```python
|
||||
# Stage 1: Format compliance (epochs=1)
|
||||
# Stage 1: Format compliance
|
||||
trainer_stage1 = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs=[incremental_format_reward, format_reward],
|
||||
@@ -400,7 +360,7 @@ trainer_stage1 = GRPOTrainer(
|
||||
)
|
||||
trainer_stage1.train()
|
||||
|
||||
# Stage 2: Correctness (epochs=1)
|
||||
# Stage 2: Correctness
|
||||
trainer_stage2 = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs=[format_reward, correctness_reward],
|
||||
@@ -409,7 +369,8 @@ trainer_stage2 = GRPOTrainer(
|
||||
trainer_stage2.train()
|
||||
```
|
||||
|
||||
### 2. Adaptive Reward Scaling
|
||||
### 2. Adaptive reward scaling
|
||||
|
||||
```python
|
||||
class AdaptiveReward:
|
||||
def __init__(self, base_reward_func, initial_weight=1.0):
|
||||
@@ -428,148 +389,116 @@ class AdaptiveReward:
|
||||
self.weight *= 0.9
|
||||
```
|
||||
|
||||
### 3. Custom Dataset Integration
|
||||
### 3. Custom dataset integration
|
||||
|
||||
```python
|
||||
def load_custom_knowledge_base(csv_path):
|
||||
"""Example: School communication platform docs."""
|
||||
import pandas as pd
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
dataset = Dataset.from_pandas(df).map(lambda x: {
|
||||
return Dataset.from_pandas(df).map(lambda x: {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': CUSTOM_SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': x['expert_answer']
|
||||
})
|
||||
return dataset
|
||||
```
|
||||
|
||||
---
|
||||
## Deployment and inference
|
||||
|
||||
## Deployment and Inference
|
||||
|
||||
### Save and Merge LoRA
|
||||
### Save and merge LoRA
|
||||
```python
|
||||
# Merge LoRA adapters into base model
|
||||
if hasattr(trainer.model, 'merge_and_unload'):
|
||||
merged_model = trainer.model.merge_and_unload()
|
||||
merged_model.save_pretrained("production_model")
|
||||
tokenizer.save_pretrained("production_model")
|
||||
```
|
||||
|
||||
### Inference Example
|
||||
### Inference
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
generator = pipeline(
|
||||
"text-generation",
|
||||
model="production_model",
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
generator = pipeline("text-generation", model="production_model", tokenizer=tokenizer)
|
||||
|
||||
result = generator(
|
||||
[
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': "What is 15 + 27?"}
|
||||
{'role': 'user', 'content': "What is 15 + 27?"},
|
||||
],
|
||||
max_new_tokens=256,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_p=0.9
|
||||
top_p=0.9,
|
||||
)
|
||||
print(result[0]['generated_text'])
|
||||
```
|
||||
|
||||
---
|
||||
## Best practices checklist
|
||||
|
||||
## Best Practices Checklist
|
||||
|
||||
**Before Training:**
|
||||
**Before training:**
|
||||
- [ ] Validate dataset format (prompts as List[Dict])
|
||||
- [ ] Test reward functions on sample data
|
||||
- [ ] Calculate expected max_prompt_length from data
|
||||
- [ ] Choose appropriate num_generations based on GPU memory
|
||||
- [ ] Calculate expected `max_prompt_length` from data
|
||||
- [ ] Choose `num_generations` based on GPU memory
|
||||
- [ ] Set up logging (wandb recommended)
|
||||
|
||||
**During Training:**
|
||||
**During training:**
|
||||
- [ ] Monitor reward progression (should increase)
|
||||
- [ ] Check reward_std (should stay > 0.1)
|
||||
- [ ] Check `reward_std` (should stay > 0.1)
|
||||
- [ ] Watch for OOM errors (reduce batch size if needed)
|
||||
- [ ] Sample generations every 50-100 steps
|
||||
- [ ] Sample generations every 50–100 steps
|
||||
- [ ] Validate format compliance on holdout set
|
||||
|
||||
**After Training:**
|
||||
**After training:**
|
||||
- [ ] Merge LoRA weights if using PEFT
|
||||
- [ ] Test on diverse prompts
|
||||
- [ ] Compare to baseline model
|
||||
- [ ] Document reward weights and hyperparameters
|
||||
- [ ] Save reproducibility config
|
||||
|
||||
---
|
||||
## Troubleshooting
|
||||
|
||||
## Troubleshooting Guide
|
||||
### Debugging workflow
|
||||
1. **Isolate reward functions** — test each independently
|
||||
2. **Check data distribution** — ensure diversity in prompts
|
||||
3. **Reduce complexity** — start with single reward, add gradually
|
||||
4. **Monitor generations** — print samples every N steps
|
||||
5. **Validate extraction logic** — ensure answer parsing works
|
||||
|
||||
### Debugging Workflow
|
||||
1. **Isolate reward functions** - Test each independently
|
||||
2. **Check data distribution** - Ensure diversity in prompts
|
||||
3. **Reduce complexity** - Start with single reward, add gradually
|
||||
4. **Monitor generations** - Print samples every N steps
|
||||
5. **Validate extraction logic** - Ensure answer parsing works
|
||||
|
||||
### Quick Fixes
|
||||
### Quick debug reward
|
||||
```python
|
||||
# Debug reward function
|
||||
def debug_reward(completions, **kwargs):
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
for i, r in enumerate(responses[:2]): # Print first 2
|
||||
for i, r in enumerate(responses[:2]):
|
||||
print(f"Response {i}: {r[:200]}...")
|
||||
return [1.0] * len(responses) # Dummy rewards
|
||||
return [1.0] * len(responses)
|
||||
|
||||
# Test without training
|
||||
trainer = GRPOTrainer(..., reward_funcs=[debug_reward])
|
||||
trainer.generate_completions(dataset[:1]) # Generate without updating
|
||||
trainer.generate_completions(dataset[:1])
|
||||
```
|
||||
|
||||
---
|
||||
## Template
|
||||
|
||||
## References and Resources
|
||||
A production-ready training script lives at **`../templates/basic_grpo_training.py`**. It uses Qwen 2.5-1.5B-Instruct with LoRA and three reward functions (incremental format, strict format, correctness) on GSM8K. Copy and adapt:
|
||||
1. `get_dataset()` — swap in your data loader
|
||||
2. Reward functions — tune to your task
|
||||
3. `SYSTEM_PROMPT` — match your output format
|
||||
4. `GRPOConfig` — adjust hyperparameters for your GPU
|
||||
|
||||
## References and resources
|
||||
|
||||
**Official Documentation:**
|
||||
- TRL GRPO Trainer: https://huggingface.co/docs/trl/grpo_trainer
|
||||
- DeepSeek R1 Paper: https://arxiv.org/abs/2501.12948
|
||||
- Unsloth Docs: https://docs.unsloth.ai/
|
||||
|
||||
**Example Repositories:**
|
||||
- Open R1 Implementation: https://github.com/huggingface/open-r1
|
||||
- TRL Examples: https://github.com/huggingface/trl/tree/main/examples
|
||||
|
||||
**Recommended Reading:**
|
||||
- Progressive Disclosure Pattern for agent instructions
|
||||
- Reward shaping in RL (Ng et al.)
|
||||
- LoRA paper (Hu et al., 2021)
|
||||
|
||||
---
|
||||
|
||||
## Usage Instructions for Agents
|
||||
|
||||
When this skill is loaded:
|
||||
|
||||
1. **Read this entire file** before implementing GRPO training
|
||||
2. **Start with the simplest reward function** (e.g., length-based) to validate setup
|
||||
3. **Use the templates** in `templates/` directory as starting points
|
||||
4. **Reference examples** in `examples/` for task-specific implementations
|
||||
5. **Follow the workflow** sequentially (don't skip steps)
|
||||
6. **Debug incrementally** - add one reward function at a time
|
||||
|
||||
**Critical Reminders:**
|
||||
- Always use multiple reward functions (3-5 is optimal)
|
||||
- Monitor reward metrics, not loss
|
||||
- Test reward functions before training
|
||||
- Start small (num_generations=4), scale up gradually
|
||||
- Save checkpoints frequently (every 100 steps)
|
||||
|
||||
This skill is designed for **expert-level implementation**. Beginners should start with supervised fine-tuning before attempting GRPO.
|
||||
|
||||
- GRPO paper (DeepSeek): https://arxiv.org/abs/2402.03300
|
||||
- DeepSeek R1 paper: https://arxiv.org/abs/2501.12948
|
||||
- Open R1 implementation: https://github.com/huggingface/open-r1
|
||||
- TRL examples: https://github.com/huggingface/trl/tree/main/examples
|
||||
- Unsloth (faster training): https://docs.unsloth.ai/
|
||||
|
||||
## Critical reminders
|
||||
|
||||
- **Loss goes UP during training** — this is normal (it's KL divergence)
|
||||
- **Use 3–5 reward functions** — single rewards often fail
|
||||
- **Test rewards before training** — debug each function independently
|
||||
- **Monitor `reward_std`** — should stay > 0.1 (avoid mode collapse)
|
||||
- **Start with `num_generations=4–8`** — scale up if GPU allows
|
||||
+59
-12
@@ -42,9 +42,10 @@ class TestToolProgressCallback:
|
||||
def test_emits_tool_call_start(self, mock_conn, event_loop_fixture):
|
||||
"""Tool progress should emit a ToolCallStart update."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
# Run callback in the event loop context
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
@@ -66,9 +67,10 @@ class TestToolProgressCallback:
|
||||
def test_handles_string_args(self, mock_conn, event_loop_fixture):
|
||||
"""If args is a JSON string, it should be parsed."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
@@ -82,9 +84,10 @@ class TestToolProgressCallback:
|
||||
def test_handles_non_dict_args(self, mock_conn, event_loop_fixture):
|
||||
"""If args is not a dict, it should be wrapped."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
@@ -98,10 +101,11 @@ class TestToolProgressCallback:
|
||||
def test_duplicate_same_name_tool_calls_use_fifo_ids(self, mock_conn, event_loop_fixture):
|
||||
"""Multiple same-name tool calls should be tracked independently in order."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
@@ -163,7 +167,7 @@ class TestStepCallback:
|
||||
tool_call_ids = {"terminal": "tc-abc123"}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
@@ -181,7 +185,7 @@ class TestStepCallback:
|
||||
tool_call_ids = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
cb(1, [{"name": "unknown_tool", "result": "ok"}])
|
||||
@@ -193,7 +197,7 @@ class TestStepCallback:
|
||||
tool_call_ids = {"read_file": "tc-def456"}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
@@ -212,7 +216,7 @@ class TestStepCallback:
|
||||
tool_call_ids = {"terminal": deque(["tc-xyz789"])}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
@@ -224,7 +228,7 @@ class TestStepCallback:
|
||||
cb(1, [{"name": "terminal", "result": '{"output": "hello"}'}])
|
||||
|
||||
mock_btc.assert_called_once_with(
|
||||
"tc-xyz789", "terminal", result='{"output": "hello"}'
|
||||
"tc-xyz789", "terminal", result='{"output": "hello"}', function_args=None, snapshot=None
|
||||
)
|
||||
|
||||
def test_none_result_passed_through(self, mock_conn, event_loop_fixture):
|
||||
@@ -234,7 +238,7 @@ class TestStepCallback:
|
||||
tool_call_ids = {"web_search": deque(["tc-aaa"])}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
@@ -244,7 +248,50 @@ class TestStepCallback:
|
||||
|
||||
cb(1, [{"name": "web_search", "result": None}])
|
||||
|
||||
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None)
|
||||
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None, function_args=None, snapshot=None)
|
||||
|
||||
def test_step_callback_passes_arguments_and_snapshot(self, mock_conn, event_loop_fixture):
|
||||
from collections import deque
|
||||
|
||||
tool_call_ids = {"write_file": deque(["tc-write"])}
|
||||
tool_call_meta = {"tc-write": {"args": {"path": "fallback.txt"}, "snapshot": "snap"}}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb(1, [{"name": "write_file", "result": '{"bytes_written": 23}', "arguments": {"path": "diff-test.txt"}}])
|
||||
|
||||
mock_btc.assert_called_once_with(
|
||||
"tc-write",
|
||||
"write_file",
|
||||
result='{"bytes_written": 23}',
|
||||
function_args={"path": "diff-test.txt"},
|
||||
snapshot="snap",
|
||||
)
|
||||
|
||||
def test_tool_progress_captures_snapshot_metadata(self, mock_conn, event_loop_fixture):
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
with patch("acp_adapter.events.make_tool_call_id", return_value="tc-meta"), \
|
||||
patch("acp_adapter.events._send_update") as mock_send, \
|
||||
patch("agent.display.capture_local_edit_snapshot", return_value="snapshot"):
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
cb("tool.started", "write_file", None, {"path": "diff-test.txt", "content": "hello"})
|
||||
|
||||
assert list(tool_call_ids["write_file"]) == ["tc-meta"]
|
||||
assert tool_call_meta["tc-meta"] == {
|
||||
"args": {"path": "diff-test.txt", "content": "hello"},
|
||||
"snapshot": "snapshot",
|
||||
}
|
||||
mock_send.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -29,6 +29,7 @@ from acp.schema import (
|
||||
|
||||
from acp_adapter.server import HermesACPAgent
|
||||
from acp_adapter.session import SessionManager
|
||||
from acp_adapter.tools import build_tool_start
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -181,6 +182,25 @@ class TestMcpRegistrationE2E:
|
||||
assert complete_event.raw_output is not None
|
||||
assert "hello" in str(complete_event.raw_output)
|
||||
|
||||
def test_patch_mode_tool_start_emits_diff_blocks_for_v4a_patch(self):
|
||||
update = build_tool_start(
|
||||
"tc-1",
|
||||
"patch",
|
||||
{
|
||||
"mode": "patch",
|
||||
"patch": "*** Begin Patch\n*** Update File: src/app.py\n@@\n-old line\n+new line\n*** Add File: src/new.py\n+hello\n*** End Patch",
|
||||
},
|
||||
)
|
||||
|
||||
assert len(update.content) == 2
|
||||
assert update.content[0].type == "diff"
|
||||
assert update.content[0].path == "src/app.py"
|
||||
assert update.content[0].old_text == "old line"
|
||||
assert update.content[0].new_text == "new line"
|
||||
assert update.content[1].type == "diff"
|
||||
assert update.content[1].path == "src/new.py"
|
||||
assert update.content[1].new_text == "hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_tool_results_paired_by_call_id(self, acp_agent, mock_manager):
|
||||
"""The ToolCallUpdate's toolCallId must match the ToolCallStart's."""
|
||||
|
||||
@@ -20,7 +20,9 @@ from acp.schema import (
|
||||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
ResumeSessionResponse,
|
||||
SessionModelState,
|
||||
SetSessionConfigOptionResponse,
|
||||
SetSessionModelResponse,
|
||||
SetSessionModeResponse,
|
||||
SessionInfo,
|
||||
TextContentBlock,
|
||||
@@ -127,6 +129,25 @@ class TestSessionOps:
|
||||
assert state is not None
|
||||
assert state.cwd == "/home/user/project"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_session_returns_model_state(self):
|
||||
manager = SessionManager(
|
||||
agent_factory=lambda: SimpleNamespace(model="gpt-5.4", provider="openai-codex")
|
||||
)
|
||||
acp_agent = HermesACPAgent(session_manager=manager)
|
||||
|
||||
with patch(
|
||||
"hermes_cli.models.curated_models_for_provider",
|
||||
return_value=[("gpt-5.4", "recommended"), ("gpt-5.4-mini", "")],
|
||||
):
|
||||
resp = await acp_agent.new_session(cwd="/tmp")
|
||||
|
||||
assert isinstance(resp.models, SessionModelState)
|
||||
assert resp.models.current_model_id == "openai-codex:gpt-5.4"
|
||||
assert resp.models.available_models[0].model_id == "openai-codex:gpt-5.4"
|
||||
assert resp.models.available_models[0].description is not None
|
||||
assert "Provider:" in resp.models.available_models[0].description
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_available_commands_include_help(self, agent):
|
||||
help_cmd = next(
|
||||
@@ -204,6 +225,33 @@ class TestListAndFork:
|
||||
assert fork_resp.session_id
|
||||
assert fork_resp.session_id != new_resp.session_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_includes_title_and_updated_at(self, agent):
|
||||
with patch.object(
|
||||
agent.session_manager,
|
||||
"list_sessions",
|
||||
return_value=[
|
||||
{
|
||||
"session_id": "session-1",
|
||||
"cwd": "/tmp/project",
|
||||
"title": "Fix Zed session history",
|
||||
"updated_at": 123.0,
|
||||
}
|
||||
],
|
||||
):
|
||||
resp = await agent.list_sessions(cwd="/tmp/project")
|
||||
|
||||
assert isinstance(resp.sessions[0], SessionInfo)
|
||||
assert resp.sessions[0].title == "Fix Zed session history"
|
||||
assert resp.sessions[0].updated_at == "123.0"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_passes_cwd_filter(self, agent):
|
||||
with patch.object(agent.session_manager, "list_sessions", return_value=[]) as mock_list:
|
||||
await agent.list_sessions(cwd="/mnt/e/Projects/AI/browser-link-3")
|
||||
|
||||
mock_list.assert_called_once_with(cwd="/mnt/e/Projects/AI/browser-link-3")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session configuration / model routing
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -257,6 +305,53 @@ class TestSessionConfiguration:
|
||||
assert result == {}
|
||||
assert state.model == "gpt-5.4"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_session_model_accepts_provider_prefixed_choice(self, tmp_path, monkeypatch):
|
||||
runtime_calls = []
|
||||
|
||||
def fake_resolve_runtime_provider(requested=None, **kwargs):
|
||||
runtime_calls.append(requested)
|
||||
provider = requested or "openrouter"
|
||||
return {
|
||||
"provider": provider,
|
||||
"api_mode": "anthropic_messages" if provider == "anthropic" else "chat_completions",
|
||||
"base_url": f"https://{provider}.example/v1",
|
||||
"api_key": f"{provider}-key",
|
||||
"command": None,
|
||||
"args": [],
|
||||
}
|
||||
|
||||
def fake_agent(**kwargs):
|
||||
return SimpleNamespace(
|
||||
model=kwargs.get("model"),
|
||||
provider=kwargs.get("provider"),
|
||||
base_url=kwargs.get("base_url"),
|
||||
api_mode=kwargs.get("api_mode"),
|
||||
)
|
||||
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: {
|
||||
"model": {"provider": "openrouter", "default": "openrouter/gpt-5"}
|
||||
})
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
fake_resolve_runtime_provider,
|
||||
)
|
||||
manager = SessionManager(db=SessionDB(tmp_path / "state.db"))
|
||||
|
||||
with patch("run_agent.AIAgent", side_effect=fake_agent):
|
||||
acp_agent = HermesACPAgent(session_manager=manager)
|
||||
state = manager.create_session(cwd="/tmp")
|
||||
result = await acp_agent.set_session_model(
|
||||
model_id="anthropic:claude-sonnet-4-6",
|
||||
session_id=state.session_id,
|
||||
)
|
||||
|
||||
assert isinstance(result, SetSessionModelResponse)
|
||||
assert state.model == "claude-sonnet-4-6"
|
||||
assert state.agent.provider == "anthropic"
|
||||
assert state.agent.base_url == "https://anthropic.example/v1"
|
||||
assert runtime_calls[-1] == "anthropic"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# prompt
|
||||
@@ -354,6 +449,31 @@ class TestPrompt:
|
||||
update = last_call[1].get("update") or last_call[0][1]
|
||||
assert update.session_update == "agent_message_chunk"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_auto_titles_session(self, agent):
|
||||
new_resp = await agent.new_session(cwd=".")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
state.agent.run_conversation = MagicMock(return_value={
|
||||
"final_response": "Here is the fix.",
|
||||
"messages": [
|
||||
{"role": "user", "content": "fix the broken ACP history"},
|
||||
{"role": "assistant", "content": "Here is the fix."},
|
||||
],
|
||||
})
|
||||
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
with patch("agent.title_generator.maybe_auto_title") as mock_title:
|
||||
prompt = [TextContentBlock(type="text", text="fix the broken ACP history")]
|
||||
await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
|
||||
|
||||
mock_title.assert_called_once()
|
||||
assert mock_title.call_args.args[1] == new_resp.session_id
|
||||
assert mock_title.call_args.args[2] == "fix the broken ACP history"
|
||||
assert mock_title.call_args.args[3] == "Here is the fix."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_populates_usage_from_top_level_run_conversation_fields(self, agent):
|
||||
"""ACP should map top-level token fields into PromptResponse.usage."""
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import contextlib
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
@@ -100,15 +101,23 @@ class TestListAndCleanup:
|
||||
def test_list_sessions_returns_created(self, manager):
|
||||
s1 = manager.create_session(cwd="/a")
|
||||
s2 = manager.create_session(cwd="/b")
|
||||
s1.history.append({"role": "user", "content": "hello from a"})
|
||||
s2.history.append({"role": "user", "content": "hello from b"})
|
||||
listing = manager.list_sessions()
|
||||
ids = {s["session_id"] for s in listing}
|
||||
assert s1.session_id in ids
|
||||
assert s2.session_id in ids
|
||||
assert len(listing) == 2
|
||||
|
||||
def test_list_sessions_hides_empty_threads(self, manager):
|
||||
manager.create_session(cwd="/empty")
|
||||
assert manager.list_sessions() == []
|
||||
|
||||
def test_cleanup_clears_all(self, manager):
|
||||
manager.create_session()
|
||||
manager.create_session()
|
||||
s1 = manager.create_session()
|
||||
s2 = manager.create_session()
|
||||
s1.history.append({"role": "user", "content": "one"})
|
||||
s2.history.append({"role": "user", "content": "two"})
|
||||
assert len(manager.list_sessions()) == 2
|
||||
manager.cleanup()
|
||||
assert manager.list_sessions() == []
|
||||
@@ -194,6 +203,8 @@ class TestPersistence:
|
||||
def test_list_sessions_includes_db_only(self, manager):
|
||||
"""Sessions only in DB (not in memory) appear in list_sessions."""
|
||||
state = manager.create_session(cwd="/db-only")
|
||||
state.history.append({"role": "user", "content": "database only thread"})
|
||||
manager.save_session(state.session_id)
|
||||
sid = state.session_id
|
||||
|
||||
# Drop from memory.
|
||||
@@ -204,6 +215,53 @@ class TestPersistence:
|
||||
ids = {s["session_id"] for s in listing}
|
||||
assert sid in ids
|
||||
|
||||
def test_list_sessions_filters_by_cwd(self, manager):
|
||||
keep = manager.create_session(cwd="/keep")
|
||||
drop = manager.create_session(cwd="/drop")
|
||||
keep.history.append({"role": "user", "content": "keep me"})
|
||||
drop.history.append({"role": "user", "content": "drop me"})
|
||||
|
||||
listing = manager.list_sessions(cwd="/keep")
|
||||
ids = {s["session_id"] for s in listing}
|
||||
assert keep.session_id in ids
|
||||
assert drop.session_id not in ids
|
||||
|
||||
def test_list_sessions_matches_windows_and_wsl_paths(self, manager):
|
||||
state = manager.create_session(cwd="/mnt/e/Projects/AI/browser-link-3")
|
||||
state.history.append({"role": "user", "content": "same project from WSL"})
|
||||
|
||||
listing = manager.list_sessions(cwd=r"E:\Projects\AI\browser-link-3")
|
||||
ids = {s["session_id"] for s in listing}
|
||||
assert state.session_id in ids
|
||||
|
||||
def test_list_sessions_prefers_title_then_preview(self, manager):
|
||||
state = manager.create_session(cwd="/named")
|
||||
state.history.append({"role": "user", "content": "Investigate broken ACP history in Zed"})
|
||||
manager.save_session(state.session_id)
|
||||
db = manager._get_db()
|
||||
db.set_session_title(state.session_id, "Fix Zed ACP history")
|
||||
|
||||
listing = manager.list_sessions(cwd="/named")
|
||||
assert listing[0]["title"] == "Fix Zed ACP history"
|
||||
|
||||
db.set_session_title(state.session_id, "")
|
||||
listing = manager.list_sessions(cwd="/named")
|
||||
assert listing[0]["title"].startswith("Investigate broken ACP history")
|
||||
|
||||
def test_list_sessions_sorted_by_most_recent_activity(self, manager):
|
||||
older = manager.create_session(cwd="/ordered")
|
||||
older.history.append({"role": "user", "content": "older"})
|
||||
manager.save_session(older.session_id)
|
||||
time.sleep(0.02)
|
||||
newer = manager.create_session(cwd="/ordered")
|
||||
newer.history.append({"role": "user", "content": "newer"})
|
||||
manager.save_session(newer.session_id)
|
||||
|
||||
listing = manager.list_sessions(cwd="/ordered")
|
||||
assert [item["session_id"] for item in listing[:2]] == [newer.session_id, older.session_id]
|
||||
assert listing[0]["updated_at"]
|
||||
assert listing[1]["updated_at"]
|
||||
|
||||
def test_fork_restores_source_from_db(self, manager):
|
||||
"""Forking a session that is only in DB should work."""
|
||||
original = manager.create_session()
|
||||
|
||||
@@ -215,6 +215,46 @@ class TestBuildToolComplete:
|
||||
assert len(display_text) < 6000
|
||||
assert "truncated" in display_text
|
||||
|
||||
def test_build_tool_complete_for_patch_uses_diff_blocks(self):
|
||||
"""Completed patch calls should keep structured diff content for Zed."""
|
||||
patch_result = (
|
||||
'{"success": true, "diff": "--- a/README.md\\n+++ b/README.md\\n@@ -1 +1,2 @@\\n old line\\n+new line\\n", '
|
||||
'"files_modified": ["README.md"]}'
|
||||
)
|
||||
result = build_tool_complete("tc-p1", "patch", patch_result)
|
||||
assert isinstance(result, ToolCallProgress)
|
||||
assert len(result.content) == 1
|
||||
diff_item = result.content[0]
|
||||
assert isinstance(diff_item, FileEditToolCallContent)
|
||||
assert diff_item.path == "README.md"
|
||||
assert diff_item.old_text == "old line"
|
||||
assert diff_item.new_text == "old line\nnew line"
|
||||
|
||||
def test_build_tool_complete_for_patch_falls_back_to_text_when_no_diff(self):
|
||||
result = build_tool_complete("tc-p2", "patch", '{"success": true}')
|
||||
assert isinstance(result, ToolCallProgress)
|
||||
assert isinstance(result.content[0], ContentToolCallContent)
|
||||
|
||||
def test_build_tool_complete_for_write_file_uses_snapshot_diff(self, tmp_path):
|
||||
target = tmp_path / "diff-test.txt"
|
||||
snapshot = type("Snapshot", (), {"paths": [target], "before": {str(target): None}})()
|
||||
target.write_text("hello from hermes\n", encoding="utf-8")
|
||||
|
||||
result = build_tool_complete(
|
||||
"tc-wf1",
|
||||
"write_file",
|
||||
'{"bytes_written": 18, "dirs_created": false}',
|
||||
function_args={"path": str(target), "content": "hello from hermes\n"},
|
||||
snapshot=snapshot,
|
||||
)
|
||||
assert isinstance(result, ToolCallProgress)
|
||||
assert len(result.content) == 1
|
||||
diff_item = result.content[0]
|
||||
assert isinstance(diff_item, FileEditToolCallContent)
|
||||
assert diff_item.path.endswith("diff-test.txt")
|
||||
assert diff_item.old_text is None
|
||||
assert diff_item.new_text == "hello from hermes"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_locations
|
||||
|
||||
@@ -696,6 +696,95 @@ class TestIsConnectionError:
|
||||
assert _is_connection_error(err) is False
|
||||
|
||||
|
||||
class TestKimiForCodingTemperature:
|
||||
"""kimi-for-coding now requires temperature=0.6 exactly."""
|
||||
|
||||
def test_build_call_kwargs_forces_fixed_temperature(self):
|
||||
from agent.auxiliary_client import _build_call_kwargs
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
provider="kimi-coding",
|
||||
model="kimi-for-coding",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
assert kwargs["temperature"] == 0.6
|
||||
|
||||
def test_build_call_kwargs_injects_temperature_when_missing(self):
|
||||
from agent.auxiliary_client import _build_call_kwargs
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
provider="kimi-coding",
|
||||
model="kimi-for-coding",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=None,
|
||||
)
|
||||
|
||||
assert kwargs["temperature"] == 0.6
|
||||
|
||||
def test_auto_routed_kimi_for_coding_sync_call_uses_fixed_temperature(self):
|
||||
client = MagicMock()
|
||||
client.base_url = "https://api.kimi.com/coding/v1"
|
||||
response = MagicMock()
|
||||
client.chat.completions.create.return_value = response
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._get_cached_client",
|
||||
return_value=(client, "kimi-for-coding"),
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", "kimi-for-coding", None, None, None),
|
||||
):
|
||||
result = call_llm(
|
||||
task="session_search",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
assert result is response
|
||||
kwargs = client.chat.completions.create.call_args.kwargs
|
||||
assert kwargs["model"] == "kimi-for-coding"
|
||||
assert kwargs["temperature"] == 0.6
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_routed_kimi_for_coding_async_call_uses_fixed_temperature(self):
|
||||
client = MagicMock()
|
||||
client.base_url = "https://api.kimi.com/coding/v1"
|
||||
response = MagicMock()
|
||||
client.chat.completions.create = AsyncMock(return_value=response)
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._get_cached_client",
|
||||
return_value=(client, "kimi-for-coding"),
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", "kimi-for-coding", None, None, None),
|
||||
):
|
||||
result = await async_call_llm(
|
||||
task="session_search",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
assert result is response
|
||||
kwargs = client.chat.completions.create.call_args.kwargs
|
||||
assert kwargs["model"] == "kimi-for-coding"
|
||||
assert kwargs["temperature"] == 0.6
|
||||
|
||||
def test_non_kimi_model_still_preserves_temperature(self):
|
||||
from agent.auxiliary_client import _build_call_kwargs
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
provider="kimi-coding",
|
||||
model="kimi-k2.5",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
assert kwargs["temperature"] == 0.3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# async_call_llm payment / connection fallback (#7512 bug 2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,311 @@
|
||||
"""Regression tests for the ``auto`` → main-model-first policy.
|
||||
|
||||
Prior to this change, aggregator users (OpenRouter / Nous Portal) had aux
|
||||
tasks routed through a cheap provider-side default (Gemini Flash) while
|
||||
non-aggregator users got their main model. This made behavior inconsistent
|
||||
and surprising — users picked Claude but got Gemini Flash summaries.
|
||||
|
||||
The current policy: ``auto`` means "use my main chat model" for every user,
|
||||
regardless of provider type. Explicit per-task overrides in ``config.yaml``
|
||||
(``auxiliary.<task>.provider``) still win. The cheap fallback chain only
|
||||
runs when the main provider has no working client.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ── Text aux tasks — _resolve_auto ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestResolveAutoMainFirst:
|
||||
"""_resolve_auto() must prefer main provider + main model for every user."""
|
||||
|
||||
def test_openrouter_main_uses_main_model_for_aux(self, monkeypatch):
|
||||
"""OpenRouter main user → aux uses their picked OR model, not Gemini Flash."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-test-key")
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider",
|
||||
return_value="openrouter",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="anthropic/claude-sonnet-4.6",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve:
|
||||
mock_client = MagicMock()
|
||||
mock_resolve.return_value = (mock_client, "anthropic/claude-sonnet-4.6")
|
||||
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
|
||||
client, model = _resolve_auto()
|
||||
|
||||
assert client is mock_client
|
||||
assert model == "anthropic/claude-sonnet-4.6"
|
||||
# Verify it asked resolve_provider_client for the MAIN provider+model,
|
||||
# not a fallback-chain provider
|
||||
mock_resolve.assert_called_once()
|
||||
assert mock_resolve.call_args.args[0] == "openrouter"
|
||||
assert mock_resolve.call_args.args[1] == "anthropic/claude-sonnet-4.6"
|
||||
|
||||
def test_nous_main_uses_main_model_for_aux(self, monkeypatch):
|
||||
"""Nous Portal main user → aux uses their picked Nous model, not free-tier MiMo."""
|
||||
# No OPENROUTER_API_KEY → ensures if main failed we'd fall to chain
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="nous",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="anthropic/claude-opus-4.6",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve:
|
||||
mock_client = MagicMock()
|
||||
mock_resolve.return_value = (mock_client, "anthropic/claude-opus-4.6")
|
||||
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
|
||||
client, model = _resolve_auto()
|
||||
|
||||
assert client is mock_client
|
||||
assert model == "anthropic/claude-opus-4.6"
|
||||
assert mock_resolve.call_args.args[0] == "nous"
|
||||
|
||||
def test_non_aggregator_main_still_uses_main(self, monkeypatch):
|
||||
"""Non-aggregator main (DeepSeek) → unchanged behavior, main model used."""
|
||||
monkeypatch.setenv("DEEPSEEK_API_KEY", "ds-test")
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="deepseek",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model", return_value="deepseek-chat",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve:
|
||||
mock_client = MagicMock()
|
||||
mock_resolve.return_value = (mock_client, "deepseek-chat")
|
||||
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
|
||||
client, model = _resolve_auto()
|
||||
|
||||
assert client is mock_client
|
||||
assert model == "deepseek-chat"
|
||||
assert mock_resolve.call_args.args[0] == "deepseek"
|
||||
|
||||
def test_main_unavailable_falls_through_to_chain(self, monkeypatch):
|
||||
"""Main provider with no working client → fall back to aux chain."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
chain_client = MagicMock()
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="anthropic",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model", return_value="claude-opus",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(None, None), # main provider has no client
|
||||
), patch(
|
||||
"agent.auxiliary_client._try_openrouter",
|
||||
return_value=(chain_client, "google/gemini-3-flash-preview"),
|
||||
):
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
|
||||
client, model = _resolve_auto()
|
||||
|
||||
assert client is chain_client
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
|
||||
def test_no_main_config_uses_chain_directly(self):
|
||||
"""No main provider configured → skip step 1, use chain (no regression)."""
|
||||
chain_client = MagicMock()
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model", return_value="",
|
||||
), patch(
|
||||
"agent.auxiliary_client._try_openrouter",
|
||||
return_value=(chain_client, "google/gemini-3-flash-preview"),
|
||||
):
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
|
||||
client, model = _resolve_auto()
|
||||
|
||||
assert client is chain_client
|
||||
|
||||
def test_runtime_override_wins_over_config(self, monkeypatch):
|
||||
"""main_runtime kwarg overrides config-read main provider/model."""
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider",
|
||||
return_value="openrouter",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model", return_value="config-model",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve:
|
||||
mock_resolve.return_value = (MagicMock(), "runtime-model")
|
||||
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
|
||||
_resolve_auto(main_runtime={
|
||||
"provider": "anthropic",
|
||||
"model": "runtime-model",
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
"api_mode": "",
|
||||
})
|
||||
|
||||
# Runtime override wins
|
||||
assert mock_resolve.call_args.args[0] == "anthropic"
|
||||
assert mock_resolve.call_args.args[1] == "runtime-model"
|
||||
|
||||
|
||||
# ── Vision — resolve_vision_provider_client ─────────────────────────────────
|
||||
|
||||
|
||||
class TestResolveVisionMainFirst:
|
||||
"""Vision auto-detection prefers main provider + main model first."""
|
||||
|
||||
def test_openrouter_main_vision_uses_main_model(self, monkeypatch):
|
||||
"""OpenRouter main with vision-capable model → aux vision uses main model."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="openrouter",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="anthropic/claude-sonnet-4.6",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve, patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", None, None, None, None),
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_resolve.return_value = (mock_client, "anthropic/claude-sonnet-4.6")
|
||||
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "openrouter"
|
||||
assert client is mock_client
|
||||
assert model == "anthropic/claude-sonnet-4.6"
|
||||
# Verify it did NOT call the strict vision backend for OpenRouter
|
||||
# (which would have used a cheap gemini-flash-preview default)
|
||||
mock_resolve.assert_called_once()
|
||||
assert mock_resolve.call_args.args[0] == "openrouter"
|
||||
assert mock_resolve.call_args.args[1] == "anthropic/claude-sonnet-4.6"
|
||||
|
||||
def test_nous_main_vision_uses_main_model(self):
|
||||
"""Nous Portal main → aux vision uses main model, not free-tier MiMo-V2-Omni."""
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="nous",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="openai/gpt-5",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve, patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", None, None, None, None),
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_resolve.return_value = (mock_client, "openai/gpt-5")
|
||||
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "nous"
|
||||
assert model == "openai/gpt-5"
|
||||
|
||||
def test_exotic_provider_with_vision_override_preserved(self):
|
||||
"""xiaomi → mimo-v2-omni override still wins over main_model."""
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="xiaomi",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="mimo-v2-pro", # text model
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve, patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", None, None, None, None),
|
||||
):
|
||||
mock_resolve.return_value = (MagicMock(), "mimo-v2-omni")
|
||||
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "xiaomi"
|
||||
# Should use mimo-v2-omni (vision override), not mimo-v2-pro (text main)
|
||||
assert mock_resolve.call_args.args[1] == "mimo-v2-omni"
|
||||
|
||||
def test_main_unavailable_vision_falls_through_to_aggregators(self):
|
||||
"""Main provider fails → fall back to OpenRouter/Nous strict backends."""
|
||||
fallback_client = MagicMock()
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="deepseek",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model", return_value="deepseek-chat",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(None, None),
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_strict_vision_backend",
|
||||
return_value=(fallback_client, "google/gemini-3-flash-preview"),
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", None, None, None, None),
|
||||
):
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert client is fallback_client
|
||||
assert provider in ("openrouter", "nous")
|
||||
|
||||
def test_explicit_provider_override_still_wins(self):
|
||||
"""Explicit config override bypasses main-first policy."""
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="openrouter",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="anthropic/claude-opus-4.6",
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("nous", None, None, None, None), # explicit override
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_strict_vision_backend"
|
||||
) as mock_strict:
|
||||
mock_strict.return_value = (MagicMock(), "nous-default-model")
|
||||
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
# Explicit "nous" override → uses strict backend, NOT main model path
|
||||
assert provider == "nous"
|
||||
mock_strict.assert_called_once_with("nous")
|
||||
|
||||
|
||||
# ── Constant cleanup ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_aggregator_providers_constant_removed():
|
||||
"""The dead _AGGREGATOR_PROVIDERS constant should no longer live in the module.
|
||||
|
||||
Removed when the main-first policy made the aggregator-skip guard obsolete.
|
||||
"""
|
||||
import agent.auxiliary_client as aux_mod
|
||||
|
||||
assert not hasattr(aux_mod, "_AGGREGATOR_PROVIDERS"), (
|
||||
"_AGGREGATOR_PROVIDERS was removed when _resolve_auto stopped "
|
||||
"treating aggregators specially. If you re-added it, the main-first "
|
||||
"policy may have regressed."
|
||||
)
|
||||
@@ -826,6 +826,160 @@ class TestGeminiCloudCodeClient:
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
|
||||
class TestGeminiHttpErrorParsing:
|
||||
"""Regression coverage for _gemini_http_error Google-envelope parsing.
|
||||
|
||||
These are the paths that users actually hit during Google-side throttling
|
||||
(April 2026: gemini-2.5-pro MODEL_CAPACITY_EXHAUSTED, gemma-4-26b-it
|
||||
returning 404). The error needs to carry status_code + response so the
|
||||
main loop's error_classifier and Retry-After logic work.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _fake_response(status: int, body: dict | str = "", headers=None):
|
||||
"""Minimal httpx.Response stand-in (duck-typed for _gemini_http_error)."""
|
||||
class _FakeResponse:
|
||||
def __init__(self):
|
||||
self.status_code = status
|
||||
if isinstance(body, dict):
|
||||
self.text = json.dumps(body)
|
||||
else:
|
||||
self.text = body
|
||||
self.headers = headers or {}
|
||||
return _FakeResponse()
|
||||
|
||||
def test_model_capacity_exhausted_produces_friendly_message(self):
|
||||
from agent.gemini_cloudcode_adapter import _gemini_http_error
|
||||
|
||||
body = {
|
||||
"error": {
|
||||
"code": 429,
|
||||
"message": "Resource has been exhausted (e.g. check quota).",
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{
|
||||
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
|
||||
"reason": "MODEL_CAPACITY_EXHAUSTED",
|
||||
"domain": "googleapis.com",
|
||||
"metadata": {"model": "gemini-2.5-pro"},
|
||||
},
|
||||
{
|
||||
"@type": "type.googleapis.com/google.rpc.RetryInfo",
|
||||
"retryDelay": "30s",
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
err = _gemini_http_error(self._fake_response(429, body))
|
||||
assert err.status_code == 429
|
||||
assert err.code == "code_assist_capacity_exhausted"
|
||||
assert err.retry_after == 30.0
|
||||
assert err.details["reason"] == "MODEL_CAPACITY_EXHAUSTED"
|
||||
# Message must be user-friendly, not a raw JSON dump.
|
||||
message = str(err)
|
||||
assert "gemini-2.5-pro" in message
|
||||
assert "capacity exhausted" in message.lower()
|
||||
assert "30s" in message
|
||||
# response attr is preserved for run_agent's Retry-After header path.
|
||||
assert err.response is not None
|
||||
|
||||
def test_resource_exhausted_without_reason(self):
|
||||
from agent.gemini_cloudcode_adapter import _gemini_http_error
|
||||
|
||||
body = {
|
||||
"error": {
|
||||
"code": 429,
|
||||
"message": "Quota exceeded for requests per minute.",
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
}
|
||||
}
|
||||
err = _gemini_http_error(self._fake_response(429, body))
|
||||
assert err.status_code == 429
|
||||
assert err.code == "code_assist_rate_limited"
|
||||
message = str(err)
|
||||
assert "quota" in message.lower()
|
||||
|
||||
def test_404_model_not_found_produces_model_retired_message(self):
|
||||
from agent.gemini_cloudcode_adapter import _gemini_http_error
|
||||
|
||||
body = {
|
||||
"error": {
|
||||
"code": 404,
|
||||
"message": "models/gemma-4-26b-it is not found for API version v1internal",
|
||||
"status": "NOT_FOUND",
|
||||
}
|
||||
}
|
||||
err = _gemini_http_error(self._fake_response(404, body))
|
||||
assert err.status_code == 404
|
||||
message = str(err)
|
||||
assert "not available" in message.lower() or "retired" in message.lower()
|
||||
# Error message should reference the actual model text from Google.
|
||||
assert "gemma-4-26b-it" in message
|
||||
|
||||
def test_unauthorized_preserves_status_code(self):
|
||||
from agent.gemini_cloudcode_adapter import _gemini_http_error
|
||||
|
||||
err = _gemini_http_error(self._fake_response(
|
||||
401, {"error": {"code": 401, "message": "Invalid token", "status": "UNAUTHENTICATED"}},
|
||||
))
|
||||
assert err.status_code == 401
|
||||
assert err.code == "code_assist_unauthorized"
|
||||
|
||||
def test_retry_after_header_fallback(self):
|
||||
"""If the body has no RetryInfo detail, fall back to Retry-After header."""
|
||||
from agent.gemini_cloudcode_adapter import _gemini_http_error
|
||||
|
||||
resp = self._fake_response(
|
||||
429,
|
||||
{"error": {"code": 429, "message": "Rate limited", "status": "RESOURCE_EXHAUSTED"}},
|
||||
headers={"Retry-After": "45"},
|
||||
)
|
||||
err = _gemini_http_error(resp)
|
||||
assert err.retry_after == 45.0
|
||||
|
||||
def test_malformed_body_still_produces_structured_error(self):
|
||||
"""Non-JSON body must not swallow status_code — we still want the classifier path."""
|
||||
from agent.gemini_cloudcode_adapter import _gemini_http_error
|
||||
|
||||
err = _gemini_http_error(self._fake_response(500, "<html>internal error</html>"))
|
||||
assert err.status_code == 500
|
||||
# Raw body snippet must still be there for debugging.
|
||||
assert "500" in str(err)
|
||||
|
||||
def test_status_code_flows_through_error_classifier(self):
|
||||
"""End-to-end: CodeAssistError from a 429 must classify as rate_limit.
|
||||
|
||||
This is the whole point of adding status_code to CodeAssistError —
|
||||
_extract_status_code must see it and FailoverReason.rate_limit must
|
||||
fire, so the main loop triggers fallback_providers.
|
||||
"""
|
||||
from agent.gemini_cloudcode_adapter import _gemini_http_error
|
||||
from agent.error_classifier import classify_api_error, FailoverReason
|
||||
|
||||
body = {
|
||||
"error": {
|
||||
"code": 429,
|
||||
"message": "Resource has been exhausted",
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{
|
||||
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
|
||||
"reason": "MODEL_CAPACITY_EXHAUSTED",
|
||||
"metadata": {"model": "gemini-2.5-pro"},
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
err = _gemini_http_error(self._fake_response(429, body))
|
||||
|
||||
classified = classify_api_error(
|
||||
err, provider="google-gemini-cli", model="gemini-2.5-pro",
|
||||
)
|
||||
assert classified.status_code == 429
|
||||
assert classified.reason == FailoverReason.rate_limit
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Provider registration
|
||||
# =============================================================================
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user