Compare commits
347 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 50c8198918 | |||
| 6fb69229ca | |||
| 2edebedc9e | |||
| f9667331e5 | |||
| 9527707f80 | |||
| cf012a05d8 | |||
| 3b69b2fd61 | |||
| 8826d9c197 | |||
| a2c9f5d0a7 | |||
| 8322b42c6c | |||
| 285bb2b915 | |||
| 54e0eb24c0 | |||
| 73bccc94c7 | |||
| 598cba62ad | |||
| 5ff65dbf68 | |||
| c20e236b71 | |||
| 994faacce8 | |||
| 8a59f8a9ed | |||
| 1c352f6b1d | |||
| 11a89cc032 | |||
| 45acd9beb5 | |||
| c5c0bb9a73 | |||
| 20f2258f34 | |||
| 607be54a24 | |||
| e5333e793c | |||
| 148459716c | |||
| 53e4a2f2c6 | |||
| 07db20c72d | |||
| 38436eb4e3 | |||
| 86fd0f846d | |||
| 4459913f40 | |||
| d7ef562a05 | |||
| 47010e0757 | |||
| 213e39463b | |||
| 2297c5f5ce | |||
| c7fece1f9d | |||
| c096a6935f | |||
| a155b4a159 | |||
| b449a0e049 | |||
| 85cdb04bd4 | |||
| 9b14b76eb3 | |||
| 2992802b35 | |||
| 04a0c3cb95 | |||
| 8444f66890 | |||
| bb85404b16 | |||
| 8ab1aa2efc | |||
| 511ed4dacc | |||
| d465fc5869 | |||
| 016ae5c334 | |||
| 304fb921bf | |||
| 64b354719f | |||
| e9b8ece103 | |||
| 3f43aec15d | |||
| aa583cb14e | |||
| 0a83187801 | |||
| 2b60478fc2 | |||
| c6fd2619f7 | |||
| d2206c69cc | |||
| 103beea7a6 | |||
| 287d3e12c7 | |||
| 6fd58e1e4a | |||
| 235e6ecc0e | |||
| 1648e41c17 | |||
| c4cdf3b861 | |||
| 02f5e3dc27 | |||
| b7d330211a | |||
| a5f4d652d3 | |||
| 6358501915 | |||
| 31e7276474 | |||
| 036dacf659 | |||
| 3207b9bda0 | |||
| eb07c05646 | |||
| f362083c64 | |||
| 3b569ff576 | |||
| bd09e42eac | |||
| cc3aa76675 | |||
| 2ff1ef6ae6 | |||
| 1229d8855c | |||
| d49126b987 | |||
| cb883f9e97 | |||
| d5b9db8b4a | |||
| 6a37802476 | |||
| d0e1388ca9 | |||
| 78a74bb097 | |||
| bedbeebbc8 | |||
| f53250b5e1 | |||
| 00591e3801 | |||
| be768db627 | |||
| 42721dbe1c | |||
| 8f553a55b2 | |||
| a82097e7a2 | |||
| 0dd5055d59 | |||
| 5b386ced71 | |||
| 0219da9626 | |||
| 1f37ef2fd1 | |||
| 6ea7386a6f | |||
| 8dcd08d8bb | |||
| 3a0ec1d935 | |||
| e105b7ac93 | |||
| 4b1567f425 | |||
| cedc95c100 | |||
| c7334b4a50 | |||
| 3f3d8a7b24 | |||
| 32a694ad5f | |||
| f5dc4e905d | |||
| 93fe4b357d | |||
| 8d7b7feb0d | |||
| fc04f83062 | |||
| fe0e7edd27 | |||
| 86f02d8d71 | |||
| 5fbe16635b | |||
| fdf42d62a0 | |||
| f64241ed90 | |||
| b46db048c3 | |||
| f696b4745a | |||
| 5ca52bae5b | |||
| c60b6dc317 | |||
| 47a0dd1024 | |||
| 91e7aff219 | |||
| d404849351 | |||
| ee95822e07 | |||
| e5b880264b | |||
| 541a3e27d7 | |||
| 0741f22463 | |||
| 7d888ab49c | |||
| 0f4403346d | |||
| d7fb435e0e | |||
| 13f2d997b0 | |||
| 9deeee7bb7 | |||
| 08930a65ea | |||
| 6ee65b4d61 | |||
| 498fc6780e | |||
| 4ed6e4c1a5 | |||
| 649f38390c | |||
| 678b69ec1b | |||
| 53da34a4fc | |||
| 24342813fe | |||
| ca03e80348 | |||
| 504e7eb9e5 | |||
| b594b30de4 | |||
| 995177d542 | |||
| 590c9964e1 | |||
| a97b08e30c | |||
| aca81ac7bb | |||
| 9039273ff0 | |||
| 29d5d36b14 | |||
| eabe14af1c | |||
| ef37aa7cce | |||
| a448e7a04d | |||
| 0231f8882b | |||
| 7c932c5aa4 | |||
| f268215019 | |||
| 8b312248dc | |||
| 82969615bb | |||
| 902d6b97d6 | |||
| 5d929caa59 | |||
| efa6c9f715 | |||
| 5435287dec | |||
| 41d3d7afb7 | |||
| 39231f29c6 | |||
| c730ab8ad7 | |||
| c74017f405 | |||
| 40f2368875 | |||
| 319aabbb80 | |||
| 26f3a05c9c | |||
| 15096903c7 | |||
| 26859e3fcb | |||
| aedc767c66 | |||
| 23212d6b40 | |||
| 7ffefc2d6c | |||
| 2812bfe5b9 | |||
| ca30803d89 | |||
| 7f1204840d | |||
| dd2ec6bfa0 | |||
| 3746c60439 | |||
| 727f0eaf74 | |||
| 275256cdb4 | |||
| 9503896aa2 | |||
| 04e36851b7 | |||
| a8e0a1148f | |||
| 842a122964 | |||
| 2d693c865c | |||
| f3920fec0b | |||
| c6ed61430a | |||
| cb2a737bc8 | |||
| 18840bcff8 | |||
| 0478266831 | |||
| beccd1bc04 | |||
| 68ecdb6e26 | |||
| fc0623f0af | |||
| 9c71f3a6ea | |||
| c4b9750bc1 | |||
| 39b1336d1f | |||
| f81dba0da2 | |||
| 8e06db56fd | |||
| cb31732c4f | |||
| 097702c8a7 | |||
| 72aebfbb24 | |||
| c9f78d110a | |||
| baa0de7649 | |||
| 57e4b61155 | |||
| 53a024a941 | |||
| cb7b740e32 | |||
| 4b4b4d47bc | |||
| 46cef4b7fa | |||
| 9931d1d814 | |||
| cc15b55bb9 | |||
| 371166fe26 | |||
| 33c615504d | |||
| 561cea0d4a | |||
| 496bfb3c59 | |||
| 99d859ce4a | |||
| 4cbf54fb33 | |||
| 77cd5bf565 | |||
| bf54f1fb2f | |||
| 3bc661ea29 | |||
| 52c11d172a | |||
| 9804aa7443 | |||
| 7aed09e1ba | |||
| dd2b0b4775 | |||
| ea2d5754ab | |||
| 9a3a2925ed | |||
| c189d5e98b | |||
| 6bbac046a7 | |||
| bbc7316007 | |||
| 35dbb1da3f | |||
| 6d6b3b03ac | |||
| 1b573b7b21 | |||
| 7e4dd6ea02 | |||
| aeb53131f3 | |||
| 783c6b6ed6 | |||
| 4a260b51fe | |||
| ebe3270430 | |||
| 77b97b810a | |||
| 9db94e8521 | |||
| cac1b1b724 | |||
| 56524bb1d9 | |||
| 0642b6cc53 | |||
| eec1db36f7 | |||
| 713a614ea8 | |||
| a27167fb30 | |||
| a2c0597ae4 | |||
| 0fd33a98cd | |||
| ddb0871769 | |||
| e03bef684e | |||
| 4b026d6761 | |||
| 8efd3db1b4 | |||
| ef51bb0091 | |||
| 3bf0f39337 | |||
| 690d62a6d1 | |||
| 2aea75e91e | |||
| 5552e1ffe1 | |||
| 90890f8f04 | |||
| 8e0df1d532 | |||
| 29721fcc58 | |||
| a1d2a0c0fd | |||
| ec553fdb49 | |||
| 24a498eb90 | |||
| 9ccb490cf3 | |||
| 32302c37dd | |||
| 5e5e65f6d5 | |||
| acbf1794f2 | |||
| e2ea8934d4 | |||
| 7e7f78f86c | |||
| 5fb6a4418b | |||
| bf6af95ff5 | |||
| 3fd5cf6e3c | |||
| b04248f4d5 | |||
| 7803d21bcc | |||
| 8760faf991 | |||
| cab6447d58 | |||
| 57e8d44af8 | |||
| cb79018977 | |||
| 90f0aa174d | |||
| 304f1463a9 | |||
| 294c377c0c | |||
| 660379637a | |||
| bc80848e49 | |||
| 658cd2dd4c | |||
| 8c1ba639c6 | |||
| 17a9c47178 | |||
| e1df13cf20 | |||
| 4fe78d5b88 | |||
| aa5b697a9d | |||
| aca479c1ae | |||
| b85ff282bc | |||
| f805323517 | |||
| 4406b4b100 | |||
| 17ecdce936 | |||
| 7e813a30e0 | |||
| 6e24b9947e | |||
| 99fd3b518d | |||
| c5511bbc5a | |||
| b7d4ea1550 | |||
| 74241328f0 | |||
| df5874c119 | |||
| 21afb3fa3c | |||
| 31b2c12f0f | |||
| 405c1b4e84 | |||
| 5ff96551d5 | |||
| 2b4272ef5b | |||
| 670dcea8f4 | |||
| 17f13013eb | |||
| 00e1d42b9e | |||
| b2ea9b4176 | |||
| 0d7c19a42f | |||
| 8755b9dfc0 | |||
| 54bd25ff4a | |||
| b66550ed08 | |||
| c49bbbe8c2 | |||
| 9d8f9765c1 | |||
| f226e6be10 | |||
| a435c7274a | |||
| b597123489 | |||
| af0f4a52fe | |||
| b50d81f212 | |||
| a9fa054df9 | |||
| 31cb23890a | |||
| a3cfb1de86 | |||
| 371efafc46 | |||
| ebd2d83ef2 | |||
| af077b2c0d | |||
| 2d884ff12d | |||
| b397c91d4a | |||
| 9c2c9e3a3e | |||
| c3eeb03e26 | |||
| d9d0ac06b9 | |||
| 29f2610e4b | |||
| dcb97f7465 | |||
| 86308b6de4 | |||
| 2d349bbf7a | |||
| 39878aff00 | |||
| afd670a36f | |||
| e2b3b1c5e4 | |||
| 4c7d5ec778 | |||
| f116c59071 | |||
| 0f556a17f5 | |||
| ee92460763 | |||
| 2893e9df71 | |||
| 5a5d90c85a | |||
| 56a69e519b | |||
| fab4d8d470 | |||
| 1218994992 | |||
| f4bf57ff7a | |||
| bbba9ed4f2 | |||
| 2818dd8611 | |||
| 2ea5345a7b |
@@ -1 +1,5 @@
|
||||
watch_file pyproject.toml uv.lock
|
||||
watch_file ui-tui/package-lock.json ui-tui/package.json
|
||||
watch_file flake.nix flake.lock nix/devShell.nix nix/tui.nix nix/package.nix nix/python.nix
|
||||
|
||||
use flake
|
||||
|
||||
@@ -16,13 +16,8 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: test (${{ matrix.group }}/4)
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
group: [1, 2, 3, 4]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
@@ -42,11 +37,10 @@ jobs:
|
||||
source .venv/bin/activate
|
||||
uv pip install -e ".[all,dev]"
|
||||
|
||||
- name: Run tests (shard ${{ matrix.group }}/4)
|
||||
- name: Run tests
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python -m pytest tests/ -q --ignore=tests/integration --ignore=tests/e2e --tb=short \
|
||||
--splits 4 --group ${{ matrix.group }}
|
||||
python -m pytest tests/ -q --ignore=tests/integration --ignore=tests/e2e --tb=short -n auto
|
||||
env:
|
||||
# Ensure tests don't accidentally call real APIs
|
||||
OPENROUTER_API_KEY: ""
|
||||
|
||||
@@ -60,5 +60,6 @@ mini-swe-agent/
|
||||
|
||||
# Nix
|
||||
.direnv/
|
||||
.nix-stamps/
|
||||
result
|
||||
website/static/api/skills-index.json
|
||||
|
||||
@@ -56,6 +56,19 @@ hermes-agent/
|
||||
│ ├── run.py # Main loop, slash commands, message dispatch
|
||||
│ ├── session.py # SessionStore — conversation persistence
|
||||
│ └── platforms/ # Adapters: telegram, discord, slack, whatsapp, homeassistant, signal, qqbot
|
||||
├── ui-tui/ # Ink (React) terminal UI — `hermes --tui`
|
||||
│ ├── src/entry.tsx # TTY gate + render()
|
||||
│ ├── src/app.tsx # Main state machine and UI
|
||||
│ ├── src/gatewayClient.ts # Child process + JSON-RPC bridge
|
||||
│ ├── src/app/ # Decomposed app logic (event handler, slash handler, stores, hooks)
|
||||
│ ├── src/components/ # Ink components (branding, markdown, prompts, pickers, etc.)
|
||||
│ ├── src/hooks/ # useCompletion, useInputHistory, useQueue, useVirtualHistory
|
||||
│ └── src/lib/ # Pure helpers (history, osc52, text, rpc, messages)
|
||||
├── tui_gateway/ # Python JSON-RPC backend for the TUI
|
||||
│ ├── entry.py # stdio entrypoint
|
||||
│ ├── server.py # RPC handlers and session logic
|
||||
│ ├── render.py # Optional rich/ANSI bridge
|
||||
│ └── slash_worker.py # Persistent HermesCLI subprocess for slash commands
|
||||
├── acp_adapter/ # ACP server (VS Code / Zed / JetBrains integration)
|
||||
├── cron/ # Scheduler (jobs.py, scheduler.py)
|
||||
├── environments/ # RL training environments (Atropos)
|
||||
@@ -179,6 +192,59 @@ if canonical == "mycommand":
|
||||
|
||||
---
|
||||
|
||||
## TUI Architecture (ui-tui + tui_gateway)
|
||||
|
||||
The TUI is a full replacement for the classic (prompt_toolkit) CLI, activated via `hermes --tui` or `HERMES_TUI=1`.
|
||||
|
||||
### Process Model
|
||||
|
||||
```
|
||||
hermes --tui
|
||||
└─ Node (Ink) ──stdio JSON-RPC── Python (tui_gateway)
|
||||
│ └─ AIAgent + tools + sessions
|
||||
└─ renders transcript, composer, prompts, activity
|
||||
```
|
||||
|
||||
TypeScript owns the screen. Python owns sessions, tools, model calls, and slash command logic.
|
||||
|
||||
### Transport
|
||||
|
||||
Newline-delimited JSON-RPC over stdio. Requests from Ink, events from Python. See `tui_gateway/server.py` for the full method/event catalog.
|
||||
|
||||
### Key Surfaces
|
||||
|
||||
| Surface | Ink component | Gateway method |
|
||||
|---------|---------------|----------------|
|
||||
| Chat streaming | `app.tsx` + `messageLine.tsx` | `prompt.submit` → `message.delta/complete` |
|
||||
| Tool activity | `thinking.tsx` | `tool.start/progress/complete` |
|
||||
| Approvals | `prompts.tsx` | `approval.respond` ← `approval.request` |
|
||||
| Clarify/sudo/secret | `prompts.tsx`, `maskedPrompt.tsx` | `clarify/sudo/secret.respond` |
|
||||
| Session picker | `sessionPicker.tsx` | `session.list/resume` |
|
||||
| Slash commands | Local handler + fallthrough | `slash.exec` → `_SlashWorker`, `command.dispatch` |
|
||||
| Completions | `useCompletion` hook | `complete.slash`, `complete.path` |
|
||||
| Theming | `theme.ts` + `branding.tsx` | `gateway.ready` with skin data |
|
||||
|
||||
### Slash Command Flow
|
||||
|
||||
1. Built-in client commands (`/help`, `/quit`, `/clear`, `/resume`, `/copy`, `/paste`, etc.) handled locally in `app.tsx`
|
||||
2. Everything else → `slash.exec` (runs in persistent `_SlashWorker` subprocess) → `command.dispatch` fallback
|
||||
|
||||
### Dev Commands
|
||||
|
||||
```bash
|
||||
cd ui-tui
|
||||
npm install # first time
|
||||
npm run dev # watch mode (rebuilds hermes-ink + tsx --watch)
|
||||
npm start # production
|
||||
npm run build # full build (hermes-ink + tsc)
|
||||
npm run type-check # typecheck only (tsc --noEmit)
|
||||
npm run lint # eslint
|
||||
npm run fmt # prettier
|
||||
npm test # vitest
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Adding New Tools
|
||||
|
||||
Requires changes in **2 files**:
|
||||
@@ -458,13 +524,45 @@ def profile_env(tmp_path, monkeypatch):
|
||||
|
||||
## Testing
|
||||
|
||||
**ALWAYS use `scripts/run_tests.sh`** — do not call `pytest` directly. The script enforces
|
||||
hermetic environment parity with CI (unset credential vars, TZ=UTC, LANG=C.UTF-8,
|
||||
4 xdist workers matching GHA ubuntu-latest). Direct `pytest` on a 16+ core
|
||||
developer machine with API keys set diverges from CI in ways that have caused
|
||||
multiple "works locally, fails in CI" incidents (and the reverse).
|
||||
|
||||
```bash
|
||||
source venv/bin/activate
|
||||
python -m pytest tests/ -q # Full suite (~3000 tests, ~3 min)
|
||||
python -m pytest tests/test_model_tools.py -q # Toolset resolution
|
||||
python -m pytest tests/test_cli_init.py -q # CLI config loading
|
||||
python -m pytest tests/gateway/ -q # Gateway tests
|
||||
python -m pytest tests/tools/ -q # Tool-level tests
|
||||
scripts/run_tests.sh # full suite, CI-parity
|
||||
scripts/run_tests.sh tests/gateway/ # one directory
|
||||
scripts/run_tests.sh tests/agent/test_foo.py::test_x # one test
|
||||
scripts/run_tests.sh -v --tb=long # pass-through pytest flags
|
||||
```
|
||||
|
||||
### Why the wrapper (and why the old "just call pytest" doesn't work)
|
||||
|
||||
Five real sources of local-vs-CI drift the script closes:
|
||||
|
||||
| | Without wrapper | With wrapper |
|
||||
|---|---|---|
|
||||
| Provider API keys | Whatever is in your env (auto-detects pool) | All `*_API_KEY`/`*_TOKEN`/etc. unset |
|
||||
| HOME / `~/.hermes/` | Your real config+auth.json | Temp dir per test |
|
||||
| Timezone | Local TZ (PDT etc.) | UTC |
|
||||
| Locale | Whatever is set | C.UTF-8 |
|
||||
| xdist workers | `-n auto` = all cores (20+ on a workstation) | `-n 4` matching CI |
|
||||
|
||||
`tests/conftest.py` also enforces points 1-4 as an autouse fixture so ANY pytest
|
||||
invocation (including IDE integrations) gets hermetic behavior — but the wrapper
|
||||
is belt-and-suspenders.
|
||||
|
||||
### Running without the wrapper (only if you must)
|
||||
|
||||
If you can't use the wrapper (e.g. on Windows or inside an IDE that shells
|
||||
pytest directly), at minimum activate the venv and pass `-n 4`:
|
||||
|
||||
```bash
|
||||
source venv/bin/activate
|
||||
python -m pytest tests/ -q -n 4
|
||||
```
|
||||
|
||||
Worker count above 4 will surface test-ordering flakes that CI never sees.
|
||||
|
||||
Always run the full suite before pushing changes.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
**The self-improving AI agent built by [Nous Research](https://nousresearch.com).** It's the only agent with a built-in learning loop — it creates skills from experience, improves them during use, nudges itself to persist knowledge, searches its own past conversations, and builds a deepening model of who you are across sessions. Run it on a $5 VPS, a GPU cluster, or serverless infrastructure that costs nearly nothing when idle. It's not tied to your laptop — talk to it from Telegram while it works on a cloud VM.
|
||||
|
||||
Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [Xiaomi MiMo](https://platform.xiaomimimo.com), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), [Hugging Face](https://huggingface.co), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
|
||||
Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [NVIDIA NIM](https://build.nvidia.com) (Nemotron), [Xiaomi MiMo](https://platform.xiaomimimo.com), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), [Hugging Face](https://huggingface.co), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
|
||||
|
||||
<table>
|
||||
<tr><td><b>A real terminal interface</b></td><td>Full TUI with multiline editing, slash-command autocomplete, conversation history, interrupt-and-redirect, and streaming tool output.</td></tr>
|
||||
@@ -141,11 +141,18 @@ See `hermes claw migrate --help` for all options, or use the `openclaw-migration
|
||||
|
||||
We welcome contributions! See the [Contributing Guide](https://hermes-agent.nousresearch.com/docs/developer-guide/contributing) for development setup, code style, and PR process.
|
||||
|
||||
Quick start for contributors:
|
||||
Quick start for contributors — clone and go with `setup-hermes.sh`:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/NousResearch/hermes-agent.git
|
||||
cd hermes-agent
|
||||
./setup-hermes.sh # installs uv, creates venv, installs .[all], symlinks ~/.local/bin/hermes
|
||||
./hermes # auto-detects the venv, no need to `source` first
|
||||
```
|
||||
|
||||
Manual path (equivalent to the above):
|
||||
|
||||
```bash
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
uv venv venv --python 3.11
|
||||
source venv/bin/activate
|
||||
|
||||
+20
-1
@@ -49,6 +49,7 @@ def make_tool_progress_cb(
|
||||
session_id: str,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
tool_call_ids: Dict[str, Deque[str]],
|
||||
tool_call_meta: Dict[str, Dict[str, Any]],
|
||||
) -> Callable:
|
||||
"""Create a ``tool_progress_callback`` for AIAgent.
|
||||
|
||||
@@ -84,6 +85,16 @@ def make_tool_progress_cb(
|
||||
tool_call_ids[name] = queue
|
||||
queue.append(tc_id)
|
||||
|
||||
snapshot = None
|
||||
if name in {"write_file", "patch", "skill_manage"}:
|
||||
try:
|
||||
from agent.display import capture_local_edit_snapshot
|
||||
|
||||
snapshot = capture_local_edit_snapshot(name, args)
|
||||
except Exception:
|
||||
logger.debug("Failed to capture ACP edit snapshot for %s", name, exc_info=True)
|
||||
tool_call_meta[tc_id] = {"args": args, "snapshot": snapshot}
|
||||
|
||||
update = build_tool_start(tc_id, name, args)
|
||||
_send_update(conn, session_id, loop, update)
|
||||
|
||||
@@ -119,6 +130,7 @@ def make_step_cb(
|
||||
session_id: str,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
tool_call_ids: Dict[str, Deque[str]],
|
||||
tool_call_meta: Dict[str, Dict[str, Any]],
|
||||
) -> Callable:
|
||||
"""Create a ``step_callback`` for AIAgent.
|
||||
|
||||
@@ -132,10 +144,12 @@ def make_step_cb(
|
||||
for tool_info in prev_tools:
|
||||
tool_name = None
|
||||
result = None
|
||||
function_args = None
|
||||
|
||||
if isinstance(tool_info, dict):
|
||||
tool_name = tool_info.get("name") or tool_info.get("function_name")
|
||||
result = tool_info.get("result") or tool_info.get("output")
|
||||
function_args = tool_info.get("arguments") or tool_info.get("args")
|
||||
elif isinstance(tool_info, str):
|
||||
tool_name = tool_info
|
||||
|
||||
@@ -145,8 +159,13 @@ def make_step_cb(
|
||||
tool_call_ids[tool_name] = queue
|
||||
if tool_name and queue:
|
||||
tc_id = queue.popleft()
|
||||
meta = tool_call_meta.pop(tc_id, {})
|
||||
update = build_tool_complete(
|
||||
tc_id, tool_name, result=str(result) if result is not None else None
|
||||
tc_id,
|
||||
tool_name,
|
||||
result=str(result) if result is not None else None,
|
||||
function_args=function_args or meta.get("args"),
|
||||
snapshot=meta.get("snapshot"),
|
||||
)
|
||||
_send_update(conn, session_id, loop, update)
|
||||
if not queue:
|
||||
|
||||
+148
-30
@@ -26,6 +26,7 @@ from acp.schema import (
|
||||
McpServerHttp,
|
||||
McpServerSse,
|
||||
McpServerStdio,
|
||||
ModelInfo,
|
||||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
ResumeSessionResponse,
|
||||
@@ -36,6 +37,7 @@ from acp.schema import (
|
||||
SessionCapabilities,
|
||||
SessionForkCapabilities,
|
||||
SessionListCapabilities,
|
||||
SessionModelState,
|
||||
SessionResumeCapabilities,
|
||||
SessionInfo,
|
||||
TextContentBlock,
|
||||
@@ -147,6 +149,98 @@ class HermesACPAgent(acp.Agent):
|
||||
self._conn = conn
|
||||
logger.info("ACP client connected")
|
||||
|
||||
@staticmethod
|
||||
def _encode_model_choice(provider: str | None, model: str | None) -> str:
|
||||
"""Encode a model selection so ACP clients can keep provider context."""
|
||||
raw_model = str(model or "").strip()
|
||||
if not raw_model:
|
||||
return ""
|
||||
raw_provider = str(provider or "").strip().lower()
|
||||
if not raw_provider:
|
||||
return raw_model
|
||||
return f"{raw_provider}:{raw_model}"
|
||||
|
||||
def _build_model_state(self, state: SessionState) -> SessionModelState | None:
|
||||
"""Return the ACP model selector payload for editors like Zed."""
|
||||
model = str(state.model or getattr(state.agent, "model", "") or "").strip()
|
||||
provider = getattr(state.agent, "provider", None) or detect_provider() or "openrouter"
|
||||
|
||||
try:
|
||||
from hermes_cli.models import curated_models_for_provider, normalize_provider, provider_label
|
||||
|
||||
normalized_provider = normalize_provider(provider)
|
||||
provider_name = provider_label(normalized_provider)
|
||||
available_models: list[ModelInfo] = []
|
||||
seen_ids: set[str] = set()
|
||||
|
||||
for model_id, description in curated_models_for_provider(normalized_provider):
|
||||
rendered_model = str(model_id or "").strip()
|
||||
if not rendered_model:
|
||||
continue
|
||||
choice_id = self._encode_model_choice(normalized_provider, rendered_model)
|
||||
if choice_id in seen_ids:
|
||||
continue
|
||||
desc_parts = [f"Provider: {provider_name}"]
|
||||
if description:
|
||||
desc_parts.append(str(description).strip())
|
||||
if rendered_model == model:
|
||||
desc_parts.append("current")
|
||||
available_models.append(
|
||||
ModelInfo(
|
||||
model_id=choice_id,
|
||||
name=rendered_model,
|
||||
description=" • ".join(part for part in desc_parts if part),
|
||||
)
|
||||
)
|
||||
seen_ids.add(choice_id)
|
||||
|
||||
current_model_id = self._encode_model_choice(normalized_provider, model)
|
||||
if current_model_id and current_model_id not in seen_ids:
|
||||
available_models.insert(
|
||||
0,
|
||||
ModelInfo(
|
||||
model_id=current_model_id,
|
||||
name=model,
|
||||
description=f"Provider: {provider_name} • current",
|
||||
),
|
||||
)
|
||||
|
||||
if available_models:
|
||||
return SessionModelState(
|
||||
available_models=available_models,
|
||||
current_model_id=current_model_id or available_models[0].model_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Could not build ACP model state", exc_info=True)
|
||||
|
||||
if not model:
|
||||
return None
|
||||
|
||||
fallback_choice = self._encode_model_choice(provider, model)
|
||||
return SessionModelState(
|
||||
available_models=[ModelInfo(model_id=fallback_choice, name=model)],
|
||||
current_model_id=fallback_choice,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_model_selection(raw_model: str, current_provider: str) -> tuple[str, str]:
|
||||
"""Resolve ``provider:model`` input into the provider and normalized model id."""
|
||||
target_provider = current_provider
|
||||
new_model = raw_model.strip()
|
||||
|
||||
try:
|
||||
from hermes_cli.models import detect_provider_for_model, parse_model_input
|
||||
|
||||
target_provider, new_model = parse_model_input(new_model, current_provider)
|
||||
if target_provider == current_provider:
|
||||
detected = detect_provider_for_model(new_model, current_provider)
|
||||
if detected:
|
||||
target_provider, new_model = detected
|
||||
except Exception:
|
||||
logger.debug("Provider detection failed, using model as-is", exc_info=True)
|
||||
|
||||
return target_provider, new_model
|
||||
|
||||
async def _register_session_mcp_servers(
|
||||
self,
|
||||
state: SessionState,
|
||||
@@ -273,7 +367,10 @@ class HermesACPAgent(acp.Agent):
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("New session %s (cwd=%s)", state.session_id, cwd)
|
||||
self._schedule_available_commands_update(state.session_id)
|
||||
return NewSessionResponse(session_id=state.session_id)
|
||||
return NewSessionResponse(
|
||||
session_id=state.session_id,
|
||||
models=self._build_model_state(state),
|
||||
)
|
||||
|
||||
async def load_session(
|
||||
self,
|
||||
@@ -289,7 +386,7 @@ class HermesACPAgent(acp.Agent):
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("Loaded session %s", session_id)
|
||||
self._schedule_available_commands_update(session_id)
|
||||
return LoadSessionResponse()
|
||||
return LoadSessionResponse(models=self._build_model_state(state))
|
||||
|
||||
async def resume_session(
|
||||
self,
|
||||
@@ -305,7 +402,7 @@ class HermesACPAgent(acp.Agent):
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("Resumed session %s", state.session_id)
|
||||
self._schedule_available_commands_update(state.session_id)
|
||||
return ResumeSessionResponse()
|
||||
return ResumeSessionResponse(models=self._build_model_state(state))
|
||||
|
||||
async def cancel(self, session_id: str, **kwargs: Any) -> None:
|
||||
state = self.session_manager.get_session(session_id)
|
||||
@@ -340,11 +437,20 @@ class HermesACPAgent(acp.Agent):
|
||||
cwd: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ListSessionsResponse:
|
||||
infos = self.session_manager.list_sessions()
|
||||
sessions = [
|
||||
SessionInfo(session_id=s["session_id"], cwd=s["cwd"])
|
||||
for s in infos
|
||||
]
|
||||
infos = self.session_manager.list_sessions(cwd=cwd)
|
||||
sessions = []
|
||||
for s in infos:
|
||||
updated_at = s.get("updated_at")
|
||||
if updated_at is not None and not isinstance(updated_at, str):
|
||||
updated_at = str(updated_at)
|
||||
sessions.append(
|
||||
SessionInfo(
|
||||
session_id=s["session_id"],
|
||||
cwd=s["cwd"],
|
||||
title=s.get("title"),
|
||||
updated_at=updated_at,
|
||||
)
|
||||
)
|
||||
return ListSessionsResponse(sessions=sessions)
|
||||
|
||||
# ---- Prompt (core) ------------------------------------------------------
|
||||
@@ -389,12 +495,13 @@ class HermesACPAgent(acp.Agent):
|
||||
state.cancel_event.clear()
|
||||
|
||||
tool_call_ids: dict[str, Deque[str]] = defaultdict(deque)
|
||||
tool_call_meta: dict[str, dict[str, Any]] = {}
|
||||
previous_approval_cb = None
|
||||
|
||||
if conn:
|
||||
tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids)
|
||||
tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids, tool_call_meta)
|
||||
thinking_cb = make_thinking_cb(conn, session_id, loop)
|
||||
step_cb = make_step_cb(conn, session_id, loop, tool_call_ids)
|
||||
step_cb = make_step_cb(conn, session_id, loop, tool_call_ids, tool_call_meta)
|
||||
message_cb = make_message_cb(conn, session_id, loop)
|
||||
approval_cb = make_approval_callback(conn.request_permission, loop, session_id)
|
||||
else:
|
||||
@@ -449,6 +556,19 @@ class HermesACPAgent(acp.Agent):
|
||||
self.session_manager.save_session(session_id)
|
||||
|
||||
final_response = result.get("final_response", "")
|
||||
if final_response:
|
||||
try:
|
||||
from agent.title_generator import maybe_auto_title
|
||||
|
||||
maybe_auto_title(
|
||||
self.session_manager._get_db(),
|
||||
session_id,
|
||||
user_text,
|
||||
final_response,
|
||||
state.history,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to auto-title ACP session %s", session_id, exc_info=True)
|
||||
if final_response and conn:
|
||||
update = acp.update_agent_message_text(final_response)
|
||||
await conn.session_update(session_id, update)
|
||||
@@ -556,27 +676,15 @@ class HermesACPAgent(acp.Agent):
|
||||
provider = getattr(state.agent, "provider", None) or "auto"
|
||||
return f"Current model: {model}\nProvider: {provider}"
|
||||
|
||||
new_model = args.strip()
|
||||
target_provider = None
|
||||
current_provider = getattr(state.agent, "provider", None) or "openrouter"
|
||||
|
||||
# Auto-detect provider for the requested model
|
||||
try:
|
||||
from hermes_cli.models import parse_model_input, detect_provider_for_model
|
||||
target_provider, new_model = parse_model_input(new_model, current_provider)
|
||||
if target_provider == current_provider:
|
||||
detected = detect_provider_for_model(new_model, current_provider)
|
||||
if detected:
|
||||
target_provider, new_model = detected
|
||||
except Exception:
|
||||
logger.debug("Provider detection failed, using model as-is", exc_info=True)
|
||||
target_provider, new_model = self._resolve_model_selection(args, current_provider)
|
||||
|
||||
state.model = new_model
|
||||
state.agent = self.session_manager._make_agent(
|
||||
session_id=state.session_id,
|
||||
cwd=state.cwd,
|
||||
model=new_model,
|
||||
requested_provider=target_provider or current_provider,
|
||||
requested_provider=target_provider,
|
||||
)
|
||||
self.session_manager.save_session(state.session_id)
|
||||
provider_label = getattr(state.agent, "provider", None) or target_provider or current_provider
|
||||
@@ -678,20 +786,30 @@ class HermesACPAgent(acp.Agent):
|
||||
"""Switch the model for a session (called by ACP protocol)."""
|
||||
state = self.session_manager.get_session(session_id)
|
||||
if state:
|
||||
state.model = model_id
|
||||
current_provider = getattr(state.agent, "provider", None)
|
||||
current_base_url = getattr(state.agent, "base_url", None)
|
||||
current_api_mode = getattr(state.agent, "api_mode", None)
|
||||
requested_provider, resolved_model = self._resolve_model_selection(
|
||||
model_id,
|
||||
current_provider or "openrouter",
|
||||
)
|
||||
state.model = resolved_model
|
||||
provider_changed = bool(current_provider and requested_provider != current_provider)
|
||||
current_base_url = None if provider_changed else getattr(state.agent, "base_url", None)
|
||||
current_api_mode = None if provider_changed else getattr(state.agent, "api_mode", None)
|
||||
state.agent = self.session_manager._make_agent(
|
||||
session_id=session_id,
|
||||
cwd=state.cwd,
|
||||
model=model_id,
|
||||
requested_provider=current_provider,
|
||||
model=resolved_model,
|
||||
requested_provider=requested_provider,
|
||||
base_url=current_base_url,
|
||||
api_mode=current_api_mode,
|
||||
)
|
||||
self.session_manager.save_session(session_id)
|
||||
logger.info("Session %s: model switched to %s", session_id, model_id)
|
||||
logger.info(
|
||||
"Session %s: model switched to %s via provider %s",
|
||||
session_id,
|
||||
resolved_model,
|
||||
requested_provider,
|
||||
)
|
||||
return SetSessionModelResponse()
|
||||
logger.warning("Session %s: model switch requested for missing session", session_id)
|
||||
return None
|
||||
|
||||
+127
-34
@@ -13,8 +13,12 @@ from hermes_constants import get_hermes_home
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -22,6 +26,64 @@ from typing import Any, Dict, List, Optional
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_cwd_for_compare(cwd: str | None) -> str:
|
||||
raw = str(cwd or ".").strip()
|
||||
if not raw:
|
||||
raw = "."
|
||||
expanded = os.path.expanduser(raw)
|
||||
|
||||
# Normalize Windows drive paths into the equivalent WSL mount form so
|
||||
# ACP history filters match the same workspace across Windows and WSL.
|
||||
match = re.match(r"^([A-Za-z]):[\\/](.*)$", expanded)
|
||||
if match:
|
||||
drive = match.group(1).lower()
|
||||
tail = match.group(2).replace("\\", "/")
|
||||
expanded = f"/mnt/{drive}/{tail}"
|
||||
elif re.match(r"^/mnt/[A-Za-z]/", expanded):
|
||||
expanded = f"/mnt/{expanded[5].lower()}/{expanded[7:]}"
|
||||
|
||||
return os.path.normpath(expanded)
|
||||
|
||||
|
||||
def _build_session_title(title: Any, preview: Any, cwd: str | None) -> str:
|
||||
explicit = str(title or "").strip()
|
||||
if explicit:
|
||||
return explicit
|
||||
preview_text = str(preview or "").strip()
|
||||
if preview_text:
|
||||
return preview_text
|
||||
leaf = os.path.basename(str(cwd or "").rstrip("/\\"))
|
||||
return leaf or "New thread"
|
||||
|
||||
|
||||
def _format_updated_at(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value
|
||||
try:
|
||||
return datetime.fromtimestamp(float(value), tz=timezone.utc).isoformat()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _updated_at_sort_key(value: Any) -> float:
|
||||
if value is None:
|
||||
return float("-inf")
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
raw = str(value).strip()
|
||||
if not raw:
|
||||
return float("-inf")
|
||||
try:
|
||||
return datetime.fromisoformat(raw.replace("Z", "+00:00")).timestamp()
|
||||
except Exception:
|
||||
try:
|
||||
return float(raw)
|
||||
except Exception:
|
||||
return float("-inf")
|
||||
|
||||
|
||||
def _acp_stderr_print(*args, **kwargs) -> None:
|
||||
"""Best-effort human-readable output sink for ACP stdio sessions.
|
||||
|
||||
@@ -162,47 +224,78 @@ class SessionManager:
|
||||
logger.info("Forked ACP session %s -> %s", session_id, new_id)
|
||||
return state
|
||||
|
||||
def list_sessions(self) -> List[Dict[str, Any]]:
|
||||
def list_sessions(self, cwd: str | None = None) -> List[Dict[str, Any]]:
|
||||
"""Return lightweight info dicts for all sessions (memory + database)."""
|
||||
normalized_cwd = _normalize_cwd_for_compare(cwd) if cwd else None
|
||||
db = self._get_db()
|
||||
persisted_rows: dict[str, dict[str, Any]] = {}
|
||||
|
||||
if db is not None:
|
||||
try:
|
||||
for row in db.list_sessions_rich(source="acp", limit=1000):
|
||||
persisted_rows[str(row["id"])] = dict(row)
|
||||
except Exception:
|
||||
logger.debug("Failed to load ACP sessions from DB", exc_info=True)
|
||||
|
||||
# Collect in-memory sessions first.
|
||||
with self._lock:
|
||||
seen_ids = set(self._sessions.keys())
|
||||
results = [
|
||||
{
|
||||
"session_id": s.session_id,
|
||||
"cwd": s.cwd,
|
||||
"model": s.model,
|
||||
"history_len": len(s.history),
|
||||
}
|
||||
for s in self._sessions.values()
|
||||
]
|
||||
results = []
|
||||
for s in self._sessions.values():
|
||||
history_len = len(s.history)
|
||||
if history_len <= 0:
|
||||
continue
|
||||
if normalized_cwd and _normalize_cwd_for_compare(s.cwd) != normalized_cwd:
|
||||
continue
|
||||
persisted = persisted_rows.get(s.session_id, {})
|
||||
preview = next(
|
||||
(
|
||||
str(msg.get("content") or "").strip()
|
||||
for msg in s.history
|
||||
if msg.get("role") == "user" and str(msg.get("content") or "").strip()
|
||||
),
|
||||
persisted.get("preview") or "",
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"session_id": s.session_id,
|
||||
"cwd": s.cwd,
|
||||
"model": s.model,
|
||||
"history_len": history_len,
|
||||
"title": _build_session_title(persisted.get("title"), preview, s.cwd),
|
||||
"updated_at": _format_updated_at(
|
||||
persisted.get("last_active") or persisted.get("started_at") or time.time()
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# Merge any persisted sessions not currently in memory.
|
||||
db = self._get_db()
|
||||
if db is not None:
|
||||
try:
|
||||
rows = db.search_sessions(source="acp", limit=1000)
|
||||
for row in rows:
|
||||
sid = row["id"]
|
||||
if sid in seen_ids:
|
||||
continue
|
||||
# Extract cwd from model_config JSON.
|
||||
cwd = "."
|
||||
mc = row.get("model_config")
|
||||
if mc:
|
||||
try:
|
||||
cwd = json.loads(mc).get("cwd", ".")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
results.append({
|
||||
"session_id": sid,
|
||||
"cwd": cwd,
|
||||
"model": row.get("model") or "",
|
||||
"history_len": row.get("message_count") or 0,
|
||||
})
|
||||
except Exception:
|
||||
logger.debug("Failed to list ACP sessions from DB", exc_info=True)
|
||||
for sid, row in persisted_rows.items():
|
||||
if sid in seen_ids:
|
||||
continue
|
||||
message_count = int(row.get("message_count") or 0)
|
||||
if message_count <= 0:
|
||||
continue
|
||||
# Extract cwd from model_config JSON.
|
||||
session_cwd = "."
|
||||
mc = row.get("model_config")
|
||||
if mc:
|
||||
try:
|
||||
session_cwd = json.loads(mc).get("cwd", ".")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
if normalized_cwd and _normalize_cwd_for_compare(session_cwd) != normalized_cwd:
|
||||
continue
|
||||
results.append({
|
||||
"session_id": sid,
|
||||
"cwd": session_cwd,
|
||||
"model": row.get("model") or "",
|
||||
"history_len": message_count,
|
||||
"title": _build_session_title(row.get("title"), row.get("preview"), session_cwd),
|
||||
"updated_at": _format_updated_at(row.get("last_active") or row.get("started_at")),
|
||||
})
|
||||
|
||||
results.sort(key=lambda item: _updated_at_sort_key(item.get("updated_at")), reverse=True)
|
||||
return results
|
||||
|
||||
def update_cwd(self, session_id: str, cwd: str) -> Optional[SessionState]:
|
||||
|
||||
+174
-9
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -96,6 +97,170 @@ def build_tool_title(tool_name: str, args: Dict[str, Any]) -> str:
|
||||
return tool_name
|
||||
|
||||
|
||||
def _build_patch_mode_content(patch_text: str) -> List[Any]:
|
||||
"""Parse V4A patch mode input into ACP diff blocks when possible."""
|
||||
if not patch_text:
|
||||
return [acp.tool_content(acp.text_block(""))]
|
||||
|
||||
try:
|
||||
from tools.patch_parser import OperationType, parse_v4a_patch
|
||||
|
||||
operations, error = parse_v4a_patch(patch_text)
|
||||
if error or not operations:
|
||||
return [acp.tool_content(acp.text_block(patch_text))]
|
||||
|
||||
content: List[Any] = []
|
||||
for op in operations:
|
||||
if op.operation == OperationType.UPDATE:
|
||||
old_chunks: list[str] = []
|
||||
new_chunks: list[str] = []
|
||||
for hunk in op.hunks:
|
||||
old_lines = [line.content for line in hunk.lines if line.prefix in (" ", "-")]
|
||||
new_lines = [line.content for line in hunk.lines if line.prefix in (" ", "+")]
|
||||
if old_lines or new_lines:
|
||||
old_chunks.append("\n".join(old_lines))
|
||||
new_chunks.append("\n".join(new_lines))
|
||||
|
||||
old_text = "\n...\n".join(chunk for chunk in old_chunks if chunk)
|
||||
new_text = "\n...\n".join(chunk for chunk in new_chunks if chunk)
|
||||
if old_text or new_text:
|
||||
content.append(
|
||||
acp.tool_diff_content(
|
||||
path=op.file_path,
|
||||
old_text=old_text or None,
|
||||
new_text=new_text or "",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if op.operation == OperationType.ADD:
|
||||
added_lines = [line.content for hunk in op.hunks for line in hunk.lines if line.prefix == "+"]
|
||||
content.append(
|
||||
acp.tool_diff_content(
|
||||
path=op.file_path,
|
||||
new_text="\n".join(added_lines),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if op.operation == OperationType.DELETE:
|
||||
content.append(
|
||||
acp.tool_diff_content(
|
||||
path=op.file_path,
|
||||
old_text=f"Delete file: {op.file_path}",
|
||||
new_text="",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if op.operation == OperationType.MOVE:
|
||||
content.append(
|
||||
acp.tool_content(acp.text_block(f"Move file: {op.file_path} -> {op.new_path}"))
|
||||
)
|
||||
|
||||
return content or [acp.tool_content(acp.text_block(patch_text))]
|
||||
except Exception:
|
||||
return [acp.tool_content(acp.text_block(patch_text))]
|
||||
|
||||
|
||||
def _strip_diff_prefix(path: str) -> str:
|
||||
raw = str(path or "").strip()
|
||||
if raw.startswith(("a/", "b/")):
|
||||
return raw[2:]
|
||||
return raw
|
||||
|
||||
|
||||
def _parse_unified_diff_content(diff_text: str) -> List[Any]:
|
||||
"""Convert unified diff text into ACP diff content blocks."""
|
||||
if not diff_text:
|
||||
return []
|
||||
|
||||
content: List[Any] = []
|
||||
current_old_path: Optional[str] = None
|
||||
current_new_path: Optional[str] = None
|
||||
old_lines: list[str] = []
|
||||
new_lines: list[str] = []
|
||||
|
||||
def _flush() -> None:
|
||||
nonlocal current_old_path, current_new_path, old_lines, new_lines
|
||||
if current_old_path is None and current_new_path is None:
|
||||
return
|
||||
path = current_new_path if current_new_path and current_new_path != "/dev/null" else current_old_path
|
||||
if not path or path == "/dev/null":
|
||||
current_old_path = None
|
||||
current_new_path = None
|
||||
old_lines = []
|
||||
new_lines = []
|
||||
return
|
||||
content.append(
|
||||
acp.tool_diff_content(
|
||||
path=_strip_diff_prefix(path),
|
||||
old_text="\n".join(old_lines) if old_lines else None,
|
||||
new_text="\n".join(new_lines),
|
||||
)
|
||||
)
|
||||
current_old_path = None
|
||||
current_new_path = None
|
||||
old_lines = []
|
||||
new_lines = []
|
||||
|
||||
for line in diff_text.splitlines():
|
||||
if line.startswith("--- "):
|
||||
_flush()
|
||||
current_old_path = line[4:].strip()
|
||||
continue
|
||||
if line.startswith("+++ "):
|
||||
current_new_path = line[4:].strip()
|
||||
continue
|
||||
if line.startswith("@@"):
|
||||
continue
|
||||
if current_old_path is None and current_new_path is None:
|
||||
continue
|
||||
if line.startswith("+"):
|
||||
new_lines.append(line[1:])
|
||||
elif line.startswith("-"):
|
||||
old_lines.append(line[1:])
|
||||
elif line.startswith(" "):
|
||||
shared = line[1:]
|
||||
old_lines.append(shared)
|
||||
new_lines.append(shared)
|
||||
|
||||
_flush()
|
||||
return content
|
||||
|
||||
|
||||
def _build_tool_complete_content(
|
||||
tool_name: str,
|
||||
result: Optional[str],
|
||||
*,
|
||||
function_args: Optional[Dict[str, Any]] = None,
|
||||
snapshot: Any = None,
|
||||
) -> List[Any]:
|
||||
"""Build structured ACP completion content, falling back to plain text."""
|
||||
display_result = result or ""
|
||||
if len(display_result) > 5000:
|
||||
display_result = display_result[:4900] + f"\n... ({len(result)} chars total, truncated)"
|
||||
|
||||
if tool_name in {"write_file", "patch", "skill_manage"}:
|
||||
try:
|
||||
from agent.display import extract_edit_diff
|
||||
|
||||
diff_text = extract_edit_diff(
|
||||
tool_name,
|
||||
result,
|
||||
function_args=function_args,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
if isinstance(diff_text, str) and diff_text.strip():
|
||||
diff_content = _parse_unified_diff_content(diff_text)
|
||||
if diff_content:
|
||||
return diff_content
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return [acp.tool_content(acp.text_block(display_result))]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build ACP content objects for tool-call events
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -119,9 +284,8 @@ def build_tool_start(
|
||||
new = arguments.get("new_string", "")
|
||||
content = [acp.tool_diff_content(path=path, new_text=new, old_text=old)]
|
||||
else:
|
||||
# Patch mode — show the patch content as text
|
||||
patch_text = arguments.get("patch", "")
|
||||
content = [acp.tool_content(acp.text_block(patch_text))]
|
||||
content = _build_patch_mode_content(patch_text)
|
||||
return acp.start_tool_call(
|
||||
tool_call_id, title, kind=kind, content=content, locations=locations,
|
||||
raw_input=arguments,
|
||||
@@ -178,16 +342,17 @@ def build_tool_complete(
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
result: Optional[str] = None,
|
||||
function_args: Optional[Dict[str, Any]] = None,
|
||||
snapshot: Any = None,
|
||||
) -> ToolCallProgress:
|
||||
"""Create a ToolCallUpdate (progress) event for a completed tool call."""
|
||||
kind = get_tool_kind(tool_name)
|
||||
|
||||
# Truncate very large results for the UI
|
||||
display_result = result or ""
|
||||
if len(display_result) > 5000:
|
||||
display_result = display_result[:4900] + f"\n... ({len(result)} chars total, truncated)"
|
||||
|
||||
content = [acp.tool_content(acp.text_block(display_result))]
|
||||
content = _build_tool_complete_content(
|
||||
tool_name,
|
||||
result,
|
||||
function_args=function_args,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
return acp.update_tool_call(
|
||||
tool_call_id,
|
||||
kind=kind,
|
||||
|
||||
+78
-32
@@ -94,6 +94,17 @@ def _normalize_aux_provider(provider: Optional[str]) -> str:
|
||||
return "custom"
|
||||
return _PROVIDER_ALIASES.get(normalized, normalized)
|
||||
|
||||
|
||||
_FIXED_TEMPERATURE_MODELS: Dict[str, float] = {
|
||||
"kimi-for-coding": 0.6,
|
||||
}
|
||||
|
||||
|
||||
def _fixed_temperature_for_model(model: Optional[str]) -> Optional[float]:
|
||||
"""Return a required temperature override for models with strict contracts."""
|
||||
normalized = (model or "").strip().lower()
|
||||
return _FIXED_TEMPERATURE_MODELS.get(normalized)
|
||||
|
||||
# Default auxiliary models for direct API-key providers (cheap/fast for side tasks)
|
||||
_API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = {
|
||||
"gemini": "gemini-3-flash-preview",
|
||||
@@ -734,6 +745,15 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
extra["default_headers"] = copilot_default_headers()
|
||||
elif "generativelanguage.googleapis.com" in base_url.lower():
|
||||
# Google's OpenAI-compatible endpoint only accepts x-goog-api-key.
|
||||
# Passing api_key= causes the SDK to inject Authorization: Bearer,
|
||||
# which Google rejects with HTTP 400 "Multiple authentication
|
||||
# credentials received". Use a placeholder for api_key and pass
|
||||
# the real key via x-goog-api-key header instead.
|
||||
# Fixes: https://github.com/NousResearch/hermes-agent/issues/7893
|
||||
extra["default_headers"] = {"x-goog-api-key": api_key}
|
||||
api_key = "not-used"
|
||||
return OpenAI(api_key=api_key, base_url=base_url, **extra), model
|
||||
|
||||
creds = resolve_api_key_provider_credentials(provider_id)
|
||||
@@ -755,6 +775,15 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
extra["default_headers"] = copilot_default_headers()
|
||||
elif "generativelanguage.googleapis.com" in base_url.lower():
|
||||
# Google's OpenAI-compatible endpoint only accepts x-goog-api-key.
|
||||
# Passing api_key= causes the SDK to inject Authorization: Bearer,
|
||||
# which Google rejects with HTTP 400 "Multiple authentication
|
||||
# credentials received". Use a placeholder for api_key and pass
|
||||
# the real key via x-goog-api-key header instead.
|
||||
# Fixes: https://github.com/NousResearch/hermes-agent/issues/7893
|
||||
extra["default_headers"] = {"x-goog-api-key": api_key}
|
||||
api_key = "not-used"
|
||||
return OpenAI(api_key=api_key, base_url=base_url, **extra), model
|
||||
|
||||
return None, None
|
||||
@@ -1064,8 +1093,6 @@ _AUTO_PROVIDER_LABELS = {
|
||||
"_resolve_api_key_provider": "api-key",
|
||||
}
|
||||
|
||||
_AGGREGATOR_PROVIDERS = frozenset({"openrouter", "nous"})
|
||||
|
||||
_MAIN_RUNTIME_FIELDS = ("provider", "model", "base_url", "api_key", "api_mode")
|
||||
|
||||
|
||||
@@ -1196,11 +1223,15 @@ def _resolve_auto(main_runtime: Optional[Dict[str, Any]] = None) -> Tuple[Option
|
||||
"""Full auto-detection chain.
|
||||
|
||||
Priority:
|
||||
1. If the user's main provider is NOT an aggregator (OpenRouter / Nous),
|
||||
use their main provider + main model directly. This ensures users on
|
||||
Alibaba, DeepSeek, ZAI, etc. get auxiliary tasks handled by the same
|
||||
provider they already have credentials for — no OpenRouter key needed.
|
||||
2. OpenRouter → Nous → custom → Codex → API-key providers (original chain).
|
||||
1. User's main provider + main model, regardless of provider type.
|
||||
This means auxiliary tasks (compression, vision, web extraction,
|
||||
session search, etc.) use the same model the user configured for
|
||||
chat. Users on OpenRouter/Nous get their chosen chat model; users
|
||||
on DeepSeek/ZAI/Alibaba get theirs; etc. Running aux tasks on the
|
||||
user's picked model keeps behavior predictable — no surprise
|
||||
switches to a cheap fallback model for side tasks.
|
||||
2. OpenRouter → Nous → custom → Codex → API-key providers (fallback
|
||||
chain, only used when the main provider has no working client).
|
||||
"""
|
||||
global auxiliary_is_nous, _stale_base_url_warned
|
||||
auxiliary_is_nous = False # Reset — _try_nous() will set True if it wins
|
||||
@@ -1230,11 +1261,16 @@ def _resolve_auto(main_runtime: Optional[Dict[str, Any]] = None) -> Tuple[Option
|
||||
)
|
||||
_stale_base_url_warned = True
|
||||
|
||||
# ── Step 1: non-aggregator main provider → use main model directly ──
|
||||
# ── Step 1: main provider + main model → use them directly ──
|
||||
#
|
||||
# This is the primary aux backend for every user. "auto" means
|
||||
# "use my main chat model for side tasks as well" — including users
|
||||
# on aggregators (OpenRouter, Nous) who previously got routed to a
|
||||
# cheap provider-side default. Explicit per-task overrides set via
|
||||
# config.yaml (auxiliary.<task>.provider) still win over this.
|
||||
main_provider = runtime_provider or _read_main_provider()
|
||||
main_model = runtime_model or _read_main_model()
|
||||
if (main_provider and main_model
|
||||
and main_provider not in _AGGREGATOR_PROVIDERS
|
||||
and main_provider not in ("auto", "")):
|
||||
resolved_provider = main_provider
|
||||
explicit_base_url = None
|
||||
@@ -1593,6 +1629,15 @@ def resolve_provider_client(
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
headers.update(copilot_default_headers())
|
||||
elif "generativelanguage.googleapis.com" in base_url.lower():
|
||||
# Google's OpenAI-compatible endpoint only accepts x-goog-api-key.
|
||||
# Passing api_key= causes the OpenAI SDK to inject Authorization: Bearer,
|
||||
# which Google rejects with HTTP 400 "Multiple authentication credentials
|
||||
# received". Use a placeholder for api_key and pass the real key via
|
||||
# x-goog-api-key header instead.
|
||||
# Fixes: https://github.com/NousResearch/hermes-agent/issues/7893
|
||||
headers["x-goog-api-key"] = api_key
|
||||
api_key = "not-used"
|
||||
|
||||
client = OpenAI(api_key=api_key, base_url=base_url,
|
||||
**({"default_headers": headers} if headers else {}))
|
||||
@@ -1817,34 +1862,31 @@ def resolve_vision_provider_client(
|
||||
|
||||
if requested == "auto":
|
||||
# Vision auto-detection order:
|
||||
# 1. Active provider + model (user's main chat config)
|
||||
# 2. OpenRouter (known vision-capable default model)
|
||||
# 3. Nous Portal (known vision-capable default model)
|
||||
# 1. User's main provider + main model (including aggregators).
|
||||
# _PROVIDER_VISION_MODELS provides per-provider vision model
|
||||
# overrides when the provider has a dedicated multimodal model
|
||||
# that differs from the chat model (e.g. xiaomi → mimo-v2-omni,
|
||||
# zai → glm-5v-turbo).
|
||||
# 2. OpenRouter (vision-capable aggregator fallback)
|
||||
# 3. Nous Portal (vision-capable aggregator fallback)
|
||||
# 4. Stop
|
||||
main_provider = _read_main_provider()
|
||||
main_model = _read_main_model()
|
||||
if main_provider and main_provider not in ("auto", ""):
|
||||
if main_provider in _VISION_AUTO_PROVIDER_ORDER:
|
||||
# Known strict backend — use its defaults.
|
||||
sync_client, default_model = _resolve_strict_vision_backend(main_provider)
|
||||
if sync_client is not None:
|
||||
return _finalize(main_provider, sync_client, default_model)
|
||||
else:
|
||||
# Exotic provider (DeepSeek, Alibaba, Xiaomi, named custom, etc.)
|
||||
# Use provider-specific vision model if available, otherwise main model.
|
||||
vision_model = _PROVIDER_VISION_MODELS.get(main_provider, main_model)
|
||||
rpc_client, rpc_model = resolve_provider_client(
|
||||
main_provider, vision_model,
|
||||
api_mode=resolved_api_mode)
|
||||
if rpc_client is not None:
|
||||
logger.info(
|
||||
"Vision auto-detect: using active provider %s (%s)",
|
||||
main_provider, rpc_model or vision_model,
|
||||
)
|
||||
return _finalize(
|
||||
main_provider, rpc_client, rpc_model or vision_model)
|
||||
vision_model = _PROVIDER_VISION_MODELS.get(main_provider, main_model)
|
||||
rpc_client, rpc_model = resolve_provider_client(
|
||||
main_provider, vision_model,
|
||||
api_mode=resolved_api_mode)
|
||||
if rpc_client is not None:
|
||||
logger.info(
|
||||
"Vision auto-detect: using main provider %s (%s)",
|
||||
main_provider, rpc_model or vision_model,
|
||||
)
|
||||
return _finalize(
|
||||
main_provider, rpc_client, rpc_model or vision_model)
|
||||
|
||||
# Fall back through aggregators.
|
||||
# Fall back through aggregators (uses their dedicated vision model,
|
||||
# not the user's main model) when main provider has no client.
|
||||
for candidate in _VISION_AUTO_PROVIDER_ORDER:
|
||||
if candidate == main_provider:
|
||||
continue # already tried above
|
||||
@@ -2293,6 +2335,10 @@ def _build_call_kwargs(
|
||||
"timeout": timeout,
|
||||
}
|
||||
|
||||
fixed_temperature = _fixed_temperature_for_model(model)
|
||||
if fixed_temperature is not None:
|
||||
temperature = fixed_temperature
|
||||
|
||||
# Opus 4.7+ rejects any non-default temperature/top_p/top_k — silently
|
||||
# drop here so auxiliary callers that hardcode temperature (e.g. 0.3 on
|
||||
# flush_memories, 0 on structured-JSON extraction) don't 400 the moment
|
||||
|
||||
@@ -1130,6 +1130,14 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
|
||||
state = _load_provider_state(auth_store, "nous")
|
||||
if state:
|
||||
active_sources.add("device_code")
|
||||
# Prefer a user-supplied label embedded in the singleton state
|
||||
# (set by persist_nous_credentials(label=...) when the user ran
|
||||
# `hermes auth add nous --label <name>`). Fall back to the
|
||||
# auto-derived token fingerprint for logins that didn't supply one.
|
||||
custom_label = str(state.get("label") or "").strip()
|
||||
seeded_label = custom_label or label_from_token(
|
||||
state.get("access_token", ""), "device_code"
|
||||
)
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
provider,
|
||||
@@ -1148,7 +1156,7 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
|
||||
"agent_key": state.get("agent_key"),
|
||||
"agent_key_expires_at": state.get("agent_key_expires_at"),
|
||||
"tls": state.get("tls") if isinstance(state.get("tls"), dict) else None,
|
||||
"label": label_from_token(state.get("access_token", ""), "device_code"),
|
||||
"label": seeded_label,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1208,6 +1216,19 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
|
||||
logger.debug("Qwen OAuth token seed failed: %s", exc)
|
||||
|
||||
elif provider == "openai-codex":
|
||||
# Respect user suppression — `hermes auth remove openai-codex` marks
|
||||
# the device_code source as suppressed so it won't be re-seeded from
|
||||
# either the Hermes auth store or ~/.codex/auth.json. Without this
|
||||
# gate the removal is instantly undone on the next load_pool() call.
|
||||
codex_suppressed = False
|
||||
try:
|
||||
from hermes_cli.auth import is_source_suppressed
|
||||
codex_suppressed = is_source_suppressed(provider, "device_code")
|
||||
except ImportError:
|
||||
pass
|
||||
if codex_suppressed:
|
||||
return changed, active_sources
|
||||
|
||||
state = _load_provider_state(auth_store, "openai-codex")
|
||||
tokens = state.get("tokens") if isinstance(state, dict) else None
|
||||
# Fallback: import from Codex CLI (~/.codex/auth.json) if Hermes auth
|
||||
|
||||
@@ -747,18 +747,149 @@ class GeminiCloudCodeClient:
|
||||
|
||||
|
||||
def _gemini_http_error(response: httpx.Response) -> CodeAssistError:
|
||||
"""Translate an httpx response into a CodeAssistError with rich metadata.
|
||||
|
||||
Parses Google's error envelope (``{"error": {"code", "message", "status",
|
||||
"details": [...]}}``) so the agent's error classifier can reason about
|
||||
the failure — ``status_code`` enables the rate_limit / auth classification
|
||||
paths, and ``response`` lets the main loop honor ``Retry-After`` just
|
||||
like it does for OpenAI SDK exceptions.
|
||||
|
||||
Also lifts a few recognizable Google conditions into human-readable
|
||||
messages so the user sees something better than a 500-char JSON dump:
|
||||
|
||||
MODEL_CAPACITY_EXHAUSTED → "Gemini model capacity exhausted for
|
||||
<model>. This is a Google-side throttle..."
|
||||
RESOURCE_EXHAUSTED w/o reason → quota-style message
|
||||
404 → "Model <name> not found at cloudcode-pa..."
|
||||
"""
|
||||
status = response.status_code
|
||||
|
||||
# Parse the body once, surviving any weird encodings.
|
||||
body_text = ""
|
||||
body_json: Dict[str, Any] = {}
|
||||
try:
|
||||
body = response.text[:500]
|
||||
body_text = response.text
|
||||
except Exception:
|
||||
body = ""
|
||||
# Let run_agent's retry logic see auth errors as rotatable via `api_key`
|
||||
body_text = ""
|
||||
if body_text:
|
||||
try:
|
||||
parsed = json.loads(body_text)
|
||||
if isinstance(parsed, dict):
|
||||
body_json = parsed
|
||||
except (ValueError, TypeError):
|
||||
body_json = {}
|
||||
|
||||
# Dig into Google's error envelope. Shape is:
|
||||
# {"error": {"code": 429, "message": "...", "status": "RESOURCE_EXHAUSTED",
|
||||
# "details": [{"@type": ".../ErrorInfo", "reason": "MODEL_CAPACITY_EXHAUSTED",
|
||||
# "metadata": {...}},
|
||||
# {"@type": ".../RetryInfo", "retryDelay": "30s"}]}}
|
||||
err_obj = body_json.get("error") if isinstance(body_json, dict) else None
|
||||
if not isinstance(err_obj, dict):
|
||||
err_obj = {}
|
||||
err_status = str(err_obj.get("status") or "").strip()
|
||||
err_message = str(err_obj.get("message") or "").strip()
|
||||
err_details_list = err_obj.get("details") if isinstance(err_obj.get("details"), list) else []
|
||||
|
||||
# Extract google.rpc.ErrorInfo reason + metadata. There may be more
|
||||
# than one ErrorInfo (rare), so we pick the first one with a reason.
|
||||
error_reason = ""
|
||||
error_metadata: Dict[str, Any] = {}
|
||||
retry_delay_seconds: Optional[float] = None
|
||||
for detail in err_details_list:
|
||||
if not isinstance(detail, dict):
|
||||
continue
|
||||
type_url = str(detail.get("@type") or "")
|
||||
if not error_reason and type_url.endswith("/google.rpc.ErrorInfo"):
|
||||
reason = detail.get("reason")
|
||||
if isinstance(reason, str) and reason:
|
||||
error_reason = reason
|
||||
md = detail.get("metadata")
|
||||
if isinstance(md, dict):
|
||||
error_metadata = md
|
||||
elif retry_delay_seconds is None and type_url.endswith("/google.rpc.RetryInfo"):
|
||||
# retryDelay is a google.protobuf.Duration string like "30s" or "1.5s".
|
||||
delay_raw = detail.get("retryDelay")
|
||||
if isinstance(delay_raw, str) and delay_raw.endswith("s"):
|
||||
try:
|
||||
retry_delay_seconds = float(delay_raw[:-1])
|
||||
except ValueError:
|
||||
pass
|
||||
elif isinstance(delay_raw, (int, float)):
|
||||
retry_delay_seconds = float(delay_raw)
|
||||
|
||||
# Fall back to the Retry-After header if the body didn't include RetryInfo.
|
||||
if retry_delay_seconds is None:
|
||||
try:
|
||||
header_val = response.headers.get("Retry-After") or response.headers.get("retry-after")
|
||||
except Exception:
|
||||
header_val = None
|
||||
if header_val:
|
||||
try:
|
||||
retry_delay_seconds = float(header_val)
|
||||
except (TypeError, ValueError):
|
||||
retry_delay_seconds = None
|
||||
|
||||
# Classify the error code. ``code_assist_rate_limited`` stays the default
|
||||
# for 429s; a more specific reason tag helps downstream callers (e.g. tests,
|
||||
# logs) without changing the rate_limit classification path.
|
||||
code = f"code_assist_http_{status}"
|
||||
if status == 401:
|
||||
code = "code_assist_unauthorized"
|
||||
elif status == 429:
|
||||
code = "code_assist_rate_limited"
|
||||
if error_reason == "MODEL_CAPACITY_EXHAUSTED":
|
||||
code = "code_assist_capacity_exhausted"
|
||||
|
||||
# Build a human-readable message. Keep the status + a raw-body tail for
|
||||
# debugging, but lead with a friendlier summary when we recognize the
|
||||
# Google signal.
|
||||
model_hint = ""
|
||||
if isinstance(error_metadata, dict):
|
||||
model_hint = str(error_metadata.get("model") or error_metadata.get("modelId") or "").strip()
|
||||
|
||||
if status == 429 and error_reason == "MODEL_CAPACITY_EXHAUSTED":
|
||||
target = model_hint or "this Gemini model"
|
||||
message = (
|
||||
f"Gemini capacity exhausted for {target} (Google-side throttle, "
|
||||
f"not a Hermes issue). Try a different Gemini model or set a "
|
||||
f"fallback_providers entry to a non-Gemini provider."
|
||||
)
|
||||
if retry_delay_seconds is not None:
|
||||
message += f" Google suggests retrying in {retry_delay_seconds:g}s."
|
||||
elif status == 429 and err_status == "RESOURCE_EXHAUSTED":
|
||||
message = (
|
||||
f"Gemini quota exhausted ({err_message or 'RESOURCE_EXHAUSTED'}). "
|
||||
f"Check /gquota for remaining daily requests."
|
||||
)
|
||||
if retry_delay_seconds is not None:
|
||||
message += f" Retry suggested in {retry_delay_seconds:g}s."
|
||||
elif status == 404:
|
||||
# Google returns 404 when a model has been retired or renamed.
|
||||
target = model_hint or (err_message or "model")
|
||||
message = (
|
||||
f"Code Assist 404: {target} is not available at "
|
||||
f"cloudcode-pa.googleapis.com. It may have been renamed or "
|
||||
f"retired. Check hermes_cli/models.py for the current list."
|
||||
)
|
||||
elif err_message:
|
||||
# Generic fallback with the parsed message.
|
||||
message = f"Code Assist HTTP {status} ({err_status or 'error'}): {err_message}"
|
||||
else:
|
||||
# Last-ditch fallback — raw body snippet.
|
||||
message = f"Code Assist returned HTTP {status}: {body_text[:500]}"
|
||||
|
||||
return CodeAssistError(
|
||||
f"Code Assist returned HTTP {status}: {body}",
|
||||
message,
|
||||
code=code,
|
||||
status_code=status,
|
||||
response=response,
|
||||
retry_after=retry_delay_seconds,
|
||||
details={
|
||||
"status": err_status,
|
||||
"reason": error_reason,
|
||||
"metadata": error_metadata,
|
||||
"message": err_message,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -68,9 +68,45 @@ _ONBOARDING_POLL_INTERVAL_SECONDS = 5.0
|
||||
|
||||
|
||||
class CodeAssistError(RuntimeError):
|
||||
def __init__(self, message: str, *, code: str = "code_assist_error") -> None:
|
||||
"""Exception raised by the Code Assist (``cloudcode-pa``) integration.
|
||||
|
||||
Carries HTTP status / response / retry-after metadata so the agent's
|
||||
``error_classifier._extract_status_code`` and the main loop's Retry-After
|
||||
handling (which walks ``error.response.headers``) pick up the right
|
||||
signals. Without these, 429s from the OAuth path look like opaque
|
||||
``RuntimeError`` and skip the rate-limit path.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
code: str = "code_assist_error",
|
||||
status_code: Optional[int] = None,
|
||||
response: Any = None,
|
||||
retry_after: Optional[float] = None,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
# ``status_code`` is picked up by ``agent.error_classifier._extract_status_code``
|
||||
# so a 429 from Code Assist classifies as FailoverReason.rate_limit and
|
||||
# triggers the main loop's fallback_providers chain the same way SDK
|
||||
# errors do.
|
||||
self.status_code = status_code
|
||||
# ``response`` is the underlying ``httpx.Response`` (or a shim with a
|
||||
# ``.headers`` mapping and ``.json()`` method). The main loop reads
|
||||
# ``error.response.headers["Retry-After"]`` to honor Google's retry
|
||||
# hints when the backend throttles us.
|
||||
self.response = response
|
||||
# Parsed ``Retry-After`` seconds (kept separately for convenience —
|
||||
# Google returns retry hints in both the header and the error body's
|
||||
# ``google.rpc.RetryInfo`` details, and we pick whichever we found).
|
||||
self.retry_after = retry_after
|
||||
# Parsed structured error details from the Google error envelope
|
||||
# (e.g. ``{"reason": "MODEL_CAPACITY_EXHAUSTED", "status": "RESOURCE_EXHAUSTED"}``).
|
||||
# Useful for logging and for tests that want to assert on specifics.
|
||||
self.details = details or {}
|
||||
|
||||
|
||||
class ProjectIdRequiredError(CodeAssistError):
|
||||
|
||||
@@ -38,6 +38,7 @@ _PROVIDER_PREFIXES: frozenset[str] = frozenset({
|
||||
"mimo", "xiaomi-mimo",
|
||||
"arcee-ai", "arceeai",
|
||||
"xai", "x-ai", "x.ai", "grok",
|
||||
"nvidia", "nim", "nvidia-nim", "nemotron",
|
||||
"qwen-portal",
|
||||
})
|
||||
|
||||
@@ -124,7 +125,6 @@ DEFAULT_CONTEXT_LENGTHS = {
|
||||
"gemini": 1048576,
|
||||
# Gemma (open models served via AI Studio)
|
||||
"gemma-4-31b": 256000,
|
||||
"gemma-4-26b": 256000,
|
||||
"gemma-3": 131072,
|
||||
"gemma": 8192, # fallback for older gemma models
|
||||
# DeepSeek
|
||||
@@ -158,6 +158,8 @@ DEFAULT_CONTEXT_LENGTHS = {
|
||||
"grok": 131072, # catch-all (grok-beta, unknown grok-*)
|
||||
# Kimi
|
||||
"kimi": 262144,
|
||||
# Nemotron — NVIDIA's open-weights series (128K context across all sizes)
|
||||
"nemotron": 131072,
|
||||
# Arcee
|
||||
"trinity": 262144,
|
||||
# OpenRouter
|
||||
@@ -240,6 +242,7 @@ _URL_TO_PROVIDER: Dict[str, str] = {
|
||||
"api.fireworks.ai": "fireworks",
|
||||
"opencode.ai": "opencode-go",
|
||||
"api.x.ai": "xai",
|
||||
"integrate.api.nvidia.com": "nvidia",
|
||||
"api.xiaomimimo.com": "xiaomi",
|
||||
"xiaomimimo.com": "xiaomi",
|
||||
"ollama.com": "ollama-cloud",
|
||||
|
||||
@@ -654,7 +654,7 @@ def build_skills_system_prompt(
|
||||
):
|
||||
continue
|
||||
skills_by_category.setdefault(category, []).append(
|
||||
(skill_name, entry.get("description", ""))
|
||||
(frontmatter_name, entry.get("description", ""))
|
||||
)
|
||||
category_descriptions = {
|
||||
str(k): str(v)
|
||||
@@ -679,7 +679,7 @@ def build_skills_system_prompt(
|
||||
):
|
||||
continue
|
||||
skills_by_category.setdefault(entry["category"], []).append(
|
||||
(skill_name, entry["description"])
|
||||
(entry["frontmatter_name"], entry["description"])
|
||||
)
|
||||
|
||||
# Read category-level DESCRIPTION.md files
|
||||
@@ -722,9 +722,10 @@ def build_skills_system_prompt(
|
||||
continue
|
||||
entry = _build_snapshot_entry(skill_file, ext_dir, frontmatter, desc)
|
||||
skill_name = entry["skill_name"]
|
||||
if skill_name in seen_skill_names:
|
||||
frontmatter_name = entry["frontmatter_name"]
|
||||
if frontmatter_name in seen_skill_names:
|
||||
continue
|
||||
if entry["frontmatter_name"] in disabled or skill_name in disabled:
|
||||
if frontmatter_name in disabled or skill_name in disabled:
|
||||
continue
|
||||
if not _skill_should_show(
|
||||
extract_skill_conditions(frontmatter),
|
||||
@@ -732,9 +733,9 @@ def build_skills_system_prompt(
|
||||
available_toolsets,
|
||||
):
|
||||
continue
|
||||
seen_skill_names.add(skill_name)
|
||||
seen_skill_names.add(frontmatter_name)
|
||||
skills_by_category.setdefault(entry["category"], []).append(
|
||||
(skill_name, entry["description"])
|
||||
(frontmatter_name, entry["description"])
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Error reading external skill %s: %s", skill_file, e)
|
||||
|
||||
@@ -24,6 +24,7 @@ model:
|
||||
# "minimax" - MiniMax global (requires: MINIMAX_API_KEY)
|
||||
# "minimax-cn" - MiniMax China (requires: MINIMAX_CN_API_KEY)
|
||||
# "huggingface" - Hugging Face Inference (requires: HF_TOKEN)
|
||||
# "nvidia" - NVIDIA NIM / build.nvidia.com (requires: NVIDIA_API_KEY)
|
||||
# "xiaomi" - Xiaomi MiMo (requires: XIAOMI_API_KEY)
|
||||
# "arcee" - Arcee AI Trinity models (requires: ARCEEAI_API_KEY)
|
||||
# "ollama-cloud" - Ollama Cloud (requires: OLLAMA_API_KEY — https://ollama.com/settings)
|
||||
|
||||
@@ -18,6 +18,8 @@ import os
|
||||
import shutil
|
||||
import sys
|
||||
import json
|
||||
import re
|
||||
import base64
|
||||
import atexit
|
||||
import tempfile
|
||||
import time
|
||||
@@ -78,6 +80,42 @@ _project_env = Path(__file__).parent / '.env'
|
||||
load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env)
|
||||
|
||||
|
||||
_REASONING_TAGS = (
|
||||
"REASONING_SCRATCHPAD",
|
||||
"think",
|
||||
"reasoning",
|
||||
"THINKING",
|
||||
"thinking",
|
||||
)
|
||||
|
||||
|
||||
def _strip_reasoning_tags(text: str) -> str:
|
||||
cleaned = text
|
||||
for tag in _REASONING_TAGS:
|
||||
cleaned = re.sub(rf"<{tag}>.*?</{tag}>\s*", "", cleaned, flags=re.DOTALL)
|
||||
cleaned = re.sub(rf"<{tag}>.*$", "", cleaned, flags=re.DOTALL)
|
||||
return cleaned.strip()
|
||||
|
||||
|
||||
def _assistant_content_as_text(content: Any) -> str:
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
str(part.get("text", ""))
|
||||
for part in content
|
||||
if isinstance(part, dict) and part.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _assistant_copy_text(content: Any) -> str:
|
||||
return _strip_reasoning_tags(_assistant_content_as_text(content))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Configuration Loading
|
||||
# =============================================================================
|
||||
@@ -1172,6 +1210,10 @@ def _resolve_attachment_path(raw_path: str) -> Path | None:
|
||||
return None
|
||||
|
||||
expanded = os.path.expandvars(os.path.expanduser(token))
|
||||
if os.name != "nt":
|
||||
normalized = expanded.replace("\\", "/")
|
||||
if len(normalized) >= 3 and normalized[1] == ":" and normalized[2] == "/" and normalized[0].isalpha():
|
||||
expanded = f"/mnt/{normalized[0].lower()}/{normalized[3:]}"
|
||||
path = Path(expanded)
|
||||
if not path.is_absolute():
|
||||
base_dir = Path(os.getenv("TERMINAL_CWD", os.getcwd()))
|
||||
@@ -1254,10 +1296,12 @@ def _detect_file_drop(user_input: str) -> "dict | None":
|
||||
or stripped.startswith("~")
|
||||
or stripped.startswith("./")
|
||||
or stripped.startswith("../")
|
||||
or (len(stripped) >= 3 and stripped[1] == ":" and stripped[2] in ("\\", "/") and stripped[0].isalpha())
|
||||
or stripped.startswith('"/')
|
||||
or stripped.startswith('"~')
|
||||
or stripped.startswith("'/")
|
||||
or stripped.startswith("'~")
|
||||
or (len(stripped) >= 4 and stripped[0] in ("'", '"') and stripped[2] == ":" and stripped[3] in ("\\", "/") and stripped[1].isalpha())
|
||||
)
|
||||
if not starts_like_path:
|
||||
return None
|
||||
@@ -3125,21 +3169,6 @@ class HermesCLI:
|
||||
MAX_ASST_LEN = 200 # truncate assistant text
|
||||
MAX_ASST_LINES = 3 # max lines of assistant text
|
||||
|
||||
def _strip_reasoning(text: str) -> str:
|
||||
"""Remove <REASONING_SCRATCHPAD>...</REASONING_SCRATCHPAD> blocks
|
||||
from displayed text (reasoning model internal thoughts)."""
|
||||
import re
|
||||
cleaned = re.sub(
|
||||
r"<REASONING_SCRATCHPAD>.*?</REASONING_SCRATCHPAD>\s*",
|
||||
"", text, flags=re.DOTALL,
|
||||
)
|
||||
# Also strip unclosed reasoning tags at the end
|
||||
cleaned = re.sub(
|
||||
r"<REASONING_SCRATCHPAD>.*$",
|
||||
"", cleaned, flags=re.DOTALL,
|
||||
)
|
||||
return cleaned.strip()
|
||||
|
||||
# Collect displayable entries (skip system, tool-result messages)
|
||||
entries = [] # list of (role, display_text)
|
||||
_last_asst_idx = None # index of last assistant entry
|
||||
@@ -3171,7 +3200,7 @@ class HermesCLI:
|
||||
|
||||
elif role == "assistant":
|
||||
text = "" if content is None else str(content)
|
||||
text = _strip_reasoning(text)
|
||||
text = _strip_reasoning_tags(text)
|
||||
parts = []
|
||||
full_parts = [] # un-truncated version
|
||||
if text:
|
||||
@@ -3510,6 +3539,26 @@ class HermesCLI:
|
||||
killed = process_registry.kill_all()
|
||||
print(f" ✅ Stopped {killed} process(es).")
|
||||
|
||||
def _handle_agents_command(self):
|
||||
"""Handle /agents — show background processes and agent status."""
|
||||
from tools.process_registry import format_uptime_short, process_registry
|
||||
|
||||
processes = process_registry.list_sessions()
|
||||
running = [p for p in processes if p.get("status") == "running"]
|
||||
finished = [p for p in processes if p.get("status") != "running"]
|
||||
|
||||
_cprint(f" Running processes: {len(running)}")
|
||||
for p in running:
|
||||
cmd = p.get("command", "")[:80]
|
||||
up = format_uptime_short(p.get("uptime_seconds", 0))
|
||||
_cprint(f" {p.get('session_id', '?')} · {up} · {cmd}")
|
||||
|
||||
if finished:
|
||||
_cprint(f" Recently finished: {len(finished)}")
|
||||
|
||||
agent_running = getattr(self, "_agent_running", False)
|
||||
_cprint(f" Agent: {'running' if agent_running else 'idle'}")
|
||||
|
||||
def _handle_paste_command(self):
|
||||
"""Handle /paste — explicitly check clipboard for an image.
|
||||
|
||||
@@ -3535,6 +3584,61 @@ class HermesCLI:
|
||||
else:
|
||||
_cprint(f" {_DIM}(._.) No image found in clipboard{_RST}")
|
||||
|
||||
def _write_osc52_clipboard(self, text: str) -> None:
|
||||
"""Copy *text* to terminal clipboard via OSC 52."""
|
||||
payload = base64.b64encode(text.encode("utf-8")).decode("ascii")
|
||||
seq = f"\x1b]52;c;{payload}\x07"
|
||||
out = getattr(self, "_app", None)
|
||||
output = getattr(out, "output", None) if out else None
|
||||
if output and hasattr(output, "write_raw"):
|
||||
output.write_raw(seq)
|
||||
output.flush()
|
||||
return
|
||||
if output and hasattr(output, "write"):
|
||||
output.write(seq)
|
||||
output.flush()
|
||||
return
|
||||
sys.stdout.write(seq)
|
||||
sys.stdout.flush()
|
||||
|
||||
def _handle_copy_command(self, cmd_original: str) -> None:
|
||||
"""Handle /copy [number] — copy assistant output to clipboard."""
|
||||
parts = cmd_original.split(maxsplit=1)
|
||||
arg = parts[1].strip() if len(parts) > 1 else ""
|
||||
|
||||
assistant = [m for m in self.conversation_history if m.get("role") == "assistant"]
|
||||
if not assistant:
|
||||
_cprint(" Nothing to copy yet.")
|
||||
return
|
||||
|
||||
if arg:
|
||||
try:
|
||||
idx = int(arg) - 1
|
||||
except ValueError:
|
||||
_cprint(" Usage: /copy [number]")
|
||||
return
|
||||
if idx < 0 or idx >= len(assistant):
|
||||
_cprint(f" Invalid response number. Use 1-{len(assistant)}.")
|
||||
return
|
||||
else:
|
||||
idx = len(assistant) - 1
|
||||
while idx >= 0 and not _assistant_copy_text(assistant[idx].get("content")):
|
||||
idx -= 1
|
||||
if idx < 0:
|
||||
_cprint(" Nothing to copy in assistant responses yet.")
|
||||
return
|
||||
|
||||
text = _assistant_copy_text(assistant[idx].get("content"))
|
||||
if not text:
|
||||
_cprint(" Nothing to copy in that assistant response.")
|
||||
return
|
||||
|
||||
try:
|
||||
self._write_osc52_clipboard(text)
|
||||
_cprint(f" Copied assistant response #{idx + 1} to clipboard")
|
||||
except Exception as e:
|
||||
_cprint(f" Clipboard copy failed: {e}")
|
||||
|
||||
def _handle_image_command(self, cmd_original: str):
|
||||
"""Handle /image <path> — attach a local image file for the next prompt."""
|
||||
raw_args = (cmd_original.split(None, 1)[1].strip() if " " in cmd_original else "")
|
||||
@@ -3671,7 +3775,7 @@ class HermesCLI:
|
||||
skin = get_active_skin()
|
||||
separator_color = skin.get_color("banner_dim", "#B8860B")
|
||||
accent_color = skin.get_color("ui_accent", "#FFBF00")
|
||||
label_color = skin.get_color("ui_label", "#4dd0e1")
|
||||
label_color = skin.get_color("ui_label", "#DAA520")
|
||||
except Exception:
|
||||
separator_color, accent_color, label_color = "#B8860B", "#FFBF00", "cyan"
|
||||
toolsets_info = ""
|
||||
@@ -4514,6 +4618,34 @@ class HermesCLI:
|
||||
self._restore_modal_input_snapshot()
|
||||
self._invalidate(min_interval=0.0)
|
||||
|
||||
@staticmethod
|
||||
def _compute_model_picker_viewport(
|
||||
selected: int,
|
||||
scroll_offset: int,
|
||||
n: int,
|
||||
term_rows: int,
|
||||
reserved_below: int = 6,
|
||||
panel_chrome: int = 6,
|
||||
min_visible: int = 3,
|
||||
) -> tuple[int, int]:
|
||||
"""Resolve (scroll_offset, visible) for the /model picker viewport.
|
||||
|
||||
``reserved_below`` matches the approval / clarify panels — input area,
|
||||
status bar, and separators below the panel. ``panel_chrome`` covers
|
||||
this panel's own borders + blanks + hint row. The remaining rows hold
|
||||
the scrollable list, with the offset slid to keep ``selected`` on screen.
|
||||
"""
|
||||
max_visible = max(min_visible, term_rows - reserved_below - panel_chrome)
|
||||
if n <= max_visible:
|
||||
return 0, n
|
||||
visible = max_visible
|
||||
if selected < scroll_offset:
|
||||
scroll_offset = selected
|
||||
elif selected >= scroll_offset + visible:
|
||||
scroll_offset = selected - visible + 1
|
||||
scroll_offset = max(0, min(scroll_offset, n - visible))
|
||||
return scroll_offset, visible
|
||||
|
||||
def _apply_model_switch_result(self, result, persist_global: bool) -> None:
|
||||
if not result.success:
|
||||
_cprint(f" ✗ {result.error_message}")
|
||||
@@ -5525,6 +5657,8 @@ class HermesCLI:
|
||||
self._show_usage()
|
||||
elif canonical == "insights":
|
||||
self._show_insights(cmd_original)
|
||||
elif canonical == "copy":
|
||||
self._handle_copy_command(cmd_original)
|
||||
elif canonical == "debug":
|
||||
self._handle_debug_command()
|
||||
elif canonical == "paste":
|
||||
@@ -5568,6 +5702,8 @@ class HermesCLI:
|
||||
self._handle_snapshot_command(cmd_original)
|
||||
elif canonical == "stop":
|
||||
self._handle_stop_command()
|
||||
elif canonical == "agents":
|
||||
self._handle_agents_command()
|
||||
elif canonical == "background":
|
||||
self._handle_background_command(cmd_original)
|
||||
elif canonical == "btw":
|
||||
@@ -5584,6 +5720,30 @@ class HermesCLI:
|
||||
_cprint(f" Queued for the next turn: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
else:
|
||||
_cprint(f" Queued: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
elif canonical == "steer":
|
||||
# Inject a message after the next tool call without interrupting.
|
||||
# If the agent is actively running, push the text into the agent's
|
||||
# pending_steer slot — the drain hook in _execute_tool_calls_*
|
||||
# will append it to the next tool result's content. If no agent
|
||||
# is running, fall back to queue semantics (same as /queue).
|
||||
parts = cmd_original.split(None, 1)
|
||||
payload = parts[1].strip() if len(parts) > 1 else ""
|
||||
if not payload:
|
||||
_cprint(" Usage: /steer <prompt>")
|
||||
elif self._agent_running and self.agent is not None and hasattr(self.agent, "steer"):
|
||||
try:
|
||||
accepted = self.agent.steer(payload)
|
||||
except Exception as exc:
|
||||
_cprint(f" Steer failed: {exc}")
|
||||
else:
|
||||
if accepted:
|
||||
_cprint(f" ⏩ Steer queued — arrives after the next tool call: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
else:
|
||||
_cprint(" Steer rejected (empty payload).")
|
||||
else:
|
||||
# No active run — treat as a normal next-turn message.
|
||||
self._pending_input.put(payload)
|
||||
_cprint(f" No agent running; queued as next turn: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
elif canonical == "skin":
|
||||
self._handle_skin_command(cmd_original)
|
||||
elif canonical == "voice":
|
||||
@@ -6881,8 +7041,7 @@ class HermesCLI:
|
||||
)
|
||||
raise RuntimeError(
|
||||
"Voice mode requires sounddevice and numpy.\n"
|
||||
"Install with: pip install sounddevice numpy\n"
|
||||
"Or: pip install hermes-agent[voice]"
|
||||
f"Install with: {sys.executable} -m pip install sounddevice numpy"
|
||||
)
|
||||
if not reqs.get("stt_available", reqs.get("stt_key_set")):
|
||||
raise RuntimeError(
|
||||
@@ -7158,8 +7317,7 @@ class HermesCLI:
|
||||
_cprint(f" {_DIM}Then install/update the Termux:API Android app for microphone capture{_RST}")
|
||||
_cprint(f" {_BOLD}Option 2: pkg install python-numpy portaudio && python -m pip install sounddevice{_RST}")
|
||||
else:
|
||||
_cprint(f"\n {_BOLD}Install: pip install {' '.join(reqs['missing_packages'])}{_RST}")
|
||||
_cprint(f" {_DIM}Or: pip install hermes-agent[voice]{_RST}")
|
||||
_cprint(f"\n {_BOLD}Install: {sys.executable} -m pip install {' '.join(reqs['missing_packages'])}{_RST}")
|
||||
return
|
||||
|
||||
with self._voice_lock:
|
||||
@@ -8110,7 +8268,15 @@ class HermesCLI:
|
||||
else:
|
||||
print(f"\n⚡ Sending after interrupt: '{preview}'")
|
||||
self._pending_input.put(combined)
|
||||
|
||||
|
||||
# If a /steer was left over (agent finished before another tool
|
||||
# batch could absorb it), deliver it as the next user turn.
|
||||
_leftover_steer = result.get("pending_steer") if result else None
|
||||
if _leftover_steer and hasattr(self, '_pending_input'):
|
||||
preview = _leftover_steer[:60] + ("..." if len(_leftover_steer) > 60 else "")
|
||||
print(f"\n⏩ Delivering leftover /steer as next turn: '{preview}'")
|
||||
self._pending_input.put(_leftover_steer)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -8528,6 +8694,7 @@ class HermesCLI:
|
||||
# --- /model picker modal ---
|
||||
if self._model_picker_state:
|
||||
self._handle_model_picker_selection()
|
||||
event.app.current_buffer.reset()
|
||||
event.app.invalidate()
|
||||
return
|
||||
|
||||
@@ -8693,6 +8860,13 @@ class HermesCLI:
|
||||
state["selected"] = min(max_idx, state.get("selected", 0) + 1)
|
||||
event.app.invalidate()
|
||||
|
||||
@kb.add('escape', filter=Condition(lambda: bool(self._model_picker_state)), eager=True)
|
||||
def model_picker_escape(event):
|
||||
"""ESC closes the /model picker."""
|
||||
self._close_model_picker()
|
||||
event.app.current_buffer.reset()
|
||||
event.app.invalidate()
|
||||
|
||||
# --- History navigation: up/down browse history in normal input mode ---
|
||||
# The TextArea is multiline, so by default up/down only move the cursor.
|
||||
# Buffer.auto_up/auto_down handle both: cursor movement when multi-line,
|
||||
@@ -9494,6 +9668,22 @@ class HermesCLI:
|
||||
|
||||
box_width = _panel_box_width(title, [hint] + choices, min_width=46, max_width=84)
|
||||
inner_text_width = max(8, box_width - 6)
|
||||
selected = state.get("selected", 0)
|
||||
|
||||
# Scrolling viewport: the panel renders into a Window with no max
|
||||
# height, so without limiting visible items the bottom border and
|
||||
# any items past the available terminal rows get clipped on long
|
||||
# provider catalogs (e.g. Ollama Cloud's 36+ models).
|
||||
try:
|
||||
from prompt_toolkit.application import get_app
|
||||
term_rows = get_app().output.get_size().rows
|
||||
except Exception:
|
||||
term_rows = shutil.get_terminal_size((100, 24)).lines
|
||||
scroll_offset, visible = HermesCLI._compute_model_picker_viewport(
|
||||
selected, state.get("_scroll_offset", 0), len(choices), term_rows,
|
||||
)
|
||||
state["_scroll_offset"] = scroll_offset
|
||||
|
||||
lines = []
|
||||
lines.append(('class:clarify-border', '╭─ '))
|
||||
lines.append(('class:clarify-title', title))
|
||||
@@ -9501,8 +9691,8 @@ class HermesCLI:
|
||||
_append_blank_panel_line(lines, 'class:clarify-border', box_width)
|
||||
_append_panel_line(lines, 'class:clarify-border', 'class:clarify-hint', hint, box_width)
|
||||
_append_blank_panel_line(lines, 'class:clarify-border', box_width)
|
||||
selected = state.get("selected", 0)
|
||||
for idx, choice in enumerate(choices):
|
||||
for idx in range(scroll_offset, scroll_offset + visible):
|
||||
choice = choices[idx]
|
||||
style = 'class:clarify-selected' if idx == selected else 'class:clarify-choice'
|
||||
prefix = '❯ ' if idx == selected else ' '
|
||||
for wrapped in _wrap_panel_text(prefix + choice, inner_text_width, subsequent_indent=' '):
|
||||
@@ -9907,8 +10097,36 @@ class HermesCLI:
|
||||
|
||||
# Register signal handlers for graceful shutdown on SSH disconnect / SIGTERM
|
||||
def _signal_handler(signum, frame):
|
||||
"""Handle SIGHUP/SIGTERM by triggering graceful cleanup."""
|
||||
"""Handle SIGHUP/SIGTERM by triggering graceful cleanup.
|
||||
|
||||
Calls ``self.agent.interrupt()`` first so the agent daemon
|
||||
thread's poll loop sees the per-thread interrupt and kills the
|
||||
tool's subprocess group via ``_kill_process`` (os.killpg).
|
||||
Without this, the main thread dies from KeyboardInterrupt and
|
||||
the daemon thread is killed with it — before it can run one
|
||||
more poll iteration to clean up the subprocess, which was
|
||||
spawned with ``os.setsid`` and therefore survives as an orphan
|
||||
with PPID=1.
|
||||
|
||||
Grace window (``HERMES_SIGTERM_GRACE``, default 1.5 s) gives
|
||||
the daemon time to: detect the interrupt (next 200 ms poll) →
|
||||
call _kill_process (SIGTERM + 1 s wait + SIGKILL if needed) →
|
||||
return from _wait_for_process. ``time.sleep`` releases the
|
||||
GIL so the daemon actually runs during the window.
|
||||
"""
|
||||
logger.debug("Received signal %s, triggering graceful shutdown", signum)
|
||||
try:
|
||||
if getattr(self, "agent", None) and getattr(self, "_agent_running", False):
|
||||
self.agent.interrupt(f"received signal {signum}")
|
||||
import time as _t
|
||||
try:
|
||||
_grace = float(os.getenv("HERMES_SIGTERM_GRACE", "1.5"))
|
||||
except (TypeError, ValueError):
|
||||
_grace = 1.5
|
||||
if _grace > 0:
|
||||
_t.sleep(_grace)
|
||||
except Exception:
|
||||
pass # never block signal handling
|
||||
raise KeyboardInterrupt()
|
||||
|
||||
try:
|
||||
@@ -10211,6 +10429,45 @@ def main(
|
||||
|
||||
# Register cleanup for single-query mode (interactive mode registers in run())
|
||||
atexit.register(_run_cleanup)
|
||||
|
||||
# Also install signal handlers in single-query / `-q` mode. Interactive
|
||||
# mode registers its own inside HermesCLI.run(), but `-q` runs
|
||||
# cli.agent.run_conversation() below and AIAgent spawns worker threads
|
||||
# for tools — so when SIGTERM arrives on the main thread, raising
|
||||
# KeyboardInterrupt only unwinds the main thread, not the worker
|
||||
# running _wait_for_process. Python then exits, the child subprocess
|
||||
# (spawned with os.setsid, its own process group) is reparented to
|
||||
# init and keeps running as an orphan.
|
||||
#
|
||||
# Fix: route SIGTERM/SIGHUP through agent.interrupt() which sets the
|
||||
# per-thread interrupt flag the worker's poll loop checks every 200 ms.
|
||||
# Give the worker a grace window to call _kill_process (SIGTERM to the
|
||||
# process group, then SIGKILL after 1 s), then raise KeyboardInterrupt
|
||||
# so main unwinds normally. HERMES_SIGTERM_GRACE overrides the 1.5 s
|
||||
# default for debugging.
|
||||
def _signal_handler_q(signum, frame):
|
||||
logger.debug("Received signal %s in single-query mode", signum)
|
||||
try:
|
||||
_agent = getattr(cli, "agent", None)
|
||||
if _agent is not None:
|
||||
_agent.interrupt(f"received signal {signum}")
|
||||
import time as _t
|
||||
try:
|
||||
_grace = float(os.getenv("HERMES_SIGTERM_GRACE", "1.5"))
|
||||
except (TypeError, ValueError):
|
||||
_grace = 1.5
|
||||
if _grace > 0:
|
||||
_t.sleep(_grace)
|
||||
except Exception:
|
||||
pass # never block signal handling
|
||||
raise KeyboardInterrupt()
|
||||
try:
|
||||
import signal as _signal
|
||||
_signal.signal(_signal.SIGTERM, _signal_handler_q)
|
||||
if hasattr(_signal, "SIGHUP"):
|
||||
_signal.signal(_signal.SIGHUP, _signal_handler_q)
|
||||
except Exception:
|
||||
pass # signal handler may fail in restricted environments
|
||||
|
||||
# Handle single query mode
|
||||
if query or image:
|
||||
|
||||
+183
-106
@@ -27,7 +27,7 @@ except ImportError:
|
||||
except ImportError:
|
||||
msvcrt = None
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
# Add parent directory to path for imports BEFORE repo-level imports.
|
||||
# Without this, standalone invocations (e.g. after `hermes update` reloads
|
||||
@@ -49,6 +49,33 @@ _KNOWN_DELIVERY_PLATFORMS = frozenset({
|
||||
"qqbot",
|
||||
})
|
||||
|
||||
# Platforms that support a configured cron/notification home target, mapped to
|
||||
# the environment variable used by gateway setup/runtime config.
|
||||
_HOME_TARGET_ENV_VARS = {
|
||||
"matrix": "MATRIX_HOME_ROOM",
|
||||
"telegram": "TELEGRAM_HOME_CHANNEL",
|
||||
"discord": "DISCORD_HOME_CHANNEL",
|
||||
"slack": "SLACK_HOME_CHANNEL",
|
||||
"signal": "SIGNAL_HOME_CHANNEL",
|
||||
"mattermost": "MATTERMOST_HOME_CHANNEL",
|
||||
"sms": "SMS_HOME_CHANNEL",
|
||||
"email": "EMAIL_HOME_ADDRESS",
|
||||
"dingtalk": "DINGTALK_HOME_CHANNEL",
|
||||
"feishu": "FEISHU_HOME_CHANNEL",
|
||||
"wecom": "WECOM_HOME_CHANNEL",
|
||||
"weixin": "WEIXIN_HOME_CHANNEL",
|
||||
"bluebubbles": "BLUEBUBBLES_HOME_CHANNEL",
|
||||
"qqbot": "QQBOT_HOME_CHANNEL",
|
||||
}
|
||||
|
||||
# Legacy env var names kept for back-compat. Each entry is the current
|
||||
# primary env var → the previous name. _get_home_target_chat_id falls
|
||||
# back to the legacy name if the primary is unset, so users who set the
|
||||
# old name before the rename keep working until they migrate.
|
||||
_LEGACY_HOME_TARGET_ENV_VARS = {
|
||||
"QQBOT_HOME_CHANNEL": "QQ_HOME_CHANNEL",
|
||||
}
|
||||
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run
|
||||
|
||||
# Sentinel: when a cron agent has nothing new to report, it can start its
|
||||
@@ -76,15 +103,28 @@ def _resolve_origin(job: dict) -> Optional[dict]:
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_delivery_target(job: dict) -> Optional[dict]:
|
||||
"""Resolve the concrete auto-delivery target for a cron job, if any."""
|
||||
deliver = job.get("deliver", "local")
|
||||
def _get_home_target_chat_id(platform_name: str) -> str:
|
||||
"""Return the configured home target chat/room ID for a delivery platform."""
|
||||
env_var = _HOME_TARGET_ENV_VARS.get(platform_name.lower())
|
||||
if not env_var:
|
||||
return ""
|
||||
value = os.getenv(env_var, "")
|
||||
if not value:
|
||||
legacy = _LEGACY_HOME_TARGET_ENV_VARS.get(env_var)
|
||||
if legacy:
|
||||
value = os.getenv(legacy, "")
|
||||
return value
|
||||
|
||||
|
||||
def _resolve_single_delivery_target(job: dict, deliver_value: str) -> Optional[dict]:
|
||||
"""Resolve one concrete auto-delivery target for a cron job."""
|
||||
|
||||
origin = _resolve_origin(job)
|
||||
|
||||
if deliver == "local":
|
||||
if deliver_value == "local":
|
||||
return None
|
||||
|
||||
if deliver == "origin":
|
||||
if deliver_value == "origin":
|
||||
if origin:
|
||||
return {
|
||||
"platform": origin["platform"],
|
||||
@@ -93,8 +133,8 @@ def _resolve_delivery_target(job: dict) -> Optional[dict]:
|
||||
}
|
||||
# Origin missing (e.g. job created via API/script) — try each
|
||||
# platform's home channel as a fallback instead of silently dropping.
|
||||
for platform_name in ("matrix", "telegram", "discord", "slack", "bluebubbles"):
|
||||
chat_id = os.getenv(f"{platform_name.upper()}_HOME_CHANNEL", "")
|
||||
for platform_name in _HOME_TARGET_ENV_VARS:
|
||||
chat_id = _get_home_target_chat_id(platform_name)
|
||||
if chat_id:
|
||||
logger.info(
|
||||
"Job '%s' has deliver=origin but no origin; falling back to %s home channel",
|
||||
@@ -108,8 +148,8 @@ def _resolve_delivery_target(job: dict) -> Optional[dict]:
|
||||
}
|
||||
return None
|
||||
|
||||
if ":" in deliver:
|
||||
platform_name, rest = deliver.split(":", 1)
|
||||
if ":" in deliver_value:
|
||||
platform_name, rest = deliver_value.split(":", 1)
|
||||
platform_key = platform_name.lower()
|
||||
|
||||
from tools.send_message_tool import _parse_target_ref
|
||||
@@ -139,7 +179,7 @@ def _resolve_delivery_target(job: dict) -> Optional[dict]:
|
||||
"thread_id": thread_id,
|
||||
}
|
||||
|
||||
platform_name = deliver
|
||||
platform_name = deliver_value
|
||||
if origin and origin.get("platform") == platform_name:
|
||||
return {
|
||||
"platform": platform_name,
|
||||
@@ -149,7 +189,7 @@ def _resolve_delivery_target(job: dict) -> Optional[dict]:
|
||||
|
||||
if platform_name.lower() not in _KNOWN_DELIVERY_PLATFORMS:
|
||||
return None
|
||||
chat_id = os.getenv(f"{platform_name.upper()}_HOME_CHANNEL", "")
|
||||
chat_id = _get_home_target_chat_id(platform_name)
|
||||
if not chat_id:
|
||||
return None
|
||||
|
||||
@@ -160,6 +200,30 @@ def _resolve_delivery_target(job: dict) -> Optional[dict]:
|
||||
}
|
||||
|
||||
|
||||
def _resolve_delivery_targets(job: dict) -> List[dict]:
|
||||
"""Resolve all concrete auto-delivery targets for a cron job (supports comma-separated deliver)."""
|
||||
deliver = job.get("deliver", "local")
|
||||
if deliver == "local":
|
||||
return []
|
||||
parts = [p.strip() for p in str(deliver).split(",") if p.strip()]
|
||||
seen = set()
|
||||
targets = []
|
||||
for part in parts:
|
||||
target = _resolve_single_delivery_target(job, part)
|
||||
if target:
|
||||
key = (target["platform"].lower(), str(target["chat_id"]), target.get("thread_id"))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
targets.append(target)
|
||||
return targets
|
||||
|
||||
|
||||
def _resolve_delivery_target(job: dict) -> Optional[dict]:
|
||||
"""Resolve the concrete auto-delivery target for a cron job, if any."""
|
||||
targets = _resolve_delivery_targets(job)
|
||||
return targets[0] if targets else None
|
||||
|
||||
|
||||
# Media extension sets — keep in sync with gateway/platforms/base.py:_process_message_background
|
||||
_AUDIO_EXTS = frozenset({'.ogg', '.opus', '.mp3', '.wav', '.m4a'})
|
||||
_VIDEO_EXTS = frozenset({'.mp4', '.mov', '.avi', '.mkv', '.webm', '.3gp'})
|
||||
@@ -200,7 +264,7 @@ def _send_media_via_adapter(adapter, chat_id: str, media_files: list, metadata:
|
||||
|
||||
def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Optional[str]:
|
||||
"""
|
||||
Deliver job output to the configured target (origin chat, specific platform, etc.).
|
||||
Deliver job output to the configured target(s) (origin chat, specific platform, etc.).
|
||||
|
||||
When ``adapters`` and ``loop`` are provided (gateway is running), tries to
|
||||
use the live adapter first — this supports E2EE rooms (e.g. Matrix) where
|
||||
@@ -209,33 +273,14 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option
|
||||
|
||||
Returns None on success, or an error string on failure.
|
||||
"""
|
||||
target = _resolve_delivery_target(job)
|
||||
if not target:
|
||||
targets = _resolve_delivery_targets(job)
|
||||
if not targets:
|
||||
if job.get("deliver", "local") != "local":
|
||||
msg = f"no delivery target resolved for deliver={job.get('deliver', 'local')}"
|
||||
logger.warning("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
return None # local-only jobs don't deliver — not a failure
|
||||
|
||||
platform_name = target["platform"]
|
||||
chat_id = target["chat_id"]
|
||||
thread_id = target.get("thread_id")
|
||||
|
||||
# Diagnostic: log thread_id for topic-aware delivery debugging
|
||||
origin = job.get("origin") or {}
|
||||
origin_thread = origin.get("thread_id")
|
||||
if origin_thread and not thread_id:
|
||||
logger.warning(
|
||||
"Job '%s': origin has thread_id=%s but delivery target lost it "
|
||||
"(deliver=%s, target=%s)",
|
||||
job["id"], origin_thread, job.get("deliver", "local"), target,
|
||||
)
|
||||
elif thread_id:
|
||||
logger.debug(
|
||||
"Job '%s': delivering to %s:%s thread_id=%s",
|
||||
job["id"], platform_name, chat_id, thread_id,
|
||||
)
|
||||
|
||||
from tools.send_message_tool import _send_to_platform
|
||||
from gateway.config import load_gateway_config, Platform
|
||||
|
||||
@@ -258,24 +303,6 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option
|
||||
"bluebubbles": Platform.BLUEBUBBLES,
|
||||
"qqbot": Platform.QQBOT,
|
||||
}
|
||||
platform = platform_map.get(platform_name.lower())
|
||||
if not platform:
|
||||
msg = f"unknown platform '{platform_name}'"
|
||||
logger.warning("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
|
||||
try:
|
||||
config = load_gateway_config()
|
||||
except Exception as e:
|
||||
msg = f"failed to load gateway config: {e}"
|
||||
logger.error("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
|
||||
pconfig = config.platforms.get(platform)
|
||||
if not pconfig or not pconfig.enabled:
|
||||
msg = f"platform '{platform_name}' not configured/enabled"
|
||||
logger.warning("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
|
||||
# Optionally wrap the content with a header/footer so the user knows this
|
||||
# is a cron delivery. Wrapping is on by default; set cron.wrap_response: false
|
||||
@@ -304,67 +331,117 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option
|
||||
from gateway.platforms.base import BasePlatformAdapter
|
||||
media_files, cleaned_delivery_content = BasePlatformAdapter.extract_media(delivery_content)
|
||||
|
||||
# Prefer the live adapter when the gateway is running — this supports E2EE
|
||||
# rooms (e.g. Matrix) where the standalone HTTP path cannot encrypt.
|
||||
runtime_adapter = (adapters or {}).get(platform)
|
||||
if runtime_adapter is not None and loop is not None and getattr(loop, "is_running", lambda: False)():
|
||||
send_metadata = {"thread_id": thread_id} if thread_id else None
|
||||
try:
|
||||
# Send cleaned text (MEDIA tags stripped) — not the raw content
|
||||
text_to_send = cleaned_delivery_content.strip()
|
||||
adapter_ok = True
|
||||
if text_to_send:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
runtime_adapter.send(chat_id, text_to_send, metadata=send_metadata),
|
||||
loop,
|
||||
)
|
||||
send_result = future.result(timeout=60)
|
||||
if send_result and not getattr(send_result, "success", True):
|
||||
err = getattr(send_result, "error", "unknown")
|
||||
logger.warning(
|
||||
"Job '%s': live adapter send to %s:%s failed (%s), falling back to standalone",
|
||||
job["id"], platform_name, chat_id, err,
|
||||
)
|
||||
adapter_ok = False # fall through to standalone path
|
||||
try:
|
||||
config = load_gateway_config()
|
||||
except Exception as e:
|
||||
msg = f"failed to load gateway config: {e}"
|
||||
logger.error("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
|
||||
# Send extracted media files as native attachments via the live adapter
|
||||
if adapter_ok and media_files:
|
||||
_send_media_via_adapter(runtime_adapter, chat_id, media_files, send_metadata, loop, job)
|
||||
delivery_errors = []
|
||||
|
||||
if adapter_ok:
|
||||
logger.info("Job '%s': delivered to %s:%s via live adapter", job["id"], platform_name, chat_id)
|
||||
return None
|
||||
except Exception as e:
|
||||
for target in targets:
|
||||
platform_name = target["platform"]
|
||||
chat_id = target["chat_id"]
|
||||
thread_id = target.get("thread_id")
|
||||
|
||||
# Diagnostic: log thread_id for topic-aware delivery debugging
|
||||
origin = job.get("origin") or {}
|
||||
origin_thread = origin.get("thread_id")
|
||||
if origin_thread and not thread_id:
|
||||
logger.warning(
|
||||
"Job '%s': live adapter delivery to %s:%s failed (%s), falling back to standalone",
|
||||
job["id"], platform_name, chat_id, e,
|
||||
"Job '%s': origin has thread_id=%s but delivery target lost it "
|
||||
"(deliver=%s, target=%s)",
|
||||
job["id"], origin_thread, job.get("deliver", "local"), target,
|
||||
)
|
||||
elif thread_id:
|
||||
logger.debug(
|
||||
"Job '%s': delivering to %s:%s thread_id=%s",
|
||||
job["id"], platform_name, chat_id, thread_id,
|
||||
)
|
||||
|
||||
# Standalone path: run the async send in a fresh event loop (safe from any thread)
|
||||
coro = _send_to_platform(platform, pconfig, chat_id, cleaned_delivery_content, thread_id=thread_id, media_files=media_files)
|
||||
try:
|
||||
result = asyncio.run(coro)
|
||||
except RuntimeError:
|
||||
# asyncio.run() checks for a running loop before awaiting the coroutine;
|
||||
# when it raises, the original coro was never started — close it to
|
||||
# prevent "coroutine was never awaited" RuntimeWarning, then retry in a
|
||||
# fresh thread that has no running loop.
|
||||
coro.close()
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, cleaned_delivery_content, thread_id=thread_id, media_files=media_files))
|
||||
result = future.result(timeout=30)
|
||||
except Exception as e:
|
||||
msg = f"delivery to {platform_name}:{chat_id} failed: {e}"
|
||||
logger.error("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
platform = platform_map.get(platform_name.lower())
|
||||
if not platform:
|
||||
msg = f"unknown platform '{platform_name}'"
|
||||
logger.warning("Job '%s': %s", job["id"], msg)
|
||||
delivery_errors.append(msg)
|
||||
continue
|
||||
|
||||
if result and result.get("error"):
|
||||
msg = f"delivery error: {result['error']}"
|
||||
logger.error("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
# Prefer the live adapter when the gateway is running — this supports E2EE
|
||||
# rooms (e.g. Matrix) where the standalone HTTP path cannot encrypt.
|
||||
runtime_adapter = (adapters or {}).get(platform)
|
||||
delivered = False
|
||||
if runtime_adapter is not None and loop is not None and getattr(loop, "is_running", lambda: False)():
|
||||
send_metadata = {"thread_id": thread_id} if thread_id else None
|
||||
try:
|
||||
# Send cleaned text (MEDIA tags stripped) — not the raw content
|
||||
text_to_send = cleaned_delivery_content.strip()
|
||||
adapter_ok = True
|
||||
if text_to_send:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
runtime_adapter.send(chat_id, text_to_send, metadata=send_metadata),
|
||||
loop,
|
||||
)
|
||||
send_result = future.result(timeout=60)
|
||||
if send_result and not getattr(send_result, "success", True):
|
||||
err = getattr(send_result, "error", "unknown")
|
||||
logger.warning(
|
||||
"Job '%s': live adapter send to %s:%s failed (%s), falling back to standalone",
|
||||
job["id"], platform_name, chat_id, err,
|
||||
)
|
||||
adapter_ok = False # fall through to standalone path
|
||||
|
||||
logger.info("Job '%s': delivered to %s:%s", job["id"], platform_name, chat_id)
|
||||
# Send extracted media files as native attachments via the live adapter
|
||||
if adapter_ok and media_files:
|
||||
_send_media_via_adapter(runtime_adapter, chat_id, media_files, send_metadata, loop, job)
|
||||
|
||||
if adapter_ok:
|
||||
logger.info("Job '%s': delivered to %s:%s via live adapter", job["id"], platform_name, chat_id)
|
||||
delivered = True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Job '%s': live adapter delivery to %s:%s failed (%s), falling back to standalone",
|
||||
job["id"], platform_name, chat_id, e,
|
||||
)
|
||||
|
||||
if not delivered:
|
||||
pconfig = config.platforms.get(platform)
|
||||
if not pconfig or not pconfig.enabled:
|
||||
msg = f"platform '{platform_name}' not configured/enabled"
|
||||
logger.warning("Job '%s': %s", job["id"], msg)
|
||||
delivery_errors.append(msg)
|
||||
continue
|
||||
|
||||
# Standalone path: run the async send in a fresh event loop (safe from any thread)
|
||||
coro = _send_to_platform(platform, pconfig, chat_id, cleaned_delivery_content, thread_id=thread_id, media_files=media_files)
|
||||
try:
|
||||
result = asyncio.run(coro)
|
||||
except RuntimeError:
|
||||
# asyncio.run() checks for a running loop before awaiting the coroutine;
|
||||
# when it raises, the original coro was never started — close it to
|
||||
# prevent "coroutine was never awaited" RuntimeWarning, then retry in a
|
||||
# fresh thread that has no running loop.
|
||||
coro.close()
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, cleaned_delivery_content, thread_id=thread_id, media_files=media_files))
|
||||
result = future.result(timeout=30)
|
||||
except Exception as e:
|
||||
msg = f"delivery to {platform_name}:{chat_id} failed: {e}"
|
||||
logger.error("Job '%s': %s", job["id"], msg)
|
||||
delivery_errors.append(msg)
|
||||
continue
|
||||
|
||||
if result and result.get("error"):
|
||||
msg = f"delivery error: {result['error']}"
|
||||
logger.error("Job '%s': %s", job["id"], msg)
|
||||
delivery_errors.append(msg)
|
||||
continue
|
||||
|
||||
logger.info("Job '%s': delivered to %s:%s", job["id"], platform_name, chat_id)
|
||||
|
||||
if delivery_errors:
|
||||
return "; ".join(delivery_errors)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
# Ink Gateway TUI Migration — Post-mortem
|
||||
|
||||
Planned: 2026-04-01 · Delivered: 2026-04 · Status: shipped, classic (prompt_toolkit) CLI still present
|
||||
|
||||
## What Shipped
|
||||
|
||||
Three layers, same repo, Python runtime unchanged.
|
||||
|
||||
```
|
||||
ui-tui (Node/TS) ──stdio JSON-RPC──▶ tui_gateway (Py) ──▶ AIAgent (run_agent.py)
|
||||
```
|
||||
|
||||
### Backend — `tui_gateway/`
|
||||
|
||||
```
|
||||
tui_gateway/
|
||||
├── entry.py # subprocess entrypoint, stdio read/write loop
|
||||
├── server.py # everything: sessions dict, @method handlers, _emit
|
||||
├── render.py # stream renderer, diff rendering, message rendering
|
||||
├── slash_worker.py # subprocess that runs hermes_cli slash commands
|
||||
└── __init__.py
|
||||
```
|
||||
|
||||
`server.py` owns the full runtime-control surface: session store (`_sessions: dict[str, dict]`), method registry (`@method("…")` decorator), event emitter (`_emit`), agent lifecycle (`_make_agent`, `_init_session`, `_wire_callbacks`), approval/sudo/clarify round-trips, and JSON-RPC dispatch.
|
||||
|
||||
Protocol methods (`@method(...)` in `server.py`):
|
||||
|
||||
- session: `session.{create, resume, list, close, interrupt, usage, history, compress, branch, title, save, undo}`
|
||||
- prompt: `prompt.{submit, background, btw}`
|
||||
- tools: `tools.{list, show, configure}`
|
||||
- slash: `slash.exec`, `command.{dispatch, resolve}`, `commands.catalog`, `complete.{path, slash}`
|
||||
- approvals: `approval.respond`, `sudo.respond`, `clarify.respond`, `secret.respond`
|
||||
- config/state: `config.{get, set, show}`, `model.options`, `reload.mcp`
|
||||
- ops: `shell.exec`, `cli.exec`, `terminal.resize`, `input.detect_drop`, `clipboard.paste`, `paste.collapse`, `image.attach`, `process.stop`
|
||||
- misc: `agents.list`, `skills.manage`, `plugins.list`, `cron.manage`, `insights.get`, `rollback.{list, diff, restore}`, `browser.manage`
|
||||
|
||||
Protocol events (`_emit(…)` → handled in `ui-tui/src/app/createGatewayEventHandler.ts`):
|
||||
|
||||
- lifecycle: `gateway.{ready, stderr}`, `session.info`, `skin.changed`
|
||||
- stream: `message.{start, delta, complete}`, `thinking.delta`, `reasoning.{delta, available}`, `status.update`
|
||||
- tools: `tool.{start, progress, complete, generating}`, `subagent.{start, thinking, tool, progress, complete}`
|
||||
- interactive: `approval.request`, `sudo.request`, `clarify.request`, `secret.request`
|
||||
- async: `background.complete`, `btw.complete`, `error`
|
||||
|
||||
### Frontend — `ui-tui/src/`
|
||||
|
||||
```
|
||||
src/
|
||||
├── entry.tsx # node bootstrap: bootBanner → spawn python → dynamic-import Ink → render(<App/>)
|
||||
├── app.tsx # <GatewayProvider> wraps <AppLayout>
|
||||
├── bootBanner.ts # raw-ANSI banner to stdout in ~2ms, pre-React
|
||||
├── gatewayClient.ts # JSON-RPC client over child_process stdio
|
||||
├── gatewayTypes.ts # typed RPC responses + GatewayEvent union
|
||||
├── theme.ts # DEFAULT_THEME + fromSkin
|
||||
│
|
||||
├── app/ # hooks + stores — the orchestration layer
|
||||
│ ├── uiStore.ts # nanostore: sid, info, busy, usage, theme, status…
|
||||
│ ├── turnStore.ts # nanostore: per-turn activity / reasoning / tools
|
||||
│ ├── turnController.ts # imperative singleton for stream-time operations
|
||||
│ ├── overlayStore.ts # nanostore: modal/overlay state
|
||||
│ ├── useMainApp.ts # top-level composition hook
|
||||
│ ├── useSessionLifecycle.ts # session.create/resume/close/reset
|
||||
│ ├── useSubmission.ts # shell/slash/prompt dispatch + interpolation
|
||||
│ ├── useConfigSync.ts # config.get + mtime poll
|
||||
│ ├── useComposerState.ts # input buffer, paste snippets, editor mode
|
||||
│ ├── useInputHandlers.ts # key bindings
|
||||
│ ├── createGatewayEventHandler.ts # event-stream dispatcher
|
||||
│ ├── createSlashHandler.ts # slash command router (registry + python fallback)
|
||||
│ └── slash/commands/ # core.ts, ops.ts, session.ts — TS-owned slash commands
|
||||
│
|
||||
├── components/ # AppLayout, AppChrome, AppOverlays, MessageLine, Thinking, Markdown, pickers, prompts, Banner, SessionPanel
|
||||
├── config/ # env, limits, timing constants
|
||||
├── content/ # charms, faces, fortunes, hotkeys, placeholders, verbs
|
||||
├── domain/ # details, messages, paths, roles, slash, usage, viewport
|
||||
├── protocol/ # interpolation, paste regex
|
||||
├── hooks/ # useCompletion, useInputHistory, useQueue, useVirtualHistory
|
||||
└── lib/ # history, messages, osc52, rpc, text
|
||||
```
|
||||
|
||||
### CLI entry points — `hermes_cli/main.py`
|
||||
|
||||
- `hermes --tui` → `node dist/entry.js` (auto-builds when `.ts`/`.tsx` newer than `dist/entry.js`)
|
||||
- `hermes --tui --dev` → `tsx src/entry.tsx` (skip build)
|
||||
- `HERMES_TUI_DIR=…` → external prebuilt dist (nix, distro packaging)
|
||||
|
||||
## Diverged From Original Plan
|
||||
|
||||
| Plan | Reality | Why |
|
||||
|---|---|---|
|
||||
| `tui_gateway/{controller,session_state,events,protocol}.py` | all collapsed into `server.py` | no second consumer ever emerged, keeping one file cheaper than four |
|
||||
| `ui-tui/src/main.tsx` | split into `entry.tsx` (bootstrap) + `app.tsx` (shell) | boot banner + early python spawn wanted a pre-React moment |
|
||||
| `ui-tui/src/state/store.ts` | three nanostores (`uiStore`, `turnStore`, `overlayStore`) | separate lifetimes: ui persists, turn resets per reply, overlay is modal |
|
||||
| `approval.requested` / `sudo.requested` / `clarify.requested` | `*.request` (no `-ed`) | cosmetic |
|
||||
| `session.cancel` | dropped | `session.interrupt` covers it |
|
||||
| `HERMES_EXPERIMENTAL_TUI=1`, `display.experimental_tui: true`, `/tui on/off/status` | none shipped | `--tui` went from opt-in to first-class without an experimental phase |
|
||||
|
||||
## Post-migration Additions (not in original plan)
|
||||
|
||||
- **Async `session.create`** — returns sid in ~1ms, agent builds on a background thread, `session.info` broadcasts when ready; `_wait_agent()` gates every agent-touching handler via `_sess`
|
||||
- **`bootBanner`** — raw-ANSI logo painted to stdout at T≈2ms, before Ink loads; `<AlternateScreen>` wipes it seamlessly when React mounts
|
||||
- **Selection uniform bg** — `theme.color.selectionBg` wired via `useSelection().setSelectionBgColor`; replaces SGR-inverse per-cell swap that fragmented over amber/gold fg
|
||||
- **Slash command registry** — TS-owned commands in `app/slash/commands/{core,ops,session}.ts`, everything else falls through to `slash.exec` (python worker)
|
||||
- **Turn store + controller split** — imperative singleton (`turnController`) holds refs/timers, nanostore (`turnStore`) holds render-visible state
|
||||
|
||||
## What's Still Open
|
||||
|
||||
- **Classic CLI not deleted.** `cli.py` still has ~80 `prompt_toolkit` references; classic REPL is still the default when `--tui` is absent. The original plan's "Cut 4 · prompt_toolkit removal later" hasn't happened.
|
||||
- **No config-file opt-in.** `HERMES_EXPERIMENTAL_TUI` and `display.experimental_tui` were never built; only the CLI flag exists. Fine for now — if we want "default to TUI", a single line in `main.py` flips it.
|
||||
@@ -6,6 +6,11 @@
|
||||
# All fields are optional — missing values inherit from the default skin.
|
||||
# Activate with: /skin <name> or display.skin: <name> in config.yaml
|
||||
#
|
||||
# Keys are marked:
|
||||
# (both) — applies to both the classic CLI and the TUI
|
||||
# (classic) — classic CLI only (see hermes --tui in user-guide/tui.md)
|
||||
# (tui) — TUI only
|
||||
#
|
||||
# See hermes_cli/skin_engine.py for the full schema reference.
|
||||
# ============================================================================
|
||||
|
||||
@@ -14,43 +19,47 @@ name: example
|
||||
description: An example custom skin — copy and modify this template
|
||||
|
||||
# ── Colors ──────────────────────────────────────────────────────────────────
|
||||
# Hex color values for Rich markup. These control the CLI's visual palette.
|
||||
# Hex color values. These control the visual palette.
|
||||
colors:
|
||||
# Banner panel (the startup welcome box)
|
||||
# Banner panel (the startup welcome box) — (both)
|
||||
banner_border: "#CD7F32" # Panel border
|
||||
banner_title: "#FFD700" # Panel title text
|
||||
banner_accent: "#FFBF00" # Section headers (Available Tools, Skills, etc.)
|
||||
banner_dim: "#B8860B" # Dim/muted text (separators, model info)
|
||||
banner_text: "#FFF8DC" # Body text (tool names, skill names)
|
||||
|
||||
# UI elements
|
||||
ui_accent: "#FFBF00" # General accent color
|
||||
# UI elements — (both)
|
||||
ui_accent: "#FFBF00" # General accent (falls back to banner_accent)
|
||||
ui_label: "#4dd0e1" # Labels
|
||||
ui_ok: "#4caf50" # Success indicators
|
||||
ui_error: "#ef5350" # Error indicators
|
||||
ui_warn: "#ffa726" # Warning indicators
|
||||
|
||||
# Input area
|
||||
prompt: "#FFF8DC" # Prompt text color
|
||||
input_rule: "#CD7F32" # Horizontal rule around input
|
||||
prompt: "#FFF8DC" # Prompt text / `❯` glyph color (both)
|
||||
input_rule: "#CD7F32" # Horizontal rule above input (classic)
|
||||
|
||||
# Response box
|
||||
response_border: "#FFD700" # Response box border (ANSI color)
|
||||
# Response box — (classic)
|
||||
response_border: "#FFD700" # Response box border
|
||||
|
||||
# Session display
|
||||
session_label: "#DAA520" # Session label
|
||||
session_border: "#8B8682" # Session ID dim color
|
||||
# Session display — (both)
|
||||
session_label: "#DAA520" # "Session: " label
|
||||
session_border: "#8B8682" # Session ID text
|
||||
|
||||
# TUI surfaces
|
||||
status_bar_bg: "#1a1a2e" # Status / usage bar background
|
||||
voice_status_bg: "#1a1a2e" # Voice-mode badge background
|
||||
completion_menu_bg: "#1a1a2e" # Completion list background
|
||||
completion_menu_current_bg: "#333355" # Active completion row background
|
||||
completion_menu_meta_bg: "#1a1a2e" # Completion meta column background
|
||||
completion_menu_meta_current_bg: "#333355" # Active completion meta background
|
||||
# TUI / CLI surfaces — (classic: status bar, voice badge, completion meta)
|
||||
status_bar_bg: "#1a1a2e" # Status / usage bar background (classic)
|
||||
voice_status_bg: "#1a1a2e" # Voice-mode badge background (classic)
|
||||
completion_menu_bg: "#1a1a2e" # Completion list background (both)
|
||||
completion_menu_current_bg: "#333355" # Active completion row background (both)
|
||||
completion_menu_meta_bg: "#1a1a2e" # Completion meta column bg (classic)
|
||||
completion_menu_meta_current_bg: "#333355" # Active meta bg (classic)
|
||||
|
||||
# Drag-to-select background — (tui)
|
||||
selection_bg: "#3a3a55" # Uniform selection highlight in the TUI
|
||||
|
||||
# ── Spinner ─────────────────────────────────────────────────────────────────
|
||||
# Customize the animated spinner shown during API calls and tool execution.
|
||||
# (classic) — the TUI uses its own animated indicators; spinner config here
|
||||
# is only read by the classic prompt_toolkit CLI.
|
||||
spinner:
|
||||
# Faces shown while waiting for the API response
|
||||
waiting_faces:
|
||||
@@ -78,17 +87,17 @@ spinner:
|
||||
# - ["⟪▲", "▲⟫"]
|
||||
|
||||
# ── Branding ────────────────────────────────────────────────────────────────
|
||||
# Text strings used throughout the CLI interface.
|
||||
# Text strings used throughout the interface.
|
||||
branding:
|
||||
agent_name: "Hermes Agent" # Banner title, about display
|
||||
welcome: "Welcome! Type your message or /help for commands."
|
||||
goodbye: "Goodbye! ⚕" # Exit message
|
||||
response_label: " ⚕ Hermes " # Response box header label
|
||||
prompt_symbol: "❯ " # Input prompt symbol
|
||||
help_header: "(^_^)? Available Commands" # /help header text
|
||||
agent_name: "Hermes Agent" # (both) Banner title, about display
|
||||
welcome: "Welcome! Type your message or /help for commands." # (both)
|
||||
goodbye: "Goodbye! ⚕" # (both) Exit message
|
||||
response_label: " ⚕ Hermes " # (classic) Response box header label
|
||||
prompt_symbol: "❯ " # (both) Input prompt glyph
|
||||
help_header: "(^_^)? Available Commands" # (both) /help overlay title
|
||||
|
||||
# ── Tool Output ─────────────────────────────────────────────────────────────
|
||||
# Character used as the prefix for tool output lines.
|
||||
# Character used as the prefix for tool output lines. (both)
|
||||
# Default is "┊" (thin dotted vertical line). Some alternatives:
|
||||
# "╎" (light triple dash vertical)
|
||||
# "▏" (left one-eighth block)
|
||||
|
||||
Generated
+21
@@ -36,6 +36,26 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"npm-lockfile-fix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1775903712,
|
||||
"narHash": "sha256-2GV79U6iVH4gKAPWYrxUReB0S41ty/Y3dBLquU8AlaA=",
|
||||
"owner": "jeslie0",
|
||||
"repo": "npm-lockfile-fix",
|
||||
"rev": "c6093acb0c0548e0f9b8b3d82918823721930fe8",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "jeslie0",
|
||||
"repo": "npm-lockfile-fix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-build-systems": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
@@ -124,6 +144,7 @@
|
||||
"inputs": {
|
||||
"flake-parts": "flake-parts",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"npm-lockfile-fix": "npm-lockfile-fix",
|
||||
"pyproject-build-systems": "pyproject-build-systems",
|
||||
"pyproject-nix": "pyproject-nix_2",
|
||||
"uv2nix": "uv2nix_2"
|
||||
|
||||
@@ -19,11 +19,20 @@
|
||||
url = "github:pyproject-nix/build-system-pkgs";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
npm-lockfile-fix = {
|
||||
url = "github:jeslie0/npm-lockfile-fix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
};
|
||||
|
||||
outputs = inputs:
|
||||
outputs =
|
||||
inputs:
|
||||
inputs.flake-parts.lib.mkFlake { inherit inputs; } {
|
||||
systems = [ "x86_64-linux" "aarch64-linux" "aarch64-darwin" ];
|
||||
systems = [
|
||||
"x86_64-linux"
|
||||
"aarch64-linux"
|
||||
"aarch64-darwin"
|
||||
];
|
||||
|
||||
imports = [
|
||||
./nix/packages.nix
|
||||
|
||||
@@ -100,7 +100,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
|
||||
|
||||
|
||||
def _build_discord(adapter) -> List[Dict[str, str]]:
|
||||
"""Enumerate all text channels the Discord bot can see."""
|
||||
"""Enumerate all text channels and forum channels the Discord bot can see."""
|
||||
channels = []
|
||||
client = getattr(adapter, "_client", None)
|
||||
if not client:
|
||||
@@ -119,6 +119,15 @@ def _build_discord(adapter) -> List[Dict[str, str]]:
|
||||
"guild": guild.name,
|
||||
"type": "channel",
|
||||
})
|
||||
# Forum channels (type 15) — creating a message auto-spawns a thread post.
|
||||
forums = getattr(guild, "forum_channels", None) or []
|
||||
for ch in forums:
|
||||
channels.append({
|
||||
"id": str(ch.id),
|
||||
"name": ch.name,
|
||||
"guild": guild.name,
|
||||
"type": "forum",
|
||||
})
|
||||
# Also include DM-capable users we've interacted with is not
|
||||
# feasible via guild enumeration; those come from sessions.
|
||||
|
||||
@@ -191,6 +200,15 @@ def load_directory() -> Dict[str, Any]:
|
||||
return {"updated_at": None, "platforms": {}}
|
||||
|
||||
|
||||
def lookup_channel_type(platform_name: str, chat_id: str) -> Optional[str]:
|
||||
"""Return the channel ``type`` string (e.g. ``"channel"``, ``"forum"``) for *chat_id*, or *None* if unknown."""
|
||||
directory = load_directory()
|
||||
for ch in directory.get("platforms", {}).get(platform_name, []):
|
||||
if ch.get("id") == chat_id:
|
||||
return ch.get("type")
|
||||
return None
|
||||
|
||||
|
||||
def resolve_channel_name(platform_name: str, name: str) -> Optional[str]:
|
||||
"""
|
||||
Resolve a human-friendly channel name to a numeric ID.
|
||||
|
||||
+89
-2
@@ -258,6 +258,13 @@ class GatewayConfig:
|
||||
# Streaming configuration
|
||||
streaming: StreamingConfig = field(default_factory=StreamingConfig)
|
||||
|
||||
# Session store pruning: drop SessionEntry records older than this many
|
||||
# days from the in-memory dict and sessions.json. Keeps the store from
|
||||
# growing unbounded in gateways serving many chats/threads/users over
|
||||
# months. Pruning is invisible to users — if they resume, they get a
|
||||
# fresh session exactly as if the reset policy had fired. 0 = disabled.
|
||||
session_store_max_age_days: int = 90
|
||||
|
||||
def get_connected_platforms(self) -> List[Platform]:
|
||||
"""Return list of platforms that are enabled and configured."""
|
||||
connected = []
|
||||
@@ -307,6 +314,14 @@ class GatewayConfig:
|
||||
# QQBot uses extra dict for app credentials
|
||||
elif platform == Platform.QQBOT and config.extra.get("app_id") and config.extra.get("client_secret"):
|
||||
connected.append(platform)
|
||||
# DingTalk uses client_id/client_secret from config.extra or env vars
|
||||
elif platform == Platform.DINGTALK and (
|
||||
config.extra.get("client_id") or os.getenv("DINGTALK_CLIENT_ID")
|
||||
) and (
|
||||
config.extra.get("client_secret") or os.getenv("DINGTALK_CLIENT_SECRET")
|
||||
):
|
||||
connected.append(platform)
|
||||
|
||||
return connected
|
||||
|
||||
def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]:
|
||||
@@ -357,6 +372,7 @@ class GatewayConfig:
|
||||
"thread_sessions_per_user": self.thread_sessions_per_user,
|
||||
"unauthorized_dm_behavior": self.unauthorized_dm_behavior,
|
||||
"streaming": self.streaming.to_dict(),
|
||||
"session_store_max_age_days": self.session_store_max_age_days,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -404,6 +420,13 @@ class GatewayConfig:
|
||||
"pair",
|
||||
)
|
||||
|
||||
try:
|
||||
session_store_max_age_days = int(data.get("session_store_max_age_days", 90))
|
||||
if session_store_max_age_days < 0:
|
||||
session_store_max_age_days = 0
|
||||
except (TypeError, ValueError):
|
||||
session_store_max_age_days = 90
|
||||
|
||||
return cls(
|
||||
platforms=platforms,
|
||||
default_reset_policy=default_policy,
|
||||
@@ -418,6 +441,7 @@ class GatewayConfig:
|
||||
thread_sessions_per_user=_coerce_bool(thread_sessions_per_user, False),
|
||||
unauthorized_dm_behavior=unauthorized_dm_behavior,
|
||||
streaming=StreamingConfig.from_dict(data.get("streaming", {})),
|
||||
session_store_max_age_days=session_store_max_age_days,
|
||||
)
|
||||
|
||||
def get_unauthorized_dm_behavior(self, platform: Optional[Platform] = None) -> str:
|
||||
@@ -617,6 +641,20 @@ def load_gateway_config() -> GatewayConfig:
|
||||
if isinstance(ntc, list):
|
||||
ntc = ",".join(str(v) for v in ntc)
|
||||
os.environ["DISCORD_NO_THREAD_CHANNELS"] = str(ntc)
|
||||
# allow_mentions: granular control over what the bot can ping.
|
||||
# Safe defaults (no @everyone/roles) are applied in the adapter;
|
||||
# these YAML keys only override when set and let users opt back
|
||||
# into unsafe modes (e.g. roles=true) if they actually want it.
|
||||
allow_mentions_cfg = discord_cfg.get("allow_mentions")
|
||||
if isinstance(allow_mentions_cfg, dict):
|
||||
for yaml_key, env_key in (
|
||||
("everyone", "DISCORD_ALLOW_MENTION_EVERYONE"),
|
||||
("roles", "DISCORD_ALLOW_MENTION_ROLES"),
|
||||
("users", "DISCORD_ALLOW_MENTION_USERS"),
|
||||
("replied_user", "DISCORD_ALLOW_MENTION_REPLIED_USER"),
|
||||
):
|
||||
if yaml_key in allow_mentions_cfg and not os.getenv(env_key):
|
||||
os.environ[env_key] = str(allow_mentions_cfg[yaml_key]).lower()
|
||||
|
||||
# Telegram settings → env vars (env vars take precedence)
|
||||
telegram_cfg = yaml_cfg.get("telegram", {})
|
||||
@@ -663,6 +701,24 @@ def load_gateway_config() -> GatewayConfig:
|
||||
frc = ",".join(str(v) for v in frc)
|
||||
os.environ["WHATSAPP_FREE_RESPONSE_CHATS"] = str(frc)
|
||||
|
||||
# DingTalk settings → env vars (env vars take precedence)
|
||||
dingtalk_cfg = yaml_cfg.get("dingtalk", {})
|
||||
if isinstance(dingtalk_cfg, dict):
|
||||
if "require_mention" in dingtalk_cfg and not os.getenv("DINGTALK_REQUIRE_MENTION"):
|
||||
os.environ["DINGTALK_REQUIRE_MENTION"] = str(dingtalk_cfg["require_mention"]).lower()
|
||||
if "mention_patterns" in dingtalk_cfg and not os.getenv("DINGTALK_MENTION_PATTERNS"):
|
||||
os.environ["DINGTALK_MENTION_PATTERNS"] = json.dumps(dingtalk_cfg["mention_patterns"])
|
||||
frc = dingtalk_cfg.get("free_response_chats")
|
||||
if frc is not None and not os.getenv("DINGTALK_FREE_RESPONSE_CHATS"):
|
||||
if isinstance(frc, list):
|
||||
frc = ",".join(str(v) for v in frc)
|
||||
os.environ["DINGTALK_FREE_RESPONSE_CHATS"] = str(frc)
|
||||
allowed = dingtalk_cfg.get("allowed_users")
|
||||
if allowed is not None and not os.getenv("DINGTALK_ALLOWED_USERS"):
|
||||
if isinstance(allowed, list):
|
||||
allowed = ",".join(str(v) for v in allowed)
|
||||
os.environ["DINGTALK_ALLOWED_USERS"] = str(allowed)
|
||||
|
||||
# Matrix settings → env vars (env vars take precedence)
|
||||
matrix_cfg = yaml_cfg.get("matrix", {})
|
||||
if isinstance(matrix_cfg, dict):
|
||||
@@ -1006,6 +1062,25 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
if webhook_secret:
|
||||
config.platforms[Platform.WEBHOOK].extra["secret"] = webhook_secret
|
||||
|
||||
# DingTalk
|
||||
dingtalk_client_id = os.getenv("DINGTALK_CLIENT_ID")
|
||||
dingtalk_client_secret = os.getenv("DINGTALK_CLIENT_SECRET")
|
||||
if dingtalk_client_id and dingtalk_client_secret:
|
||||
if Platform.DINGTALK not in config.platforms:
|
||||
config.platforms[Platform.DINGTALK] = PlatformConfig()
|
||||
config.platforms[Platform.DINGTALK].enabled = True
|
||||
config.platforms[Platform.DINGTALK].extra.update({
|
||||
"client_id": dingtalk_client_id,
|
||||
"client_secret": dingtalk_client_secret,
|
||||
})
|
||||
dingtalk_home = os.getenv("DINGTALK_HOME_CHANNEL")
|
||||
if dingtalk_home:
|
||||
config.platforms[Platform.DINGTALK].home_channel = HomeChannel(
|
||||
platform=Platform.DINGTALK,
|
||||
chat_id=dingtalk_home,
|
||||
name=os.getenv("DINGTALK_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# Feishu / Lark
|
||||
feishu_app_id = os.getenv("FEISHU_APP_ID")
|
||||
feishu_app_secret = os.getenv("FEISHU_APP_SECRET")
|
||||
@@ -1154,12 +1229,24 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
qq_group_allowed = os.getenv("QQ_GROUP_ALLOWED_USERS", "").strip()
|
||||
if qq_group_allowed:
|
||||
extra["group_allow_from"] = qq_group_allowed
|
||||
qq_home = os.getenv("QQ_HOME_CHANNEL", "").strip()
|
||||
qq_home = os.getenv("QQBOT_HOME_CHANNEL", "").strip()
|
||||
qq_home_name_env = "QQBOT_HOME_CHANNEL_NAME"
|
||||
if not qq_home:
|
||||
# Back-compat: accept the pre-rename name and log a one-time warning.
|
||||
legacy_home = os.getenv("QQ_HOME_CHANNEL", "").strip()
|
||||
if legacy_home:
|
||||
qq_home = legacy_home
|
||||
qq_home_name_env = "QQ_HOME_CHANNEL_NAME"
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(
|
||||
"QQ_HOME_CHANNEL is deprecated; rename to QQBOT_HOME_CHANNEL "
|
||||
"in your .env for consistency with the platform key."
|
||||
)
|
||||
if qq_home:
|
||||
config.platforms[Platform.QQBOT].home_channel = HomeChannel(
|
||||
platform=Platform.QQBOT,
|
||||
chat_id=qq_home,
|
||||
name=os.getenv("QQ_HOME_CHANNEL_NAME", "Home"),
|
||||
name=os.getenv("QQBOT_HOME_CHANNEL_NAME") or os.getenv(qq_home_name_env, "Home"),
|
||||
)
|
||||
|
||||
# Session settings
|
||||
|
||||
@@ -669,6 +669,15 @@ class MessageEvent:
|
||||
# Original platform data
|
||||
raw_message: Any = None
|
||||
message_id: Optional[str] = None
|
||||
|
||||
# Platform-specific update identifier. For Telegram this is the
|
||||
# ``update_id`` from the PTB Update wrapper; other platforms currently
|
||||
# ignore it. Used by ``/restart`` to record the triggering update so the
|
||||
# new gateway can advance the Telegram offset past it and avoid processing
|
||||
# the same ``/restart`` twice if PTB's graceful-shutdown ACK times out
|
||||
# ("Error while calling `get_updates` one more time to mark all fetched
|
||||
# updates" in gateway.log).
|
||||
platform_update_id: Optional[int] = None
|
||||
|
||||
# Media attachments
|
||||
# media_urls: local file paths (for vision tool access)
|
||||
@@ -1045,16 +1054,40 @@ class BasePlatformAdapter(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
# Default: the adapter treats ``finalize=True`` on edit_message as a
|
||||
# no-op and is happy to have the stream consumer skip redundant final
|
||||
# edits. Subclasses that *require* an explicit finalize call to close
|
||||
# out the message lifecycle (e.g. rich card / AI assistant surfaces
|
||||
# such as DingTalk AI Cards) override this to True (class attribute or
|
||||
# property) so the stream consumer knows not to short-circuit.
|
||||
REQUIRES_EDIT_FINALIZE: bool = False
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
content: str,
|
||||
*,
|
||||
finalize: bool = False,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Edit a previously sent message. Optional — platforms that don't
|
||||
support editing return success=False and callers fall back to
|
||||
sending a new message.
|
||||
|
||||
``finalize`` signals that this is the last edit in a streaming
|
||||
sequence. Most platforms (Telegram, Slack, Discord, Matrix,
|
||||
etc.) treat it as a no-op because their edit APIs have no notion
|
||||
of message lifecycle state — an edit is an edit. Platforms that
|
||||
render streaming updates with a distinct "in progress" state and
|
||||
require explicit closure (e.g. rich card / AI assistant surfaces
|
||||
such as DingTalk AI Cards) use it to finalize the message and
|
||||
transition the UI out of the streaming indicator — those should
|
||||
also set ``REQUIRES_EDIT_FINALIZE = True`` so callers route a
|
||||
final edit through even when content is unchanged. Callers
|
||||
should set ``finalize=True`` on the final edit of a streamed
|
||||
response (typically when ``got_done`` fires in the stream
|
||||
consumer) and leave it ``False`` on intermediate edits.
|
||||
"""
|
||||
return SendResult(success=False, error="Not supported")
|
||||
|
||||
@@ -1579,7 +1612,9 @@ class BasePlatformAdapter(ABC):
|
||||
# session lifecycle and its cleanup races with the running task
|
||||
# (see PR #4926).
|
||||
cmd = event.get_command()
|
||||
if cmd in ("approve", "deny", "status", "stop", "new", "reset", "background", "restart", "queue", "q"):
|
||||
from hermes_cli.commands import should_bypass_active_session
|
||||
|
||||
if should_bypass_active_session(cmd):
|
||||
logger.debug(
|
||||
"[%s] Command '/%s' bypassing active-session guard for %s",
|
||||
self.name, cmd, session_key,
|
||||
@@ -1991,6 +2026,7 @@ class BasePlatformAdapter(ABC):
|
||||
chat_topic: Optional[str] = None,
|
||||
user_id_alt: Optional[str] = None,
|
||||
chat_id_alt: Optional[str] = None,
|
||||
is_bot: bool = False,
|
||||
) -> SessionSource:
|
||||
"""Helper to build a SessionSource for this platform."""
|
||||
# Normalize empty topic to None
|
||||
@@ -2007,6 +2043,7 @@ class BasePlatformAdapter(ABC):
|
||||
chat_topic=chat_topic.strip() if chat_topic else None,
|
||||
user_id_alt=user_id_alt,
|
||||
chat_id_alt=chat_id_alt,
|
||||
is_bot=is_bot,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
|
||||
+1084
-69
File diff suppressed because it is too large
Load Diff
+548
-91
@@ -51,7 +51,9 @@ from gateway.platforms.base import (
|
||||
ProcessingOutcome,
|
||||
SendResult,
|
||||
cache_image_from_url,
|
||||
cache_image_from_bytes,
|
||||
cache_audio_from_url,
|
||||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
@@ -80,6 +82,41 @@ def check_discord_requirements() -> bool:
|
||||
return DISCORD_AVAILABLE
|
||||
|
||||
|
||||
def _build_allowed_mentions():
|
||||
"""Build Discord ``AllowedMentions`` with safe defaults, overridable via env.
|
||||
|
||||
Discord bots default to parsing ``@everyone``, ``@here``, role pings, and
|
||||
user pings when ``allowed_mentions`` is unset on the client — any LLM
|
||||
output or echoed user content that contains ``@everyone`` would therefore
|
||||
ping the whole server. We explicitly deny ``@everyone`` and role pings
|
||||
by default and keep user / replied-user pings enabled so normal
|
||||
conversation still works.
|
||||
|
||||
Override via environment variables (or ``discord.allow_mentions.*`` in
|
||||
config.yaml):
|
||||
|
||||
DISCORD_ALLOW_MENTION_EVERYONE default false — @everyone + @here
|
||||
DISCORD_ALLOW_MENTION_ROLES default false — @role pings
|
||||
DISCORD_ALLOW_MENTION_USERS default true — @user pings
|
||||
DISCORD_ALLOW_MENTION_REPLIED_USER default true — reply-ping author
|
||||
"""
|
||||
if not DISCORD_AVAILABLE:
|
||||
return None
|
||||
|
||||
def _b(name: str, default: bool) -> bool:
|
||||
raw = os.getenv(name, "").strip().lower()
|
||||
if not raw:
|
||||
return default
|
||||
return raw in ("true", "1", "yes", "on")
|
||||
|
||||
return discord.AllowedMentions(
|
||||
everyone=_b("DISCORD_ALLOW_MENTION_EVERYONE", False),
|
||||
roles=_b("DISCORD_ALLOW_MENTION_ROLES", False),
|
||||
users=_b("DISCORD_ALLOW_MENTION_USERS", True),
|
||||
replied_user=_b("DISCORD_ALLOW_MENTION_REPLIED_USER", True),
|
||||
)
|
||||
|
||||
|
||||
class VoiceReceiver:
|
||||
"""Captures and decodes voice audio from a Discord voice channel.
|
||||
|
||||
@@ -458,6 +495,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self._client: Optional[commands.Bot] = None
|
||||
self._ready_event = asyncio.Event()
|
||||
self._allowed_user_ids: set = set() # For button approval authorization
|
||||
self._allowed_role_ids: set = set() # For DISCORD_ALLOWED_ROLES filtering
|
||||
# Voice channel state (per-guild)
|
||||
self._voice_clients: Dict[int, Any] = {} # guild_id -> VoiceClient
|
||||
# Text batching: merge rapid successive messages (Telegram-style)
|
||||
@@ -536,6 +574,15 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if uid.strip()
|
||||
}
|
||||
|
||||
# Parse DISCORD_ALLOWED_ROLES — comma-separated role IDs.
|
||||
# Users with ANY of these roles can interact with the bot.
|
||||
roles_env = os.getenv("DISCORD_ALLOWED_ROLES", "")
|
||||
if roles_env:
|
||||
self._allowed_role_ids = {
|
||||
int(rid.strip()) for rid in roles_env.split(",")
|
||||
if rid.strip().isdigit()
|
||||
}
|
||||
|
||||
# Set up intents.
|
||||
# Message Content is required for normal text replies.
|
||||
# Server Members is only needed when the allowlist contains usernames
|
||||
@@ -547,7 +594,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
intents.message_content = True
|
||||
intents.dm_messages = True
|
||||
intents.guild_messages = True
|
||||
intents.members = any(not entry.isdigit() for entry in self._allowed_user_ids)
|
||||
intents.members = (
|
||||
any(not entry.isdigit() for entry in self._allowed_user_ids)
|
||||
or bool(self._allowed_role_ids) # Need members intent for role lookup
|
||||
)
|
||||
intents.voice_states = True
|
||||
|
||||
# Resolve proxy (DISCORD_PROXY > generic env vars > macOS system proxy)
|
||||
@@ -556,10 +606,15 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if proxy_url:
|
||||
logger.info("[%s] Using proxy for Discord: %s", self.name, proxy_url)
|
||||
|
||||
# Create bot — proxy= for HTTP, connector= for SOCKS
|
||||
# Create bot — proxy= for HTTP, connector= for SOCKS.
|
||||
# allowed_mentions is set with safe defaults (no @everyone/roles)
|
||||
# so LLM output or echoed user content can't ping the whole
|
||||
# server; override per DISCORD_ALLOW_MENTION_* env vars or the
|
||||
# discord.allow_mentions.* block in config.yaml.
|
||||
self._client = commands.Bot(
|
||||
command_prefix="!", # Not really used, we handle raw messages
|
||||
intents=intents,
|
||||
allowed_mentions=_build_allowed_mentions(),
|
||||
**proxy_kwargs_for_bot(proxy_url),
|
||||
)
|
||||
adapter_self = self # capture for closure
|
||||
@@ -594,14 +649,13 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if message.type not in (discord.MessageType.default, discord.MessageType.reply):
|
||||
return
|
||||
|
||||
# Check if the message author is in the allowed user list
|
||||
if not self._is_allowed_user(str(message.author.id)):
|
||||
return
|
||||
|
||||
# Bot message filtering (DISCORD_ALLOW_BOTS):
|
||||
# "none" — ignore all other bots (default)
|
||||
# "mentions" — accept bot messages only when they @mention us
|
||||
# "all" — accept all bot messages
|
||||
# Must run BEFORE the user allowlist check so that bots
|
||||
# permitted by DISCORD_ALLOW_BOTS are not rejected for
|
||||
# not being in DISCORD_ALLOWED_USERS (fixes #4466).
|
||||
if getattr(message.author, "bot", False):
|
||||
allow_bots = os.getenv("DISCORD_ALLOW_BOTS", "none").lower().strip()
|
||||
if allow_bots == "none":
|
||||
@@ -609,7 +663,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
elif allow_bots == "mentions":
|
||||
if not self._client.user or self._client.user not in message.mentions:
|
||||
return
|
||||
# "all" falls through to handle_message
|
||||
# "all" falls through; bot is permitted — skip the
|
||||
# human-user allowlist below (bots aren't in it).
|
||||
else:
|
||||
# Non-bot: enforce the configured user/role allowlists.
|
||||
if not self._is_allowed_user(str(message.author.id), message.author):
|
||||
return
|
||||
|
||||
# Multi-agent filtering: if the message mentions specific bots
|
||||
# but NOT this bot, the sender is talking to another agent —
|
||||
@@ -798,6 +857,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
When metadata contains a thread_id, the message is sent to that
|
||||
thread instead of the parent channel identified by chat_id.
|
||||
|
||||
Forum channels (type 15) reject direct messages — a thread post is
|
||||
created automatically.
|
||||
"""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
@@ -823,6 +885,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
# Forum channels reject channel.send() — create a thread post instead.
|
||||
if self._is_forum_parent(channel):
|
||||
return await self._send_to_forum(channel, content)
|
||||
|
||||
# Format and split message if needed
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
@@ -833,7 +899,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if reply_to and self._reply_to_mode != "off":
|
||||
try:
|
||||
ref_msg = await channel.fetch_message(int(reply_to))
|
||||
reference = ref_msg
|
||||
if hasattr(ref_msg, "to_reference"):
|
||||
reference = ref_msg.to_reference(fail_if_not_exists=False)
|
||||
else:
|
||||
reference = ref_msg
|
||||
except Exception as e:
|
||||
logger.debug("Could not fetch reply-to message: %s", e)
|
||||
|
||||
@@ -851,14 +920,20 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
err_text = str(e)
|
||||
if (
|
||||
chunk_reference is not None
|
||||
and "error code: 50035" in err_text
|
||||
and "Cannot reply to a system message" in err_text
|
||||
and (
|
||||
(
|
||||
"error code: 50035" in err_text
|
||||
and "Cannot reply to a system message" in err_text
|
||||
)
|
||||
or "error code: 10008" in err_text
|
||||
)
|
||||
):
|
||||
logger.warning(
|
||||
"[%s] Reply target %s is a Discord system message; retrying send without reply reference",
|
||||
"[%s] Reply target %s rejected the reply reference; retrying send without reply reference",
|
||||
self.name,
|
||||
reply_to,
|
||||
)
|
||||
reference = None
|
||||
msg = await channel.send(
|
||||
content=chunk,
|
||||
reference=None,
|
||||
@@ -877,6 +952,120 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
logger.error("[%s] Failed to send Discord message: %s", self.name, e, exc_info=True)
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def _send_to_forum(self, forum_channel: Any, content: str) -> SendResult:
|
||||
"""Create a thread post in a forum channel with the message as starter content.
|
||||
|
||||
Forum channels (type 15) don't support direct messages. Instead we
|
||||
POST to /channels/{forum_id}/threads with a thread name derived from
|
||||
the first line of the message. Any follow-up chunk failures are
|
||||
reported in ``raw_response['warnings']`` so the caller can surface
|
||||
partial-send issues.
|
||||
"""
|
||||
from tools.send_message_tool import _derive_forum_thread_name
|
||||
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
|
||||
thread_name = _derive_forum_thread_name(content)
|
||||
|
||||
starter_content = chunks[0] if chunks else thread_name
|
||||
|
||||
try:
|
||||
thread = await forum_channel.create_thread(
|
||||
name=thread_name,
|
||||
content=starter_content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("[%s] Failed to create forum thread in %s: %s", self.name, forum_channel.id, e)
|
||||
return SendResult(success=False, error=f"Forum thread creation failed: {e}")
|
||||
|
||||
thread_channel = thread if hasattr(thread, "send") else getattr(thread, "thread", None)
|
||||
thread_id = str(getattr(thread_channel, "id", getattr(thread, "id", "")))
|
||||
starter_msg = getattr(thread, "message", None)
|
||||
message_id = str(getattr(starter_msg, "id", thread_id)) if starter_msg else thread_id
|
||||
|
||||
# Send remaining chunks into the newly created thread. Track any
|
||||
# per-chunk failures so the caller sees partial-send outcomes.
|
||||
message_ids = [message_id]
|
||||
warnings: list[str] = []
|
||||
for chunk in chunks[1:]:
|
||||
try:
|
||||
msg = await thread_channel.send(content=chunk)
|
||||
message_ids.append(str(msg.id))
|
||||
except Exception as e:
|
||||
warning = f"Failed to send follow-up chunk to forum thread {thread_id}: {e}"
|
||||
logger.warning("[%s] %s", self.name, warning)
|
||||
warnings.append(warning)
|
||||
|
||||
raw_response: Dict[str, Any] = {"message_ids": message_ids, "thread_id": thread_id}
|
||||
if warnings:
|
||||
raw_response["warnings"] = warnings
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_ids[0],
|
||||
raw_response=raw_response,
|
||||
)
|
||||
|
||||
async def _forum_post_file(
|
||||
self,
|
||||
forum_channel: Any,
|
||||
*,
|
||||
thread_name: Optional[str] = None,
|
||||
content: str = "",
|
||||
file: Any = None,
|
||||
files: Optional[list] = None,
|
||||
) -> SendResult:
|
||||
"""Create a forum thread whose starter message carries file attachments.
|
||||
|
||||
Used by the send_voice / send_image_file / send_document paths when
|
||||
the target channel is a forum (type 15). ``create_thread`` on a
|
||||
ForumChannel accepts the same file/files/content kwargs as
|
||||
``channel.send``, creating the thread and starter message atomically.
|
||||
"""
|
||||
from tools.send_message_tool import _derive_forum_thread_name
|
||||
|
||||
if not thread_name:
|
||||
# Prefer the text content, fall back to the first attached
|
||||
# filename, fall back to the generic default.
|
||||
hint = content or ""
|
||||
if not hint.strip():
|
||||
if file is not None:
|
||||
hint = getattr(file, "filename", "") or ""
|
||||
elif files:
|
||||
hint = getattr(files[0], "filename", "") or ""
|
||||
thread_name = _derive_forum_thread_name(hint) if hint.strip() else "New Post"
|
||||
|
||||
kwargs: Dict[str, Any] = {"name": thread_name}
|
||||
if content:
|
||||
kwargs["content"] = content
|
||||
if file is not None:
|
||||
kwargs["file"] = file
|
||||
if files:
|
||||
kwargs["files"] = files
|
||||
|
||||
try:
|
||||
thread = await forum_channel.create_thread(**kwargs)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"[%s] Failed to create forum thread with file in %s: %s",
|
||||
self.name,
|
||||
getattr(forum_channel, "id", "?"),
|
||||
e,
|
||||
)
|
||||
return SendResult(success=False, error=f"Forum thread creation failed: {e}")
|
||||
|
||||
thread_channel = thread if hasattr(thread, "send") else getattr(thread, "thread", None)
|
||||
thread_id = str(getattr(thread_channel, "id", getattr(thread, "id", "")))
|
||||
starter_msg = getattr(thread, "message", None)
|
||||
message_id = str(getattr(starter_msg, "id", thread_id)) if starter_msg else thread_id
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_id,
|
||||
raw_response={"thread_id": thread_id},
|
||||
)
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -907,7 +1096,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send a local file as a Discord attachment."""
|
||||
"""Send a local file as a Discord attachment.
|
||||
|
||||
Forum channels (type 15) get a new thread whose starter message
|
||||
carries the file — they reject direct POST /messages.
|
||||
"""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
@@ -920,6 +1113,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
filename = file_name or os.path.basename(file_path)
|
||||
with open(file_path, "rb") as fh:
|
||||
file = discord.File(fh, filename=filename)
|
||||
if self._is_forum_parent(channel):
|
||||
return await self._forum_post_file(
|
||||
channel,
|
||||
content=(caption or "").strip(),
|
||||
file=file,
|
||||
)
|
||||
msg = await channel.send(content=caption if caption else None, file=file)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
@@ -968,6 +1167,18 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
with open(audio_path, "rb") as f:
|
||||
file_data = f.read()
|
||||
|
||||
# Forum channels (type 15) reject direct POST /messages — the
|
||||
# native voice flag path also targets /messages so it would fail
|
||||
# too. Create a thread post with the audio as the starter
|
||||
# attachment instead.
|
||||
if self._is_forum_parent(channel):
|
||||
forum_file = discord.File(io.BytesIO(file_data), filename=filename)
|
||||
return await self._forum_post_file(
|
||||
channel,
|
||||
content=(caption or "").strip(),
|
||||
file=forum_file,
|
||||
)
|
||||
|
||||
# Try sending as a native voice message via raw API (flags=8192).
|
||||
try:
|
||||
import base64
|
||||
@@ -1310,11 +1521,48 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _is_allowed_user(self, user_id: str) -> bool:
|
||||
"""Check if user is in DISCORD_ALLOWED_USERS."""
|
||||
if not self._allowed_user_ids:
|
||||
def _is_allowed_user(self, user_id: str, author=None) -> bool:
|
||||
"""Check if user is allowed via DISCORD_ALLOWED_USERS or DISCORD_ALLOWED_ROLES.
|
||||
|
||||
Uses OR semantics: if the user matches EITHER allowlist, they're allowed.
|
||||
If both allowlists are empty, everyone is allowed (backwards compatible).
|
||||
When author is a Member, checks .roles directly; otherwise falls back
|
||||
to scanning the bot's mutual guilds for a Member record.
|
||||
"""
|
||||
# ``getattr`` fallbacks here guard against test fixtures that build
|
||||
# an adapter via ``object.__new__(DiscordAdapter)`` and skip __init__
|
||||
# (see AGENTS.md pitfall #17 — same pattern as gateway.run).
|
||||
allowed_users = getattr(self, "_allowed_user_ids", set())
|
||||
allowed_roles = getattr(self, "_allowed_role_ids", set())
|
||||
has_users = bool(allowed_users)
|
||||
has_roles = bool(allowed_roles)
|
||||
if not has_users and not has_roles:
|
||||
return True
|
||||
return user_id in self._allowed_user_ids
|
||||
# Check user ID allowlist
|
||||
if has_users and user_id in allowed_users:
|
||||
return True
|
||||
# Check role allowlist
|
||||
if has_roles:
|
||||
# Try direct role check from Member object
|
||||
direct_roles = getattr(author, "roles", None) if author is not None else None
|
||||
if direct_roles:
|
||||
if any(getattr(r, "id", None) in allowed_roles for r in direct_roles):
|
||||
return True
|
||||
# Fallback: scan mutual guilds for member's roles
|
||||
if self._client is not None:
|
||||
try:
|
||||
uid_int = int(user_id)
|
||||
except (TypeError, ValueError):
|
||||
uid_int = None
|
||||
if uid_int is not None:
|
||||
for guild in self._client.guilds:
|
||||
m = guild.get_member(uid_int)
|
||||
if m is None:
|
||||
continue
|
||||
m_roles = getattr(m, "roles", None) or []
|
||||
if any(getattr(r, "id", None) in allowed_roles for r in m_roles):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
@@ -1383,6 +1631,13 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
import io
|
||||
file = discord.File(io.BytesIO(image_data), filename=f"image.{ext}")
|
||||
|
||||
if self._is_forum_parent(channel):
|
||||
return await self._forum_post_file(
|
||||
channel,
|
||||
content=(caption or "").strip(),
|
||||
file=file,
|
||||
)
|
||||
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
@@ -1445,6 +1700,13 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
import io
|
||||
file = discord.File(io.BytesIO(animation_data), filename="animation.gif")
|
||||
|
||||
if self._is_forum_parent(channel):
|
||||
return await self._forum_post_file(
|
||||
channel,
|
||||
content=(caption or "").strip(),
|
||||
file=file,
|
||||
)
|
||||
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
@@ -1732,6 +1994,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
async def slash_stop(interaction: discord.Interaction):
|
||||
await self._run_simple_slash(interaction, "/stop", "Stop requested~")
|
||||
|
||||
@tree.command(name="steer", description="Inject a message after the next tool call (no interrupt)")
|
||||
@discord.app_commands.describe(prompt="Text to inject into the agent's next tool result")
|
||||
async def slash_steer(interaction: discord.Interaction, prompt: str):
|
||||
await self._run_simple_slash(interaction, f"/steer {prompt}".strip())
|
||||
|
||||
@tree.command(name="compress", description="Compress conversation context")
|
||||
async def slash_compress(interaction: discord.Interaction):
|
||||
await self._run_simple_slash(interaction, "/compress")
|
||||
@@ -1904,12 +2171,23 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self._register_skill_group(tree)
|
||||
|
||||
def _register_skill_group(self, tree) -> None:
|
||||
"""Register a ``/skill`` command group with category subcommand groups.
|
||||
"""Register a single ``/skill`` command with autocomplete on the name.
|
||||
|
||||
Skills are organized by their directory category under ``SKILLS_DIR``.
|
||||
Each category becomes a subcommand group; root-level skills become
|
||||
direct subcommands. Discord supports 25 subcommand groups × 25
|
||||
subcommands each = 625 skills — well beyond the old 100-command cap.
|
||||
Discord enforces an ~8000-byte per-command payload limit. The older
|
||||
nested layout (``/skill <category> <name>``) registered one giant
|
||||
command whose serialized payload grew linearly with the skill
|
||||
catalog — with the default ~75 skills the payload was ~14 KB and
|
||||
``tree.sync()`` rejected the entire slash-command batch (issues
|
||||
#11321, #10259, #11385, #10261, #10214).
|
||||
|
||||
Autocomplete options are fetched dynamically by Discord when the
|
||||
user types — they do NOT count against the per-command registration
|
||||
budget. So we register ONE flat ``/skill`` command with
|
||||
``name: str`` (autocompleted) and ``args: str = ""``. This scales
|
||||
to thousands of skills with no size math, no splitting, and no
|
||||
hidden skills. The slash picker also becomes more discoverable —
|
||||
Discord live-filters by the user's typed prefix against both the
|
||||
skill name and its description.
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.commands import discord_skill_commands_by_category
|
||||
@@ -1920,68 +2198,97 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Reuse the existing collector for consistent filtering
|
||||
# (per-platform disabled, hub-excluded, name clamping), then
|
||||
# flatten — the category grouping was only useful for the
|
||||
# nested layout.
|
||||
categories, uncategorized, hidden = discord_skill_commands_by_category(
|
||||
reserved_names=existing_names,
|
||||
)
|
||||
entries: list[tuple[str, str, str]] = list(uncategorized)
|
||||
for cat_skills in categories.values():
|
||||
entries.extend(cat_skills)
|
||||
|
||||
if not categories and not uncategorized:
|
||||
if not entries:
|
||||
return
|
||||
|
||||
skill_group = discord.app_commands.Group(
|
||||
# Stable alphabetical order so the autocomplete suggestion
|
||||
# list is predictable across restarts.
|
||||
entries.sort(key=lambda t: t[0])
|
||||
|
||||
# name -> (description, cmd_key) — used by both the autocomplete
|
||||
# callback and the handler for O(1) dispatch.
|
||||
skill_lookup: dict[str, tuple[str, str]] = {
|
||||
n: (d, k) for n, d, k in entries
|
||||
}
|
||||
|
||||
async def _autocomplete_name(
|
||||
interaction: "discord.Interaction", current: str,
|
||||
) -> list:
|
||||
"""Filter skills by the user's typed prefix.
|
||||
|
||||
Matches both the skill name and its description so
|
||||
"/skill pdf" surfaces skills whose description mentions
|
||||
PDFs even if the name doesn't. Discord caps this list at
|
||||
25 entries per query.
|
||||
"""
|
||||
q = (current or "").strip().lower()
|
||||
choices: list = []
|
||||
for name, desc, _key in entries:
|
||||
if not q or q in name.lower() or (desc and q in desc.lower()):
|
||||
if desc:
|
||||
label = f"{name} — {desc}"
|
||||
else:
|
||||
label = name
|
||||
# Discord's Choice.name is capped at 100 chars.
|
||||
if len(label) > 100:
|
||||
label = label[:97] + "..."
|
||||
choices.append(
|
||||
discord.app_commands.Choice(name=label, value=name)
|
||||
)
|
||||
if len(choices) >= 25:
|
||||
break
|
||||
return choices
|
||||
|
||||
@discord.app_commands.describe(
|
||||
name="Which skill to run",
|
||||
args="Optional arguments for the skill",
|
||||
)
|
||||
@discord.app_commands.autocomplete(name=_autocomplete_name)
|
||||
async def _skill_handler(
|
||||
interaction: "discord.Interaction", name: str, args: str = "",
|
||||
):
|
||||
entry = skill_lookup.get(name)
|
||||
if not entry:
|
||||
await interaction.response.send_message(
|
||||
f"Unknown skill: `{name}`. Start typing for "
|
||||
f"autocomplete suggestions.",
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
_desc, cmd_key = entry
|
||||
await self._run_simple_slash(
|
||||
interaction, f"{cmd_key} {args}".strip()
|
||||
)
|
||||
|
||||
cmd = discord.app_commands.Command(
|
||||
name="skill",
|
||||
description="Run a Hermes skill",
|
||||
callback=_skill_handler,
|
||||
)
|
||||
tree.add_command(cmd)
|
||||
|
||||
# ── Helper: build a callback for a skill command key ──
|
||||
def _make_handler(_key: str):
|
||||
@discord.app_commands.describe(args="Optional arguments for the skill")
|
||||
async def _handler(interaction: discord.Interaction, args: str = ""):
|
||||
await self._run_simple_slash(interaction, f"{_key} {args}".strip())
|
||||
_handler.__name__ = f"skill_{_key.lstrip('/').replace('-', '_')}"
|
||||
return _handler
|
||||
|
||||
# ── Uncategorized (root-level) skills → direct subcommands ──
|
||||
for discord_name, description, cmd_key in uncategorized:
|
||||
cmd = discord.app_commands.Command(
|
||||
name=discord_name,
|
||||
description=description or f"Run the {discord_name} skill",
|
||||
callback=_make_handler(cmd_key),
|
||||
)
|
||||
skill_group.add_command(cmd)
|
||||
|
||||
# ── Category subcommand groups ──
|
||||
for cat_name in sorted(categories):
|
||||
cat_desc = f"{cat_name.replace('-', ' ').title()} skills"
|
||||
if len(cat_desc) > 100:
|
||||
cat_desc = cat_desc[:97] + "..."
|
||||
cat_group = discord.app_commands.Group(
|
||||
name=cat_name,
|
||||
description=cat_desc,
|
||||
parent=skill_group,
|
||||
)
|
||||
for discord_name, description, cmd_key in categories[cat_name]:
|
||||
cmd = discord.app_commands.Command(
|
||||
name=discord_name,
|
||||
description=description or f"Run the {discord_name} skill",
|
||||
callback=_make_handler(cmd_key),
|
||||
)
|
||||
cat_group.add_command(cmd)
|
||||
|
||||
tree.add_command(skill_group)
|
||||
|
||||
total = sum(len(v) for v in categories.values()) + len(uncategorized)
|
||||
logger.info(
|
||||
"[%s] Registered /skill group: %d skill(s) across %d categories"
|
||||
" + %d uncategorized",
|
||||
self.name, total, len(categories), len(uncategorized),
|
||||
"[%s] Registered /skill command with %d skill(s) via autocomplete",
|
||||
self.name, len(entries),
|
||||
)
|
||||
if hidden:
|
||||
logger.warning(
|
||||
"[%s] %d skill(s) not registered (Discord subcommand limits)",
|
||||
logger.info(
|
||||
"[%s] %d skill(s) filtered out of /skill (name clamp / reserved)",
|
||||
self.name, hidden,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("[%s] Failed to register /skill group: %s", self.name, exc)
|
||||
logger.warning("[%s] Failed to register /skill command: %s", self.name, exc)
|
||||
|
||||
def _build_slash_event(self, interaction: discord.Interaction, text: str) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Discord slash command interaction."""
|
||||
@@ -2140,6 +2447,26 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
from gateway.platforms.base import resolve_channel_prompt
|
||||
return resolve_channel_prompt(self.config.extra, channel_id, parent_id)
|
||||
|
||||
def _discord_require_mention(self) -> bool:
|
||||
"""Return whether Discord channel messages require a bot mention."""
|
||||
configured = self.config.extra.get("require_mention")
|
||||
if configured is not None:
|
||||
if isinstance(configured, str):
|
||||
return configured.lower() not in ("false", "0", "no", "off")
|
||||
return bool(configured)
|
||||
return os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no", "off")
|
||||
|
||||
def _discord_free_response_channels(self) -> set:
|
||||
"""Return Discord channel IDs where no bot mention is required."""
|
||||
raw = self.config.extra.get("free_response_channels")
|
||||
if raw is None:
|
||||
raw = os.getenv("DISCORD_FREE_RESPONSE_CHANNELS", "")
|
||||
if isinstance(raw, list):
|
||||
return {str(part).strip() for part in raw if str(part).strip()}
|
||||
if isinstance(raw, str) and raw.strip():
|
||||
return {part.strip() for part in raw.split(",") if part.strip()}
|
||||
return set()
|
||||
|
||||
def _thread_parent_channel(self, channel: Any) -> Any:
|
||||
"""Return the parent text channel when invoked from a thread."""
|
||||
return getattr(channel, "parent", None) or channel
|
||||
@@ -2242,8 +2569,15 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
Returns the created thread object, or ``None`` on failure.
|
||||
"""
|
||||
# Build a short thread name from the message
|
||||
# Build a short thread name from the message. Strip Discord mention
|
||||
# syntax (users / roles / channels) so thread titles don't end up
|
||||
# showing raw <@id>, <@&id>, or <#id> markers — the ID isn't
|
||||
# meaningful to humans glancing at the thread list (#6336).
|
||||
content = (message.content or "").strip()
|
||||
# <@123>, <@!123>, <@&123>, <#123> — collapse to empty; normalize spaces.
|
||||
content = re.sub(r"<@[!&]?\d+>", "", content)
|
||||
content = re.sub(r"<#\d+>", "", content)
|
||||
content = re.sub(r"\s+", " ", content).strip()
|
||||
thread_name = content[:80] if content else "Hermes"
|
||||
if len(content) > 80:
|
||||
thread_name = thread_name[:77] + "..."
|
||||
@@ -2251,9 +2585,25 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
thread = await message.create_thread(name=thread_name, auto_archive_duration=1440)
|
||||
return thread
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Auto-thread creation failed: %s", self.name, e)
|
||||
return None
|
||||
except Exception as direct_error:
|
||||
display_name = getattr(getattr(message, "author", None), "display_name", None) or "unknown user"
|
||||
reason = f"Auto-threaded from mention by {display_name}"
|
||||
try:
|
||||
seed_msg = await message.channel.send(f"\U0001f9f5 Thread created by Hermes: **{thread_name}**")
|
||||
thread = await seed_msg.create_thread(
|
||||
name=thread_name,
|
||||
auto_archive_duration=1440,
|
||||
reason=reason,
|
||||
)
|
||||
return thread
|
||||
except Exception as fallback_error:
|
||||
logger.warning(
|
||||
"[%s] Auto-thread creation failed. Direct error: %s. Fallback error: %s",
|
||||
self.name,
|
||||
direct_error,
|
||||
fallback_error,
|
||||
)
|
||||
return None
|
||||
|
||||
async def send_exec_approval(
|
||||
self, chat_id: str, command: str, session_key: str,
|
||||
@@ -2440,6 +2790,124 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
return f"{parent_name} / {thread_name}"
|
||||
return thread_name
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Attachment download helpers
|
||||
#
|
||||
# Discord attachments (images / audio / documents) are fetched via the
|
||||
# authenticated bot session whenever the Attachment object exposes
|
||||
# ``read()``. That sidesteps two classes of bug that hit the older
|
||||
# plain-HTTP path:
|
||||
#
|
||||
# 1. ``cdn.discordapp.com`` URLs increasingly require bot auth on
|
||||
# download — unauthenticated httpx sees 403 Forbidden.
|
||||
# (issue #8242)
|
||||
# 2. Some user environments (VPNs, corporate DNS, tunnels) resolve
|
||||
# ``cdn.discordapp.com`` to private-looking IPs that our
|
||||
# ``is_safe_url`` guard classifies as SSRF risks. Routing the
|
||||
# fetch through discord.py's own HTTP client handles DNS
|
||||
# internally so our guard isn't consulted for the attachment
|
||||
# path. (issue #6587)
|
||||
#
|
||||
# If ``att.read()`` is unavailable (unexpected object shape / test
|
||||
# stub) or the bot session fetch fails, we fall back to the existing
|
||||
# SSRF-gated URL downloaders. The fallback keeps defense-in-depth
|
||||
# against any future Discord payload-schema drift that could slip a
|
||||
# non-CDN URL into the ``att.url`` field. (issue #11345)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _read_attachment_bytes(self, att) -> Optional[bytes]:
|
||||
"""Read an attachment via discord.py's authenticated bot session.
|
||||
|
||||
Returns the raw bytes on success, or ``None`` if ``att`` doesn't
|
||||
expose a callable ``read()`` or the read itself fails. Callers
|
||||
should treat ``None`` as a signal to fall back to the URL-based
|
||||
downloaders.
|
||||
"""
|
||||
reader = getattr(att, "read", None)
|
||||
if reader is None or not callable(reader):
|
||||
return None
|
||||
try:
|
||||
return await reader()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"[Discord] Authenticated attachment read failed for %s: %s",
|
||||
getattr(att, "filename", None) or getattr(att, "url", "<unknown>"),
|
||||
e,
|
||||
)
|
||||
return None
|
||||
|
||||
async def _cache_discord_image(self, att, ext: str) -> str:
|
||||
"""Cache a Discord image attachment to local disk.
|
||||
|
||||
Primary path: ``att.read()`` + ``cache_image_from_bytes``
|
||||
(authenticated, no SSRF gate).
|
||||
|
||||
Fallback: ``cache_image_from_url`` (plain httpx, SSRF-gated).
|
||||
"""
|
||||
raw_bytes = await self._read_attachment_bytes(att)
|
||||
if raw_bytes is not None:
|
||||
try:
|
||||
return cache_image_from_bytes(raw_bytes, ext=ext)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"[Discord] cache_image_from_bytes rejected att.read() data; falling back to URL: %s",
|
||||
e,
|
||||
)
|
||||
return await cache_image_from_url(att.url, ext=ext)
|
||||
|
||||
async def _cache_discord_audio(self, att, ext: str) -> str:
|
||||
"""Cache a Discord audio attachment to local disk.
|
||||
|
||||
Primary path: ``att.read()`` + ``cache_audio_from_bytes``
|
||||
(authenticated, no SSRF gate).
|
||||
|
||||
Fallback: ``cache_audio_from_url`` (plain httpx, SSRF-gated).
|
||||
"""
|
||||
raw_bytes = await self._read_attachment_bytes(att)
|
||||
if raw_bytes is not None:
|
||||
try:
|
||||
return cache_audio_from_bytes(raw_bytes, ext=ext)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"[Discord] cache_audio_from_bytes failed; falling back to URL: %s",
|
||||
e,
|
||||
)
|
||||
return await cache_audio_from_url(att.url, ext=ext)
|
||||
|
||||
async def _cache_discord_document(self, att, ext: str) -> bytes:
|
||||
"""Download a Discord document attachment and return the raw bytes.
|
||||
|
||||
Primary path: ``att.read()`` (authenticated, no SSRF gate).
|
||||
|
||||
Fallback: SSRF-gated ``aiohttp`` download. This closes the gap
|
||||
where the old document path made raw ``aiohttp.ClientSession``
|
||||
requests with no safety check (#11345). The caller is responsible
|
||||
for passing the returned bytes to ``cache_document_from_bytes``
|
||||
(and, where applicable, for injecting text content).
|
||||
"""
|
||||
raw_bytes = await self._read_attachment_bytes(att)
|
||||
if raw_bytes is not None:
|
||||
return raw_bytes
|
||||
|
||||
# Fallback: SSRF-gated URL download.
|
||||
if not is_safe_url(att.url):
|
||||
raise ValueError(
|
||||
f"Blocked unsafe attachment URL (SSRF protection): {att.url}"
|
||||
)
|
||||
import aiohttp
|
||||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
|
||||
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||||
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
|
||||
async with aiohttp.ClientSession(**_sess_kw) as session:
|
||||
async with session.get(
|
||||
att.url,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
**_req_kw,
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"HTTP {resp.status}")
|
||||
return await resp.read()
|
||||
|
||||
async def _handle_message(self, message: DiscordMessage) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
# In server channels (not DMs), require the bot to be @mentioned
|
||||
@@ -2482,12 +2950,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
logger.debug("[%s] Ignoring message in ignored channel: %s", self.name, channel_ids)
|
||||
return
|
||||
|
||||
free_channels_raw = os.getenv("DISCORD_FREE_RESPONSE_CHANNELS", "")
|
||||
free_channels = {ch.strip() for ch in free_channels_raw.split(",") if ch.strip()}
|
||||
free_channels = self._discord_free_response_channels()
|
||||
if parent_channel_id:
|
||||
channel_ids.add(parent_channel_id)
|
||||
|
||||
require_mention = os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
|
||||
require_mention = self._discord_require_mention()
|
||||
# Voice-linked text channels act as free-response while voice is active.
|
||||
# Only the exact bound channel gets the exemption, not sibling threads.
|
||||
voice_linked_ids = {str(ch_id) for ch_id in self._voice_text_channels.values()}
|
||||
@@ -2515,9 +2982,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if not is_thread and not isinstance(message.channel, discord.DMChannel):
|
||||
no_thread_channels_raw = os.getenv("DISCORD_NO_THREAD_CHANNELS", "")
|
||||
no_thread_channels = {ch.strip() for ch in no_thread_channels_raw.split(",") if ch.strip()}
|
||||
skip_thread = bool(channel_ids & no_thread_channels)
|
||||
skip_thread = bool(channel_ids & no_thread_channels) or is_free_channel
|
||||
auto_thread = os.getenv("DISCORD_AUTO_THREAD", "true").lower() in ("true", "1", "yes")
|
||||
if auto_thread and not skip_thread and not is_voice_linked_channel:
|
||||
is_reply_message = getattr(message, "type", None) == discord.MessageType.reply
|
||||
if auto_thread and not skip_thread and not is_voice_linked_channel and not is_reply_message:
|
||||
thread = await self._auto_create_thread(message)
|
||||
if thread:
|
||||
is_thread = True
|
||||
@@ -2578,6 +3046,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
user_name=message.author.display_name,
|
||||
thread_id=thread_id,
|
||||
chat_topic=chat_topic,
|
||||
is_bot=getattr(message.author, "bot", False),
|
||||
)
|
||||
|
||||
# Build media URLs -- download image attachments to local cache so the
|
||||
@@ -2593,7 +3062,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
ext = "." + content_type.split("/")[-1].split(";")[0]
|
||||
if ext not in (".jpg", ".jpeg", ".png", ".gif", ".webp"):
|
||||
ext = ".jpg"
|
||||
cached_path = await cache_image_from_url(att.url, ext=ext)
|
||||
cached_path = await self._cache_discord_image(att, ext)
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(content_type)
|
||||
print(f"[Discord] Cached user image: {cached_path}", flush=True)
|
||||
@@ -2607,7 +3076,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
ext = "." + content_type.split("/")[-1].split(";")[0]
|
||||
if ext not in (".ogg", ".mp3", ".wav", ".webm", ".m4a"):
|
||||
ext = ".ogg"
|
||||
cached_path = await cache_audio_from_url(att.url, ext=ext)
|
||||
cached_path = await self._cache_discord_audio(att, ext)
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(content_type)
|
||||
print(f"[Discord] Cached user audio: {cached_path}", flush=True)
|
||||
@@ -2638,19 +3107,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
)
|
||||
else:
|
||||
try:
|
||||
import aiohttp
|
||||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
|
||||
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||||
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
|
||||
async with aiohttp.ClientSession(**_sess_kw) as session:
|
||||
async with session.get(
|
||||
att.url,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
**_req_kw,
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"HTTP {resp.status}")
|
||||
raw_bytes = await resp.read()
|
||||
raw_bytes = await self._cache_discord_document(att, ext)
|
||||
cached_path = cache_document_from_bytes(
|
||||
raw_bytes, att.filename or f"document{ext}"
|
||||
)
|
||||
|
||||
@@ -1228,6 +1228,10 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
.register_p2_im_chat_member_bot_deleted_v1(self._on_bot_removed_from_chat)
|
||||
.register_p2_im_chat_access_event_bot_p2p_chat_entered_v1(self._on_p2p_chat_entered)
|
||||
.register_p2_im_message_recalled_v1(self._on_message_recalled)
|
||||
.register_p2_customized_event(
|
||||
"drive.notice.comment_add_v1",
|
||||
self._on_drive_comment_event,
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
@@ -1965,6 +1969,25 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
def _on_message_recalled(self, data: Any) -> None:
|
||||
logger.debug("[Feishu] Message recalled by user")
|
||||
|
||||
def _on_drive_comment_event(self, data: Any) -> None:
|
||||
"""Handle drive document comment notification (drive.notice.comment_add_v1).
|
||||
|
||||
Delegates to :mod:`gateway.platforms.feishu_comment` for parsing,
|
||||
logging, and reaction. Scheduling follows the same
|
||||
``run_coroutine_threadsafe`` pattern used by ``_on_message_event``.
|
||||
"""
|
||||
from gateway.platforms.feishu_comment import handle_drive_comment_event
|
||||
|
||||
loop = self._loop
|
||||
if not self._loop_accepts_callbacks(loop):
|
||||
logger.warning("[Feishu] Dropping drive comment event before adapter loop is ready")
|
||||
return
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
handle_drive_comment_event(self._client, data, self_open_id=self._bot_open_id),
|
||||
loop,
|
||||
)
|
||||
future.add_done_callback(self._log_background_failure)
|
||||
|
||||
def _on_reaction_event(self, event_type: str, data: Any) -> None:
|
||||
"""Route user reactions on bot messages as synthetic text events."""
|
||||
event = getattr(data, "event", None)
|
||||
@@ -2590,6 +2613,8 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
self._on_reaction_event(event_type, data)
|
||||
elif event_type == "card.action.trigger":
|
||||
self._on_card_action_trigger(data)
|
||||
elif event_type == "drive.notice.comment_add_v1":
|
||||
self._on_drive_comment_event(data)
|
||||
else:
|
||||
logger.debug("[Feishu] Ignoring webhook event type: %s", event_type or "unknown")
|
||||
return web.json_response({"code": 0, "msg": "ok"})
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,429 @@
|
||||
"""
|
||||
Feishu document comment access-control rules.
|
||||
|
||||
3-tier rule resolution: exact doc > wildcard "*" > top-level > code defaults.
|
||||
Each field (enabled/policy/allow_from) falls back independently.
|
||||
Config: ~/.hermes/feishu_comment_rules.json (mtime-cached, hot-reload).
|
||||
Pairing store: ~/.hermes/feishu_comment_pairing.json.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Paths
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Uses the canonical ``get_hermes_home()`` helper (HERMES_HOME-aware and
|
||||
# profile-safe). Resolved at import time; this module is lazy-imported by
|
||||
# the Feishu comment event handler, which runs long after profile overrides
|
||||
# have been applied, so freezing paths here is safe.
|
||||
|
||||
RULES_FILE = get_hermes_home() / "feishu_comment_rules.json"
|
||||
PAIRING_FILE = get_hermes_home() / "feishu_comment_pairing.json"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_VALID_POLICIES = ("allowlist", "pairing")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CommentDocumentRule:
|
||||
"""Per-document rule. ``None`` means 'inherit from lower tier'."""
|
||||
enabled: Optional[bool] = None
|
||||
policy: Optional[str] = None
|
||||
allow_from: Optional[frozenset] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CommentsConfig:
|
||||
"""Top-level comment access config."""
|
||||
enabled: bool = True
|
||||
policy: str = "pairing"
|
||||
allow_from: frozenset = field(default_factory=frozenset)
|
||||
documents: Dict[str, CommentDocumentRule] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResolvedCommentRule:
|
||||
"""Fully resolved rule after field-by-field fallback."""
|
||||
enabled: bool
|
||||
policy: str
|
||||
allow_from: frozenset
|
||||
match_source: str # e.g. "exact:docx:xxx" | "wildcard" | "top" | "default"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mtime-cached file loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _MtimeCache:
|
||||
"""Generic mtime-based file cache. ``stat()`` per access, re-read only on change."""
|
||||
|
||||
def __init__(self, path: Path):
|
||||
self._path = path
|
||||
self._mtime: float = 0.0
|
||||
self._data: Optional[dict] = None
|
||||
|
||||
def load(self) -> dict:
|
||||
try:
|
||||
st = self._path.stat()
|
||||
mtime = st.st_mtime
|
||||
except FileNotFoundError:
|
||||
self._mtime = 0.0
|
||||
self._data = {}
|
||||
return {}
|
||||
|
||||
if mtime == self._mtime and self._data is not None:
|
||||
return self._data
|
||||
|
||||
try:
|
||||
with open(self._path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if not isinstance(data, dict):
|
||||
data = {}
|
||||
except (json.JSONDecodeError, OSError):
|
||||
logger.warning("[Feishu-Rules] Failed to read %s, using empty config", self._path)
|
||||
data = {}
|
||||
|
||||
self._mtime = mtime
|
||||
self._data = data
|
||||
return data
|
||||
|
||||
|
||||
_rules_cache = _MtimeCache(RULES_FILE)
|
||||
_pairing_cache = _MtimeCache(PAIRING_FILE)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _parse_frozenset(raw: Any) -> Optional[frozenset]:
|
||||
"""Parse a list of strings into a frozenset; return None if key absent."""
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, (list, tuple)):
|
||||
return frozenset(str(u).strip() for u in raw if str(u).strip())
|
||||
return None
|
||||
|
||||
|
||||
def _parse_document_rule(raw: dict) -> CommentDocumentRule:
|
||||
enabled = raw.get("enabled")
|
||||
if enabled is not None:
|
||||
enabled = bool(enabled)
|
||||
policy = raw.get("policy")
|
||||
if policy is not None:
|
||||
policy = str(policy).strip().lower()
|
||||
if policy not in _VALID_POLICIES:
|
||||
policy = None
|
||||
allow_from = _parse_frozenset(raw.get("allow_from"))
|
||||
return CommentDocumentRule(enabled=enabled, policy=policy, allow_from=allow_from)
|
||||
|
||||
|
||||
def load_config() -> CommentsConfig:
|
||||
"""Load comment rules from disk (mtime-cached)."""
|
||||
raw = _rules_cache.load()
|
||||
if not raw:
|
||||
return CommentsConfig()
|
||||
|
||||
documents: Dict[str, CommentDocumentRule] = {}
|
||||
raw_docs = raw.get("documents", {})
|
||||
if isinstance(raw_docs, dict):
|
||||
for key, rule_raw in raw_docs.items():
|
||||
if isinstance(rule_raw, dict):
|
||||
documents[str(key)] = _parse_document_rule(rule_raw)
|
||||
|
||||
policy = str(raw.get("policy", "pairing")).strip().lower()
|
||||
if policy not in _VALID_POLICIES:
|
||||
policy = "pairing"
|
||||
|
||||
return CommentsConfig(
|
||||
enabled=raw.get("enabled", True),
|
||||
policy=policy,
|
||||
allow_from=_parse_frozenset(raw.get("allow_from")) or frozenset(),
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rule resolution (§8.4 field-by-field fallback)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def has_wiki_keys(cfg: CommentsConfig) -> bool:
|
||||
"""Check if any document rule key starts with 'wiki:'."""
|
||||
return any(k.startswith("wiki:") for k in cfg.documents)
|
||||
|
||||
|
||||
def resolve_rule(
|
||||
cfg: CommentsConfig,
|
||||
file_type: str,
|
||||
file_token: str,
|
||||
wiki_token: str = "",
|
||||
) -> ResolvedCommentRule:
|
||||
"""Resolve effective rule: exact doc → wiki key → wildcard → top-level → defaults."""
|
||||
exact_key = f"{file_type}:{file_token}"
|
||||
|
||||
exact = cfg.documents.get(exact_key)
|
||||
exact_src = f"exact:{exact_key}"
|
||||
if exact is None and wiki_token:
|
||||
wiki_key = f"wiki:{wiki_token}"
|
||||
exact = cfg.documents.get(wiki_key)
|
||||
exact_src = f"exact:{wiki_key}"
|
||||
|
||||
wildcard = cfg.documents.get("*")
|
||||
|
||||
layers = []
|
||||
if exact is not None:
|
||||
layers.append((exact, exact_src))
|
||||
if wildcard is not None:
|
||||
layers.append((wildcard, "wildcard"))
|
||||
|
||||
def _pick(field_name: str):
|
||||
for layer, source in layers:
|
||||
val = getattr(layer, field_name)
|
||||
if val is not None:
|
||||
return val, source
|
||||
return getattr(cfg, field_name), "top"
|
||||
|
||||
enabled, en_src = _pick("enabled")
|
||||
policy, pol_src = _pick("policy")
|
||||
allow_from, _ = _pick("allow_from")
|
||||
|
||||
# match_source = highest-priority tier that contributed any field
|
||||
priority_order = {"exact": 0, "wildcard": 1, "top": 2}
|
||||
best_src = min(
|
||||
[en_src, pol_src],
|
||||
key=lambda s: priority_order.get(s.split(":")[0], 3),
|
||||
)
|
||||
|
||||
return ResolvedCommentRule(
|
||||
enabled=enabled,
|
||||
policy=policy,
|
||||
allow_from=allow_from,
|
||||
match_source=best_src,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pairing store
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_pairing_approved() -> set:
|
||||
"""Return set of approved user open_ids (mtime-cached)."""
|
||||
data = _pairing_cache.load()
|
||||
approved = data.get("approved", {})
|
||||
if isinstance(approved, dict):
|
||||
return set(approved.keys())
|
||||
if isinstance(approved, list):
|
||||
return set(str(u) for u in approved if u)
|
||||
return set()
|
||||
|
||||
|
||||
def _save_pairing(data: dict) -> None:
|
||||
PAIRING_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = PAIRING_FILE.with_suffix(".tmp")
|
||||
with open(tmp, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
tmp.replace(PAIRING_FILE)
|
||||
# Invalidate cache so next load picks up change
|
||||
_pairing_cache._mtime = 0.0
|
||||
_pairing_cache._data = None
|
||||
|
||||
|
||||
def pairing_add(user_open_id: str) -> bool:
|
||||
"""Add a user to the pairing-approved list. Returns True if newly added."""
|
||||
data = _pairing_cache.load()
|
||||
approved = data.get("approved", {})
|
||||
if not isinstance(approved, dict):
|
||||
approved = {}
|
||||
if user_open_id in approved:
|
||||
return False
|
||||
approved[user_open_id] = {"approved_at": time.time()}
|
||||
data["approved"] = approved
|
||||
_save_pairing(data)
|
||||
return True
|
||||
|
||||
|
||||
def pairing_remove(user_open_id: str) -> bool:
|
||||
"""Remove a user from the pairing-approved list. Returns True if removed."""
|
||||
data = _pairing_cache.load()
|
||||
approved = data.get("approved", {})
|
||||
if not isinstance(approved, dict):
|
||||
return False
|
||||
if user_open_id not in approved:
|
||||
return False
|
||||
del approved[user_open_id]
|
||||
data["approved"] = approved
|
||||
_save_pairing(data)
|
||||
return True
|
||||
|
||||
|
||||
def pairing_list() -> Dict[str, Any]:
|
||||
"""Return the approved dict {user_open_id: {approved_at: ...}}."""
|
||||
data = _pairing_cache.load()
|
||||
approved = data.get("approved", {})
|
||||
return dict(approved) if isinstance(approved, dict) else {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Access check (public API for feishu_comment.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def is_user_allowed(rule: ResolvedCommentRule, user_open_id: str) -> bool:
|
||||
"""Check if user passes the resolved rule's policy gate."""
|
||||
if user_open_id in rule.allow_from:
|
||||
return True
|
||||
if rule.policy == "pairing":
|
||||
return user_open_id in _load_pairing_approved()
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _print_status() -> None:
|
||||
cfg = load_config()
|
||||
print(f"Rules file: {RULES_FILE}")
|
||||
print(f" exists: {RULES_FILE.exists()}")
|
||||
print(f"Pairing file: {PAIRING_FILE}")
|
||||
print(f" exists: {PAIRING_FILE.exists()}")
|
||||
print()
|
||||
print(f"Top-level:")
|
||||
print(f" enabled: {cfg.enabled}")
|
||||
print(f" policy: {cfg.policy}")
|
||||
print(f" allow_from: {sorted(cfg.allow_from) if cfg.allow_from else '[]'}")
|
||||
print()
|
||||
if cfg.documents:
|
||||
print(f"Document rules ({len(cfg.documents)}):")
|
||||
for key, rule in sorted(cfg.documents.items()):
|
||||
parts = []
|
||||
if rule.enabled is not None:
|
||||
parts.append(f"enabled={rule.enabled}")
|
||||
if rule.policy is not None:
|
||||
parts.append(f"policy={rule.policy}")
|
||||
if rule.allow_from is not None:
|
||||
parts.append(f"allow_from={sorted(rule.allow_from)}")
|
||||
print(f" [{key}] {', '.join(parts) if parts else '(empty — inherits all)'}")
|
||||
else:
|
||||
print("Document rules: (none)")
|
||||
print()
|
||||
approved = pairing_list()
|
||||
print(f"Pairing approved ({len(approved)}):")
|
||||
for uid, meta in sorted(approved.items()):
|
||||
ts = meta.get("approved_at", 0)
|
||||
print(f" {uid} (approved_at={ts})")
|
||||
|
||||
|
||||
def _do_check(doc_key: str, user_open_id: str) -> None:
|
||||
cfg = load_config()
|
||||
parts = doc_key.split(":", 1)
|
||||
if len(parts) != 2:
|
||||
print(f"Error: doc_key must be 'fileType:fileToken', got '{doc_key}'")
|
||||
return
|
||||
file_type, file_token = parts
|
||||
rule = resolve_rule(cfg, file_type, file_token)
|
||||
allowed = is_user_allowed(rule, user_open_id)
|
||||
print(f"Document: {doc_key}")
|
||||
print(f"User: {user_open_id}")
|
||||
print(f"Resolved rule:")
|
||||
print(f" enabled: {rule.enabled}")
|
||||
print(f" policy: {rule.policy}")
|
||||
print(f" allow_from: {sorted(rule.allow_from) if rule.allow_from else '[]'}")
|
||||
print(f" match_source: {rule.match_source}")
|
||||
print(f"Result: {'ALLOWED' if allowed else 'DENIED'}")
|
||||
|
||||
|
||||
def _main() -> int:
|
||||
import sys
|
||||
|
||||
try:
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
load_hermes_dotenv()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
usage = (
|
||||
"Usage: python -m gateway.platforms.feishu_comment_rules <command> [args]\n"
|
||||
"\n"
|
||||
"Commands:\n"
|
||||
" status Show rules config and pairing state\n"
|
||||
" check <fileType:token> <user> Simulate access check\n"
|
||||
" pairing add <user_open_id> Add user to pairing-approved list\n"
|
||||
" pairing remove <user_open_id> Remove user from pairing-approved list\n"
|
||||
" pairing list List pairing-approved users\n"
|
||||
"\n"
|
||||
f"Rules config file: {RULES_FILE}\n"
|
||||
" Edit this JSON file directly to configure policies and document rules.\n"
|
||||
" Changes take effect on the next comment event (no restart needed).\n"
|
||||
)
|
||||
|
||||
args = sys.argv[1:]
|
||||
if not args:
|
||||
print(usage)
|
||||
return 1
|
||||
|
||||
cmd = args[0]
|
||||
|
||||
if cmd == "status":
|
||||
_print_status()
|
||||
|
||||
elif cmd == "check":
|
||||
if len(args) < 3:
|
||||
print("Usage: check <fileType:fileToken> <user_open_id>")
|
||||
return 1
|
||||
_do_check(args[1], args[2])
|
||||
|
||||
elif cmd == "pairing":
|
||||
if len(args) < 2:
|
||||
print("Usage: pairing <add|remove|list> [args]")
|
||||
return 1
|
||||
sub = args[1]
|
||||
if sub == "add":
|
||||
if len(args) < 3:
|
||||
print("Usage: pairing add <user_open_id>")
|
||||
return 1
|
||||
if pairing_add(args[2]):
|
||||
print(f"Added: {args[2]}")
|
||||
else:
|
||||
print(f"Already approved: {args[2]}")
|
||||
elif sub == "remove":
|
||||
if len(args) < 3:
|
||||
print("Usage: pairing remove <user_open_id>")
|
||||
return 1
|
||||
if pairing_remove(args[2]):
|
||||
print(f"Removed: {args[2]}")
|
||||
else:
|
||||
print(f"Not in approved list: {args[2]}")
|
||||
elif sub == "list":
|
||||
approved = pairing_list()
|
||||
if not approved:
|
||||
print("(no approved users)")
|
||||
for uid, meta in sorted(approved.items()):
|
||||
print(f" {uid} approved_at={meta.get('approved_at', '?')}")
|
||||
else:
|
||||
print(f"Unknown pairing subcommand: {sub}")
|
||||
return 1
|
||||
else:
|
||||
print(f"Unknown command: {cmd}\n")
|
||||
print(usage)
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(_main())
|
||||
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
QQBot platform package.
|
||||
|
||||
Re-exports the main adapter symbols from ``adapter.py`` (the original
|
||||
``qqbot.py``) so that **all existing import paths remain unchanged**::
|
||||
|
||||
from gateway.platforms.qqbot import QQAdapter # works
|
||||
from gateway.platforms.qqbot import check_qq_requirements # works
|
||||
|
||||
New modules:
|
||||
- ``constants`` — shared constants (API URLs, timeouts, message types)
|
||||
- ``utils`` — User-Agent builder, config helpers
|
||||
- ``crypto`` — AES-256-GCM key generation and decryption
|
||||
- ``onboard`` — QR-code scan-to-configure flow
|
||||
"""
|
||||
|
||||
# -- Adapter (original qqbot.py) ------------------------------------------
|
||||
from .adapter import ( # noqa: F401
|
||||
QQAdapter,
|
||||
QQCloseError,
|
||||
check_qq_requirements,
|
||||
_coerce_list,
|
||||
_ssrf_redirect_guard,
|
||||
)
|
||||
|
||||
# -- Onboard (QR-code scan-to-configure) -----------------------------------
|
||||
from .onboard import ( # noqa: F401
|
||||
BindStatus,
|
||||
create_bind_task,
|
||||
poll_bind_result,
|
||||
build_connect_url,
|
||||
)
|
||||
from .crypto import decrypt_secret, generate_bind_key # noqa: F401
|
||||
|
||||
# -- Utils -----------------------------------------------------------------
|
||||
from .utils import build_user_agent, get_api_headers, coerce_list # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
# adapter
|
||||
"QQAdapter",
|
||||
"QQCloseError",
|
||||
"check_qq_requirements",
|
||||
"_coerce_list",
|
||||
"_ssrf_redirect_guard",
|
||||
# onboard
|
||||
"BindStatus",
|
||||
"create_bind_task",
|
||||
"poll_bind_result",
|
||||
"build_connect_url",
|
||||
# crypto
|
||||
"decrypt_secret",
|
||||
"generate_bind_key",
|
||||
# utils
|
||||
"build_user_agent",
|
||||
"get_api_headers",
|
||||
"coerce_list",
|
||||
]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,74 @@
|
||||
"""QQBot package-level constants shared across adapter, onboard, and other modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QQBot adapter version — bump on functional changes to the adapter package.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
QQBOT_VERSION = "1.1.0"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# The portal domain is configurable via QQ_API_HOST for corporate proxies
|
||||
# or test environments. Default: q.qq.com (production).
|
||||
PORTAL_HOST = os.getenv("QQ_PORTAL_HOST", "q.qq.com")
|
||||
|
||||
API_BASE = "https://api.sgroup.qq.com"
|
||||
TOKEN_URL = "https://bots.qq.com/app/getAppAccessToken"
|
||||
GATEWAY_URL_PATH = "/gateway"
|
||||
|
||||
# QR-code onboard endpoints (on the portal host)
|
||||
ONBOARD_CREATE_PATH = "/lite/create_bind_task"
|
||||
ONBOARD_POLL_PATH = "/lite/poll_bind_result"
|
||||
QR_URL_TEMPLATE = (
|
||||
"https://q.qq.com/qqbot/openclaw/connect.html"
|
||||
"?task_id={task_id}&_wv=2&source=hermes"
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Timeouts & retry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_API_TIMEOUT = 30.0
|
||||
FILE_UPLOAD_TIMEOUT = 120.0
|
||||
CONNECT_TIMEOUT_SECONDS = 20.0
|
||||
|
||||
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
|
||||
MAX_RECONNECT_ATTEMPTS = 100
|
||||
RATE_LIMIT_DELAY = 60 # seconds
|
||||
QUICK_DISCONNECT_THRESHOLD = 5.0 # seconds
|
||||
MAX_QUICK_DISCONNECT_COUNT = 3
|
||||
|
||||
ONBOARD_POLL_INTERVAL = 2.0 # seconds between poll_bind_result calls
|
||||
ONBOARD_API_TIMEOUT = 10.0
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message limits
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MAX_MESSAGE_LENGTH = 4000
|
||||
DEDUP_WINDOW_SECONDS = 300
|
||||
DEDUP_MAX_SIZE = 1000
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QQ Bot message types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MSG_TYPE_TEXT = 0
|
||||
MSG_TYPE_MARKDOWN = 2
|
||||
MSG_TYPE_MEDIA = 7
|
||||
MSG_TYPE_INPUT_NOTIFY = 6
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QQ Bot file media types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MEDIA_TYPE_IMAGE = 1
|
||||
MEDIA_TYPE_VIDEO = 2
|
||||
MEDIA_TYPE_VOICE = 3
|
||||
MEDIA_TYPE_FILE = 4
|
||||
@@ -0,0 +1,45 @@
|
||||
"""AES-256-GCM utilities for QQBot scan-to-configure credential decryption."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
|
||||
|
||||
def generate_bind_key() -> str:
|
||||
"""Generate a 256-bit random AES key and return it as base64.
|
||||
|
||||
The key is passed to ``create_bind_task`` so the server can encrypt
|
||||
the bot's *client_secret* before returning it. Only this CLI holds
|
||||
the key, ensuring the secret never travels in plaintext.
|
||||
"""
|
||||
return base64.b64encode(os.urandom(32)).decode()
|
||||
|
||||
|
||||
def decrypt_secret(encrypted_base64: str, key_base64: str) -> str:
|
||||
"""Decrypt a base64-encoded AES-256-GCM ciphertext.
|
||||
|
||||
Ciphertext layout (after base64-decoding)::
|
||||
|
||||
IV (12 bytes) ‖ ciphertext (N bytes) ‖ AuthTag (16 bytes)
|
||||
|
||||
Args:
|
||||
encrypted_base64: The ``bot_encrypt_secret`` value from
|
||||
``poll_bind_result``.
|
||||
key_base64: The base64 AES key generated by
|
||||
:func:`generate_bind_key`.
|
||||
|
||||
Returns:
|
||||
The decrypted *client_secret* as a UTF-8 string.
|
||||
"""
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
key = base64.b64decode(key_base64)
|
||||
raw = base64.b64decode(encrypted_base64)
|
||||
|
||||
iv = raw[:12]
|
||||
ciphertext_with_tag = raw[12:] # AESGCM expects ciphertext + tag concatenated
|
||||
|
||||
aesgcm = AESGCM(key)
|
||||
plaintext = aesgcm.decrypt(iv, ciphertext_with_tag, None)
|
||||
return plaintext.decode("utf-8")
|
||||
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
QQBot scan-to-configure (QR code onboard) module.
|
||||
|
||||
Calls the ``q.qq.com`` ``create_bind_task`` / ``poll_bind_result`` APIs to
|
||||
generate a QR-code URL and poll for scan completion. On success the caller
|
||||
receives the bot's *app_id*, *client_secret* (decrypted locally), and the
|
||||
scanner's *user_openid* — enough to fully configure the QQBot gateway.
|
||||
|
||||
Reference: https://bot.q.qq.com/wiki/develop/api-v2/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from enum import IntEnum
|
||||
from typing import Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
from .constants import (
|
||||
ONBOARD_API_TIMEOUT,
|
||||
ONBOARD_CREATE_PATH,
|
||||
ONBOARD_POLL_PATH,
|
||||
PORTAL_HOST,
|
||||
QR_URL_TEMPLATE,
|
||||
)
|
||||
from .crypto import generate_bind_key
|
||||
from .utils import get_api_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bind status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BindStatus(IntEnum):
|
||||
"""Status codes returned by ``poll_bind_result``."""
|
||||
|
||||
NONE = 0
|
||||
PENDING = 1
|
||||
COMPLETED = 2
|
||||
EXPIRED = 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def create_bind_task(
|
||||
timeout: float = ONBOARD_API_TIMEOUT,
|
||||
) -> Tuple[str, str]:
|
||||
"""Create a bind task and return *(task_id, aes_key_base64)*.
|
||||
|
||||
The AES key is generated locally and sent to the server so it can
|
||||
encrypt the bot credentials before returning them.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the API returns a non-zero ``retcode``.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
url = f"https://{PORTAL_HOST}{ONBOARD_CREATE_PATH}"
|
||||
key = generate_bind_key()
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
|
||||
resp = await client.post(url, json={"key": key}, headers=get_api_headers())
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
if data.get("retcode") != 0:
|
||||
raise RuntimeError(data.get("msg", "create_bind_task failed"))
|
||||
|
||||
task_id = data.get("data", {}).get("task_id")
|
||||
if not task_id:
|
||||
raise RuntimeError("create_bind_task: missing task_id in response")
|
||||
|
||||
logger.debug("create_bind_task ok: task_id=%s", task_id)
|
||||
return task_id, key
|
||||
|
||||
|
||||
async def poll_bind_result(
|
||||
task_id: str,
|
||||
timeout: float = ONBOARD_API_TIMEOUT,
|
||||
) -> Tuple[BindStatus, str, str, str]:
|
||||
"""Poll the bind result for *task_id*.
|
||||
|
||||
Returns:
|
||||
A 4-tuple of ``(status, bot_appid, bot_encrypt_secret, user_openid)``.
|
||||
|
||||
* ``bot_encrypt_secret`` is AES-256-GCM encrypted — decrypt it with
|
||||
:func:`~gateway.platforms.qqbot.crypto.decrypt_secret` using the
|
||||
key from :func:`create_bind_task`.
|
||||
* ``user_openid`` is the OpenID of the person who scanned the code
|
||||
(available when ``status == COMPLETED``).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the API returns a non-zero ``retcode``.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
url = f"https://{PORTAL_HOST}{ONBOARD_POLL_PATH}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
|
||||
resp = await client.post(url, json={"task_id": task_id}, headers=get_api_headers())
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
if data.get("retcode") != 0:
|
||||
raise RuntimeError(data.get("msg", "poll_bind_result failed"))
|
||||
|
||||
d = data.get("data", {})
|
||||
return (
|
||||
BindStatus(d.get("status", 0)),
|
||||
str(d.get("bot_appid", "")),
|
||||
d.get("bot_encrypt_secret", ""),
|
||||
d.get("user_openid", ""),
|
||||
)
|
||||
|
||||
|
||||
def build_connect_url(task_id: str) -> str:
|
||||
"""Build the QR-code target URL for a given *task_id*."""
|
||||
return QR_URL_TEMPLATE.format(task_id=quote(task_id))
|
||||
@@ -0,0 +1,71 @@
|
||||
"""QQBot shared utilities — User-Agent, HTTP helpers, config coercion."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import platform
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .constants import QQBOT_VERSION
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User-Agent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_hermes_version() -> str:
|
||||
"""Return the hermes-agent package version, or 'dev' if unavailable."""
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
return version("hermes-agent")
|
||||
except Exception:
|
||||
return "dev"
|
||||
|
||||
|
||||
def build_user_agent() -> str:
|
||||
"""Build a descriptive User-Agent string.
|
||||
|
||||
Format::
|
||||
|
||||
QQBotAdapter/<qqbot_version> (Python/<py_version>; <os>; Hermes/<hermes_version>)
|
||||
|
||||
Example::
|
||||
|
||||
QQBotAdapter/1.0.0 (Python/3.11.15; darwin; Hermes/0.9.0)
|
||||
"""
|
||||
py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
|
||||
os_name = platform.system().lower()
|
||||
hermes_version = _get_hermes_version()
|
||||
return f"QQBotAdapter/{QQBOT_VERSION} (Python/{py_version}; {os_name}; Hermes/{hermes_version})"
|
||||
|
||||
|
||||
def get_api_headers() -> Dict[str, str]:
|
||||
"""Return standard HTTP headers for QQBot API requests.
|
||||
|
||||
Includes ``Content-Type``, ``Accept``, and a dynamic ``User-Agent``.
|
||||
``q.qq.com`` requires ``Accept: application/json`` — without it,
|
||||
the server returns a JavaScript anti-bot challenge page.
|
||||
"""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"User-Agent": build_user_agent(),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def coerce_list(value: Any) -> List[str]:
|
||||
"""Coerce config values into a trimmed string list.
|
||||
|
||||
Accepts comma-separated strings, lists, tuples, sets, or single values.
|
||||
"""
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, str):
|
||||
return [item.strip() for item in value.split(",") if item.strip()]
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
return [str(item).strip() for item in value if str(item).strip()]
|
||||
return [str(value).strip()] if str(value).strip() else []
|
||||
@@ -160,6 +160,14 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
self._sse_task: Optional[asyncio.Task] = None
|
||||
self._health_monitor_task: Optional[asyncio.Task] = None
|
||||
self._typing_tasks: Dict[str, asyncio.Task] = {}
|
||||
# Per-chat typing-indicator backoff. When signal-cli reports
|
||||
# NETWORK_FAILURE (recipient offline / unroutable), base.py's
|
||||
# _keep_typing refresh loop would otherwise hammer sendTyping every
|
||||
# ~2s indefinitely, producing WARNING-level log spam and pointless
|
||||
# RPC traffic. We track consecutive failures per chat and skip the
|
||||
# RPC during a cooldown window instead.
|
||||
self._typing_failures: Dict[str, int] = {}
|
||||
self._typing_skip_until: Dict[str, float] = {}
|
||||
self._running = False
|
||||
self._last_sse_activity = 0.0
|
||||
self._sse_response: Optional[httpx.Response] = None
|
||||
@@ -548,8 +556,22 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
# JSON-RPC Communication
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _rpc(self, method: str, params: dict, rpc_id: str = None) -> Any:
|
||||
"""Send a JSON-RPC 2.0 request to signal-cli daemon."""
|
||||
async def _rpc(
|
||||
self,
|
||||
method: str,
|
||||
params: dict,
|
||||
rpc_id: str = None,
|
||||
*,
|
||||
log_failures: bool = True,
|
||||
) -> Any:
|
||||
"""Send a JSON-RPC 2.0 request to signal-cli daemon.
|
||||
|
||||
When ``log_failures=False``, error and exception paths log at DEBUG
|
||||
instead of WARNING — used by the typing-indicator path to silence
|
||||
repeated NETWORK_FAILURE spam for unreachable recipients while
|
||||
still preserving visibility for the first occurrence and for
|
||||
unrelated RPCs.
|
||||
"""
|
||||
if not self.client:
|
||||
logger.warning("Signal: RPC called but client not connected")
|
||||
return None
|
||||
@@ -574,13 +596,19 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
data = resp.json()
|
||||
|
||||
if "error" in data:
|
||||
logger.warning("Signal RPC error (%s): %s", method, data["error"])
|
||||
if log_failures:
|
||||
logger.warning("Signal RPC error (%s): %s", method, data["error"])
|
||||
else:
|
||||
logger.debug("Signal RPC error (%s): %s", method, data["error"])
|
||||
return None
|
||||
|
||||
return data.get("result")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Signal RPC %s failed: %s", method, e)
|
||||
if log_failures:
|
||||
logger.warning("Signal RPC %s failed: %s", method, e)
|
||||
else:
|
||||
logger.debug("Signal RPC %s failed: %s", method, e)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -627,7 +655,28 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
self._recent_sent_timestamps.pop()
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata=None) -> None:
|
||||
"""Send a typing indicator."""
|
||||
"""Send a typing indicator.
|
||||
|
||||
base.py's ``_keep_typing`` refresh loop calls this every ~2s while
|
||||
the agent is processing. If signal-cli returns NETWORK_FAILURE for
|
||||
this recipient (offline, unroutable, group membership lost, etc.)
|
||||
the unmitigated behaviour is: a WARNING log every 2 seconds for as
|
||||
long as the agent keeps running. Instead we:
|
||||
|
||||
- silence the WARNING after the first consecutive failure (subsequent
|
||||
attempts log at DEBUG) so transport issues are still visible once
|
||||
but don't flood the log,
|
||||
- skip the RPC entirely during an exponential cooldown window once
|
||||
three consecutive failures have happened, so we stop hammering
|
||||
signal-cli with requests it can't deliver.
|
||||
|
||||
A successful sendTyping clears the counters.
|
||||
"""
|
||||
now = time.monotonic()
|
||||
skip_until = self._typing_skip_until.get(chat_id, 0.0)
|
||||
if now < skip_until:
|
||||
return
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
"account": self.account,
|
||||
}
|
||||
@@ -637,7 +686,26 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
params["recipient"] = [chat_id]
|
||||
|
||||
await self._rpc("sendTyping", params, rpc_id="typing")
|
||||
fails = self._typing_failures.get(chat_id, 0)
|
||||
result = await self._rpc(
|
||||
"sendTyping",
|
||||
params,
|
||||
rpc_id="typing",
|
||||
log_failures=(fails == 0),
|
||||
)
|
||||
|
||||
if result is None:
|
||||
fails += 1
|
||||
self._typing_failures[chat_id] = fails
|
||||
# After 3 consecutive failures, back off exponentially (16s,
|
||||
# 32s, 60s cap) to stop spamming signal-cli for a recipient
|
||||
# that clearly isn't reachable right now.
|
||||
if fails >= 3:
|
||||
backoff = min(60.0, 16.0 * (2 ** (fails - 3)))
|
||||
self._typing_skip_until[chat_id] = now + backoff
|
||||
else:
|
||||
self._typing_failures.pop(chat_id, None)
|
||||
self._typing_skip_until.pop(chat_id, None)
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
@@ -789,6 +857,10 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
# Reset per-chat typing backoff state so the next agent turn starts
|
||||
# fresh rather than inheriting a cooldown from a prior conversation.
|
||||
self._typing_failures.pop(chat_id, None)
|
||||
self._typing_skip_until.pop(chat_id, None)
|
||||
|
||||
async def stop_typing(self, chat_id: str) -> None:
|
||||
"""Public interface for stopping typing — called by base adapter's
|
||||
|
||||
@@ -118,6 +118,84 @@ def _strip_mdv2(text: str) -> str:
|
||||
return cleaned
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Markdown table → code block conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
# Telegram's MarkdownV2 has no table syntax — '|' is just an escaped literal,
|
||||
# so pipe tables render as noisy backslash-pipe text with no alignment.
|
||||
# Wrapping the table in a fenced code block makes Telegram render it as
|
||||
# monospace preformatted text with columns intact.
|
||||
|
||||
# Matches a GFM table delimiter row: optional outer pipes, cells containing
|
||||
# only dashes (with optional leading/trailing colons for alignment) separated
|
||||
# by '|'. Requires at least one internal '|' so lone '---' horizontal rules
|
||||
# are NOT matched.
|
||||
_TABLE_SEPARATOR_RE = re.compile(
|
||||
r'^\s*\|?\s*:?-+:?\s*(?:\|\s*:?-+:?\s*){1,}\|?\s*$'
|
||||
)
|
||||
|
||||
|
||||
def _is_table_row(line: str) -> bool:
|
||||
"""Return True if *line* could plausibly be a table data row."""
|
||||
stripped = line.strip()
|
||||
return bool(stripped) and '|' in stripped
|
||||
|
||||
|
||||
def _wrap_markdown_tables(text: str) -> str:
|
||||
"""Wrap GFM-style pipe tables in ``` fences so Telegram renders them.
|
||||
|
||||
Detected by a row containing '|' immediately followed by a delimiter
|
||||
row matching :data:`_TABLE_SEPARATOR_RE`. Subsequent pipe-containing
|
||||
non-blank lines are consumed as the table body and included in the
|
||||
wrapped block. Tables inside existing fenced code blocks are left
|
||||
alone.
|
||||
"""
|
||||
if '|' not in text or '-' not in text:
|
||||
return text
|
||||
|
||||
lines = text.split('\n')
|
||||
out: list[str] = []
|
||||
in_fence = False
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
stripped = line.lstrip()
|
||||
|
||||
# Track existing fenced code blocks — never touch content inside.
|
||||
if stripped.startswith('```'):
|
||||
in_fence = not in_fence
|
||||
out.append(line)
|
||||
i += 1
|
||||
continue
|
||||
if in_fence:
|
||||
out.append(line)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Look for a header row (contains '|') immediately followed by a
|
||||
# delimiter row.
|
||||
if (
|
||||
'|' in line
|
||||
and i + 1 < len(lines)
|
||||
and _TABLE_SEPARATOR_RE.match(lines[i + 1])
|
||||
):
|
||||
table_block = [line, lines[i + 1]]
|
||||
j = i + 2
|
||||
while j < len(lines) and _is_table_row(lines[j]):
|
||||
table_block.append(lines[j])
|
||||
j += 1
|
||||
out.append('```')
|
||||
out.extend(table_block)
|
||||
out.append('```')
|
||||
i = j
|
||||
continue
|
||||
|
||||
out.append(line)
|
||||
i += 1
|
||||
|
||||
return '\n'.join(out)
|
||||
|
||||
|
||||
class TelegramAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Telegram bot adapter.
|
||||
@@ -1916,6 +1994,12 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
|
||||
text = content
|
||||
|
||||
# 0) Pre-wrap GFM-style pipe tables in ``` fences. Telegram can't
|
||||
# render tables natively, but fenced code blocks render as
|
||||
# monospace preformatted text with columns intact. The wrapped
|
||||
# tables then flow through step (1) below as protected regions.
|
||||
text = _wrap_markdown_tables(text)
|
||||
|
||||
# 1) Protect fenced code blocks (``` ... ```)
|
||||
# Per MarkdownV2 spec, \ and ` inside pre/code must be escaped.
|
||||
def _protect_fenced(m):
|
||||
@@ -2242,7 +2326,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
if not self._should_process_message(update.message):
|
||||
return
|
||||
|
||||
event = self._build_message_event(update.message, MessageType.TEXT)
|
||||
event = self._build_message_event(update.message, MessageType.TEXT, update_id=update.update_id)
|
||||
event.text = self._clean_bot_trigger_text(event.text)
|
||||
self._enqueue_text_event(event)
|
||||
|
||||
@@ -2253,7 +2337,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
if not self._should_process_message(update.message, is_command=True):
|
||||
return
|
||||
|
||||
event = self._build_message_event(update.message, MessageType.COMMAND)
|
||||
event = self._build_message_event(update.message, MessageType.COMMAND, update_id=update.update_id)
|
||||
await self.handle_message(event)
|
||||
|
||||
async def _handle_location_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
@@ -2289,7 +2373,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
parts.append(f"Map: https://www.google.com/maps/search/?api=1&query={lat},{lon}")
|
||||
parts.append("Ask what they'd like to find nearby (restaurants, cafes, etc.) and any preferences.")
|
||||
|
||||
event = self._build_message_event(msg, MessageType.LOCATION)
|
||||
event = self._build_message_event(msg, MessageType.LOCATION, update_id=update.update_id)
|
||||
event.text = "\n".join(parts)
|
||||
await self.handle_message(event)
|
||||
|
||||
@@ -2440,7 +2524,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
|
||||
event = self._build_message_event(msg, msg_type)
|
||||
event = self._build_message_event(msg, msg_type, update_id=update.update_id)
|
||||
|
||||
# Add caption as text
|
||||
if msg.caption:
|
||||
@@ -2779,8 +2863,19 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
self.name, cache_key, thread_id,
|
||||
)
|
||||
|
||||
def _build_message_event(self, message: Message, msg_type: MessageType) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Telegram message."""
|
||||
def _build_message_event(
|
||||
self,
|
||||
message: Message,
|
||||
msg_type: MessageType,
|
||||
update_id: Optional[int] = None,
|
||||
) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Telegram message.
|
||||
|
||||
``update_id`` is the ``Update.update_id`` from PTB; passing it through
|
||||
lets ``/restart`` record the triggering offset so the new gateway
|
||||
process can advance past it (prevents ``/restart`` being re-delivered
|
||||
when PTB's graceful-shutdown ACK fails).
|
||||
"""
|
||||
chat = message.chat
|
||||
user = message.from_user
|
||||
|
||||
@@ -2859,6 +2954,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
source=source,
|
||||
raw_message=message,
|
||||
message_id=str(message.message_id),
|
||||
platform_update_id=update_id,
|
||||
reply_to_message_id=reply_to_id,
|
||||
reply_to_text=reply_to_text,
|
||||
auto_skill=topic_skill,
|
||||
|
||||
+41
-10
@@ -180,6 +180,8 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
self._text_batch_split_delay_seconds = float(os.getenv("HERMES_WECOM_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0"))
|
||||
self._pending_text_batches: Dict[str, MessageEvent] = {}
|
||||
self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._device_id = uuid.uuid4().hex
|
||||
self._last_chat_req_ids: Dict[str, str] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connection lifecycle
|
||||
@@ -277,7 +279,11 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
{
|
||||
"cmd": APP_CMD_SUBSCRIBE,
|
||||
"headers": {"req_id": req_id},
|
||||
"body": {"bot_id": self._bot_id, "secret": self._secret},
|
||||
"body": {
|
||||
"bot_id": self._bot_id,
|
||||
"secret": self._secret,
|
||||
"device_id": self._device_id,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -496,6 +502,11 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
logger.debug("[%s] DM sender %s blocked by policy", self.name, sender_id)
|
||||
return
|
||||
|
||||
# Cache the inbound req_id after policy checks so proactive sends to
|
||||
# this chat can fall back to APP_CMD_RESPONSE (required for groups —
|
||||
# WeCom AI Bots cannot initiate APP_CMD_SEND in group chats).
|
||||
self._remember_chat_req_id(chat_id, self._payload_req_id(payload))
|
||||
|
||||
text, reply_text = self._extract_text(body)
|
||||
media_urls, media_types = await self._extract_media(body)
|
||||
message_type = self._derive_message_type(body, text, media_types)
|
||||
@@ -847,6 +858,23 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
while len(self._reply_req_ids) > DEDUP_MAX_SIZE:
|
||||
self._reply_req_ids.pop(next(iter(self._reply_req_ids)))
|
||||
|
||||
def _remember_chat_req_id(self, chat_id: str, req_id: str) -> None:
|
||||
"""Cache the most recent inbound req_id per chat.
|
||||
|
||||
Used as a fallback reply target when we need to send into a group
|
||||
without an explicit ``reply_to`` — WeCom AI Bots are blocked from
|
||||
APP_CMD_SEND in groups and must use APP_CMD_RESPONSE bound to some
|
||||
prior req_id. Bounded like _reply_req_ids so long-running gateways
|
||||
don't leak memory across many chats.
|
||||
"""
|
||||
normalized_chat_id = str(chat_id or "").strip()
|
||||
normalized_req_id = str(req_id or "").strip()
|
||||
if not normalized_chat_id or not normalized_req_id:
|
||||
return
|
||||
self._last_chat_req_ids[normalized_chat_id] = normalized_req_id
|
||||
while len(self._last_chat_req_ids) > DEDUP_MAX_SIZE:
|
||||
self._last_chat_req_ids.pop(next(iter(self._last_chat_req_ids)))
|
||||
|
||||
def _reply_req_id_for_message(self, reply_to: Optional[str]) -> Optional[str]:
|
||||
normalized = str(reply_to or "").strip()
|
||||
if not normalized or normalized.startswith("quote:"):
|
||||
@@ -1163,19 +1191,15 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
self._raise_for_wecom_error(response, "send media message")
|
||||
return response
|
||||
|
||||
async def _send_reply_stream(self, reply_req_id: str, content: str) -> Dict[str, Any]:
|
||||
async def _send_reply_markdown(self, reply_req_id: str, content: str) -> Dict[str, Any]:
|
||||
response = await self._send_reply_request(
|
||||
reply_req_id,
|
||||
{
|
||||
"msgtype": "stream",
|
||||
"stream": {
|
||||
"id": self._new_req_id("stream"),
|
||||
"finish": True,
|
||||
"content": content[:self.MAX_MESSAGE_LENGTH],
|
||||
},
|
||||
"msgtype": "markdown",
|
||||
"markdown": {"content": content[:self.MAX_MESSAGE_LENGTH]},
|
||||
},
|
||||
)
|
||||
self._raise_for_wecom_error(response, "send reply stream")
|
||||
self._raise_for_wecom_error(response, "send reply markdown")
|
||||
return response
|
||||
|
||||
async def _send_reply_media_message(
|
||||
@@ -1235,6 +1259,9 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
return SendResult(success=False, error=prepared["reject_reason"])
|
||||
|
||||
reply_req_id = self._reply_req_id_for_message(reply_to)
|
||||
if not reply_req_id and chat_id in self._last_chat_req_ids:
|
||||
reply_req_id = self._last_chat_req_ids[chat_id]
|
||||
|
||||
try:
|
||||
upload_result = await self._upload_media_bytes(
|
||||
prepared["data"],
|
||||
@@ -1302,8 +1329,12 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
|
||||
try:
|
||||
reply_req_id = self._reply_req_id_for_message(reply_to)
|
||||
|
||||
if not reply_req_id and chat_id in self._last_chat_req_ids:
|
||||
reply_req_id = self._last_chat_req_ids[chat_id]
|
||||
|
||||
if reply_req_id:
|
||||
response = await self._send_reply_stream(reply_req_id, content)
|
||||
response = await self._send_reply_markdown(reply_req_id, content)
|
||||
else:
|
||||
response = await self._send_request(
|
||||
APP_CMD_SEND,
|
||||
|
||||
+310
-86
@@ -28,7 +28,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import quote, urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -96,6 +96,28 @@ MEDIA_VIDEO = 2
|
||||
MEDIA_FILE = 3
|
||||
MEDIA_VOICE = 4
|
||||
|
||||
_LIVE_ADAPTERS: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def _make_ssl_connector() -> Optional["aiohttp.TCPConnector"]:
|
||||
"""Return a TCPConnector with a certifi CA bundle, or None if certifi is unavailable.
|
||||
|
||||
Tencent's iLink server (``ilinkai.weixin.qq.com``) is not verifiable against
|
||||
some system CA stores (notably Homebrew's OpenSSL on macOS Apple Silicon).
|
||||
When ``certifi`` is installed, use its Mozilla CA bundle to guarantee
|
||||
verification. Otherwise fall back to aiohttp's default (which honors
|
||||
``SSL_CERT_FILE`` env var via ``trust_env=True``).
|
||||
"""
|
||||
try:
|
||||
import ssl
|
||||
import certifi
|
||||
except ImportError:
|
||||
return None
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
return None
|
||||
ssl_ctx = ssl.create_default_context(cafile=certifi.where())
|
||||
return aiohttp.TCPConnector(ssl=ssl_ctx)
|
||||
|
||||
ITEM_TEXT = 1
|
||||
ITEM_IMAGE = 2
|
||||
ITEM_VOICE = 3
|
||||
@@ -398,7 +420,12 @@ async def _send_message(
|
||||
text: str,
|
||||
context_token: Optional[str],
|
||||
client_id: str,
|
||||
) -> None:
|
||||
) -> Dict[str, Any]:
|
||||
"""Send a text message via iLink sendmessage API.
|
||||
|
||||
Returns the raw API response dict (may contain error codes like
|
||||
``errcode: -14`` for session expiry that the caller can inspect).
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
raise ValueError("_send_message: text must not be empty")
|
||||
message: Dict[str, Any] = {
|
||||
@@ -411,7 +438,7 @@ async def _send_message(
|
||||
}
|
||||
if context_token:
|
||||
message["context_token"] = context_token
|
||||
await _api_post(
|
||||
return await _api_post(
|
||||
session,
|
||||
base_url=base_url,
|
||||
endpoint=EP_SEND_MESSAGE,
|
||||
@@ -533,6 +560,39 @@ async def _download_bytes(
|
||||
return await response.read()
|
||||
|
||||
|
||||
_WEIXIN_CDN_ALLOWLIST: frozenset[str] = frozenset(
|
||||
{
|
||||
"novac2c.cdn.weixin.qq.com",
|
||||
"ilinkai.weixin.qq.com",
|
||||
"wx.qlogo.cn",
|
||||
"thirdwx.qlogo.cn",
|
||||
"res.wx.qq.com",
|
||||
"mmbiz.qpic.cn",
|
||||
"mmbiz.qlogo.cn",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _assert_weixin_cdn_url(url: str) -> None:
|
||||
"""Raise ValueError if *url* does not point at a known WeChat CDN host."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
scheme = parsed.scheme.lower()
|
||||
host = parsed.hostname or ""
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise ValueError(f"Unparseable media URL: {url!r}") from exc
|
||||
|
||||
if scheme not in ("http", "https"):
|
||||
raise ValueError(
|
||||
f"Media URL has disallowed scheme {scheme!r}; only http/https are permitted."
|
||||
)
|
||||
if host not in _WEIXIN_CDN_ALLOWLIST:
|
||||
raise ValueError(
|
||||
f"Media URL host {host!r} is not in the WeChat CDN allowlist. "
|
||||
"Refusing to fetch to prevent SSRF."
|
||||
)
|
||||
|
||||
|
||||
def _media_reference(item: Dict[str, Any], key: str) -> Dict[str, Any]:
|
||||
return (item.get(key) or {}).get("media") or {}
|
||||
|
||||
@@ -553,6 +613,7 @@ async def _download_and_decrypt_media(
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
elif full_url:
|
||||
_assert_weixin_cdn_url(full_url)
|
||||
raw = await _download_bytes(session, url=full_url, timeout_seconds=timeout_seconds)
|
||||
else:
|
||||
raise RuntimeError("media item had neither encrypt_query_param nor full_url")
|
||||
@@ -623,42 +684,31 @@ def _rewrite_table_block_for_weixin(lines: List[str]) -> str:
|
||||
def _normalize_markdown_blocks(content: str) -> str:
|
||||
lines = content.splitlines()
|
||||
result: List[str] = []
|
||||
i = 0
|
||||
in_code_block = False
|
||||
blank_run = 0
|
||||
|
||||
while i < len(lines):
|
||||
line = lines[i].rstrip()
|
||||
fence_match = _FENCE_RE.match(line.strip())
|
||||
if fence_match:
|
||||
for raw_line in lines:
|
||||
line = raw_line.rstrip()
|
||||
if _FENCE_RE.match(line.strip()):
|
||||
in_code_block = not in_code_block
|
||||
result.append(line)
|
||||
i += 1
|
||||
blank_run = 0
|
||||
continue
|
||||
|
||||
if in_code_block:
|
||||
result.append(line)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if (
|
||||
i + 1 < len(lines)
|
||||
and "|" in lines[i]
|
||||
and _TABLE_RULE_RE.match(lines[i + 1].rstrip())
|
||||
):
|
||||
table_lines = [lines[i].rstrip(), lines[i + 1].rstrip()]
|
||||
i += 2
|
||||
while i < len(lines) and "|" in lines[i]:
|
||||
table_lines.append(lines[i].rstrip())
|
||||
i += 1
|
||||
result.append(_rewrite_table_block_for_weixin(table_lines))
|
||||
if not line.strip():
|
||||
blank_run += 1
|
||||
if blank_run <= 1:
|
||||
result.append("")
|
||||
continue
|
||||
|
||||
result.append(_MARKDOWN_LINK_RE.sub(r"\1 (\2)", _rewrite_headers_for_weixin(line)))
|
||||
i += 1
|
||||
blank_run = 0
|
||||
result.append(line)
|
||||
|
||||
normalized = "\n".join(item.rstrip() for item in result)
|
||||
normalized = re.sub(r"\n{3,}", "\n\n", normalized)
|
||||
return normalized.strip()
|
||||
return "\n".join(result).strip()
|
||||
|
||||
|
||||
def _split_markdown_blocks(content: str) -> List[str]:
|
||||
@@ -704,8 +754,8 @@ def _split_delivery_units_for_weixin(content: str) -> List[str]:
|
||||
|
||||
Weixin can render Markdown, but chat readability is better when top-level
|
||||
line breaks become separate messages. Keep fenced code blocks intact and
|
||||
attach indented continuation lines to the previous top-level line so
|
||||
transformed tables/lists do not get torn apart.
|
||||
attach indented continuation lines to the previous top-level line so nested
|
||||
list items do not get torn apart.
|
||||
"""
|
||||
units: List[str] = []
|
||||
|
||||
@@ -747,7 +797,9 @@ def _looks_like_chatty_line_for_weixin(line: str) -> bool:
|
||||
return False
|
||||
if line.startswith((" ", "\t")):
|
||||
return False
|
||||
if stripped.startswith((">", "-", "*", "【")):
|
||||
if stripped.startswith((">", "-", "*", "【", "#", "|")):
|
||||
return False
|
||||
if _TABLE_RULE_RE.match(stripped):
|
||||
return False
|
||||
if re.match(r"^\*\*[^*]+\*\*$", stripped):
|
||||
return False
|
||||
@@ -757,10 +809,12 @@ def _looks_like_chatty_line_for_weixin(line: str) -> bool:
|
||||
|
||||
|
||||
def _looks_like_heading_line_for_weixin(line: str) -> bool:
|
||||
"""Return True when a short line behaves like a plain-text heading."""
|
||||
"""Return True when a short line behaves like a heading."""
|
||||
stripped = line.strip()
|
||||
if not stripped:
|
||||
return False
|
||||
if _HEADER_RE.match(stripped):
|
||||
return True
|
||||
return len(stripped) <= 24 and stripped.endswith((":", ":"))
|
||||
|
||||
|
||||
@@ -935,7 +989,7 @@ async def qr_login(
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
raise RuntimeError("aiohttp is required for Weixin QR login")
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with aiohttp.ClientSession(trust_env=True, connector=_make_ssl_connector()) as session:
|
||||
try:
|
||||
qr_resp = await _api_get(
|
||||
session,
|
||||
@@ -953,6 +1007,10 @@ async def qr_login(
|
||||
logger.error("weixin: QR response missing qrcode")
|
||||
return None
|
||||
|
||||
# qrcode_url is the full scannable liteapp URL; qrcode_value is just the hex token
|
||||
# WeChat needs to scan the full URL, not the raw hex string
|
||||
qr_scan_data = qrcode_url if qrcode_url else qrcode_value
|
||||
|
||||
print("\n请使用微信扫描以下二维码:")
|
||||
if qrcode_url:
|
||||
print(qrcode_url)
|
||||
@@ -960,11 +1018,11 @@ async def qr_login(
|
||||
import qrcode
|
||||
|
||||
qr = qrcode.QRCode()
|
||||
qr.add_data(qrcode_url or qrcode_value)
|
||||
qr.add_data(qr_scan_data)
|
||||
qr.make(fit=True)
|
||||
qr.print_ascii(invert=True)
|
||||
except Exception:
|
||||
print("(终端二维码渲染失败,请直接打开上面的二维码链接)")
|
||||
except Exception as _qr_exc:
|
||||
print(f"(终端二维码渲染失败: {_qr_exc},请直接打开上面的二维码链接)")
|
||||
|
||||
deadline = time.time() + timeout_seconds
|
||||
current_base_url = ILINK_BASE_URL
|
||||
@@ -1010,8 +1068,17 @@ async def qr_login(
|
||||
)
|
||||
qrcode_value = str(qr_resp.get("qrcode") or "")
|
||||
qrcode_url = str(qr_resp.get("qrcode_img_content") or "")
|
||||
qr_scan_data = qrcode_url if qrcode_url else qrcode_value
|
||||
if qrcode_url:
|
||||
print(qrcode_url)
|
||||
try:
|
||||
import qrcode as _qrcode
|
||||
qr = _qrcode.QRCode()
|
||||
qr.add_data(qr_scan_data)
|
||||
qr.make(fit=True)
|
||||
qr.print_ascii(invert=True)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.error("weixin: QR refresh failed: %s", exc)
|
||||
return None
|
||||
@@ -1059,7 +1126,8 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
self._hermes_home = hermes_home
|
||||
self._token_store = ContextTokenStore(hermes_home)
|
||||
self._typing_cache = TypingTicketCache()
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._poll_session: Optional[aiohttp.ClientSession] = None
|
||||
self._send_session: Optional[aiohttp.ClientSession] = None
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
self._dedup = MessageDeduplicator(ttl_seconds=MESSAGE_DEDUP_TTL_SECONDS)
|
||||
|
||||
@@ -1134,14 +1202,17 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
except Exception as exc:
|
||||
logger.debug("[%s] Token lock unavailable (non-fatal): %s", self.name, exc)
|
||||
|
||||
self._session = aiohttp.ClientSession(trust_env=True)
|
||||
self._poll_session = aiohttp.ClientSession(trust_env=True, connector=_make_ssl_connector())
|
||||
self._send_session = aiohttp.ClientSession(trust_env=True, connector=_make_ssl_connector())
|
||||
self._token_store.restore(self._account_id)
|
||||
self._poll_task = asyncio.create_task(self._poll_loop(), name="weixin-poll")
|
||||
self._mark_connected()
|
||||
_LIVE_ADAPTERS[self._token] = self
|
||||
logger.info("[%s] Connected account=%s base=%s", self.name, _safe_id(self._account_id), self._base_url)
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
_LIVE_ADAPTERS.pop(self._token, None)
|
||||
self._running = False
|
||||
if self._poll_task and not self._poll_task.done():
|
||||
self._poll_task.cancel()
|
||||
@@ -1150,15 +1221,18 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._poll_task = None
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
if self._poll_session and not self._poll_session.closed:
|
||||
await self._poll_session.close()
|
||||
self._poll_session = None
|
||||
if self._send_session and not self._send_session.closed:
|
||||
await self._send_session.close()
|
||||
self._send_session = None
|
||||
self._release_platform_lock()
|
||||
self._mark_disconnected()
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
async def _poll_loop(self) -> None:
|
||||
assert self._session is not None
|
||||
assert self._poll_session is not None
|
||||
sync_buf = _load_sync_buf(self._hermes_home, self._account_id)
|
||||
timeout_ms = LONG_POLL_TIMEOUT_MS
|
||||
consecutive_failures = 0
|
||||
@@ -1166,7 +1240,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
while self._running:
|
||||
try:
|
||||
response = await _get_updates(
|
||||
self._session,
|
||||
self._poll_session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
sync_buf=sync_buf,
|
||||
@@ -1223,7 +1297,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
logger.error("[%s] unhandled inbound error from=%s: %s", self.name, _safe_id(message.get("from_user_id")), exc, exc_info=True)
|
||||
|
||||
async def _process_message(self, message: Dict[str, Any]) -> None:
|
||||
assert self._session is not None
|
||||
assert self._poll_session is not None
|
||||
sender_id = str(message.get("from_user_id") or "").strip()
|
||||
if not sender_id:
|
||||
return
|
||||
@@ -1316,7 +1390,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
media = _media_reference(item, "image_item")
|
||||
try:
|
||||
data = await _download_and_decrypt_media(
|
||||
self._session,
|
||||
self._poll_session,
|
||||
cdn_base_url=self._cdn_base_url,
|
||||
encrypted_query_param=media.get("encrypt_query_param"),
|
||||
aes_key_b64=(item.get("image_item") or {}).get("aeskey")
|
||||
@@ -1334,7 +1408,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
media = _media_reference(item, "video_item")
|
||||
try:
|
||||
data = await _download_and_decrypt_media(
|
||||
self._session,
|
||||
self._poll_session,
|
||||
cdn_base_url=self._cdn_base_url,
|
||||
encrypted_query_param=media.get("encrypt_query_param"),
|
||||
aes_key_b64=media.get("aes_key"),
|
||||
@@ -1353,7 +1427,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
mime = _mime_from_filename(filename)
|
||||
try:
|
||||
data = await _download_and_decrypt_media(
|
||||
self._session,
|
||||
self._poll_session,
|
||||
cdn_base_url=self._cdn_base_url,
|
||||
encrypted_query_param=media.get("encrypt_query_param"),
|
||||
aes_key_b64=media.get("aes_key"),
|
||||
@@ -1372,7 +1446,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
return None
|
||||
try:
|
||||
data = await _download_and_decrypt_media(
|
||||
self._session,
|
||||
self._poll_session,
|
||||
cdn_base_url=self._cdn_base_url,
|
||||
encrypted_query_param=media.get("encrypt_query_param"),
|
||||
aes_key_b64=media.get("aes_key"),
|
||||
@@ -1385,13 +1459,13 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
return None
|
||||
|
||||
async def _maybe_fetch_typing_ticket(self, user_id: str, context_token: Optional[str]) -> None:
|
||||
if not self._session or not self._token:
|
||||
if not self._poll_session or not self._token:
|
||||
return
|
||||
if self._typing_cache.get(user_id):
|
||||
return
|
||||
try:
|
||||
response = await _get_config(
|
||||
self._session,
|
||||
self._poll_session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
user_id=user_id,
|
||||
@@ -1416,12 +1490,19 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
context_token: Optional[str],
|
||||
client_id: str,
|
||||
) -> None:
|
||||
"""Send a single text chunk with per-chunk retry and backoff."""
|
||||
"""Send a single text chunk with per-chunk retry and backoff.
|
||||
|
||||
On session-expired errors (errcode -14), automatically retries
|
||||
*without* ``context_token`` — iLink accepts tokenless sends as a
|
||||
degraded fallback, which keeps cron-initiated push messages working
|
||||
even when no user message has refreshed the session recently.
|
||||
"""
|
||||
last_error: Optional[Exception] = None
|
||||
retried_without_token = False
|
||||
for attempt in range(self._send_chunk_retries + 1):
|
||||
try:
|
||||
await _send_message(
|
||||
self._session,
|
||||
resp = await _send_message(
|
||||
self._send_session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
to=chat_id,
|
||||
@@ -1429,6 +1510,31 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
context_token=context_token,
|
||||
client_id=client_id,
|
||||
)
|
||||
# Check iLink response for session-expired error
|
||||
if resp and isinstance(resp, dict):
|
||||
ret = resp.get("ret")
|
||||
errcode = resp.get("errcode")
|
||||
if (ret is not None and ret not in (0,)) or (errcode is not None and errcode not in (0,)):
|
||||
is_session_expired = (
|
||||
ret == SESSION_EXPIRED_ERRCODE
|
||||
or errcode == SESSION_EXPIRED_ERRCODE
|
||||
)
|
||||
# Session expired — strip token and retry once
|
||||
if is_session_expired and not retried_without_token and context_token:
|
||||
retried_without_token = True
|
||||
context_token = None
|
||||
self._token_store._cache.pop(
|
||||
self._token_store._key(self._account_id, chat_id), None
|
||||
)
|
||||
logger.warning(
|
||||
"[%s] session expired for %s; retrying without context_token",
|
||||
self.name, _safe_id(chat_id),
|
||||
)
|
||||
continue
|
||||
errmsg = resp.get("errmsg") or resp.get("msg") or "unknown error"
|
||||
raise RuntimeError(
|
||||
f"iLink sendmessage error: ret={ret} errcode={errcode} errmsg={errmsg}"
|
||||
)
|
||||
return
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
@@ -1456,12 +1562,48 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
if not self._session or not self._token:
|
||||
if not self._send_session or not self._token:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
context_token = self._token_store.get(self._account_id, chat_id)
|
||||
last_message_id: Optional[str] = None
|
||||
|
||||
# Extract MEDIA: tags and bare local file paths before text delivery.
|
||||
media_files, cleaned_content = self.extract_media(content)
|
||||
_, image_cleaned = self.extract_images(cleaned_content)
|
||||
local_files, final_content = self.extract_local_files(image_cleaned)
|
||||
|
||||
_AUDIO_EXTS = {".ogg", ".opus", ".mp3", ".wav", ".m4a"}
|
||||
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".3gp"}
|
||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
|
||||
|
||||
async def _deliver_media(path: str, is_voice: bool = False) -> None:
|
||||
ext = Path(path).suffix.lower()
|
||||
if is_voice or ext in _AUDIO_EXTS:
|
||||
await self.send_voice(chat_id=chat_id, audio_path=path, metadata=metadata)
|
||||
elif ext in _VIDEO_EXTS:
|
||||
await self.send_video(chat_id=chat_id, video_path=path, metadata=metadata)
|
||||
elif ext in _IMAGE_EXTS:
|
||||
await self.send_image_file(chat_id=chat_id, image_path=path, metadata=metadata)
|
||||
else:
|
||||
await self.send_document(chat_id=chat_id, file_path=path, metadata=metadata)
|
||||
|
||||
try:
|
||||
chunks = [c for c in self._split_text(self.format_message(content)) if c and c.strip()]
|
||||
# Deliver extracted MEDIA: attachments first.
|
||||
for media_path, is_voice in media_files:
|
||||
try:
|
||||
await _deliver_media(media_path, is_voice)
|
||||
except Exception as exc:
|
||||
logger.warning("[%s] media delivery failed for %s: %s", self.name, media_path, exc)
|
||||
|
||||
# Deliver bare local file paths.
|
||||
for file_path in local_files:
|
||||
try:
|
||||
await _deliver_media(file_path, is_voice=False)
|
||||
except Exception as exc:
|
||||
logger.warning("[%s] local file delivery failed for %s: %s", self.name, file_path, exc)
|
||||
|
||||
# Deliver text content.
|
||||
chunks = [c for c in self._split_text(self.format_message(final_content)) if c and c.strip()]
|
||||
for idx, chunk in enumerate(chunks):
|
||||
client_id = f"hermes-weixin-{uuid.uuid4().hex}"
|
||||
await self._send_text_chunk(
|
||||
@@ -1479,14 +1621,14 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
return SendResult(success=False, error=str(exc))
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata: Optional[Dict[str, Any]] = None) -> None:
|
||||
if not self._session or not self._token:
|
||||
if not self._send_session or not self._token:
|
||||
return
|
||||
typing_ticket = self._typing_cache.get(chat_id)
|
||||
if not typing_ticket:
|
||||
return
|
||||
try:
|
||||
await _send_typing(
|
||||
self._session,
|
||||
self._send_session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
to_user_id=chat_id,
|
||||
@@ -1497,14 +1639,14 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
logger.debug("[%s] typing start failed for %s: %s", self.name, _safe_id(chat_id), exc)
|
||||
|
||||
async def stop_typing(self, chat_id: str) -> None:
|
||||
if not self._session or not self._token:
|
||||
if not self._send_session or not self._token:
|
||||
return
|
||||
typing_ticket = self._typing_cache.get(chat_id)
|
||||
if not typing_ticket:
|
||||
return
|
||||
try:
|
||||
await _send_typing(
|
||||
self._session,
|
||||
self._send_session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
to_user_id=chat_id,
|
||||
@@ -1542,24 +1684,35 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
path: str,
|
||||
caption: str = "",
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
return await self.send_document(chat_id, file_path=path, caption=caption, metadata=metadata)
|
||||
del reply_to, kwargs
|
||||
return await self.send_document(
|
||||
chat_id=chat_id,
|
||||
file_path=image_path,
|
||||
caption=caption,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: str = "",
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
if not self._session or not self._token:
|
||||
del file_name, reply_to, metadata, kwargs
|
||||
if not self._send_session or not self._token:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
message_id = await self._send_file(chat_id, file_path, caption)
|
||||
message_id = await self._send_file(chat_id, file_path, caption or "")
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
except Exception as exc:
|
||||
logger.error("[%s] send_document failed to=%s: %s", self.name, _safe_id(chat_id), exc)
|
||||
@@ -1573,7 +1726,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
if not self._session or not self._token:
|
||||
if not self._send_session or not self._token:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
message_id = await self._send_file(chat_id, video_path, caption or "")
|
||||
@@ -1590,7 +1743,24 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
return await self.send_document(chat_id, audio_path, caption=caption or "", metadata=metadata)
|
||||
if not self._send_session or not self._token:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
# Native outbound Weixin voice bubbles are not proven-working in the
|
||||
# upstream reference implementation. Prefer a reliable file attachment
|
||||
# fallback so users at least receive playable audio, even for .silk.
|
||||
fallback_caption = caption or "[voice message as attachment]"
|
||||
try:
|
||||
message_id = await self._send_file(
|
||||
chat_id,
|
||||
audio_path,
|
||||
fallback_caption,
|
||||
force_file_attachment=True,
|
||||
)
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
except Exception as exc:
|
||||
logger.error("[%s] send_voice failed to=%s: %s", self.name, _safe_id(chat_id), exc)
|
||||
return SendResult(success=False, error=str(exc))
|
||||
|
||||
async def _download_remote_media(self, url: str) -> str:
|
||||
from tools.url_safety import is_safe_url
|
||||
@@ -1598,8 +1768,8 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
if not is_safe_url(url):
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {url}")
|
||||
|
||||
assert self._session is not None
|
||||
async with self._session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as response:
|
||||
assert self._send_session is not None
|
||||
async with self._send_session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as response:
|
||||
response.raise_for_status()
|
||||
data = await response.read()
|
||||
suffix = Path(url.split("?", 1)[0]).suffix or ".bin"
|
||||
@@ -1607,16 +1777,22 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
handle.write(data)
|
||||
return handle.name
|
||||
|
||||
async def _send_file(self, chat_id: str, path: str, caption: str) -> str:
|
||||
assert self._session is not None and self._token is not None
|
||||
async def _send_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
path: str,
|
||||
caption: str,
|
||||
force_file_attachment: bool = False,
|
||||
) -> str:
|
||||
assert self._send_session is not None and self._token is not None
|
||||
plaintext = Path(path).read_bytes()
|
||||
media_type, item_builder = self._outbound_media_builder(path)
|
||||
media_type, item_builder = self._outbound_media_builder(path, force_file_attachment=force_file_attachment)
|
||||
filekey = secrets.token_hex(16)
|
||||
aes_key = secrets.token_bytes(16)
|
||||
rawsize = len(plaintext)
|
||||
rawfilemd5 = hashlib.md5(plaintext).hexdigest()
|
||||
upload_response = await _get_upload_url(
|
||||
self._session,
|
||||
self._send_session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
to_user_id=chat_id,
|
||||
@@ -1642,30 +1818,34 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
raise RuntimeError(f"getUploadUrl returned neither upload_param nor upload_full_url: {upload_response}")
|
||||
|
||||
encrypted_query_param = await _upload_ciphertext(
|
||||
self._session,
|
||||
self._send_session,
|
||||
ciphertext=ciphertext,
|
||||
upload_url=upload_url,
|
||||
)
|
||||
|
||||
context_token = self._token_store.get(self._account_id, chat_id)
|
||||
# The iLink API expects aes_key as base64(hex_string), not base64(raw_bytes).
|
||||
# Sending base64(raw_bytes) causes images to show as grey boxes on the
|
||||
# receiver side because the decryption key doesn't match.
|
||||
aes_key_for_api = base64.b64encode(aes_key.hex().encode("ascii")).decode("ascii")
|
||||
media_item = item_builder(
|
||||
encrypt_query_param=encrypted_query_param,
|
||||
aes_key_for_api=aes_key_for_api,
|
||||
ciphertext_size=len(ciphertext),
|
||||
plaintext_size=rawsize,
|
||||
filename=Path(path).name,
|
||||
rawfilemd5=rawfilemd5,
|
||||
)
|
||||
item_kwargs = {
|
||||
"encrypt_query_param": encrypted_query_param,
|
||||
"aes_key_for_api": aes_key_for_api,
|
||||
"ciphertext_size": len(ciphertext),
|
||||
"plaintext_size": rawsize,
|
||||
"filename": Path(path).name,
|
||||
"rawfilemd5": rawfilemd5,
|
||||
}
|
||||
if media_type == MEDIA_VOICE and path.endswith(".silk"):
|
||||
item_kwargs["encode_type"] = 6
|
||||
item_kwargs["sample_rate"] = 24000
|
||||
item_kwargs["bits_per_sample"] = 16
|
||||
media_item = item_builder(**item_kwargs)
|
||||
|
||||
last_message_id = None
|
||||
if caption:
|
||||
last_message_id = f"hermes-weixin-{uuid.uuid4().hex}"
|
||||
await _send_message(
|
||||
self._session,
|
||||
self._send_session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
to=chat_id,
|
||||
@@ -1676,7 +1856,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
|
||||
last_message_id = f"hermes-weixin-{uuid.uuid4().hex}"
|
||||
await _api_post(
|
||||
self._session,
|
||||
self._send_session,
|
||||
base_url=self._base_url,
|
||||
endpoint=EP_SEND_MESSAGE,
|
||||
payload={
|
||||
@@ -1695,7 +1875,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
)
|
||||
return last_message_id
|
||||
|
||||
def _outbound_media_builder(self, path: str):
|
||||
def _outbound_media_builder(self, path: str, force_file_attachment: bool = False):
|
||||
mime = mimetypes.guess_type(path)[0] or "application/octet-stream"
|
||||
if mime.startswith("image/"):
|
||||
return MEDIA_IMAGE, lambda **kw: {
|
||||
@@ -1723,7 +1903,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
"video_md5": kw.get("rawfilemd5", ""),
|
||||
},
|
||||
}
|
||||
if mime.startswith("audio/") or path.endswith(".silk"):
|
||||
if path.endswith(".silk") and not force_file_attachment:
|
||||
return MEDIA_VOICE, lambda **kw: {
|
||||
"type": ITEM_VOICE,
|
||||
"voice_item": {
|
||||
@@ -1732,9 +1912,25 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
"aes_key": kw["aes_key_for_api"],
|
||||
"encrypt_type": 1,
|
||||
},
|
||||
"encode_type": kw.get("encode_type"),
|
||||
"bits_per_sample": kw.get("bits_per_sample"),
|
||||
"sample_rate": kw.get("sample_rate"),
|
||||
"playtime": kw.get("playtime", 0),
|
||||
},
|
||||
}
|
||||
if mime.startswith("audio/"):
|
||||
return MEDIA_FILE, lambda **kw: {
|
||||
"type": ITEM_FILE,
|
||||
"file_item": {
|
||||
"media": {
|
||||
"encrypt_query_param": kw["encrypt_query_param"],
|
||||
"aes_key": kw["aes_key_for_api"],
|
||||
"encrypt_type": 1,
|
||||
},
|
||||
"file_name": kw["filename"],
|
||||
"len": str(kw["plaintext_size"]),
|
||||
},
|
||||
}
|
||||
return MEDIA_FILE, lambda **kw: {
|
||||
"type": ITEM_FILE,
|
||||
"file_item": {
|
||||
@@ -1784,7 +1980,34 @@ async def send_weixin_direct(
|
||||
token_store.restore(account_id)
|
||||
context_token = token_store.get(account_id, chat_id)
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
live_adapter = _LIVE_ADAPTERS.get(resolved_token)
|
||||
send_session = getattr(live_adapter, '_send_session', None)
|
||||
if live_adapter is not None and send_session is not None and not send_session.closed:
|
||||
last_result: Optional[SendResult] = None
|
||||
cleaned = live_adapter.format_message(message)
|
||||
if cleaned:
|
||||
last_result = await live_adapter.send(chat_id, cleaned)
|
||||
if not last_result.success:
|
||||
return {"error": f"Weixin send failed: {last_result.error}"}
|
||||
|
||||
for media_path, _is_voice in media_files or []:
|
||||
ext = Path(media_path).suffix.lower()
|
||||
if ext in {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}:
|
||||
last_result = await live_adapter.send_image_file(chat_id, media_path)
|
||||
else:
|
||||
last_result = await live_adapter.send_document(chat_id, media_path)
|
||||
if not last_result.success:
|
||||
return {"error": f"Weixin media send failed: {last_result.error}"}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"platform": "weixin",
|
||||
"chat_id": chat_id,
|
||||
"message_id": last_result.message_id if last_result else None,
|
||||
"context_token_used": bool(context_token),
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True, connector=_make_ssl_connector()) as session:
|
||||
adapter = WeixinAdapter(
|
||||
PlatformConfig(
|
||||
enabled=True,
|
||||
@@ -1797,6 +2020,7 @@ async def send_weixin_direct(
|
||||
},
|
||||
)
|
||||
)
|
||||
adapter._send_session = session
|
||||
adapter._session = session
|
||||
adapter._token = resolved_token
|
||||
adapter._account_id = account_id
|
||||
|
||||
+622
-29
@@ -24,11 +24,20 @@ import signal
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from contextvars import copy_context
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional, Any, List
|
||||
|
||||
# --- Agent cache tuning ---------------------------------------------------
|
||||
# Bounds the per-session AIAgent cache to prevent unbounded growth in
|
||||
# long-lived gateways (each AIAgent holds LLM clients, tool schemas,
|
||||
# memory providers, etc.). LRU order + idle TTL eviction are enforced
|
||||
# from _enforce_agent_cache_cap() and _session_expiry_watcher() below.
|
||||
_AGENT_CACHE_MAX_SIZE = 128
|
||||
_AGENT_CACHE_IDLE_TTL_SECS = 3600.0 # evict agents idle for >1h
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSL certificate auto-detection for NixOS and other non-standard systems.
|
||||
# Must run BEFORE any HTTP library (discord, aiohttp, etc.) is imported.
|
||||
@@ -622,8 +631,13 @@ class GatewayRunner:
|
||||
# system prompt (including memory) every turn — breaking prefix cache
|
||||
# and costing ~10x more on providers with prompt caching (Anthropic).
|
||||
# Key: session_key, Value: (AIAgent, config_signature_str)
|
||||
#
|
||||
# OrderedDict so _enforce_agent_cache_cap() can pop the least-recently-
|
||||
# used entry (move_to_end() on cache hits, popitem(last=False) for
|
||||
# eviction). Hard cap via _AGENT_CACHE_MAX_SIZE, idle TTL enforced
|
||||
# from _session_expiry_watcher().
|
||||
import threading as _threading
|
||||
self._agent_cache: Dict[str, tuple] = {}
|
||||
self._agent_cache: "OrderedDict[str, tuple]" = OrderedDict()
|
||||
self._agent_cache_lock = _threading.Lock()
|
||||
|
||||
# Per-session model overrides from /model command.
|
||||
@@ -2102,6 +2116,11 @@ class GatewayRunner:
|
||||
_cached_agent = self._running_agents.get(key)
|
||||
if _cached_agent and _cached_agent is not _AGENT_PENDING_SENTINEL:
|
||||
self._cleanup_agent_resources(_cached_agent)
|
||||
# Drop the cache entry so the AIAgent (and its LLM
|
||||
# clients, tool schemas, memory provider refs) can
|
||||
# be garbage-collected. Otherwise the cache grows
|
||||
# unbounded across the gateway's lifetime.
|
||||
self._evict_cached_agent(key)
|
||||
# Mark as flushed and persist to disk so the flag
|
||||
# survives gateway restarts.
|
||||
with self.session_store._lock:
|
||||
@@ -2145,6 +2164,44 @@ class GatewayRunner:
|
||||
logger.info(
|
||||
"Session expiry done: %d flushed", _flushed,
|
||||
)
|
||||
|
||||
# Sweep agents that have been idle beyond the TTL regardless
|
||||
# of session reset policy. This catches sessions with very
|
||||
# long / "never" reset windows, whose cached AIAgents would
|
||||
# otherwise pin memory for the gateway's entire lifetime.
|
||||
try:
|
||||
_idle_evicted = self._sweep_idle_cached_agents()
|
||||
if _idle_evicted:
|
||||
logger.info(
|
||||
"Agent cache idle sweep: evicted %d agent(s)",
|
||||
_idle_evicted,
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.debug("Idle agent sweep failed: %s", _e)
|
||||
|
||||
# Periodically prune stale SessionStore entries. The
|
||||
# in-memory dict (and sessions.json) would otherwise grow
|
||||
# unbounded in gateways serving many rotating chats /
|
||||
# threads / users over long time windows. Pruning is
|
||||
# invisible to users — a resumed session just gets a
|
||||
# fresh session_id, exactly as if the reset policy fired.
|
||||
_last_prune_ts = getattr(self, "_last_session_store_prune_ts", 0.0)
|
||||
_prune_interval = 3600.0 # once per hour
|
||||
if time.time() - _last_prune_ts > _prune_interval:
|
||||
try:
|
||||
_max_age = int(
|
||||
getattr(self.config, "session_store_max_age_days", 0) or 0
|
||||
)
|
||||
if _max_age > 0:
|
||||
_pruned = self.session_store.prune_old_entries(_max_age)
|
||||
if _pruned:
|
||||
logger.info(
|
||||
"SessionStore prune: dropped %d stale entries",
|
||||
_pruned,
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.debug("SessionStore prune failed: %s", _e)
|
||||
self._last_session_store_prune_ts = time.time()
|
||||
except Exception as e:
|
||||
logger.debug("Session expiry watcher error: %s", e)
|
||||
# Sleep in small increments so we can stop quickly
|
||||
@@ -2351,6 +2408,7 @@ class GatewayRunner:
|
||||
|
||||
self.adapters.clear()
|
||||
self._running_agents.clear()
|
||||
self._running_agents_ts.clear()
|
||||
self._pending_messages.clear()
|
||||
self._pending_approvals.clear()
|
||||
if hasattr(self, '_busy_ack_ts'):
|
||||
@@ -2375,6 +2433,20 @@ class GatewayRunner:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Close SQLite session DBs so the WAL write lock is released.
|
||||
# Without this, --replace and similar restart flows leave the
|
||||
# old gateway's connection holding the WAL lock until Python
|
||||
# actually exits — causing 'database is locked' errors when
|
||||
# the new gateway tries to open the same file.
|
||||
for _db_holder in (self, getattr(self, "session_store", None)):
|
||||
_db = getattr(_db_holder, "_db", None) if _db_holder else None
|
||||
if _db is None or not hasattr(_db, "close"):
|
||||
continue
|
||||
try:
|
||||
_db.close()
|
||||
except Exception as _e:
|
||||
logger.debug("SessionDB close error: %s", _e)
|
||||
|
||||
from gateway.status import remove_pid_file
|
||||
remove_pid_file()
|
||||
|
||||
@@ -2618,6 +2690,9 @@ class GatewayRunner:
|
||||
Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOWED_USERS",
|
||||
Platform.QQBOT: "QQ_ALLOWED_USERS",
|
||||
}
|
||||
platform_group_env_map = {
|
||||
Platform.QQBOT: "QQ_GROUP_ALLOWED_USERS",
|
||||
}
|
||||
platform_allow_all_map = {
|
||||
Platform.TELEGRAM: "TELEGRAM_ALLOW_ALL_USERS",
|
||||
Platform.DISCORD: "DISCORD_ALLOW_ALL_USERS",
|
||||
@@ -2642,6 +2717,28 @@ class GatewayRunner:
|
||||
if platform_allow_all_var and os.getenv(platform_allow_all_var, "").lower() in ("true", "1", "yes"):
|
||||
return True
|
||||
|
||||
# Discord bot senders that passed the DISCORD_ALLOW_BOTS platform
|
||||
# filter are already authorized at the platform level — skip the
|
||||
# user allowlist. Without this, bot messages allowed by
|
||||
# DISCORD_ALLOW_BOTS=mentions/all would be rejected here with
|
||||
# "Unauthorized user" (fixes #4466).
|
||||
if source.platform == Platform.DISCORD and getattr(source, "is_bot", False):
|
||||
allow_bots = os.getenv("DISCORD_ALLOW_BOTS", "none").lower().strip()
|
||||
if allow_bots in ("mentions", "all"):
|
||||
return True
|
||||
|
||||
# Discord role-based access (DISCORD_ALLOWED_ROLES): the adapter's
|
||||
# on_message pre-filter already verified role membership — if the
|
||||
# message reached here, the user passed that check. Authorize
|
||||
# directly to avoid the "no allowlists configured" branch below
|
||||
# rejecting role-only setups where DISCORD_ALLOWED_USERS is empty
|
||||
# (issue #7871).
|
||||
if (
|
||||
source.platform == Platform.DISCORD
|
||||
and os.getenv("DISCORD_ALLOWED_ROLES", "").strip()
|
||||
):
|
||||
return True
|
||||
|
||||
# Check pairing store (always checked, regardless of allowlists)
|
||||
platform_name = source.platform.value if source.platform else ""
|
||||
if self.pairing_store.is_approved(platform_name, user_id):
|
||||
@@ -2649,12 +2746,23 @@ class GatewayRunner:
|
||||
|
||||
# Check platform-specific and global allowlists
|
||||
platform_allowlist = os.getenv(platform_env_map.get(source.platform, ""), "").strip()
|
||||
group_allowlist = ""
|
||||
if source.chat_type == "group":
|
||||
group_allowlist = os.getenv(platform_group_env_map.get(source.platform, ""), "").strip()
|
||||
global_allowlist = os.getenv("GATEWAY_ALLOWED_USERS", "").strip()
|
||||
|
||||
if not platform_allowlist and not global_allowlist:
|
||||
if not platform_allowlist and not group_allowlist and not global_allowlist:
|
||||
# No allowlists configured -- check global allow-all flag
|
||||
return os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes")
|
||||
|
||||
# Some platforms authorize group traffic by chat ID rather than sender ID.
|
||||
if group_allowlist and source.chat_type == "group" and source.chat_id:
|
||||
allowed_group_ids = {
|
||||
chat_id.strip() for chat_id in group_allowlist.split(",") if chat_id.strip()
|
||||
}
|
||||
if "*" in allowed_group_ids or source.chat_id in allowed_group_ids:
|
||||
return True
|
||||
|
||||
# Check if user is in any allowlist
|
||||
allowed_ids = set()
|
||||
if platform_allowlist:
|
||||
@@ -2837,16 +2945,17 @@ class GatewayRunner:
|
||||
_quick_key[:30], _stale_age, _stale_idle,
|
||||
_raw_stale_timeout, _stale_detail,
|
||||
)
|
||||
del self._running_agents[_quick_key]
|
||||
self._running_agents_ts.pop(_quick_key, None)
|
||||
self._busy_ack_ts.pop(_quick_key, None)
|
||||
self._release_running_agent_state(_quick_key)
|
||||
|
||||
if _quick_key in self._running_agents:
|
||||
if event.get_command() == "status":
|
||||
return await self._handle_status_command(event)
|
||||
|
||||
# Resolve the command once for all early-intercept checks below.
|
||||
from hermes_cli.commands import resolve_command as _resolve_cmd_inner
|
||||
from hermes_cli.commands import (
|
||||
resolve_command as _resolve_cmd_inner,
|
||||
should_bypass_active_session as _should_bypass_active_inner,
|
||||
)
|
||||
_evt_cmd = event.get_command()
|
||||
_cmd_def_inner = _resolve_cmd_inner(_evt_cmd) if _evt_cmd else None
|
||||
|
||||
@@ -2867,8 +2976,7 @@ class GatewayRunner:
|
||||
if adapter and hasattr(adapter, 'get_pending_message'):
|
||||
adapter.get_pending_message(_quick_key) # consume and discard
|
||||
self._pending_messages.pop(_quick_key, None)
|
||||
if _quick_key in self._running_agents:
|
||||
del self._running_agents[_quick_key]
|
||||
self._release_running_agent_state(_quick_key)
|
||||
logger.info("STOP for session %s — agent interrupted, session lock released", _quick_key[:20])
|
||||
return "⚡ Stopped. You can continue this session."
|
||||
|
||||
@@ -2890,8 +2998,7 @@ class GatewayRunner:
|
||||
self._pending_messages.pop(_quick_key, None)
|
||||
# Clean up the running agent entry so the reset handler
|
||||
# doesn't think an agent is still active.
|
||||
if _quick_key in self._running_agents:
|
||||
del self._running_agents[_quick_key]
|
||||
self._release_running_agent_state(_quick_key)
|
||||
return await self._handle_reset_command(event)
|
||||
|
||||
# /queue <prompt> — queue without interrupting
|
||||
@@ -2912,6 +3019,54 @@ class GatewayRunner:
|
||||
adapter._pending_messages[_quick_key] = queued_event
|
||||
return "Queued for the next turn."
|
||||
|
||||
# /steer <prompt> — inject mid-run after the next tool call.
|
||||
# Unlike /queue (turn boundary), /steer lands BETWEEN tool-call
|
||||
# iterations inside the same agent run, by appending to the
|
||||
# last tool result's content. No interrupt, no new user turn,
|
||||
# no role-alternation violation.
|
||||
if _cmd_def_inner and _cmd_def_inner.name == "steer":
|
||||
steer_text = event.get_command_args().strip()
|
||||
if not steer_text:
|
||||
return "Usage: /steer <prompt>"
|
||||
running_agent = self._running_agents.get(_quick_key)
|
||||
if running_agent is _AGENT_PENDING_SENTINEL:
|
||||
# Agent hasn't started yet — queue as turn-boundary fallback.
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter:
|
||||
from gateway.platforms.base import MessageEvent as _ME, MessageType as _MT
|
||||
queued_event = _ME(
|
||||
text=steer_text,
|
||||
message_type=_MT.TEXT,
|
||||
source=event.source,
|
||||
message_id=event.message_id,
|
||||
channel_prompt=event.channel_prompt,
|
||||
)
|
||||
adapter._pending_messages[_quick_key] = queued_event
|
||||
return "Agent still starting — /steer queued for the next turn."
|
||||
if running_agent and hasattr(running_agent, "steer"):
|
||||
try:
|
||||
accepted = running_agent.steer(steer_text)
|
||||
except Exception as exc:
|
||||
logger.warning("Steer failed for session %s: %s", _quick_key[:20], exc)
|
||||
return f"⚠️ Steer failed: {exc}"
|
||||
if accepted:
|
||||
preview = steer_text[:60] + ("..." if len(steer_text) > 60 else "")
|
||||
return f"⏩ Steer queued — arrives after the next tool call: '{preview}'"
|
||||
return "Steer rejected (empty payload)."
|
||||
# Running agent is missing or lacks steer() — fall back to queue.
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter:
|
||||
from gateway.platforms.base import MessageEvent as _ME, MessageType as _MT
|
||||
queued_event = _ME(
|
||||
text=steer_text,
|
||||
message_type=_MT.TEXT,
|
||||
source=event.source,
|
||||
message_id=event.message_id,
|
||||
channel_prompt=event.channel_prompt,
|
||||
)
|
||||
adapter._pending_messages[_quick_key] = queued_event
|
||||
return "No active agent — /steer queued for the next turn."
|
||||
|
||||
# /model must not be used while the agent is running.
|
||||
if _cmd_def_inner and _cmd_def_inner.name == "model":
|
||||
return "Agent is running — wait or /stop first, then switch models."
|
||||
@@ -2925,11 +3080,29 @@ class GatewayRunner:
|
||||
return await self._handle_approve_command(event)
|
||||
return await self._handle_deny_command(event)
|
||||
|
||||
# /agents (/tasks alias) should be query-only and never interrupt.
|
||||
if _cmd_def_inner and _cmd_def_inner.name == "agents":
|
||||
return await self._handle_agents_command(event)
|
||||
|
||||
# /background must bypass the running-agent guard — it starts a
|
||||
# parallel task and must never interrupt the active conversation.
|
||||
if _cmd_def_inner and _cmd_def_inner.name == "background":
|
||||
return await self._handle_background_command(event)
|
||||
|
||||
# Gateway-handled info/control commands must never fall through to
|
||||
# the interrupt path. If they are queued as pending text, the
|
||||
# slash-command safety net discards them before the user sees any
|
||||
# response.
|
||||
if _cmd_def_inner and _should_bypass_active_inner(_cmd_def_inner.name):
|
||||
if _cmd_def_inner.name == "help":
|
||||
return await self._handle_help_command(event)
|
||||
if _cmd_def_inner.name == "commands":
|
||||
return await self._handle_commands_command(event)
|
||||
if _cmd_def_inner.name == "profile":
|
||||
return await self._handle_profile_command(event)
|
||||
if _cmd_def_inner.name == "update":
|
||||
return await self._handle_update_command(event)
|
||||
|
||||
if event.message_type == MessageType.PHOTO:
|
||||
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
|
||||
adapter = self.adapters.get(source.platform)
|
||||
@@ -2968,8 +3141,7 @@ class GatewayRunner:
|
||||
# Agent is being set up but not ready yet.
|
||||
if event.get_command() == "stop":
|
||||
# Force-clean the sentinel so the session is unlocked.
|
||||
if _quick_key in self._running_agents:
|
||||
del self._running_agents[_quick_key]
|
||||
self._release_running_agent_state(_quick_key)
|
||||
logger.info("HARD STOP (pending) for session %s — sentinel cleared", _quick_key[:20])
|
||||
return "⚡ Force-stopped. The agent was still starting — session unlocked."
|
||||
# Queue the message so it will be picked up after the
|
||||
@@ -3033,6 +3205,9 @@ class GatewayRunner:
|
||||
if canonical == "status":
|
||||
return await self._handle_status_command(event)
|
||||
|
||||
if canonical == "agents":
|
||||
return await self._handle_agents_command(event)
|
||||
|
||||
if canonical == "restart":
|
||||
return await self._handle_restart_command(event)
|
||||
|
||||
@@ -3133,6 +3308,21 @@ class GatewayRunner:
|
||||
if canonical == "btw":
|
||||
return await self._handle_btw_command(event)
|
||||
|
||||
if canonical == "steer":
|
||||
# No active agent — /steer has no tool call to inject into.
|
||||
# Strip the prefix so downstream treats it as a normal user
|
||||
# message. If the payload is empty, surface the usage hint.
|
||||
steer_payload = event.get_command_args().strip()
|
||||
if not steer_payload:
|
||||
return "Usage: /steer <prompt> (no agent is running; sending as a normal message)"
|
||||
try:
|
||||
event.text = steer_payload
|
||||
except Exception:
|
||||
pass
|
||||
# Do NOT return — fall through to _handle_message_with_agent
|
||||
# at the end of this function so the rewritten text is sent
|
||||
# to the agent as a regular user turn.
|
||||
|
||||
if canonical == "voice":
|
||||
return await self._handle_voice_command(event)
|
||||
|
||||
@@ -3285,8 +3475,13 @@ class GatewayRunner:
|
||||
# (exception, command fallthrough, etc.) the sentinel must
|
||||
# not linger or the session would be permanently locked out.
|
||||
if self._running_agents.get(_quick_key) is _AGENT_PENDING_SENTINEL:
|
||||
del self._running_agents[_quick_key]
|
||||
self._running_agents_ts.pop(_quick_key, None)
|
||||
self._release_running_agent_state(_quick_key)
|
||||
else:
|
||||
# Agent path already cleaned _running_agents; make sure
|
||||
# the paired metadata dicts are gone too.
|
||||
self._running_agents_ts.pop(_quick_key, None)
|
||||
if hasattr(self, "_busy_ack_ts"):
|
||||
self._busy_ack_ts.pop(_quick_key, None)
|
||||
|
||||
async def _prepare_inbound_message_text(
|
||||
self,
|
||||
@@ -4483,6 +4678,96 @@ class GatewayRunner:
|
||||
])
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _handle_agents_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /agents command - list active agents and running tasks."""
|
||||
from tools.process_registry import format_uptime_short, process_registry
|
||||
|
||||
now = time.time()
|
||||
current_session_key = self._session_key_for_source(event.source)
|
||||
|
||||
running_agents: dict = getattr(self, "_running_agents", {}) or {}
|
||||
running_started: dict = getattr(self, "_running_agents_ts", {}) or {}
|
||||
|
||||
agent_rows: list[dict] = []
|
||||
for session_key, agent in running_agents.items():
|
||||
started = float(running_started.get(session_key, now))
|
||||
elapsed = max(0, int(now - started))
|
||||
is_pending = agent is _AGENT_PENDING_SENTINEL
|
||||
agent_rows.append(
|
||||
{
|
||||
"session_key": session_key,
|
||||
"elapsed": elapsed,
|
||||
"state": "starting" if is_pending else "running",
|
||||
"session_id": "" if is_pending else str(getattr(agent, "session_id", "") or ""),
|
||||
"model": "" if is_pending else str(getattr(agent, "model", "") or ""),
|
||||
}
|
||||
)
|
||||
|
||||
agent_rows.sort(key=lambda row: row["elapsed"], reverse=True)
|
||||
|
||||
running_processes: list[dict] = []
|
||||
try:
|
||||
running_processes = [
|
||||
p for p in process_registry.list_sessions()
|
||||
if p.get("status") == "running"
|
||||
]
|
||||
except Exception:
|
||||
running_processes = []
|
||||
|
||||
background_tasks = [
|
||||
t for t in (getattr(self, "_background_tasks", set()) or set())
|
||||
if hasattr(t, "done") and not t.done()
|
||||
]
|
||||
|
||||
lines = [
|
||||
"🤖 **Active Agents & Tasks**",
|
||||
"",
|
||||
f"**Active agents:** {len(agent_rows)}",
|
||||
]
|
||||
|
||||
if agent_rows:
|
||||
for idx, row in enumerate(agent_rows[:12], 1):
|
||||
current = " · this chat" if row["session_key"] == current_session_key else ""
|
||||
sid = f" · `{row['session_id']}`" if row["session_id"] else ""
|
||||
model = f" · `{row['model']}`" if row["model"] else ""
|
||||
lines.append(
|
||||
f"{idx}. `{row['session_key']}` · {row['state']} · "
|
||||
f"{format_uptime_short(row['elapsed'])}{sid}{model}{current}"
|
||||
)
|
||||
if len(agent_rows) > 12:
|
||||
lines.append(f"... and {len(agent_rows) - 12} more")
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
f"**Running background processes:** {len(running_processes)}",
|
||||
]
|
||||
)
|
||||
if running_processes:
|
||||
for proc in running_processes[:12]:
|
||||
cmd = " ".join(str(proc.get("command", "")).split())
|
||||
if len(cmd) > 90:
|
||||
cmd = cmd[:87] + "..."
|
||||
lines.append(
|
||||
f"- `{proc.get('session_id', '?')}` · "
|
||||
f"{format_uptime_short(int(proc.get('uptime_seconds', 0)))} · `{cmd}`"
|
||||
)
|
||||
if len(running_processes) > 12:
|
||||
lines.append(f"... and {len(running_processes) - 12} more")
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
f"**Gateway async jobs:** {len(background_tasks)}",
|
||||
]
|
||||
)
|
||||
|
||||
if not agent_rows and not running_processes and not background_tasks:
|
||||
lines.append("")
|
||||
lines.append("No active agents or running tasks.")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _handle_stop_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /stop command - interrupt a running agent.
|
||||
@@ -4502,22 +4787,40 @@ class GatewayRunner:
|
||||
agent = self._running_agents.get(session_key)
|
||||
if agent is _AGENT_PENDING_SENTINEL:
|
||||
# Force-clean the sentinel so the session is unlocked.
|
||||
if session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
self._release_running_agent_state(session_key)
|
||||
logger.info("STOP (pending) for session %s — sentinel cleared", session_key[:20])
|
||||
return "⚡ Stopped. The agent hadn't started yet — you can continue this session."
|
||||
if agent:
|
||||
agent.interrupt("Stop requested")
|
||||
# Force-clean the session lock so a truly hung agent doesn't
|
||||
# keep it locked forever.
|
||||
if session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
self._release_running_agent_state(session_key)
|
||||
return "⚡ Stopped. You can continue this session."
|
||||
else:
|
||||
return "No active task to stop."
|
||||
|
||||
async def _handle_restart_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /restart command - drain active work, then restart the gateway."""
|
||||
# Defensive idempotency check: if the previous gateway process
|
||||
# recorded this same /restart (same platform + update_id) and the new
|
||||
# process is seeing it *again*, this is a re-delivery caused by PTB's
|
||||
# graceful-shutdown `get_updates` ACK failing on the way out ("Error
|
||||
# while calling `get_updates` one more time to mark all fetched
|
||||
# updates. Suppressing error to ensure graceful shutdown. When
|
||||
# polling for updates is restarted, updates may be received twice."
|
||||
# in gateway.log). Ignoring the stale redelivery prevents a
|
||||
# self-perpetuating restart loop where every fresh gateway
|
||||
# re-processes the same /restart command and immediately restarts
|
||||
# again.
|
||||
if self._is_stale_restart_redelivery(event):
|
||||
logger.info(
|
||||
"Ignoring redelivered /restart (platform=%s, update_id=%s) — "
|
||||
"already processed by a previous gateway instance.",
|
||||
event.source.platform.value if event.source and event.source.platform else "?",
|
||||
event.platform_update_id,
|
||||
)
|
||||
return ""
|
||||
|
||||
if self._restart_requested or self._draining:
|
||||
count = self._running_agent_count()
|
||||
if count:
|
||||
@@ -4540,6 +4843,26 @@ class GatewayRunner:
|
||||
except Exception as e:
|
||||
logger.debug("Failed to write restart notify file: %s", e)
|
||||
|
||||
# Record the triggering platform + update_id in a dedicated dedup
|
||||
# marker. Unlike .restart_notify.json (which gets unlinked once the
|
||||
# new gateway sends the "gateway restarted" notification), this
|
||||
# marker persists so the new gateway can still detect a delayed
|
||||
# /restart redelivery from Telegram. Overwritten on every /restart.
|
||||
try:
|
||||
import json as _json
|
||||
import time as _time
|
||||
dedup_data = {
|
||||
"platform": event.source.platform.value if event.source.platform else None,
|
||||
"requested_at": _time.time(),
|
||||
}
|
||||
if event.platform_update_id is not None:
|
||||
dedup_data["update_id"] = event.platform_update_id
|
||||
(_hermes_home / ".restart_last_processed.json").write_text(
|
||||
_json.dumps(dedup_data)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to write restart dedup marker: %s", e)
|
||||
|
||||
active_agents = self._running_agent_count()
|
||||
# When running under a service manager (systemd/launchd), use the
|
||||
# service restart path: exit with code 75 so the service manager
|
||||
@@ -4555,6 +4878,58 @@ class GatewayRunner:
|
||||
return f"⏳ Draining {active_agents} active agent(s) before restart..."
|
||||
return "♻ Restarting gateway. If you aren't notified within 60 seconds, restart from the console with `hermes gateway restart`."
|
||||
|
||||
def _is_stale_restart_redelivery(self, event: MessageEvent) -> bool:
|
||||
"""Return True if this /restart is a Telegram re-delivery we already handled.
|
||||
|
||||
The previous gateway wrote ``.restart_last_processed.json`` with the
|
||||
triggering platform + update_id when it processed the /restart. If
|
||||
we now see a /restart on the same platform with an update_id <= that
|
||||
recorded value AND the marker is recent (< 5 minutes), it's a
|
||||
redelivery and should be ignored.
|
||||
|
||||
Only applies to Telegram today (the only platform that exposes a
|
||||
numeric cross-session update ordering); other platforms return False.
|
||||
"""
|
||||
if event is None or event.source is None:
|
||||
return False
|
||||
if event.platform_update_id is None:
|
||||
return False
|
||||
if event.source.platform is None:
|
||||
return False
|
||||
# Only Telegram populates platform_update_id currently; be explicit
|
||||
# so future platforms aren't accidentally gated by this check.
|
||||
try:
|
||||
platform_value = event.source.platform.value
|
||||
except Exception:
|
||||
return False
|
||||
if platform_value != "telegram":
|
||||
return False
|
||||
|
||||
try:
|
||||
import json as _json
|
||||
import time as _time
|
||||
marker_path = _hermes_home / ".restart_last_processed.json"
|
||||
if not marker_path.exists():
|
||||
return False
|
||||
data = _json.loads(marker_path.read_text())
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
if data.get("platform") != platform_value:
|
||||
return False
|
||||
recorded_uid = data.get("update_id")
|
||||
if not isinstance(recorded_uid, int):
|
||||
return False
|
||||
# Staleness guard: ignore markers older than 5 minutes. A legitimately
|
||||
# old marker (e.g. crash recovery where notify never fired) should not
|
||||
# swallow a fresh /restart from the user.
|
||||
requested_at = data.get("requested_at")
|
||||
if isinstance(requested_at, (int, float)):
|
||||
if _time.time() - requested_at > 300:
|
||||
return False
|
||||
return event.platform_update_id <= recorded_uid
|
||||
|
||||
|
||||
async def _handle_help_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /help command - list available commands."""
|
||||
from hermes_cli.commands import gateway_help_lines
|
||||
@@ -5300,8 +5675,7 @@ class GatewayRunner:
|
||||
if "pynacl" in err_lower or "nacl" in err_lower or "davey" in err_lower:
|
||||
return (
|
||||
"Voice dependencies are missing (PyNaCl / davey). "
|
||||
"Install or reinstall Hermes with the messaging extra, e.g. "
|
||||
"`pip install hermes-agent[messaging]`."
|
||||
f"Install with: `{sys.executable} -m pip install PyNaCl`"
|
||||
)
|
||||
return f"Failed to join voice channel: {e}"
|
||||
|
||||
@@ -5797,7 +6171,7 @@ class GatewayRunner:
|
||||
pass
|
||||
|
||||
# Send media files
|
||||
for media_path in (media_files or []):
|
||||
for media_path, _is_voice in (media_files or []):
|
||||
try:
|
||||
await adapter.send_document(
|
||||
chat_id=source.chat_id,
|
||||
@@ -5975,7 +6349,7 @@ class GatewayRunner:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for media_path in (media_files or []):
|
||||
for media_path, _is_voice in (media_files or []):
|
||||
try:
|
||||
await adapter.send_file(chat_id=source.chat_id, file_path=media_path)
|
||||
except Exception:
|
||||
@@ -6427,8 +6801,7 @@ class GatewayRunner:
|
||||
logger.debug("Memory flush on resume failed: %s", e)
|
||||
|
||||
# Clear any running agent for this session key
|
||||
if session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
self._release_running_agent_state(session_key)
|
||||
|
||||
# Switch the session entry to point at the old session
|
||||
new_entry = self.session_store.switch_session(session_key, target_id)
|
||||
@@ -7844,6 +8217,30 @@ class GatewayRunner:
|
||||
override = self._session_model_overrides.get(session_key)
|
||||
return override is not None and override.get("model") == agent_model
|
||||
|
||||
def _release_running_agent_state(self, session_key: str) -> None:
|
||||
"""Pop ALL per-running-agent state entries for ``session_key``.
|
||||
|
||||
Replaces ad-hoc ``del self._running_agents[key]`` calls scattered
|
||||
across the gateway. Those sites had drifted: some popped only
|
||||
``_running_agents``; some also ``_running_agents_ts``; only one
|
||||
path also cleared ``_busy_ack_ts``. Each missed entry was a
|
||||
small, persistent leak — a (str_key → float) tuple per session
|
||||
per gateway lifetime.
|
||||
|
||||
Use this at every site that ends a running turn, regardless of
|
||||
cause (normal completion, /stop, /reset, /resume, sentinel
|
||||
cleanup, stale-eviction). Per-session state that PERSISTS
|
||||
across turns (``_session_model_overrides``, ``_voice_mode``,
|
||||
``_pending_approvals``, ``_update_prompt_pending``) is NOT
|
||||
touched here — those have their own lifecycles.
|
||||
"""
|
||||
if not session_key:
|
||||
return
|
||||
self._running_agents.pop(session_key, None)
|
||||
self._running_agents_ts.pop(session_key, None)
|
||||
if hasattr(self, "_busy_ack_ts"):
|
||||
self._busy_ack_ts.pop(session_key, None)
|
||||
|
||||
def _evict_cached_agent(self, session_key: str) -> None:
|
||||
"""Remove a cached agent for a session (called on /new, /model, etc)."""
|
||||
_lock = getattr(self, "_agent_cache_lock", None)
|
||||
@@ -7851,6 +8248,153 @@ class GatewayRunner:
|
||||
with _lock:
|
||||
self._agent_cache.pop(session_key, None)
|
||||
|
||||
def _release_evicted_agent_soft(self, agent: Any) -> None:
|
||||
"""Soft cleanup for cache-evicted agents — preserves session tool state.
|
||||
|
||||
Called from _enforce_agent_cache_cap and _sweep_idle_cached_agents.
|
||||
Distinct from _cleanup_agent_resources (full teardown) because a
|
||||
cache-evicted session may resume at any time — its terminal
|
||||
sandbox, browser daemon, and tracked bg processes must outlive
|
||||
the Python AIAgent instance so the next agent built for the
|
||||
same task_id inherits them.
|
||||
"""
|
||||
if agent is None:
|
||||
return
|
||||
try:
|
||||
if hasattr(agent, "release_clients"):
|
||||
agent.release_clients()
|
||||
else:
|
||||
# Older agent instance (shouldn't happen in practice) —
|
||||
# fall back to the legacy full-close path.
|
||||
self._cleanup_agent_resources(agent)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _enforce_agent_cache_cap(self) -> None:
|
||||
"""Evict oldest cached agents when cache exceeds _AGENT_CACHE_MAX_SIZE.
|
||||
|
||||
Must be called with _agent_cache_lock held. Resource cleanup
|
||||
(memory provider shutdown, tool resource close) is scheduled
|
||||
on a daemon thread so the caller doesn't block on slow teardown
|
||||
while holding the cache lock.
|
||||
|
||||
Agents currently in _running_agents are SKIPPED — their clients,
|
||||
terminal sandboxes, background processes, and child subagents
|
||||
are all in active use by the running turn. Evicting them would
|
||||
tear down those resources mid-turn and crash the request. If
|
||||
every candidate in the LRU order is active, we simply leave the
|
||||
cache over the cap; it will be re-checked on the next insert.
|
||||
"""
|
||||
_cache = getattr(self, "_agent_cache", None)
|
||||
if _cache is None:
|
||||
return
|
||||
# OrderedDict.popitem(last=False) pops oldest; plain dict lacks the
|
||||
# arg so skip enforcement if a test fixture swapped the cache type.
|
||||
if not hasattr(_cache, "move_to_end"):
|
||||
return
|
||||
|
||||
# Snapshot of agent instances that are actively mid-turn. Use id()
|
||||
# so the lookup is O(1) and doesn't depend on AIAgent.__eq__ (which
|
||||
# MagicMock overrides in tests).
|
||||
running_ids = {
|
||||
id(a)
|
||||
for a in getattr(self, "_running_agents", {}).values()
|
||||
if a is not None and a is not _AGENT_PENDING_SENTINEL
|
||||
}
|
||||
|
||||
# Walk LRU → MRU and evict excess-LRU entries that aren't mid-turn.
|
||||
# We only consider entries in the first (size - cap) LRU positions
|
||||
# as eviction candidates. If one of those slots is held by an
|
||||
# active agent, we SKIP it without compensating by evicting a
|
||||
# newer entry — that would penalise a freshly-inserted session
|
||||
# (which has no cache history to retain) while protecting an
|
||||
# already-cached long-running one. The cache may therefore stay
|
||||
# temporarily over cap; it will re-check on the next insert,
|
||||
# after active turns have finished.
|
||||
excess = max(0, len(_cache) - _AGENT_CACHE_MAX_SIZE)
|
||||
evict_plan: List[tuple] = [] # [(key, agent), ...]
|
||||
if excess > 0:
|
||||
ordered_keys = list(_cache.keys())
|
||||
for key in ordered_keys[:excess]:
|
||||
entry = _cache.get(key)
|
||||
agent = entry[0] if isinstance(entry, tuple) and entry else None
|
||||
if agent is not None and id(agent) in running_ids:
|
||||
continue # active mid-turn; don't evict, don't substitute
|
||||
evict_plan.append((key, agent))
|
||||
|
||||
for key, _ in evict_plan:
|
||||
_cache.pop(key, None)
|
||||
|
||||
remaining_over_cap = len(_cache) - _AGENT_CACHE_MAX_SIZE
|
||||
if remaining_over_cap > 0:
|
||||
logger.warning(
|
||||
"Agent cache over cap (%d > %d); %d excess slot(s) held by "
|
||||
"mid-turn agents — will re-check on next insert.",
|
||||
len(_cache), _AGENT_CACHE_MAX_SIZE, remaining_over_cap,
|
||||
)
|
||||
|
||||
for key, agent in evict_plan:
|
||||
logger.info(
|
||||
"Agent cache at cap; evicting LRU session=%s (cache_size=%d)",
|
||||
key, len(_cache),
|
||||
)
|
||||
if agent is not None:
|
||||
threading.Thread(
|
||||
target=self._release_evicted_agent_soft,
|
||||
args=(agent,),
|
||||
daemon=True,
|
||||
name=f"agent-cache-evict-{key[:24]}",
|
||||
).start()
|
||||
|
||||
def _sweep_idle_cached_agents(self) -> int:
|
||||
"""Evict cached agents whose AIAgent has been idle > _AGENT_CACHE_IDLE_TTL_SECS.
|
||||
|
||||
Safe to call from the session expiry watcher without holding the
|
||||
cache lock — acquires it internally. Returns the number of entries
|
||||
evicted. Resource cleanup is scheduled on daemon threads.
|
||||
|
||||
Agents currently in _running_agents are SKIPPED for the same reason
|
||||
as _enforce_agent_cache_cap: tearing down an active turn's clients
|
||||
mid-flight would crash the request.
|
||||
"""
|
||||
_cache = getattr(self, "_agent_cache", None)
|
||||
_lock = getattr(self, "_agent_cache_lock", None)
|
||||
if _cache is None or _lock is None:
|
||||
return 0
|
||||
now = time.time()
|
||||
to_evict: List[tuple] = []
|
||||
running_ids = {
|
||||
id(a)
|
||||
for a in getattr(self, "_running_agents", {}).values()
|
||||
if a is not None and a is not _AGENT_PENDING_SENTINEL
|
||||
}
|
||||
with _lock:
|
||||
for key, entry in list(_cache.items()):
|
||||
agent = entry[0] if isinstance(entry, tuple) and entry else None
|
||||
if agent is None:
|
||||
continue
|
||||
if id(agent) in running_ids:
|
||||
continue # mid-turn — don't tear it down
|
||||
last_activity = getattr(agent, "_last_activity_ts", None)
|
||||
if last_activity is None:
|
||||
continue
|
||||
if (now - last_activity) > _AGENT_CACHE_IDLE_TTL_SECS:
|
||||
to_evict.append((key, agent))
|
||||
for key, _ in to_evict:
|
||||
_cache.pop(key, None)
|
||||
for key, agent in to_evict:
|
||||
logger.info(
|
||||
"Agent cache idle-TTL evict: session=%s (idle=%.0fs)",
|
||||
key, now - getattr(agent, "_last_activity_ts", now),
|
||||
)
|
||||
threading.Thread(
|
||||
target=self._release_evicted_agent_soft,
|
||||
args=(agent,),
|
||||
daemon=True,
|
||||
name=f"agent-cache-idle-{key[:24]}",
|
||||
).start()
|
||||
return len(to_evict)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Proxy mode: forward messages to a remote Hermes API server
|
||||
# ------------------------------------------------------------------
|
||||
@@ -8618,6 +9162,13 @@ class GatewayRunner:
|
||||
cached = _cache.get(session_key)
|
||||
if cached and cached[1] == _sig:
|
||||
agent = cached[0]
|
||||
# Refresh LRU order so the cap enforcement evicts
|
||||
# truly-oldest entries, not the one we just used.
|
||||
if hasattr(_cache, "move_to_end"):
|
||||
try:
|
||||
_cache.move_to_end(session_key)
|
||||
except KeyError:
|
||||
pass
|
||||
# Reset activity timestamp so the inactivity timeout
|
||||
# handler doesn't see stale idle time from the previous
|
||||
# turn and immediately kill this agent. (#9051)
|
||||
@@ -8656,6 +9207,7 @@ class GatewayRunner:
|
||||
if _cache_lock and _cache is not None:
|
||||
with _cache_lock:
|
||||
_cache[session_key] = (agent, _sig)
|
||||
self._enforce_agent_cache_cap()
|
||||
logger.debug("Created new agent for session %s (sig=%s)", session_key, _sig)
|
||||
|
||||
# Per-message state — callbacks and reasoning config change every
|
||||
@@ -9524,10 +10076,8 @@ class GatewayRunner:
|
||||
|
||||
# Clean up tracking
|
||||
tracking_task.cancel()
|
||||
if session_key and session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
if session_key:
|
||||
self._running_agents_ts.pop(session_key, None)
|
||||
self._release_running_agent_state(session_key)
|
||||
if self._draining:
|
||||
self._update_runtime_status("draining")
|
||||
|
||||
@@ -9656,6 +10206,16 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
||||
"Replacing existing gateway instance (PID %d) with --replace.",
|
||||
existing_pid,
|
||||
)
|
||||
# Record a takeover marker so the target's shutdown handler
|
||||
# recognises its SIGTERM as a planned takeover and exits 0
|
||||
# (rather than exit 1, which would trigger systemd's
|
||||
# Restart=on-failure and start a flap loop against us).
|
||||
# Best-effort — proceed even if the write fails.
|
||||
try:
|
||||
from gateway.status import write_takeover_marker
|
||||
write_takeover_marker(existing_pid)
|
||||
except Exception as e:
|
||||
logger.debug("Could not write takeover marker: %s", e)
|
||||
try:
|
||||
terminate_pid(existing_pid, force=False)
|
||||
except ProcessLookupError:
|
||||
@@ -9665,6 +10225,13 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
||||
"Permission denied killing PID %d. Cannot replace.",
|
||||
existing_pid,
|
||||
)
|
||||
# Marker is scoped to a specific target; clean it up on
|
||||
# give-up so it doesn't grief an unrelated future shutdown.
|
||||
try:
|
||||
from gateway.status import clear_takeover_marker
|
||||
clear_takeover_marker()
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
# Wait up to 10 seconds for the old process to exit
|
||||
for _ in range(20):
|
||||
@@ -9685,6 +10252,13 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
||||
except (ProcessLookupError, PermissionError, OSError):
|
||||
pass
|
||||
remove_pid_file()
|
||||
# Clean up any takeover marker the old process didn't consume
|
||||
# (e.g. SIGKILL'd before its shutdown handler could read it).
|
||||
try:
|
||||
from gateway.status import clear_takeover_marker
|
||||
clear_takeover_marker()
|
||||
except Exception:
|
||||
pass
|
||||
# Also release all scoped locks left by the old process.
|
||||
# Stopped (Ctrl+Z) processes don't release locks on exit,
|
||||
# leaving stale lock files that block the new gateway from starting.
|
||||
@@ -9752,8 +10326,27 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
||||
# Set up signal handlers
|
||||
def shutdown_signal_handler():
|
||||
nonlocal _signal_initiated_shutdown
|
||||
_signal_initiated_shutdown = True
|
||||
logger.info("Received SIGTERM/SIGINT — initiating shutdown")
|
||||
# Planned --replace takeover check: when a sibling gateway is
|
||||
# taking over via --replace, it wrote a marker naming this PID
|
||||
# before sending SIGTERM. If present, treat the signal as a
|
||||
# planned shutdown and exit 0 so systemd's Restart=on-failure
|
||||
# doesn't revive us (which would flap-fight the replacer when
|
||||
# both services are enabled, e.g. hermes.service + hermes-
|
||||
# gateway.service from pre-rename installs).
|
||||
planned_takeover = False
|
||||
try:
|
||||
from gateway.status import consume_takeover_marker_for_self
|
||||
planned_takeover = consume_takeover_marker_for_self()
|
||||
except Exception as e:
|
||||
logger.debug("Takeover marker check failed: %s", e)
|
||||
|
||||
if planned_takeover:
|
||||
logger.info(
|
||||
"Received SIGTERM as a planned --replace takeover — exiting cleanly"
|
||||
)
|
||||
else:
|
||||
_signal_initiated_shutdown = True
|
||||
logger.info("Received SIGTERM/SIGINT — initiating shutdown")
|
||||
# Diagnostic: log all hermes-related processes so we can identify
|
||||
# what triggered the signal (hermes update, hermes gateway restart,
|
||||
# a stale detached subprocess, etc.).
|
||||
|
||||
@@ -82,6 +82,7 @@ class SessionSource:
|
||||
chat_topic: Optional[str] = None # Channel topic/description (Discord, Slack)
|
||||
user_id_alt: Optional[str] = None # Signal UUID (alternative to phone number)
|
||||
chat_id_alt: Optional[str] = None # Signal group internal ID
|
||||
is_bot: bool = False # True when the message author is a bot/webhook (Discord)
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
@@ -801,6 +802,57 @@ class SessionStore:
|
||||
return True
|
||||
return False
|
||||
|
||||
def prune_old_entries(self, max_age_days: int) -> int:
|
||||
"""Drop SessionEntry records older than max_age_days.
|
||||
|
||||
Pruning is based on ``updated_at`` (last activity), not ``created_at``.
|
||||
A session that's been active within the window is kept regardless of
|
||||
how old it is. Entries marked ``suspended`` are kept — the user
|
||||
explicitly paused them for later resume. Entries held by an active
|
||||
process (via has_active_processes_fn) are also kept so long-running
|
||||
background work isn't orphaned.
|
||||
|
||||
Pruning is functionally identical to a natural reset-policy expiry:
|
||||
the transcript in SQLite stays, but the session_key → session_id
|
||||
mapping is dropped and the user starts a fresh session on return.
|
||||
|
||||
``max_age_days <= 0`` disables pruning; returns 0 immediately.
|
||||
Returns the number of entries removed.
|
||||
"""
|
||||
if max_age_days is None or max_age_days <= 0:
|
||||
return 0
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff = _now() - timedelta(days=max_age_days)
|
||||
removed_keys: list[str] = []
|
||||
|
||||
with self._lock:
|
||||
self._ensure_loaded_locked()
|
||||
for key, entry in list(self._entries.items()):
|
||||
if entry.suspended:
|
||||
continue
|
||||
# Never prune sessions with an active background process
|
||||
# attached — the user may still be waiting on output.
|
||||
if self._has_active_processes_fn is not None:
|
||||
try:
|
||||
if self._has_active_processes_fn(entry.session_id):
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
if entry.updated_at < cutoff:
|
||||
removed_keys.append(key)
|
||||
for key in removed_keys:
|
||||
self._entries.pop(key, None)
|
||||
if removed_keys:
|
||||
self._save()
|
||||
|
||||
if removed_keys:
|
||||
logger.info(
|
||||
"SessionStore pruned %d entries older than %d days",
|
||||
len(removed_keys), max_age_days,
|
||||
)
|
||||
return len(removed_keys)
|
||||
|
||||
def suspend_recently_active(self, max_age_seconds: int = 120) -> int:
|
||||
"""Mark recently-active sessions as suspended.
|
||||
|
||||
|
||||
+159
-11
@@ -188,8 +188,8 @@ def _write_json_file(path: Path, payload: dict[str, Any]) -> None:
|
||||
path.write_text(json.dumps(payload))
|
||||
|
||||
|
||||
def _read_pid_record() -> Optional[dict]:
|
||||
pid_path = _get_pid_path()
|
||||
def _read_pid_record(pid_path: Optional[Path] = None) -> Optional[dict]:
|
||||
pid_path = pid_path or _get_pid_path()
|
||||
if not pid_path.exists():
|
||||
return None
|
||||
|
||||
@@ -212,6 +212,18 @@ def _read_pid_record() -> Optional[dict]:
|
||||
return None
|
||||
|
||||
|
||||
def _cleanup_invalid_pid_path(pid_path: Path, *, cleanup_stale: bool) -> None:
|
||||
if not cleanup_stale:
|
||||
return
|
||||
try:
|
||||
if pid_path == _get_pid_path():
|
||||
remove_pid_file()
|
||||
else:
|
||||
pid_path.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def write_pid_file() -> None:
|
||||
"""Write the current process PID and metadata to the gateway PID file."""
|
||||
_write_json_file(_get_pid_path(), _build_pid_record())
|
||||
@@ -413,43 +425,179 @@ def release_all_scoped_locks() -> int:
|
||||
return removed
|
||||
|
||||
|
||||
def get_running_pid() -> Optional[int]:
|
||||
# ── --replace takeover marker ─────────────────────────────────────────
|
||||
#
|
||||
# When a new gateway starts with ``--replace``, it SIGTERMs the existing
|
||||
# gateway so it can take over the bot token. PR #5646 made SIGTERM exit
|
||||
# the gateway with code 1 so ``Restart=on-failure`` can revive it after
|
||||
# unexpected kills — but that also means a --replace takeover target
|
||||
# exits 1, which tricks systemd into reviving it 30 seconds later,
|
||||
# starting a flap loop against the replacer when both services are
|
||||
# enabled in the user's systemd (e.g. ``hermes.service`` + ``hermes-
|
||||
# gateway.service``).
|
||||
#
|
||||
# The takeover marker breaks the loop: the replacer writes a short-lived
|
||||
# file naming the target PID + start_time BEFORE sending SIGTERM.
|
||||
# The target's shutdown handler reads the marker and, if it names
|
||||
# this process, treats the SIGTERM as a planned takeover and exits 0.
|
||||
# The marker is unlinked after the target has consumed it, so a stale
|
||||
# marker left by a crashed replacer can grief at most one future
|
||||
# shutdown on the same PID — and only within _TAKEOVER_MARKER_TTL_S.
|
||||
|
||||
_TAKEOVER_MARKER_FILENAME = ".gateway-takeover.json"
|
||||
_TAKEOVER_MARKER_TTL_S = 60 # Marker older than this is treated as stale
|
||||
|
||||
|
||||
def _get_takeover_marker_path() -> Path:
|
||||
"""Return the path to the --replace takeover marker file."""
|
||||
home = get_hermes_home()
|
||||
return home / _TAKEOVER_MARKER_FILENAME
|
||||
|
||||
|
||||
def write_takeover_marker(target_pid: int) -> bool:
|
||||
"""Record that ``target_pid`` is being replaced by the current process.
|
||||
|
||||
Captures the target's ``start_time`` so that PID reuse after the
|
||||
target exits cannot later match the marker. Also records the
|
||||
replacer's PID and a UTC timestamp for TTL-based staleness checks.
|
||||
|
||||
Returns True on successful write, False on any failure. The caller
|
||||
should proceed with the SIGTERM even if the write fails (the marker
|
||||
is a best-effort signal, not a correctness requirement).
|
||||
"""
|
||||
try:
|
||||
target_start_time = _get_process_start_time(target_pid)
|
||||
record = {
|
||||
"target_pid": target_pid,
|
||||
"target_start_time": target_start_time,
|
||||
"replacer_pid": os.getpid(),
|
||||
"written_at": _utc_now_iso(),
|
||||
}
|
||||
_write_json_file(_get_takeover_marker_path(), record)
|
||||
return True
|
||||
except (OSError, PermissionError):
|
||||
return False
|
||||
|
||||
|
||||
def consume_takeover_marker_for_self() -> bool:
|
||||
"""Check & unlink the takeover marker if it names the current process.
|
||||
|
||||
Returns True only when a valid (non-stale) marker names this PID +
|
||||
start_time. A returning True indicates the current SIGTERM is a
|
||||
planned --replace takeover; the caller should exit 0 instead of
|
||||
signalling ``_signal_initiated_shutdown``.
|
||||
|
||||
Always unlinks the marker on match (and on detected staleness) so
|
||||
subsequent unrelated signals don't re-trigger.
|
||||
"""
|
||||
path = _get_takeover_marker_path()
|
||||
record = _read_json_file(path)
|
||||
if not record:
|
||||
return False
|
||||
|
||||
# Any malformed or stale marker → drop it and return False
|
||||
try:
|
||||
target_pid = int(record["target_pid"])
|
||||
target_start_time = record.get("target_start_time")
|
||||
written_at = record.get("written_at") or ""
|
||||
except (KeyError, TypeError, ValueError):
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
|
||||
# TTL guard: a stale marker older than _TAKEOVER_MARKER_TTL_S is ignored.
|
||||
stale = False
|
||||
try:
|
||||
written_dt = datetime.fromisoformat(written_at)
|
||||
age = (datetime.now(timezone.utc) - written_dt).total_seconds()
|
||||
if age > _TAKEOVER_MARKER_TTL_S:
|
||||
stale = True
|
||||
except (TypeError, ValueError):
|
||||
stale = True # Unparseable timestamp — treat as stale
|
||||
|
||||
if stale:
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
|
||||
# Does the marker name THIS process?
|
||||
our_pid = os.getpid()
|
||||
our_start_time = _get_process_start_time(our_pid)
|
||||
matches = (
|
||||
target_pid == our_pid
|
||||
and target_start_time is not None
|
||||
and our_start_time is not None
|
||||
and target_start_time == our_start_time
|
||||
)
|
||||
|
||||
# Consume the marker whether it matched or not — a marker that doesn't
|
||||
# match our identity is stale-for-us anyway.
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def clear_takeover_marker() -> None:
|
||||
"""Remove the takeover marker unconditionally. Safe to call repeatedly."""
|
||||
try:
|
||||
_get_takeover_marker_path().unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def get_running_pid(
|
||||
pid_path: Optional[Path] = None,
|
||||
*,
|
||||
cleanup_stale: bool = True,
|
||||
) -> Optional[int]:
|
||||
"""Return the PID of a running gateway instance, or ``None``.
|
||||
|
||||
Checks the PID file and verifies the process is actually alive.
|
||||
Cleans up stale PID files automatically.
|
||||
"""
|
||||
record = _read_pid_record()
|
||||
resolved_pid_path = pid_path or _get_pid_path()
|
||||
record = _read_pid_record(resolved_pid_path)
|
||||
if not record:
|
||||
remove_pid_file()
|
||||
_cleanup_invalid_pid_path(resolved_pid_path, cleanup_stale=cleanup_stale)
|
||||
return None
|
||||
|
||||
try:
|
||||
pid = int(record["pid"])
|
||||
except (KeyError, TypeError, ValueError):
|
||||
remove_pid_file()
|
||||
_cleanup_invalid_pid_path(resolved_pid_path, cleanup_stale=cleanup_stale)
|
||||
return None
|
||||
|
||||
try:
|
||||
os.kill(pid, 0) # signal 0 = existence check, no actual signal sent
|
||||
except (ProcessLookupError, PermissionError):
|
||||
remove_pid_file()
|
||||
_cleanup_invalid_pid_path(resolved_pid_path, cleanup_stale=cleanup_stale)
|
||||
return None
|
||||
|
||||
recorded_start = record.get("start_time")
|
||||
current_start = _get_process_start_time(pid)
|
||||
if recorded_start is not None and current_start is not None and current_start != recorded_start:
|
||||
remove_pid_file()
|
||||
_cleanup_invalid_pid_path(resolved_pid_path, cleanup_stale=cleanup_stale)
|
||||
return None
|
||||
|
||||
if not _looks_like_gateway_process(pid):
|
||||
if not _record_looks_like_gateway(record):
|
||||
remove_pid_file()
|
||||
_cleanup_invalid_pid_path(resolved_pid_path, cleanup_stale=cleanup_stale)
|
||||
return None
|
||||
|
||||
return pid
|
||||
|
||||
|
||||
def is_gateway_running() -> bool:
|
||||
def is_gateway_running(
|
||||
pid_path: Optional[Path] = None,
|
||||
*,
|
||||
cleanup_stale: bool = True,
|
||||
) -> bool:
|
||||
"""Check if the gateway daemon is currently running."""
|
||||
return get_running_pid() is not None
|
||||
return get_running_pid(pid_path, cleanup_stale=cleanup_stale) is not None
|
||||
|
||||
@@ -100,6 +100,14 @@ class GatewayStreamConsumer:
|
||||
self._flood_strikes = 0 # Consecutive flood-control edit failures
|
||||
self._current_edit_interval = self.cfg.edit_interval # Adaptive backoff
|
||||
self._final_response_sent = False
|
||||
# Cache adapter lifecycle capability: only platforms that need an
|
||||
# explicit finalize call (e.g. DingTalk AI Cards) force us to make
|
||||
# a redundant final edit. Everyone else keeps the fast path.
|
||||
# Use ``is True`` (not ``bool(...)``) so MagicMock attribute access
|
||||
# in tests doesn't incorrectly enable this path.
|
||||
self._adapter_requires_finalize: bool = (
|
||||
getattr(adapter, "REQUIRES_EDIT_FINALIZE", False) is True
|
||||
)
|
||||
|
||||
# Think-block filter state (mirrors CLI's _stream_delta tag suppression)
|
||||
self._in_think_block = False
|
||||
@@ -361,7 +369,16 @@ class GatewayStreamConsumer:
|
||||
if not got_done and not got_segment_break and commentary_text is None:
|
||||
display_text += self.cfg.cursor
|
||||
|
||||
current_update_visible = await self._send_or_edit(display_text)
|
||||
# Segment break: finalize the current message so platforms
|
||||
# that need explicit closure (e.g. DingTalk AI Cards) don't
|
||||
# leave the previous segment stuck in a loading state when
|
||||
# the next segment (tool progress, next chunk) creates a
|
||||
# new message below it. got_done has its own finalize
|
||||
# path below so we don't finalize here for it.
|
||||
current_update_visible = await self._send_or_edit(
|
||||
display_text,
|
||||
finalize=got_segment_break,
|
||||
)
|
||||
self._last_edit_time = time.monotonic()
|
||||
|
||||
if got_done:
|
||||
@@ -372,10 +389,22 @@ class GatewayStreamConsumer:
|
||||
if self._accumulated:
|
||||
if self._fallback_final_send:
|
||||
await self._send_fallback_final(self._accumulated)
|
||||
elif current_update_visible:
|
||||
elif (
|
||||
current_update_visible
|
||||
and not self._adapter_requires_finalize
|
||||
):
|
||||
# Mid-stream edit above already delivered the
|
||||
# final accumulated content. Skip the redundant
|
||||
# final edit — but only for adapters that don't
|
||||
# need an explicit finalize signal.
|
||||
self._final_response_sent = True
|
||||
elif self._message_id:
|
||||
self._final_response_sent = await self._send_or_edit(self._accumulated)
|
||||
# Either the mid-stream edit didn't run (no
|
||||
# visible update this tick) OR the adapter needs
|
||||
# explicit finalize=True to close the stream.
|
||||
self._final_response_sent = await self._send_or_edit(
|
||||
self._accumulated, finalize=True,
|
||||
)
|
||||
elif not self._already_sent:
|
||||
self._final_response_sent = await self._send_or_edit(self._accumulated)
|
||||
return
|
||||
@@ -633,12 +662,15 @@ class GatewayStreamConsumer:
|
||||
logger.error("Commentary send error: %s", e)
|
||||
return False
|
||||
|
||||
async def _send_or_edit(self, text: str) -> bool:
|
||||
async def _send_or_edit(self, text: str, *, finalize: bool = False) -> bool:
|
||||
"""Send or edit the streaming message.
|
||||
|
||||
Returns True if the text was successfully delivered (sent or edited),
|
||||
False otherwise. Callers like the overflow split loop use this to
|
||||
decide whether to advance past the delivered chunk.
|
||||
|
||||
``finalize`` is True when this is the last edit in a streaming
|
||||
sequence.
|
||||
"""
|
||||
# Strip MEDIA: directives so they don't appear as visible text.
|
||||
# Media files are delivered as native attachments after the stream
|
||||
@@ -672,14 +704,22 @@ class GatewayStreamConsumer:
|
||||
try:
|
||||
if self._message_id is not None:
|
||||
if self._edit_supported:
|
||||
# Skip if text is identical to what we last sent
|
||||
if text == self._last_sent_text:
|
||||
# Skip if text is identical to what we last sent.
|
||||
# Exception: adapters that require an explicit finalize
|
||||
# call (REQUIRES_EDIT_FINALIZE) must still receive the
|
||||
# finalize=True edit even when content is unchanged, so
|
||||
# their streaming UI can transition out of the in-
|
||||
# progress state. Everyone else short-circuits.
|
||||
if text == self._last_sent_text and not (
|
||||
finalize and self._adapter_requires_finalize
|
||||
):
|
||||
return True
|
||||
# Edit existing message
|
||||
result = await self.adapter.edit_message(
|
||||
chat_id=self.chat_id,
|
||||
message_id=self._message_id,
|
||||
content=text,
|
||||
finalize=finalize,
|
||||
)
|
||||
if result.success:
|
||||
self._already_sent = True
|
||||
|
||||
@@ -233,6 +233,14 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
api_key_env_vars=("XAI_API_KEY",),
|
||||
base_url_env_var="XAI_BASE_URL",
|
||||
),
|
||||
"nvidia": ProviderConfig(
|
||||
id="nvidia",
|
||||
name="NVIDIA NIM",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://integrate.api.nvidia.com/v1",
|
||||
api_key_env_vars=("NVIDIA_API_KEY",),
|
||||
base_url_env_var="NVIDIA_BASE_URL",
|
||||
),
|
||||
"ai-gateway": ProviderConfig(
|
||||
id="ai-gateway",
|
||||
name="Vercel AI Gateway",
|
||||
@@ -773,6 +781,28 @@ def is_source_suppressed(provider_id: str, source: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def unsuppress_credential_source(provider_id: str, source: str) -> bool:
|
||||
"""Clear a suppression marker so the source will be re-seeded on the next load.
|
||||
|
||||
Returns True if a marker was cleared, False if no marker existed.
|
||||
"""
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
suppressed = auth_store.get("suppressed_sources")
|
||||
if not isinstance(suppressed, dict):
|
||||
return False
|
||||
provider_list = suppressed.get(provider_id)
|
||||
if not isinstance(provider_list, list) or source not in provider_list:
|
||||
return False
|
||||
provider_list.remove(source)
|
||||
if not provider_list:
|
||||
suppressed.pop(provider_id, None)
|
||||
if not suppressed:
|
||||
auth_store.pop("suppressed_sources", None)
|
||||
_save_auth_store(auth_store)
|
||||
return True
|
||||
|
||||
|
||||
def get_provider_auth_state(provider_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Return persisted auth state for a provider, or None."""
|
||||
auth_store = _load_auth_store()
|
||||
@@ -2129,6 +2159,62 @@ def refresh_nous_oauth_from_state(
|
||||
)
|
||||
|
||||
|
||||
NOUS_DEVICE_CODE_SOURCE = "device_code"
|
||||
|
||||
|
||||
def persist_nous_credentials(
|
||||
creds: Dict[str, Any],
|
||||
*,
|
||||
label: Optional[str] = None,
|
||||
):
|
||||
"""Persist minted Nous OAuth credentials as the singleton provider state
|
||||
and ensure the credential pool is in sync.
|
||||
|
||||
Nous credentials are read at runtime from two independent locations:
|
||||
|
||||
- ``providers.nous``: singleton state read by
|
||||
``resolve_nous_runtime_credentials()`` during 401 recovery and by
|
||||
``_seed_from_singletons()`` during pool load.
|
||||
- ``credential_pool.nous``: used by the runtime ``pool.select()`` path.
|
||||
|
||||
Historically ``hermes auth add nous`` wrote a ``manual:device_code`` pool
|
||||
entry only, skipping ``providers.nous``. When the 24h agent_key TTL
|
||||
expired, the recovery path read the empty singleton state and raised
|
||||
``AuthError`` silently (``logger.debug`` at INFO level).
|
||||
|
||||
This helper writes ``providers.nous`` then calls ``load_pool("nous")`` so
|
||||
``_seed_from_singletons`` materialises the canonical ``device_code`` pool
|
||||
entry from the singleton. Re-running login upserts the same entry in
|
||||
place; the pool never accumulates duplicate device_code rows.
|
||||
|
||||
``label`` is an optional user-chosen display name (from
|
||||
``hermes auth add nous --label <name>``). It gets embedded in the
|
||||
singleton state so that ``_seed_from_singletons`` uses it as the pool
|
||||
entry's label on every subsequent ``load_pool("nous")`` instead of the
|
||||
auto-derived token fingerprint. When ``None``, the auto-derived label
|
||||
via ``label_from_token`` is used (unchanged default behaviour).
|
||||
|
||||
Returns the upserted :class:`PooledCredential` entry (or ``None`` if
|
||||
seeding somehow produced no match — shouldn't happen).
|
||||
"""
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
state = dict(creds)
|
||||
if label and str(label).strip():
|
||||
state["label"] = str(label).strip()
|
||||
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
_save_provider_state(auth_store, "nous", state)
|
||||
_save_auth_store(auth_store)
|
||||
|
||||
pool = load_pool("nous")
|
||||
return next(
|
||||
(e for e in pool.entries() if e.source == NOUS_DEVICE_CODE_SOURCE),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def resolve_nous_runtime_credentials(
|
||||
*,
|
||||
min_key_ttl_seconds: int = DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
|
||||
|
||||
+39
-13
@@ -217,22 +217,21 @@ def auth_add_command(args) -> None:
|
||||
ca_bundle=getattr(args, "ca_bundle", None),
|
||||
min_key_ttl_seconds=max(60, int(getattr(args, "min_key_ttl_seconds", 5 * 60))),
|
||||
)
|
||||
label = (getattr(args, "label", None) or "").strip() or label_from_token(
|
||||
creds.get("access_token", ""),
|
||||
_oauth_default_label(provider, len(pool.entries()) + 1),
|
||||
# Honor `--label <name>` so nous matches other providers' UX. The
|
||||
# helper embeds this into providers.nous so that label_from_token
|
||||
# doesn't overwrite it on every subsequent load_pool("nous").
|
||||
custom_label = (getattr(args, "label", None) or "").strip() or None
|
||||
entry = auth_mod.persist_nous_credentials(creds, label=custom_label)
|
||||
shown_label = entry.label if entry is not None else label_from_token(
|
||||
creds.get("access_token", ""), _oauth_default_label(provider, 1),
|
||||
)
|
||||
entry = PooledCredential.from_dict(provider, {
|
||||
**creds,
|
||||
"label": label,
|
||||
"auth_type": AUTH_TYPE_OAUTH,
|
||||
"source": f"{SOURCE_MANUAL}:device_code",
|
||||
"base_url": creds.get("inference_base_url"),
|
||||
})
|
||||
pool.add_entry(entry)
|
||||
print(f'Added {provider} OAuth credential #{len(pool.entries())}: "{entry.label}"')
|
||||
print(f'Saved {provider} OAuth device-code credentials: "{shown_label}"')
|
||||
return
|
||||
|
||||
if provider == "openai-codex":
|
||||
# Clear any existing suppression marker so a re-link after `hermes auth
|
||||
# remove openai-codex` works without the new tokens being skipped.
|
||||
auth_mod.unsuppress_credential_source(provider, "device_code")
|
||||
creds = auth_mod._codex_device_code_login()
|
||||
label = (getattr(args, "label", None) or "").strip() or label_from_token(
|
||||
creds["tokens"]["access_token"],
|
||||
@@ -352,7 +351,34 @@ def auth_remove_command(args) -> None:
|
||||
# If this was a singleton-seeded credential (OAuth device_code, hermes_pkce),
|
||||
# clear the underlying auth store / credential file so it doesn't get
|
||||
# re-seeded on the next load_pool() call.
|
||||
elif removed.source == "device_code" and provider in ("openai-codex", "nous"):
|
||||
elif provider == "openai-codex" and (
|
||||
removed.source == "device_code" or removed.source.endswith(":device_code")
|
||||
):
|
||||
# Codex tokens live in TWO places: the Hermes auth store and
|
||||
# ~/.codex/auth.json (the Codex CLI shared file). On every refresh,
|
||||
# refresh_codex_oauth_pure() writes to both. So clearing only the
|
||||
# Hermes auth store is not enough — _seed_from_singletons() will
|
||||
# auto-import from ~/.codex/auth.json on the next load_pool() and
|
||||
# the removal is instantly undone. Mark the source as suppressed
|
||||
# so auto-import is skipped; leave ~/.codex/auth.json untouched so
|
||||
# the Codex CLI itself keeps working.
|
||||
from hermes_cli.auth import (
|
||||
_load_auth_store, _save_auth_store, _auth_store_lock,
|
||||
suppress_credential_source,
|
||||
)
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
providers_dict = auth_store.get("providers")
|
||||
if isinstance(providers_dict, dict) and provider in providers_dict:
|
||||
del providers_dict[provider]
|
||||
_save_auth_store(auth_store)
|
||||
print(f"Cleared {provider} OAuth tokens from auth store")
|
||||
suppress_credential_source(provider, "device_code")
|
||||
print("Suppressed openai-codex device_code source — it will not be re-seeded.")
|
||||
print("Note: Codex CLI credentials still live in ~/.codex/auth.json")
|
||||
print("Run `hermes auth add openai-codex` to re-enable if needed.")
|
||||
|
||||
elif removed.source == "device_code" and provider == "nous":
|
||||
from hermes_cli.auth import (
|
||||
_load_auth_store, _save_auth_store, _auth_store_lock,
|
||||
)
|
||||
|
||||
+119
-70
@@ -7,8 +7,8 @@ CLI tools that ship with the platform (or are commonly installed).
|
||||
|
||||
Platform support:
|
||||
macOS — osascript (always available), pngpaste (if installed)
|
||||
Windows — PowerShell via .NET System.Windows.Forms.Clipboard
|
||||
WSL2 — powershell.exe via .NET System.Windows.Forms.Clipboard
|
||||
Windows — PowerShell via WinForms, Get-Clipboard, file-drop fallback
|
||||
WSL2 — powershell.exe via WinForms, Get-Clipboard, file-drop fallback
|
||||
Linux — wl-paste (Wayland), xclip (X11)
|
||||
"""
|
||||
|
||||
@@ -46,10 +46,11 @@ def has_clipboard_image() -> bool:
|
||||
return _macos_has_image()
|
||||
if sys.platform == "win32":
|
||||
return _windows_has_image()
|
||||
if _is_wsl():
|
||||
return _wsl_has_image()
|
||||
if os.environ.get("WAYLAND_DISPLAY"):
|
||||
return _wayland_has_image()
|
||||
# Match _linux_save fallthrough order: WSL → Wayland → X11
|
||||
if _is_wsl() and _wsl_has_image():
|
||||
return True
|
||||
if os.environ.get("WAYLAND_DISPLAY") and _wayland_has_image():
|
||||
return True
|
||||
return _xclip_has_image()
|
||||
|
||||
|
||||
@@ -135,6 +136,114 @@ _PS_EXTRACT_IMAGE = (
|
||||
"[System.Convert]::ToBase64String($ms.ToArray())"
|
||||
)
|
||||
|
||||
_PS_CHECK_IMAGE_GET_CLIPBOARD = (
|
||||
"try { "
|
||||
"$img = Get-Clipboard -Format Image -ErrorAction Stop;"
|
||||
"if ($null -ne $img) { 'True' } else { 'False' }"
|
||||
"} catch { 'False' }"
|
||||
)
|
||||
|
||||
_PS_EXTRACT_IMAGE_GET_CLIPBOARD = (
|
||||
"try { "
|
||||
"Add-Type -AssemblyName System.Drawing;"
|
||||
"Add-Type -AssemblyName PresentationCore;"
|
||||
"Add-Type -AssemblyName WindowsBase;"
|
||||
"$img = Get-Clipboard -Format Image -ErrorAction Stop;"
|
||||
"if ($null -eq $img) { exit 1 }"
|
||||
"$ms = New-Object System.IO.MemoryStream;"
|
||||
"if ($img -is [System.Drawing.Image]) {"
|
||||
"$img.Save($ms, [System.Drawing.Imaging.ImageFormat]::Png)"
|
||||
"} elseif ($img -is [System.Windows.Media.Imaging.BitmapSource]) {"
|
||||
"$enc = New-Object System.Windows.Media.Imaging.PngBitmapEncoder;"
|
||||
"$enc.Frames.Add([System.Windows.Media.Imaging.BitmapFrame]::Create($img));"
|
||||
"$enc.Save($ms)"
|
||||
"} else { exit 2 }"
|
||||
"[System.Convert]::ToBase64String($ms.ToArray())"
|
||||
"} catch { exit 1 }"
|
||||
)
|
||||
|
||||
_FILEDROP_IMAGE_EXTS = "'.png','.jpg','.jpeg','.gif','.webp','.bmp','.tiff','.tif'"
|
||||
|
||||
_PS_CHECK_FILEDROP_IMAGE = (
|
||||
"try { "
|
||||
"$files = Get-Clipboard -Format FileDropList -ErrorAction Stop;"
|
||||
f"$exts = @({_FILEDROP_IMAGE_EXTS});"
|
||||
"$hit = $files | Where-Object { $exts -contains ([System.IO.Path]::GetExtension($_).ToLowerInvariant()) } | Select-Object -First 1;"
|
||||
"if ($null -ne $hit) { 'True' } else { 'False' }"
|
||||
"} catch { 'False' }"
|
||||
)
|
||||
|
||||
_PS_EXTRACT_FILEDROP_IMAGE = (
|
||||
"try { "
|
||||
"$files = Get-Clipboard -Format FileDropList -ErrorAction Stop;"
|
||||
f"$exts = @({_FILEDROP_IMAGE_EXTS});"
|
||||
"$hit = $files | Where-Object { $exts -contains ([System.IO.Path]::GetExtension($_).ToLowerInvariant()) } | Select-Object -First 1;"
|
||||
"if ($null -eq $hit) { exit 1 }"
|
||||
"[System.Convert]::ToBase64String([System.IO.File]::ReadAllBytes($hit))"
|
||||
"} catch { exit 1 }"
|
||||
)
|
||||
|
||||
_POWERSHELL_HAS_IMAGE_SCRIPTS = (
|
||||
_PS_CHECK_IMAGE,
|
||||
_PS_CHECK_IMAGE_GET_CLIPBOARD,
|
||||
_PS_CHECK_FILEDROP_IMAGE,
|
||||
)
|
||||
|
||||
_POWERSHELL_EXTRACT_IMAGE_SCRIPTS = (
|
||||
_PS_EXTRACT_IMAGE,
|
||||
_PS_EXTRACT_IMAGE_GET_CLIPBOARD,
|
||||
_PS_EXTRACT_FILEDROP_IMAGE,
|
||||
)
|
||||
|
||||
|
||||
def _run_powershell(exe: str, script: str, timeout: int) -> subprocess.CompletedProcess:
|
||||
return subprocess.run(
|
||||
[exe, "-NoProfile", "-NonInteractive", "-Command", script],
|
||||
capture_output=True, text=True, timeout=timeout,
|
||||
)
|
||||
|
||||
|
||||
def _write_base64_image(dest: Path, b64_data: str) -> bool:
|
||||
image_bytes = base64.b64decode(b64_data, validate=True)
|
||||
dest.write_bytes(image_bytes)
|
||||
return dest.exists() and dest.stat().st_size > 0
|
||||
|
||||
|
||||
def _powershell_has_image(exe: str, *, timeout: int, label: str) -> bool:
|
||||
for script in _POWERSHELL_HAS_IMAGE_SCRIPTS:
|
||||
try:
|
||||
r = _run_powershell(exe, script, timeout=timeout)
|
||||
if r.returncode == 0 and "True" in r.stdout:
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
logger.debug("%s not found — clipboard unavailable", exe)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug("%s clipboard image check failed: %s", label, e)
|
||||
return False
|
||||
|
||||
|
||||
def _powershell_save_image(exe: str, dest: Path, *, timeout: int, label: str) -> bool:
|
||||
for script in _POWERSHELL_EXTRACT_IMAGE_SCRIPTS:
|
||||
try:
|
||||
r = _run_powershell(exe, script, timeout=timeout)
|
||||
if r.returncode != 0:
|
||||
continue
|
||||
|
||||
b64_data = r.stdout.strip()
|
||||
if not b64_data:
|
||||
continue
|
||||
|
||||
if _write_base64_image(dest, b64_data):
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
logger.debug("%s not found — clipboard unavailable", exe)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug("%s clipboard image extraction failed: %s", label, e)
|
||||
dest.unlink(missing_ok=True)
|
||||
return False
|
||||
|
||||
|
||||
# ── Native Windows ────────────────────────────────────────────────────────
|
||||
|
||||
@@ -175,15 +284,7 @@ def _windows_has_image() -> bool:
|
||||
ps = _get_ps_exe()
|
||||
if ps is None:
|
||||
return False
|
||||
try:
|
||||
r = subprocess.run(
|
||||
[ps, "-NoProfile", "-NonInteractive", "-Command", _PS_CHECK_IMAGE],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
return r.returncode == 0 and "True" in r.stdout
|
||||
except Exception as e:
|
||||
logger.debug("Windows clipboard image check failed: %s", e)
|
||||
return False
|
||||
return _powershell_has_image(ps, timeout=5, label="Windows")
|
||||
|
||||
|
||||
def _windows_save(dest: Path) -> bool:
|
||||
@@ -192,26 +293,7 @@ def _windows_save(dest: Path) -> bool:
|
||||
if ps is None:
|
||||
logger.debug("No PowerShell found — Windows clipboard image paste unavailable")
|
||||
return False
|
||||
try:
|
||||
r = subprocess.run(
|
||||
[ps, "-NoProfile", "-NonInteractive", "-Command", _PS_EXTRACT_IMAGE],
|
||||
capture_output=True, text=True, timeout=15,
|
||||
)
|
||||
if r.returncode != 0:
|
||||
return False
|
||||
|
||||
b64_data = r.stdout.strip()
|
||||
if not b64_data:
|
||||
return False
|
||||
|
||||
png_bytes = base64.b64decode(b64_data)
|
||||
dest.write_bytes(png_bytes)
|
||||
return dest.exists() and dest.stat().st_size > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Windows clipboard image extraction failed: %s", e)
|
||||
dest.unlink(missing_ok=True)
|
||||
return False
|
||||
return _powershell_save_image(ps, dest, timeout=15, label="Windows")
|
||||
|
||||
|
||||
# ── Linux ────────────────────────────────────────────────────────────────
|
||||
@@ -235,45 +317,12 @@ def _linux_save(dest: Path) -> bool:
|
||||
|
||||
def _wsl_has_image() -> bool:
|
||||
"""Check if Windows clipboard has an image (via powershell.exe)."""
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
|
||||
_PS_CHECK_IMAGE],
|
||||
capture_output=True, text=True, timeout=8,
|
||||
)
|
||||
return r.returncode == 0 and "True" in r.stdout
|
||||
except FileNotFoundError:
|
||||
logger.debug("powershell.exe not found — WSL clipboard unavailable")
|
||||
except Exception as e:
|
||||
logger.debug("WSL clipboard check failed: %s", e)
|
||||
return False
|
||||
return _powershell_has_image("powershell.exe", timeout=8, label="WSL")
|
||||
|
||||
|
||||
def _wsl_save(dest: Path) -> bool:
|
||||
"""Extract clipboard image via powershell.exe → base64 → decode to PNG."""
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
|
||||
_PS_EXTRACT_IMAGE],
|
||||
capture_output=True, text=True, timeout=15,
|
||||
)
|
||||
if r.returncode != 0:
|
||||
return False
|
||||
|
||||
b64_data = r.stdout.strip()
|
||||
if not b64_data:
|
||||
return False
|
||||
|
||||
png_bytes = base64.b64decode(b64_data)
|
||||
dest.write_bytes(png_bytes)
|
||||
return dest.exists() and dest.stat().st_size > 0
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.debug("powershell.exe not found — WSL clipboard unavailable")
|
||||
except Exception as e:
|
||||
logger.debug("WSL clipboard extraction failed: %s", e)
|
||||
dest.unlink(missing_ok=True)
|
||||
return False
|
||||
return _powershell_save_image("powershell.exe", dest, timeout=15, label="WSL")
|
||||
|
||||
|
||||
# ── Wayland (wl-paste) ──────────────────────────────────────────────────
|
||||
|
||||
+95
-7
@@ -87,8 +87,12 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
aliases=("bg",), args_hint="<prompt>"),
|
||||
CommandDef("btw", "Ephemeral side question using session context (no tools, not persisted)", "Session",
|
||||
args_hint="<question>"),
|
||||
CommandDef("agents", "Show active agents and running tasks", "Session",
|
||||
aliases=("tasks",)),
|
||||
CommandDef("queue", "Queue a prompt for the next turn (doesn't interrupt)", "Session",
|
||||
aliases=("q",), args_hint="<prompt>"),
|
||||
CommandDef("steer", "Inject a message after the next tool call without interrupting", "Session",
|
||||
args_hint="<prompt>"),
|
||||
CommandDef("status", "Show session info", "Session"),
|
||||
CommandDef("profile", "Show active profile name and home directory", "Info"),
|
||||
CommandDef("sethome", "Set this chat as the home channel", "Session",
|
||||
@@ -99,7 +103,7 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
# Configuration
|
||||
CommandDef("config", "Show current configuration", "Configuration",
|
||||
cli_only=True),
|
||||
CommandDef("model", "Switch model for this session", "Configuration", args_hint="[model] [--global]"),
|
||||
CommandDef("model", "Switch model for this session", "Configuration", args_hint="[model] [--provider name] [--global]"),
|
||||
CommandDef("provider", "Show available providers and current provider",
|
||||
"Configuration"),
|
||||
CommandDef("gquota", "Show Google Gemini Code Assist quota usage", "Info"),
|
||||
@@ -120,7 +124,7 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
args_hint="[normal|fast|status]",
|
||||
subcommands=("normal", "fast", "status", "on", "off")),
|
||||
CommandDef("skin", "Show or change the display skin/theme", "Configuration",
|
||||
cli_only=True, args_hint="[name]"),
|
||||
args_hint="[name]"),
|
||||
CommandDef("voice", "Toggle voice mode", "Configuration",
|
||||
args_hint="[on|off|tts|status]", subcommands=("on", "off", "tts", "status")),
|
||||
|
||||
@@ -155,7 +159,9 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
args_hint="[days]"),
|
||||
CommandDef("platforms", "Show gateway/messaging platform status", "Info",
|
||||
cli_only=True, aliases=("gateway",)),
|
||||
CommandDef("paste", "Check clipboard for an image and attach it", "Info",
|
||||
CommandDef("copy", "Copy the last assistant response to clipboard", "Info",
|
||||
cli_only=True, args_hint="[number]"),
|
||||
CommandDef("paste", "Attach clipboard image from your clipboard", "Info",
|
||||
cli_only=True),
|
||||
CommandDef("image", "Attach a local image file for your next prompt", "Info",
|
||||
cli_only=True, args_hint="<path>"),
|
||||
@@ -254,6 +260,36 @@ GATEWAY_KNOWN_COMMANDS: frozenset[str] = frozenset(
|
||||
)
|
||||
|
||||
|
||||
# Commands that must never be queued behind an active gateway session.
|
||||
# These are explicit control/info commands handled by the gateway itself;
|
||||
# if they get queued as pending text, the safety net in gateway.run will
|
||||
# discard them before they ever reach the user.
|
||||
ACTIVE_SESSION_BYPASS_COMMANDS: frozenset[str] = frozenset(
|
||||
{
|
||||
"agents",
|
||||
"approve",
|
||||
"background",
|
||||
"commands",
|
||||
"deny",
|
||||
"help",
|
||||
"new",
|
||||
"profile",
|
||||
"queue",
|
||||
"restart",
|
||||
"status",
|
||||
"steer",
|
||||
"stop",
|
||||
"update",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def should_bypass_active_session(command_name: str | None) -> bool:
|
||||
"""Return True when a slash command must bypass active-session queuing."""
|
||||
cmd = resolve_command(command_name) if command_name else None
|
||||
return bool(cmd and cmd.name in ACTIVE_SESSION_BYPASS_COMMANDS)
|
||||
|
||||
|
||||
def _resolve_config_gates() -> set[str]:
|
||||
"""Return canonical names of commands whose ``gateway_config_gate`` is truthy.
|
||||
|
||||
@@ -1044,6 +1080,51 @@ class SlashCommandCompleter(Completer):
|
||||
display_meta=f"{fp} {meta}" if meta else fp,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _skin_completions(sub_text: str, sub_lower: str):
|
||||
"""Yield completions for /skin from available skins."""
|
||||
try:
|
||||
from hermes_cli.skin_engine import list_skins
|
||||
for s in list_skins():
|
||||
name = s["name"]
|
||||
if name.startswith(sub_lower) and name != sub_lower:
|
||||
yield Completion(
|
||||
name,
|
||||
start_position=-len(sub_text),
|
||||
display=name,
|
||||
display_meta=s.get("description", "") or s.get("source", ""),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _personality_completions(sub_text: str, sub_lower: str):
|
||||
"""Yield completions for /personality from configured personalities."""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
personalities = load_config().get("agent", {}).get("personalities", {})
|
||||
if "none".startswith(sub_lower) and "none" != sub_lower:
|
||||
yield Completion(
|
||||
"none",
|
||||
start_position=-len(sub_text),
|
||||
display="none",
|
||||
display_meta="clear personality overlay",
|
||||
)
|
||||
for name, prompt in personalities.items():
|
||||
if name.startswith(sub_lower) and name != sub_lower:
|
||||
if isinstance(prompt, dict):
|
||||
meta = prompt.get("description") or prompt.get("system_prompt", "")[:50]
|
||||
else:
|
||||
meta = str(prompt)[:50]
|
||||
yield Completion(
|
||||
name,
|
||||
start_position=-len(sub_text),
|
||||
display=name,
|
||||
display_meta=meta,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _model_completions(self, sub_text: str, sub_lower: str):
|
||||
"""Yield completions for /model from config aliases + built-in aliases."""
|
||||
seen = set()
|
||||
@@ -1098,10 +1179,17 @@ class SlashCommandCompleter(Completer):
|
||||
sub_text = parts[1] if len(parts) > 1 else ""
|
||||
sub_lower = sub_text.lower()
|
||||
|
||||
# Dynamic model alias completions for /model
|
||||
if " " not in sub_text and base_cmd == "/model":
|
||||
yield from self._model_completions(sub_text, sub_lower)
|
||||
return
|
||||
# Dynamic completions for commands with runtime lists
|
||||
if " " not in sub_text:
|
||||
if base_cmd == "/model":
|
||||
yield from self._model_completions(sub_text, sub_lower)
|
||||
return
|
||||
if base_cmd == "/skin":
|
||||
yield from self._skin_completions(sub_text, sub_lower)
|
||||
return
|
||||
if base_cmd == "/personality":
|
||||
yield from self._personality_completions(sub_text, sub_lower)
|
||||
return
|
||||
|
||||
# Static subcommand completions
|
||||
if " " not in sub_text and base_cmd in SUBCOMMANDS and self._command_allowed(base_cmd):
|
||||
|
||||
+139
-8
@@ -12,6 +12,7 @@ This module provides:
|
||||
- hermes config wizard - Re-run setup wizard
|
||||
"""
|
||||
|
||||
import copy
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
@@ -26,6 +27,7 @@ from typing import Dict, Any, Optional, List, Tuple
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
_ENV_VAR_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
_LAST_EXPANDED_CONFIG_BY_PATH: Dict[str, Any] = {}
|
||||
# Env var names written to .env that aren't in OPTIONAL_ENV_VARS
|
||||
# (managed by setup/provider flows directly).
|
||||
_EXTRA_ENV_KEYS = frozenset({
|
||||
@@ -44,7 +46,8 @@ _EXTRA_ENV_KEYS = frozenset({
|
||||
"WEIXIN_HOME_CHANNEL", "WEIXIN_HOME_CHANNEL_NAME", "WEIXIN_DM_POLICY", "WEIXIN_GROUP_POLICY",
|
||||
"WEIXIN_ALLOWED_USERS", "WEIXIN_GROUP_ALLOWED_USERS", "WEIXIN_ALLOW_ALL_USERS",
|
||||
"BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_PASSWORD",
|
||||
"QQ_APP_ID", "QQ_CLIENT_SECRET", "QQ_HOME_CHANNEL", "QQ_HOME_CHANNEL_NAME",
|
||||
"QQ_APP_ID", "QQ_CLIENT_SECRET", "QQBOT_HOME_CHANNEL", "QQBOT_HOME_CHANNEL_NAME",
|
||||
"QQ_HOME_CHANNEL", "QQ_HOME_CHANNEL_NAME", # legacy aliases (pre-rename, still read for back-compat)
|
||||
"QQ_ALLOWED_USERS", "QQ_GROUP_ALLOWED_USERS", "QQ_ALLOW_ALL_USERS", "QQ_MARKDOWN_SUPPORT",
|
||||
"QQ_STT_API_KEY", "QQ_STT_BASE_URL", "QQ_STT_MODEL",
|
||||
"TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT",
|
||||
@@ -417,6 +420,7 @@ DEFAULT_CONFIG = {
|
||||
"command_timeout": 30, # Timeout for browser commands in seconds (screenshot, navigate, etc.)
|
||||
"record_sessions": False, # Auto-record browser sessions as WebM videos
|
||||
"allow_private_urls": False, # Allow navigating to private/internal IPs (localhost, 192.168.x.x, etc.)
|
||||
"cdp_url": "", # Optional persistent CDP endpoint for attaching to an existing Chromium/Chrome
|
||||
"camofox": {
|
||||
# When true, Hermes sends a stable profile-scoped userId to Camofox
|
||||
# so the server maps it to a persistent Firefox profile automatically.
|
||||
@@ -537,6 +541,13 @@ DEFAULT_CONFIG = {
|
||||
"api_key": "",
|
||||
"timeout": 30,
|
||||
},
|
||||
"title_generation": {
|
||||
"provider": "auto",
|
||||
"model": "",
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
"timeout": 30,
|
||||
},
|
||||
},
|
||||
|
||||
"display": {
|
||||
@@ -760,6 +771,20 @@ DEFAULT_CONFIG = {
|
||||
"wrap_response": True,
|
||||
},
|
||||
|
||||
# execute_code settings — controls the tool used for programmatic tool calls.
|
||||
"code_execution": {
|
||||
# Execution mode:
|
||||
# project (default) — scripts run in the session's working directory
|
||||
# with the active virtualenv/conda env's python, so project deps
|
||||
# (pandas, torch, project packages) and relative paths resolve.
|
||||
# strict — scripts run in an isolated temp directory with
|
||||
# hermes-agent's own python (sys.executable). Maximum isolation
|
||||
# and reproducibility; project deps and relative paths won't work.
|
||||
# Env scrubbing (strips *_API_KEY, *_TOKEN, *_SECRET, ...) and the
|
||||
# tool whitelist apply identically in both modes.
|
||||
"mode": "project",
|
||||
},
|
||||
|
||||
# Logging — controls file logging to ~/.hermes/logs/.
|
||||
# agent.log captures INFO+ (all agent activity); errors.log captures WARNING+.
|
||||
"logging": {
|
||||
@@ -777,7 +802,7 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
|
||||
# Config schema version - bump this when adding new required fields
|
||||
"_config_version": 18,
|
||||
"_config_version": 19,
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
@@ -861,6 +886,22 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"NVIDIA_API_KEY": {
|
||||
"description": "NVIDIA NIM API key (build.nvidia.com or local NIM endpoint)",
|
||||
"prompt": "NVIDIA NIM API key",
|
||||
"url": "https://build.nvidia.com/",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"NVIDIA_BASE_URL": {
|
||||
"description": "NVIDIA NIM base URL override (e.g. http://localhost:8000/v1 for local NIM)",
|
||||
"prompt": "NVIDIA NIM base URL (leave empty for default)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"GLM_API_KEY": {
|
||||
"description": "Z.AI / GLM API key (also recognized as ZAI_API_KEY / Z_AI_API_KEY)",
|
||||
"prompt": "Z.AI / GLM API key",
|
||||
@@ -1518,12 +1559,12 @@ OPTIONAL_ENV_VARS = {
|
||||
"prompt": "Allow All QQ Users",
|
||||
"category": "messaging",
|
||||
},
|
||||
"QQ_HOME_CHANNEL": {
|
||||
"QQBOT_HOME_CHANNEL": {
|
||||
"description": "Default QQ channel/group for cron delivery and notifications",
|
||||
"prompt": "QQ Home Channel",
|
||||
"category": "messaging",
|
||||
},
|
||||
"QQ_HOME_CHANNEL_NAME": {
|
||||
"QQBOT_HOME_CHANNEL_NAME": {
|
||||
"description": "Display name for the QQ home channel",
|
||||
"prompt": "QQ Home Channel Name",
|
||||
"category": "messaging",
|
||||
@@ -2610,6 +2651,85 @@ def _expand_env_vars(obj):
|
||||
return obj
|
||||
|
||||
|
||||
def _items_by_unique_name(items):
|
||||
"""Return a name-indexed dict only when all items have unique string names."""
|
||||
if not isinstance(items, list):
|
||||
return None
|
||||
indexed = {}
|
||||
for item in items:
|
||||
if not isinstance(item, dict) or not isinstance(item.get("name"), str):
|
||||
return None
|
||||
name = item["name"]
|
||||
if name in indexed:
|
||||
return None
|
||||
indexed[name] = item
|
||||
return indexed
|
||||
|
||||
|
||||
def _preserve_env_ref_templates(current, raw, loaded_expanded=None):
|
||||
"""Restore raw ``${VAR}`` templates when a value is otherwise unchanged.
|
||||
|
||||
``load_config()`` expands env refs for runtime use. When a caller later
|
||||
persists that config after modifying some unrelated setting, keep the
|
||||
original on-disk template instead of writing the expanded plaintext
|
||||
secret back to ``config.yaml``.
|
||||
|
||||
Prefer preserving the raw template when ``current`` still matches either
|
||||
the value previously returned by ``load_config()`` for this config path or
|
||||
the current environment expansion of ``raw``. This handles env-var
|
||||
rotation between load and save while still treating mixed literal/template
|
||||
string edits as caller-owned once their rendered value diverges.
|
||||
"""
|
||||
if isinstance(current, str) and isinstance(raw, str) and re.search(r"\${[^}]+}", raw):
|
||||
if current == raw:
|
||||
return raw
|
||||
if isinstance(loaded_expanded, str) and current == loaded_expanded:
|
||||
return raw
|
||||
if _expand_env_vars(raw) == current:
|
||||
return raw
|
||||
return current
|
||||
|
||||
if isinstance(current, dict) and isinstance(raw, dict):
|
||||
return {
|
||||
key: _preserve_env_ref_templates(
|
||||
value,
|
||||
raw.get(key),
|
||||
loaded_expanded.get(key) if isinstance(loaded_expanded, dict) else None,
|
||||
)
|
||||
for key, value in current.items()
|
||||
}
|
||||
|
||||
if isinstance(current, list) and isinstance(raw, list):
|
||||
# Prefer matching named config objects (e.g. custom_providers) by name
|
||||
# so harmless reordering doesn't drop the original template. If names
|
||||
# are duplicated, fall back to positional matching instead of silently
|
||||
# shadowing one entry.
|
||||
current_by_name = _items_by_unique_name(current)
|
||||
raw_by_name = _items_by_unique_name(raw)
|
||||
loaded_by_name = _items_by_unique_name(loaded_expanded)
|
||||
if current_by_name is not None and raw_by_name is not None:
|
||||
return [
|
||||
_preserve_env_ref_templates(
|
||||
item,
|
||||
raw_by_name.get(item.get("name")),
|
||||
loaded_by_name.get(item.get("name")) if loaded_by_name is not None else None,
|
||||
)
|
||||
for item in current
|
||||
]
|
||||
return [
|
||||
_preserve_env_ref_templates(
|
||||
item,
|
||||
raw[index] if index < len(raw) else None,
|
||||
loaded_expanded[index]
|
||||
if isinstance(loaded_expanded, list) and index < len(loaded_expanded)
|
||||
else None,
|
||||
)
|
||||
for index, item in enumerate(current)
|
||||
]
|
||||
|
||||
return current
|
||||
|
||||
|
||||
def _normalize_root_model_keys(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Move stale root-level provider/base_url into model section.
|
||||
|
||||
@@ -2677,7 +2797,6 @@ def read_raw_config() -> Dict[str, Any]:
|
||||
|
||||
def load_config() -> Dict[str, Any]:
|
||||
"""Load configuration from ~/.hermes/config.yaml."""
|
||||
import copy
|
||||
ensure_hermes_home()
|
||||
config_path = get_config_path()
|
||||
|
||||
@@ -2698,8 +2817,11 @@ def load_config() -> Dict[str, Any]:
|
||||
config = _deep_merge(config, user_config)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load config: {e}")
|
||||
|
||||
return _expand_env_vars(_normalize_root_model_keys(_normalize_max_turns_config(config)))
|
||||
|
||||
normalized = _normalize_root_model_keys(_normalize_max_turns_config(config))
|
||||
expanded = _expand_env_vars(normalized)
|
||||
_LAST_EXPANDED_CONFIG_BY_PATH[str(config_path)] = copy.deepcopy(expanded)
|
||||
return expanded
|
||||
|
||||
|
||||
_SECURITY_COMMENT = """
|
||||
@@ -2808,7 +2930,15 @@ def save_config(config: Dict[str, Any]):
|
||||
|
||||
ensure_hermes_home()
|
||||
config_path = get_config_path()
|
||||
normalized = _normalize_root_model_keys(_normalize_max_turns_config(config))
|
||||
current_normalized = _normalize_root_model_keys(_normalize_max_turns_config(config))
|
||||
normalized = current_normalized
|
||||
raw_existing = _normalize_root_model_keys(_normalize_max_turns_config(read_raw_config()))
|
||||
if raw_existing:
|
||||
normalized = _preserve_env_ref_templates(
|
||||
normalized,
|
||||
raw_existing,
|
||||
_LAST_EXPANDED_CONFIG_BY_PATH.get(str(config_path)),
|
||||
)
|
||||
|
||||
# Build optional commented-out sections for features that are off by
|
||||
# default or only relevant when explicitly configured.
|
||||
@@ -2826,6 +2956,7 @@ def save_config(config: Dict[str, Any]):
|
||||
extra_content="".join(parts) if parts else None,
|
||||
)
|
||||
_secure_file(config_path)
|
||||
_LAST_EXPANDED_CONFIG_BY_PATH[str(config_path)] = copy.deepcopy(current_normalized)
|
||||
|
||||
|
||||
def load_env() -> Dict[str, str]:
|
||||
|
||||
+137
-29
@@ -6,7 +6,10 @@ Currently supports:
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
@@ -31,6 +34,119 @@ _MAX_LOG_BYTES = 512_000
|
||||
_AUTO_DELETE_SECONDS = 21600
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pending-deletion tracking (replaces the old fork-and-sleep subprocess).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _pending_file() -> Path:
|
||||
"""Path to ``~/.hermes/pastes/pending.json``.
|
||||
|
||||
Each entry: ``{"url": "...", "expire_at": <unix_ts>}``. Scheduled
|
||||
DELETEs used to be handled by spawning a detached Python process per
|
||||
paste that slept for 6 hours; those accumulated forever if the user
|
||||
ran ``hermes debug share`` repeatedly. We now persist the schedule
|
||||
to disk and sweep expired entries on the next debug invocation.
|
||||
"""
|
||||
return get_hermes_home() / "pastes" / "pending.json"
|
||||
|
||||
|
||||
def _load_pending() -> list[dict]:
|
||||
path = _pending_file()
|
||||
if not path.exists():
|
||||
return []
|
||||
try:
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, list):
|
||||
# Filter to well-formed entries only
|
||||
return [
|
||||
e for e in data
|
||||
if isinstance(e, dict) and "url" in e and "expire_at" in e
|
||||
]
|
||||
except (OSError, ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
return []
|
||||
|
||||
|
||||
def _save_pending(entries: list[dict]) -> None:
|
||||
path = _pending_file()
|
||||
try:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = path.with_suffix(".json.tmp")
|
||||
tmp.write_text(json.dumps(entries, indent=2), encoding="utf-8")
|
||||
os.replace(tmp, path)
|
||||
except OSError:
|
||||
# Non-fatal — worst case the user has to run ``hermes debug delete``
|
||||
# manually.
|
||||
pass
|
||||
|
||||
|
||||
def _record_pending(urls: list[str], delay_seconds: int = _AUTO_DELETE_SECONDS) -> None:
|
||||
"""Record *urls* for deletion at ``now + delay_seconds``.
|
||||
|
||||
Only paste.rs URLs are recorded (dpaste.com auto-expires). Entries
|
||||
are merged into any existing pending.json.
|
||||
"""
|
||||
paste_rs_urls = [u for u in urls if _extract_paste_id(u)]
|
||||
if not paste_rs_urls:
|
||||
return
|
||||
|
||||
entries = _load_pending()
|
||||
# Dedupe by URL: keep the later expire_at if same URL appears twice
|
||||
by_url: dict[str, float] = {e["url"]: float(e["expire_at"]) for e in entries}
|
||||
expire_at = time.time() + delay_seconds
|
||||
for u in paste_rs_urls:
|
||||
by_url[u] = max(expire_at, by_url.get(u, 0.0))
|
||||
merged = [{"url": u, "expire_at": ts} for u, ts in by_url.items()]
|
||||
_save_pending(merged)
|
||||
|
||||
|
||||
def _sweep_expired_pastes(now: Optional[float] = None) -> tuple[int, int]:
|
||||
"""Synchronously DELETE any pending pastes whose ``expire_at`` has passed.
|
||||
|
||||
Returns ``(deleted, remaining)``. Best-effort: failed deletes stay in
|
||||
the pending file and will be retried on the next sweep. Silent —
|
||||
intended to be called from every ``hermes debug`` invocation with
|
||||
minimal noise.
|
||||
"""
|
||||
entries = _load_pending()
|
||||
if not entries:
|
||||
return (0, 0)
|
||||
|
||||
current = time.time() if now is None else now
|
||||
deleted = 0
|
||||
remaining: list[dict] = []
|
||||
|
||||
for entry in entries:
|
||||
try:
|
||||
expire_at = float(entry.get("expire_at", 0))
|
||||
except (TypeError, ValueError):
|
||||
continue # drop malformed entries
|
||||
if expire_at > current:
|
||||
remaining.append(entry)
|
||||
continue
|
||||
|
||||
url = entry.get("url", "")
|
||||
try:
|
||||
if delete_paste(url):
|
||||
deleted += 1
|
||||
continue
|
||||
except Exception:
|
||||
# Network hiccup, 404 (already gone), etc. — drop the entry
|
||||
# after a grace period; don't retry forever.
|
||||
pass
|
||||
|
||||
# Retain failed deletes for up to 24h past expiration, then give up.
|
||||
if expire_at + 86400 > current:
|
||||
remaining.append(entry)
|
||||
else:
|
||||
deleted += 1 # count as reaped (paste.rs will GC eventually)
|
||||
|
||||
if deleted:
|
||||
_save_pending(remaining)
|
||||
|
||||
return (deleted, len(remaining))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Privacy / delete helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -90,37 +206,19 @@ def delete_paste(url: str) -> bool:
|
||||
|
||||
|
||||
def _schedule_auto_delete(urls: list[str], delay_seconds: int = _AUTO_DELETE_SECONDS):
|
||||
"""Spawn a detached process to delete paste.rs pastes after *delay_seconds*.
|
||||
"""Record *urls* for deletion ``delay_seconds`` from now.
|
||||
|
||||
The child process is fully detached (``start_new_session=True``) so it
|
||||
survives the parent exiting (important for CLI mode). Only paste.rs
|
||||
URLs are attempted — dpaste.com pastes auto-expire on their own.
|
||||
Previously this spawned a detached Python subprocess per call that slept
|
||||
for 6 hours and then issued DELETE requests. Those subprocesses leaked —
|
||||
every ``hermes debug share`` invocation added ~20 MB of resident Python
|
||||
interpreters that never exited until the sleep completed.
|
||||
|
||||
The replacement is stateless: we append to ``~/.hermes/pastes/pending.json``
|
||||
and rely on opportunistic sweeps (``_sweep_expired_pastes``) called from
|
||||
every ``hermes debug`` invocation. If the user never runs ``hermes debug``
|
||||
again, paste.rs's own retention policy handles cleanup.
|
||||
"""
|
||||
import subprocess
|
||||
|
||||
paste_rs_urls = [u for u in urls if _extract_paste_id(u)]
|
||||
if not paste_rs_urls:
|
||||
return
|
||||
|
||||
# Build a tiny inline Python script. No imports beyond stdlib.
|
||||
url_list = ", ".join(f'"{u}"' for u in paste_rs_urls)
|
||||
script = (
|
||||
"import time, urllib.request; "
|
||||
f"time.sleep({delay_seconds}); "
|
||||
f"[urllib.request.urlopen(urllib.request.Request(u, method='DELETE', "
|
||||
f"headers={{'User-Agent': 'hermes-agent/auto-delete'}}), timeout=15) "
|
||||
f"for u in [{url_list}]]"
|
||||
)
|
||||
|
||||
try:
|
||||
subprocess.Popen(
|
||||
[sys.executable, "-c", script],
|
||||
start_new_session=True,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
except Exception:
|
||||
pass # Best-effort; manual delete still available.
|
||||
_record_pending(urls, delay_seconds=delay_seconds)
|
||||
|
||||
|
||||
def _delete_hint(url: str) -> str:
|
||||
@@ -455,6 +553,16 @@ def run_debug_delete(args):
|
||||
|
||||
def run_debug(args):
|
||||
"""Route debug subcommands."""
|
||||
# Opportunistic sweep of expired pastes on every ``hermes debug`` call.
|
||||
# Replaces the old per-paste sleeping subprocess that used to leak as
|
||||
# one orphaned Python interpreter per scheduled deletion. Silent and
|
||||
# best-effort — any failure is swallowed so ``hermes debug`` stays
|
||||
# reliable even when offline.
|
||||
try:
|
||||
_sweep_expired_pastes()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
subcmd = getattr(args, "debug_command", None)
|
||||
if subcmd == "share":
|
||||
run_debug_share(args)
|
||||
|
||||
@@ -0,0 +1,294 @@
|
||||
"""
|
||||
DingTalk Device Flow authorization.
|
||||
|
||||
Implements the same 3-step registration flow as dingtalk-openclaw-connector:
|
||||
1. POST /app/registration/init → get nonce
|
||||
2. POST /app/registration/begin → get device_code + verification_uri_complete
|
||||
3. POST /app/registration/poll → poll until SUCCESS → get client_id + client_secret
|
||||
|
||||
The verification_uri_complete is rendered as a QR code in the terminal so the
|
||||
user can scan it with DingTalk to authorize, yielding AppKey + AppSecret
|
||||
automatically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Configuration ──────────────────────────────────────────────────────────
|
||||
|
||||
REGISTRATION_BASE_URL = os.environ.get(
|
||||
"DINGTALK_REGISTRATION_BASE_URL", "https://oapi.dingtalk.com"
|
||||
).rstrip("/")
|
||||
|
||||
REGISTRATION_SOURCE = os.environ.get("DINGTALK_REGISTRATION_SOURCE", "openClaw")
|
||||
|
||||
|
||||
# ── API helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
class RegistrationError(Exception):
|
||||
"""Raised when a DingTalk registration API call fails."""
|
||||
|
||||
|
||||
def _api_post(path: str, payload: dict) -> dict:
|
||||
"""POST to the registration API and return the parsed JSON body."""
|
||||
url = f"{REGISTRATION_BASE_URL}{path}"
|
||||
try:
|
||||
resp = requests.post(url, json=payload, timeout=15)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
except requests.RequestException as exc:
|
||||
raise RegistrationError(f"Network error calling {url}: {exc}") from exc
|
||||
|
||||
errcode = data.get("errcode", -1)
|
||||
if errcode != 0:
|
||||
errmsg = data.get("errmsg", "unknown error")
|
||||
raise RegistrationError(f"API error [{path}]: {errmsg} (errcode={errcode})")
|
||||
return data
|
||||
|
||||
|
||||
# ── Core flow ──────────────────────────────────────────────────────────────
|
||||
|
||||
def begin_registration() -> dict:
|
||||
"""Start a device-flow registration.
|
||||
|
||||
Returns a dict with keys:
|
||||
device_code, verification_uri_complete, expires_in, interval
|
||||
"""
|
||||
# Step 1: init → nonce
|
||||
init_data = _api_post("/app/registration/init", {"source": REGISTRATION_SOURCE})
|
||||
nonce = str(init_data.get("nonce", "")).strip()
|
||||
if not nonce:
|
||||
raise RegistrationError("init response missing nonce")
|
||||
|
||||
# Step 2: begin → device_code, verification_uri_complete
|
||||
begin_data = _api_post("/app/registration/begin", {"nonce": nonce})
|
||||
device_code = str(begin_data.get("device_code", "")).strip()
|
||||
verification_uri_complete = str(begin_data.get("verification_uri_complete", "")).strip()
|
||||
if not device_code:
|
||||
raise RegistrationError("begin response missing device_code")
|
||||
if not verification_uri_complete:
|
||||
raise RegistrationError("begin response missing verification_uri_complete")
|
||||
|
||||
return {
|
||||
"device_code": device_code,
|
||||
"verification_uri_complete": verification_uri_complete,
|
||||
"expires_in": int(begin_data.get("expires_in", 7200)),
|
||||
"interval": max(int(begin_data.get("interval", 3)), 2),
|
||||
}
|
||||
|
||||
|
||||
def poll_registration(device_code: str) -> dict:
|
||||
"""Poll the registration status once.
|
||||
|
||||
Returns a dict with keys: status, client_id?, client_secret?, fail_reason?
|
||||
"""
|
||||
data = _api_post("/app/registration/poll", {"device_code": device_code})
|
||||
status_raw = str(data.get("status", "")).strip().upper()
|
||||
if status_raw not in ("WAITING", "SUCCESS", "FAIL", "EXPIRED"):
|
||||
status_raw = "UNKNOWN"
|
||||
return {
|
||||
"status": status_raw,
|
||||
"client_id": str(data.get("client_id", "")).strip() or None,
|
||||
"client_secret": str(data.get("client_secret", "")).strip() or None,
|
||||
"fail_reason": str(data.get("fail_reason", "")).strip() or None,
|
||||
}
|
||||
|
||||
|
||||
def wait_for_registration_success(
|
||||
device_code: str,
|
||||
interval: int = 3,
|
||||
expires_in: int = 7200,
|
||||
on_waiting: Optional[callable] = None,
|
||||
) -> Tuple[str, str]:
|
||||
"""Block until the registration succeeds or times out.
|
||||
|
||||
Returns (client_id, client_secret).
|
||||
"""
|
||||
deadline = time.monotonic() + expires_in
|
||||
retry_window = 120 # 2 minutes for transient errors
|
||||
retry_start = 0.0
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
time.sleep(interval)
|
||||
try:
|
||||
result = poll_registration(device_code)
|
||||
except RegistrationError:
|
||||
if retry_start == 0:
|
||||
retry_start = time.monotonic()
|
||||
if time.monotonic() - retry_start < retry_window:
|
||||
continue
|
||||
raise
|
||||
|
||||
status = result["status"]
|
||||
if status == "WAITING":
|
||||
retry_start = 0
|
||||
if on_waiting:
|
||||
on_waiting()
|
||||
continue
|
||||
if status == "SUCCESS":
|
||||
cid = result["client_id"]
|
||||
csecret = result["client_secret"]
|
||||
if not cid or not csecret:
|
||||
raise RegistrationError("authorization succeeded but credentials are missing")
|
||||
return cid, csecret
|
||||
# FAIL / EXPIRED / UNKNOWN
|
||||
if retry_start == 0:
|
||||
retry_start = time.monotonic()
|
||||
if time.monotonic() - retry_start < retry_window:
|
||||
continue
|
||||
reason = result.get("fail_reason") or status
|
||||
raise RegistrationError(f"authorization failed: {reason}")
|
||||
|
||||
raise RegistrationError("authorization timed out, please retry")
|
||||
|
||||
|
||||
# ── QR code rendering ─────────────────────────────────────────────────────
|
||||
|
||||
def _ensure_qrcode_installed() -> bool:
|
||||
"""Try to import qrcode; if missing, auto-install it via pip/uv."""
|
||||
try:
|
||||
import qrcode # noqa: F401
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import subprocess
|
||||
|
||||
# Try uv first (Hermes convention), then pip
|
||||
for cmd in (
|
||||
[sys.executable, "-m", "uv", "pip", "install", "qrcode"],
|
||||
[sys.executable, "-m", "pip", "install", "-q", "qrcode"],
|
||||
):
|
||||
try:
|
||||
subprocess.check_call(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
import qrcode # noqa: F401,F811
|
||||
return True
|
||||
except (subprocess.CalledProcessError, ImportError, FileNotFoundError):
|
||||
continue
|
||||
return False
|
||||
|
||||
|
||||
def render_qr_to_terminal(url: str) -> bool:
|
||||
"""Render *url* as a compact QR code in the terminal.
|
||||
|
||||
Returns True if the QR code was printed, False if the library is missing.
|
||||
"""
|
||||
try:
|
||||
import qrcode
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
qr = qrcode.QRCode(
|
||||
version=1,
|
||||
error_correction=qrcode.constants.ERROR_CORRECT_L,
|
||||
box_size=1,
|
||||
border=1,
|
||||
)
|
||||
qr.add_data(url)
|
||||
qr.make(fit=True)
|
||||
|
||||
# Use half-block characters for compact rendering (2 rows per character)
|
||||
matrix = qr.get_matrix()
|
||||
rows = len(matrix)
|
||||
lines: list[str] = []
|
||||
|
||||
TOP_HALF = "\u2580" # ▀
|
||||
BOTTOM_HALF = "\u2584" # ▄
|
||||
FULL_BLOCK = "\u2588" # █
|
||||
EMPTY = " "
|
||||
|
||||
for r in range(0, rows, 2):
|
||||
line_chars: list[str] = []
|
||||
for c in range(len(matrix[r])):
|
||||
top = matrix[r][c]
|
||||
bottom = matrix[r + 1][c] if r + 1 < rows else False
|
||||
if top and bottom:
|
||||
line_chars.append(FULL_BLOCK)
|
||||
elif top:
|
||||
line_chars.append(TOP_HALF)
|
||||
elif bottom:
|
||||
line_chars.append(BOTTOM_HALF)
|
||||
else:
|
||||
line_chars.append(EMPTY)
|
||||
lines.append(" " + "".join(line_chars))
|
||||
|
||||
print("\n".join(lines))
|
||||
return True
|
||||
|
||||
|
||||
# ── High-level entry point for the setup wizard ───────────────────────────
|
||||
|
||||
def dingtalk_qr_auth() -> Optional[Tuple[str, str]]:
|
||||
"""Run the interactive QR-code device-flow authorization.
|
||||
|
||||
Returns (client_id, client_secret) on success, or None if the user
|
||||
cancelled or the flow failed.
|
||||
"""
|
||||
from hermes_cli.setup import print_info, print_success, print_warning, print_error
|
||||
|
||||
print()
|
||||
print_info(" Initializing DingTalk device authorization...")
|
||||
print_info(" Note: the scan page is branded 'OpenClaw' — DingTalk's")
|
||||
print_info(" ecosystem onboarding bridge. Safe to use.")
|
||||
|
||||
try:
|
||||
reg = begin_registration()
|
||||
except RegistrationError as exc:
|
||||
print_error(f" Authorization init failed: {exc}")
|
||||
return None
|
||||
|
||||
url = reg["verification_uri_complete"]
|
||||
|
||||
# Ensure qrcode library is available (auto-install if missing)
|
||||
if not _ensure_qrcode_installed():
|
||||
print_warning(" qrcode library install failed, will show link only.")
|
||||
|
||||
print()
|
||||
print_info(" Please scan the QR code below with DingTalk to authorize:")
|
||||
print()
|
||||
|
||||
if not render_qr_to_terminal(url):
|
||||
print_warning(f" QR code render failed, please open the link below to authorize:")
|
||||
|
||||
print()
|
||||
print_info(f" Or open this link manually: {url}")
|
||||
print()
|
||||
print_info(" Waiting for QR scan authorization... (timeout: 2 hours)")
|
||||
|
||||
dot_count = 0
|
||||
|
||||
def _on_waiting():
|
||||
nonlocal dot_count
|
||||
dot_count += 1
|
||||
if dot_count % 10 == 0:
|
||||
sys.stdout.write(".")
|
||||
sys.stdout.flush()
|
||||
|
||||
try:
|
||||
client_id, client_secret = wait_for_registration_success(
|
||||
device_code=reg["device_code"],
|
||||
interval=reg["interval"],
|
||||
expires_in=reg["expires_in"],
|
||||
on_waiting=_on_waiting,
|
||||
)
|
||||
except RegistrationError as exc:
|
||||
print()
|
||||
print_error(f" Authorization failed: {exc}")
|
||||
return None
|
||||
|
||||
print()
|
||||
print_success(" QR scan authorization successful!")
|
||||
print_success(f" Client ID: {client_id}")
|
||||
print_success(f" Client Secret: {client_secret[:8]}{'*' * (len(client_secret) - 8)}")
|
||||
|
||||
return client_id, client_secret
|
||||
@@ -825,6 +825,7 @@ def run_doctor(args):
|
||||
("Arcee AI", ("ARCEEAI_API_KEY",), "https://api.arcee.ai/api/v1/models", "ARCEE_BASE_URL", True),
|
||||
("DeepSeek", ("DEEPSEEK_API_KEY",), "https://api.deepseek.com/v1/models", "DEEPSEEK_BASE_URL", True),
|
||||
("Hugging Face", ("HF_TOKEN",), "https://router.huggingface.co/v1/models", "HF_BASE_URL", True),
|
||||
("NVIDIA NIM", ("NVIDIA_API_KEY",), "https://integrate.api.nvidia.com/v1/models", "NVIDIA_BASE_URL", True),
|
||||
("Alibaba/DashScope", ("DASHSCOPE_API_KEY",), "https://dashscope-intl.aliyuncs.com/compatible-mode/v1/models", "DASHSCOPE_BASE_URL", True),
|
||||
# MiniMax: the /anthropic endpoint doesn't support /models, but the /v1 endpoint does.
|
||||
("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL", True),
|
||||
@@ -894,8 +895,8 @@ def run_doctor(args):
|
||||
_model_count = len(_br_resp.get("modelSummaries", []))
|
||||
print(f"\r {color('✓', Colors.GREEN)} {_label} {color(f'({_auth_var}, {_region}, {_model_count} models)', Colors.DIM)} ")
|
||||
except ImportError:
|
||||
print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color('(boto3 not installed — pip install hermes-agent[bedrock])', Colors.DIM)} ")
|
||||
issues.append("Install boto3 for Bedrock: pip install hermes-agent[bedrock]")
|
||||
print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'(boto3 not installed — {sys.executable} -m pip install boto3)', Colors.DIM)} ")
|
||||
issues.append(f"Install boto3 for Bedrock: {sys.executable} -m pip install boto3")
|
||||
except Exception as _e:
|
||||
_err_name = type(_e).__name__
|
||||
print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'({_err_name}: {_e})', Colors.DIM)} ")
|
||||
|
||||
+15
-35
@@ -43,41 +43,20 @@ def _redact(value: str) -> str:
|
||||
|
||||
def _gateway_status() -> str:
|
||||
"""Return a short gateway status string."""
|
||||
if sys.platform.startswith("linux"):
|
||||
from hermes_constants import is_container
|
||||
if is_container():
|
||||
try:
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
pids = find_gateway_pids()
|
||||
if pids:
|
||||
return f"running (docker, pid {pids[0]})"
|
||||
return "stopped (docker)"
|
||||
except Exception:
|
||||
return "stopped (docker)"
|
||||
try:
|
||||
from hermes_cli.gateway import get_service_name
|
||||
svc = get_service_name()
|
||||
except Exception:
|
||||
svc = "hermes-gateway"
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["systemctl", "--user", "is-active", svc],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
return "running (systemd)" if r.stdout.strip() == "active" else "stopped"
|
||||
except Exception:
|
||||
return "unknown"
|
||||
elif sys.platform == "darwin":
|
||||
try:
|
||||
from hermes_cli.gateway import get_launchd_label
|
||||
r = subprocess.run(
|
||||
["launchctl", "list", get_launchd_label()],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
return "loaded (launchd)" if r.returncode == 0 else "not loaded"
|
||||
except Exception:
|
||||
return "unknown"
|
||||
return "N/A"
|
||||
try:
|
||||
from hermes_cli.gateway import get_gateway_runtime_snapshot
|
||||
|
||||
snapshot = get_gateway_runtime_snapshot()
|
||||
if snapshot.running:
|
||||
mode = snapshot.manager
|
||||
if snapshot.has_process_service_mismatch:
|
||||
mode = "manual"
|
||||
return f"running ({mode}, pid {snapshot.gateway_pids[0]})"
|
||||
if snapshot.service_installed and not snapshot.service_running:
|
||||
return f"stopped ({snapshot.manager})"
|
||||
return f"stopped ({snapshot.manager})"
|
||||
except Exception:
|
||||
return "unknown" if sys.platform.startswith(("linux", "darwin")) else "N/A"
|
||||
|
||||
|
||||
def _count_skills(hermes_home: Path) -> int:
|
||||
@@ -296,6 +275,7 @@ def run_dump(args):
|
||||
("DEEPSEEK_API_KEY", "deepseek"),
|
||||
("DASHSCOPE_API_KEY", "dashscope"),
|
||||
("HF_TOKEN", "huggingface"),
|
||||
("NVIDIA_API_KEY", "nvidia"),
|
||||
("AI_GATEWAY_API_KEY", "ai_gateway"),
|
||||
("OPENCODE_ZEN_API_KEY", "opencode_zen"),
|
||||
("OPENCODE_GO_API_KEY", "opencode_go"),
|
||||
|
||||
+691
-34
@@ -10,6 +10,7 @@ import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
@@ -41,6 +42,23 @@ from hermes_cli.colors import Colors, color
|
||||
# Process Management (for manual gateway runs)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GatewayRuntimeSnapshot:
|
||||
manager: str
|
||||
service_installed: bool = False
|
||||
service_running: bool = False
|
||||
gateway_pids: tuple[int, ...] = ()
|
||||
service_scope: str | None = None
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
return self.service_running or bool(self.gateway_pids)
|
||||
|
||||
@property
|
||||
def has_process_service_mismatch(self) -> bool:
|
||||
return self.service_installed and self.running and not self.service_running
|
||||
|
||||
def _get_service_pids() -> set:
|
||||
"""Return PIDs currently managed by systemd or launchd gateway services.
|
||||
|
||||
@@ -157,20 +175,22 @@ def _request_gateway_self_restart(pid: int) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = False) -> list:
|
||||
"""Find PIDs of running gateway processes.
|
||||
def _append_unique_pid(pids: list[int], pid: int | None, exclude_pids: set[int]) -> None:
|
||||
if pid is None or pid <= 0:
|
||||
return
|
||||
if pid == os.getpid() or pid in exclude_pids or pid in pids:
|
||||
return
|
||||
pids.append(pid)
|
||||
|
||||
Args:
|
||||
exclude_pids: PIDs to exclude from the result (e.g. service-managed
|
||||
PIDs that should not be killed during a stale-process sweep).
|
||||
all_profiles: When ``True``, return gateway PIDs across **all**
|
||||
profiles (the pre-7923 global behaviour). ``hermes update``
|
||||
needs this because a code update affects every profile.
|
||||
When ``False`` (default), only PIDs belonging to the current
|
||||
Hermes profile are returned.
|
||||
|
||||
def _scan_gateway_pids(exclude_pids: set[int], all_profiles: bool = False) -> list[int]:
|
||||
"""Best-effort process-table scan for gateway PIDs.
|
||||
|
||||
This supplements the profile-scoped PID file so status views can still spot
|
||||
a live gateway when the PID file is stale/missing, and ``--all`` sweeps can
|
||||
discover gateways outside the current profile.
|
||||
"""
|
||||
_exclude = exclude_pids or set()
|
||||
pids = [pid for pid in _get_service_pids() if pid not in _exclude]
|
||||
pids: list[int] = []
|
||||
patterns = [
|
||||
"hermes_cli.main gateway",
|
||||
"hermes_cli.main --profile",
|
||||
@@ -203,20 +223,24 @@ def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = Fals
|
||||
if is_windows():
|
||||
result = subprocess.run(
|
||||
["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"],
|
||||
capture_output=True, text=True, timeout=10
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return []
|
||||
current_cmd = ""
|
||||
for line in result.stdout.split('\n'):
|
||||
for line in result.stdout.split("\n"):
|
||||
line = line.strip()
|
||||
if line.startswith("CommandLine="):
|
||||
current_cmd = line[len("CommandLine="):]
|
||||
elif line.startswith("ProcessId="):
|
||||
pid_str = line[len("ProcessId="):]
|
||||
if any(p in current_cmd for p in patterns) and (all_profiles or _matches_current_profile(current_cmd)):
|
||||
if any(p in current_cmd for p in patterns) and (
|
||||
all_profiles or _matches_current_profile(current_cmd)
|
||||
):
|
||||
try:
|
||||
pid = int(pid_str)
|
||||
if pid != os.getpid() and pid not in pids and pid not in _exclude:
|
||||
pids.append(pid)
|
||||
_append_unique_pid(pids, int(pid_str), exclude_pids)
|
||||
except ValueError:
|
||||
pass
|
||||
current_cmd = ""
|
||||
@@ -227,9 +251,11 @@ def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = Fals
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
for line in result.stdout.split('\n'):
|
||||
if result.returncode != 0:
|
||||
return []
|
||||
for line in result.stdout.split("\n"):
|
||||
stripped = line.strip()
|
||||
if not stripped or 'grep' in stripped:
|
||||
if not stripped or "grep" in stripped:
|
||||
continue
|
||||
|
||||
pid = None
|
||||
@@ -251,16 +277,137 @@ def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = Fals
|
||||
|
||||
if pid is None:
|
||||
continue
|
||||
if pid == os.getpid() or pid in pids or pid in _exclude:
|
||||
continue
|
||||
if any(pattern in command for pattern in patterns) and (all_profiles or _matches_current_profile(command)):
|
||||
pids.append(pid)
|
||||
if any(pattern in command for pattern in patterns) and (
|
||||
all_profiles or _matches_current_profile(command)
|
||||
):
|
||||
_append_unique_pid(pids, pid, exclude_pids)
|
||||
except (OSError, subprocess.TimeoutExpired):
|
||||
pass
|
||||
return []
|
||||
|
||||
return pids
|
||||
|
||||
|
||||
def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = False) -> list:
|
||||
"""Find PIDs of running gateway processes.
|
||||
|
||||
Args:
|
||||
exclude_pids: PIDs to exclude from the result (e.g. service-managed
|
||||
PIDs that should not be killed during a stale-process sweep).
|
||||
all_profiles: When ``True``, return gateway PIDs across **all**
|
||||
profiles (the pre-7923 global behaviour). ``hermes update``
|
||||
needs this because a code update affects every profile.
|
||||
When ``False`` (default), only PIDs belonging to the current
|
||||
Hermes profile are returned.
|
||||
"""
|
||||
_exclude = set(exclude_pids or set())
|
||||
pids: list[int] = []
|
||||
if not all_profiles:
|
||||
try:
|
||||
from gateway.status import get_running_pid
|
||||
|
||||
_append_unique_pid(pids, get_running_pid(), _exclude)
|
||||
except Exception:
|
||||
pass
|
||||
for pid in _get_service_pids():
|
||||
_append_unique_pid(pids, pid, _exclude)
|
||||
for pid in _scan_gateway_pids(_exclude, all_profiles=all_profiles):
|
||||
_append_unique_pid(pids, pid, _exclude)
|
||||
return pids
|
||||
|
||||
|
||||
def _probe_systemd_service_running(system: bool = False) -> tuple[bool, bool]:
|
||||
selected_system = _select_systemd_scope(system)
|
||||
unit_exists = get_systemd_unit_path(system=selected_system).exists()
|
||||
if not unit_exists:
|
||||
return selected_system, False
|
||||
try:
|
||||
result = _run_systemctl(
|
||||
["is-active", get_service_name()],
|
||||
system=selected_system,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
except (RuntimeError, subprocess.TimeoutExpired):
|
||||
return selected_system, False
|
||||
return selected_system, result.stdout.strip() == "active"
|
||||
|
||||
|
||||
def _probe_launchd_service_running() -> bool:
|
||||
if not get_launchd_plist_path().exists():
|
||||
return False
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["launchctl", "list", get_launchd_label()],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
return False
|
||||
return result.returncode == 0
|
||||
|
||||
|
||||
def get_gateway_runtime_snapshot(system: bool = False) -> GatewayRuntimeSnapshot:
|
||||
"""Return a unified view of gateway liveness for the current profile."""
|
||||
gateway_pids = tuple(find_gateway_pids())
|
||||
if is_termux():
|
||||
return GatewayRuntimeSnapshot(
|
||||
manager="Termux / manual process",
|
||||
gateway_pids=gateway_pids,
|
||||
)
|
||||
|
||||
from hermes_constants import is_container
|
||||
|
||||
if is_linux() and is_container():
|
||||
return GatewayRuntimeSnapshot(
|
||||
manager="docker (foreground)",
|
||||
gateway_pids=gateway_pids,
|
||||
)
|
||||
|
||||
if supports_systemd_services():
|
||||
selected_system, service_running = _probe_systemd_service_running(system=system)
|
||||
scope_label = _service_scope_label(selected_system)
|
||||
return GatewayRuntimeSnapshot(
|
||||
manager=f"systemd ({scope_label})",
|
||||
service_installed=get_systemd_unit_path(system=selected_system).exists(),
|
||||
service_running=service_running,
|
||||
gateway_pids=gateway_pids,
|
||||
service_scope=scope_label,
|
||||
)
|
||||
|
||||
if is_macos():
|
||||
return GatewayRuntimeSnapshot(
|
||||
manager="launchd",
|
||||
service_installed=get_launchd_plist_path().exists(),
|
||||
service_running=_probe_launchd_service_running(),
|
||||
gateway_pids=gateway_pids,
|
||||
service_scope="launchd",
|
||||
)
|
||||
|
||||
return GatewayRuntimeSnapshot(
|
||||
manager="manual process",
|
||||
gateway_pids=gateway_pids,
|
||||
)
|
||||
|
||||
|
||||
def _format_gateway_pids(pids: tuple[int, ...] | list[int], *, limit: int | None = 3) -> str:
|
||||
rendered = [str(pid) for pid in pids[:limit] if pid > 0] if limit is not None else [str(pid) for pid in pids if pid > 0]
|
||||
if limit is not None and len(pids) > limit:
|
||||
rendered.append("...")
|
||||
return ", ".join(rendered)
|
||||
|
||||
|
||||
def _print_gateway_process_mismatch(snapshot: GatewayRuntimeSnapshot) -> None:
|
||||
if not snapshot.has_process_service_mismatch:
|
||||
return
|
||||
print()
|
||||
print("⚠ Gateway process is running for this profile, but the service is not active")
|
||||
print(f" PID(s): {_format_gateway_pids(snapshot.gateway_pids, limit=None)}")
|
||||
print(" This is usually a manual foreground/tmux/nohup run, so `hermes gateway`")
|
||||
print(" can refuse to start another copy until this process stops.")
|
||||
|
||||
|
||||
def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None,
|
||||
all_profiles: bool = False) -> int:
|
||||
"""Kill any running gateway processes. Returns count killed.
|
||||
@@ -340,25 +487,44 @@ def _wsl_systemd_operational() -> bool:
|
||||
WSL2 with ``systemd=true`` in wsl.conf has working systemd.
|
||||
WSL2 without it (or WSL1) does not — systemctl commands fail.
|
||||
"""
|
||||
return _systemd_operational(system=True)
|
||||
|
||||
|
||||
def _systemd_operational(system: bool = False) -> bool:
|
||||
"""Return True when the requested systemd scope is usable."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["systemctl", "is-system-running"],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
result = _run_systemctl(
|
||||
["is-system-running"],
|
||||
system=system,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
# "running", "degraded", "starting" all mean systemd is PID 1
|
||||
status = result.stdout.strip().lower()
|
||||
return status in ("running", "degraded", "starting", "initializing")
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired, OSError):
|
||||
except (RuntimeError, subprocess.TimeoutExpired, OSError):
|
||||
return False
|
||||
|
||||
|
||||
def _container_systemd_operational() -> bool:
|
||||
"""Return True when a container exposes working user or system systemd."""
|
||||
if _systemd_operational(system=False):
|
||||
return True
|
||||
if _systemd_operational(system=True):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def supports_systemd_services() -> bool:
|
||||
if not is_linux() or is_termux() or is_container():
|
||||
if not is_linux() or is_termux():
|
||||
return False
|
||||
if shutil.which("systemctl") is None:
|
||||
return False
|
||||
if is_wsl():
|
||||
return _wsl_systemd_operational()
|
||||
if is_container():
|
||||
return _container_systemd_operational()
|
||||
return True
|
||||
|
||||
|
||||
@@ -521,6 +687,195 @@ def has_conflicting_systemd_units() -> bool:
|
||||
return len(get_installed_systemd_scopes()) > 1
|
||||
|
||||
|
||||
# Legacy service names from older Hermes installs that predate the
|
||||
# hermes-gateway rename. Kept as an explicit allowlist (NOT a glob) so
|
||||
# profile units (hermes-gateway-*.service) and unrelated third-party
|
||||
# "hermes" units are never matched.
|
||||
_LEGACY_SERVICE_NAMES: tuple[str, ...] = ("hermes.service",)
|
||||
|
||||
# ExecStart content markers that identify a unit as running our gateway.
|
||||
# A legacy unit is only flagged when its file contains one of these.
|
||||
_LEGACY_UNIT_EXECSTART_MARKERS: tuple[str, ...] = (
|
||||
"hermes_cli.main gateway",
|
||||
"hermes_cli/main.py gateway",
|
||||
"gateway/run.py",
|
||||
" hermes gateway ",
|
||||
"/hermes gateway ",
|
||||
)
|
||||
|
||||
|
||||
def _legacy_unit_search_paths() -> list[tuple[bool, Path]]:
|
||||
"""Return ``[(is_system, base_dir), ...]`` — directories to scan for legacy units.
|
||||
|
||||
Factored out so tests can monkeypatch the search roots without touching
|
||||
real filesystem paths.
|
||||
"""
|
||||
return [
|
||||
(False, Path.home() / ".config" / "systemd" / "user"),
|
||||
(True, Path("/etc/systemd/system")),
|
||||
]
|
||||
|
||||
|
||||
def _find_legacy_hermes_units() -> list[tuple[str, Path, bool]]:
|
||||
"""Return ``[(unit_name, unit_path, is_system)]`` for legacy Hermes gateway units.
|
||||
|
||||
Detects unit files installed by older Hermes versions that used a
|
||||
different service name (e.g. ``hermes.service`` before the rename to
|
||||
``hermes-gateway.service``). When both a legacy unit and the current
|
||||
``hermes-gateway.service`` are active, they fight over the same bot
|
||||
token — the PR #5646 signal-recovery change turns this into a 30-second
|
||||
SIGTERM flap loop.
|
||||
|
||||
Safety guards:
|
||||
|
||||
* Explicit allowlist of legacy names (no globbing). Profile units such
|
||||
as ``hermes-gateway-coder.service`` and unrelated third-party
|
||||
``hermes-*`` services are never matched.
|
||||
* ExecStart content check — only flag units that invoke our gateway
|
||||
entrypoint. A user-created ``hermes.service`` running an unrelated
|
||||
binary is left untouched.
|
||||
* Results are returned purely for caller inspection; this function
|
||||
never mutates or removes anything.
|
||||
"""
|
||||
results: list[tuple[str, Path, bool]] = []
|
||||
for is_system, base in _legacy_unit_search_paths():
|
||||
for name in _LEGACY_SERVICE_NAMES:
|
||||
unit_path = base / name
|
||||
try:
|
||||
if not unit_path.exists():
|
||||
continue
|
||||
text = unit_path.read_text(encoding="utf-8", errors="ignore")
|
||||
except (OSError, PermissionError):
|
||||
continue
|
||||
if not any(marker in text for marker in _LEGACY_UNIT_EXECSTART_MARKERS):
|
||||
# Not our gateway — leave alone
|
||||
continue
|
||||
results.append((name, unit_path, is_system))
|
||||
return results
|
||||
|
||||
|
||||
def has_legacy_hermes_units() -> bool:
|
||||
"""Return True when any legacy Hermes gateway unit files exist."""
|
||||
return bool(_find_legacy_hermes_units())
|
||||
|
||||
|
||||
def print_legacy_unit_warning() -> None:
|
||||
"""Warn about legacy Hermes gateway unit files if any are installed.
|
||||
|
||||
Idempotent: prints nothing when no legacy units are detected. Safe to
|
||||
call from any status/install/setup path.
|
||||
"""
|
||||
legacy = _find_legacy_hermes_units()
|
||||
if not legacy:
|
||||
return
|
||||
print_warning("Legacy Hermes gateway unit(s) detected from an older install:")
|
||||
for name, path, is_system in legacy:
|
||||
scope = "system" if is_system else "user"
|
||||
print_info(f" {path} ({scope} scope)")
|
||||
print_info(" These run alongside the current hermes-gateway service and")
|
||||
print_info(" cause SIGTERM flap loops — both try to use the same bot token.")
|
||||
print_info(" Remove them with:")
|
||||
print_info(" hermes gateway migrate-legacy")
|
||||
|
||||
|
||||
def remove_legacy_hermes_units(
|
||||
interactive: bool = True,
|
||||
dry_run: bool = False,
|
||||
) -> tuple[int, list[Path]]:
|
||||
"""Stop, disable, and remove legacy Hermes gateway unit files.
|
||||
|
||||
Iterates over whatever ``_find_legacy_hermes_units()`` returns — which is
|
||||
an explicit allowlist of legacy names (not a glob). Profile units and
|
||||
unrelated third-party services are never touched.
|
||||
|
||||
Args:
|
||||
interactive: When True, prompt before removing. When False, remove
|
||||
without asking (used when another prompt has already confirmed,
|
||||
e.g. from the install flow).
|
||||
dry_run: When True, list what would be removed and return.
|
||||
|
||||
Returns:
|
||||
``(removed_count, remaining_paths)`` — remaining includes units we
|
||||
couldn't remove (typically system-scope when not running as root).
|
||||
"""
|
||||
legacy = _find_legacy_hermes_units()
|
||||
if not legacy:
|
||||
print("No legacy Hermes gateway units found.")
|
||||
return 0, []
|
||||
|
||||
user_units = [(n, p) for n, p, is_sys in legacy if not is_sys]
|
||||
system_units = [(n, p) for n, p, is_sys in legacy if is_sys]
|
||||
|
||||
print()
|
||||
print("Legacy Hermes gateway unit(s) found:")
|
||||
for name, path, is_system in legacy:
|
||||
scope = "system" if is_system else "user"
|
||||
print(f" {path} ({scope} scope)")
|
||||
print()
|
||||
|
||||
if dry_run:
|
||||
print("(dry-run — nothing removed)")
|
||||
return 0, [p for _, p, _ in legacy]
|
||||
|
||||
if interactive and not prompt_yes_no("Remove these legacy units?", True):
|
||||
print("Skipped. Run again with: hermes gateway migrate-legacy")
|
||||
return 0, [p for _, p, _ in legacy]
|
||||
|
||||
removed = 0
|
||||
remaining: list[Path] = []
|
||||
|
||||
# User-scope removal
|
||||
for name, path in user_units:
|
||||
try:
|
||||
_run_systemctl(["stop", name], system=False, check=False, timeout=90)
|
||||
_run_systemctl(["disable", name], system=False, check=False, timeout=30)
|
||||
path.unlink(missing_ok=True)
|
||||
print(f" ✓ Removed {path}")
|
||||
removed += 1
|
||||
except (OSError, RuntimeError) as e:
|
||||
print(f" ⚠ Could not remove {path}: {e}")
|
||||
remaining.append(path)
|
||||
|
||||
if user_units:
|
||||
try:
|
||||
_run_systemctl(["daemon-reload"], system=False, check=False, timeout=30)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
# System-scope removal (needs root)
|
||||
if system_units:
|
||||
if os.geteuid() != 0:
|
||||
print()
|
||||
print_warning("System-scope legacy units require root to remove.")
|
||||
print_info(" Re-run with: sudo hermes gateway migrate-legacy")
|
||||
for _, path in system_units:
|
||||
remaining.append(path)
|
||||
else:
|
||||
for name, path in system_units:
|
||||
try:
|
||||
_run_systemctl(["stop", name], system=True, check=False, timeout=90)
|
||||
_run_systemctl(["disable", name], system=True, check=False, timeout=30)
|
||||
path.unlink(missing_ok=True)
|
||||
print(f" ✓ Removed {path}")
|
||||
removed += 1
|
||||
except (OSError, RuntimeError) as e:
|
||||
print(f" ⚠ Could not remove {path}: {e}")
|
||||
remaining.append(path)
|
||||
|
||||
try:
|
||||
_run_systemctl(["daemon-reload"], system=True, check=False, timeout=30)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
print()
|
||||
if remaining:
|
||||
print_warning(f"{len(remaining)} legacy unit(s) still present — see messages above.")
|
||||
else:
|
||||
print_success(f"Removed {removed} legacy unit(s).")
|
||||
|
||||
return removed, remaining
|
||||
|
||||
|
||||
def print_systemd_scope_conflict_warning() -> None:
|
||||
scopes = get_installed_systemd_scopes()
|
||||
if len(scopes) < 2:
|
||||
@@ -1054,6 +1409,19 @@ def systemd_install(force: bool = False, system: bool = False, run_as_user: str
|
||||
if system:
|
||||
_require_root_for_system_service("install")
|
||||
|
||||
# Offer to remove legacy units (hermes.service from pre-rename installs)
|
||||
# before installing the new hermes-gateway.service. If both remain, they
|
||||
# flap-fight for the Telegram bot token on every gateway startup.
|
||||
# Only removes units matching _LEGACY_SERVICE_NAMES + our ExecStart
|
||||
# signature — profile units are never touched.
|
||||
if has_legacy_hermes_units():
|
||||
print()
|
||||
print_legacy_unit_warning()
|
||||
print()
|
||||
if prompt_yes_no("Remove the legacy unit(s) before installing?", True):
|
||||
remove_legacy_hermes_units(interactive=False)
|
||||
print()
|
||||
|
||||
unit_path = get_systemd_unit_path(system=system)
|
||||
scope_flag = " --system" if system else ""
|
||||
|
||||
@@ -1092,6 +1460,7 @@ def systemd_install(force: bool = False, system: bool = False, run_as_user: str
|
||||
_ensure_linger_enabled()
|
||||
|
||||
print_systemd_scope_conflict_warning()
|
||||
print_legacy_unit_warning()
|
||||
|
||||
|
||||
def systemd_uninstall(system: bool = False):
|
||||
@@ -1215,6 +1584,10 @@ def systemd_status(deep: bool = False, system: bool = False):
|
||||
print_systemd_scope_conflict_warning()
|
||||
print()
|
||||
|
||||
if has_legacy_hermes_units():
|
||||
print_legacy_unit_warning()
|
||||
print()
|
||||
|
||||
if not systemd_unit_is_current(system=system):
|
||||
print("⚠ Installed gateway service definition is outdated")
|
||||
print(f" Run: {'sudo ' if system else ''}hermes gateway restart{scope_flag} # auto-refreshes the unit")
|
||||
@@ -1998,7 +2371,7 @@ _PLATFORMS = [
|
||||
{"name": "QQ_ALLOWED_USERS", "prompt": "Allowed user OpenIDs (comma-separated, leave empty for open access)", "password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Optional — restrict DM access to specific user OpenIDs."},
|
||||
{"name": "QQ_HOME_CHANNEL", "prompt": "Home channel (user/group OpenID for cron delivery, or empty)", "password": False,
|
||||
{"name": "QQBOT_HOME_CHANNEL", "prompt": "Home channel (user/group OpenID for cron delivery, or empty)", "password": False,
|
||||
"help": "OpenID to deliver cron results and notifications to."},
|
||||
],
|
||||
},
|
||||
@@ -2211,9 +2584,62 @@ def _setup_sms():
|
||||
|
||||
|
||||
def _setup_dingtalk():
|
||||
"""Configure DingTalk via the standard platform setup."""
|
||||
"""Configure DingTalk — QR scan (recommended) or manual credential entry."""
|
||||
from hermes_cli.setup import (
|
||||
prompt_choice, prompt_yes_no, print_info, print_success, print_warning,
|
||||
)
|
||||
|
||||
dingtalk_platform = next(p for p in _PLATFORMS if p["key"] == "dingtalk")
|
||||
_setup_standard_platform(dingtalk_platform)
|
||||
emoji = dingtalk_platform["emoji"]
|
||||
label = dingtalk_platform["label"]
|
||||
|
||||
print()
|
||||
print(color(f" ─── {emoji} {label} Setup ───", Colors.CYAN))
|
||||
|
||||
existing = get_env_value("DINGTALK_CLIENT_ID")
|
||||
if existing:
|
||||
print()
|
||||
print_success(f"{label} is already configured (Client ID: {existing}).")
|
||||
if not prompt_yes_no(f" Reconfigure {label}?", False):
|
||||
return
|
||||
|
||||
print()
|
||||
method = prompt_choice(
|
||||
" Choose setup method",
|
||||
[
|
||||
"QR Code Scan (Recommended, auto-obtain Client ID and Client Secret)",
|
||||
"Manual Input (Client ID and Client Secret)",
|
||||
],
|
||||
default=0,
|
||||
)
|
||||
|
||||
if method == 0:
|
||||
# ── QR-code device-flow authorization ──
|
||||
try:
|
||||
from hermes_cli.dingtalk_auth import dingtalk_qr_auth
|
||||
except ImportError as exc:
|
||||
print_warning(f" QR auth module failed to load ({exc}), falling back to manual input.")
|
||||
_setup_standard_platform(dingtalk_platform)
|
||||
return
|
||||
|
||||
result = dingtalk_qr_auth()
|
||||
if result is None:
|
||||
print_warning(" QR auth incomplete, falling back to manual input.")
|
||||
_setup_standard_platform(dingtalk_platform)
|
||||
return
|
||||
|
||||
client_id, client_secret = result
|
||||
save_env_value("DINGTALK_CLIENT_ID", client_id)
|
||||
save_env_value("DINGTALK_CLIENT_SECRET", client_secret)
|
||||
save_env_value("DINGTALK_ALLOW_ALL_USERS", "true")
|
||||
print()
|
||||
print_success(f"{emoji} {label} configured via QR scan!")
|
||||
else:
|
||||
# ── Manual entry ──
|
||||
_setup_standard_platform(dingtalk_platform)
|
||||
# Also enable allow-all by default for convenience
|
||||
if get_env_value("DINGTALK_CLIENT_ID"):
|
||||
save_env_value("DINGTALK_ALLOW_ALL_USERS", "true")
|
||||
|
||||
|
||||
def _setup_wecom():
|
||||
@@ -2572,6 +2998,215 @@ def _setup_feishu():
|
||||
print_info(f" Bot: {bot_name}")
|
||||
|
||||
|
||||
def _setup_qqbot():
|
||||
"""Interactive setup for QQ Bot — scan-to-configure or manual credentials."""
|
||||
print()
|
||||
print(color(" ─── 🐧 QQ Bot Setup ───", Colors.CYAN))
|
||||
|
||||
existing_app_id = get_env_value("QQ_APP_ID")
|
||||
existing_secret = get_env_value("QQ_CLIENT_SECRET")
|
||||
if existing_app_id and existing_secret:
|
||||
print()
|
||||
print_success("QQ Bot is already configured.")
|
||||
if not prompt_yes_no(" Reconfigure QQ Bot?", False):
|
||||
return
|
||||
|
||||
# ── Choose setup method ──
|
||||
print()
|
||||
method_choices = [
|
||||
"Scan QR code to add bot automatically (recommended)",
|
||||
"Enter existing App ID and App Secret manually",
|
||||
]
|
||||
method_idx = prompt_choice(" How would you like to set up QQ Bot?", method_choices, 0)
|
||||
|
||||
credentials = None
|
||||
used_qr = False
|
||||
|
||||
if method_idx == 0:
|
||||
# ── QR scan-to-configure ──
|
||||
try:
|
||||
credentials = _qqbot_qr_flow()
|
||||
except KeyboardInterrupt:
|
||||
print()
|
||||
print_warning(" QQ Bot setup cancelled.")
|
||||
return
|
||||
if credentials:
|
||||
used_qr = True
|
||||
if not credentials:
|
||||
print_info(" QR setup did not complete. Continuing with manual input.")
|
||||
|
||||
# ── Manual credential input ──
|
||||
if not credentials:
|
||||
print()
|
||||
print_info(" Go to https://q.qq.com to register a QQ Bot application.")
|
||||
print_info(" Note your App ID and App Secret from the application page.")
|
||||
print()
|
||||
app_id = prompt(" App ID", password=False)
|
||||
if not app_id:
|
||||
print_warning(" Skipped — QQ Bot won't work without an App ID.")
|
||||
return
|
||||
app_secret = prompt(" App Secret", password=True)
|
||||
if not app_secret:
|
||||
print_warning(" Skipped — QQ Bot won't work without an App Secret.")
|
||||
return
|
||||
credentials = {"app_id": app_id.strip(), "client_secret": app_secret.strip(), "user_openid": ""}
|
||||
|
||||
# ── Save core credentials ──
|
||||
save_env_value("QQ_APP_ID", credentials["app_id"])
|
||||
save_env_value("QQ_CLIENT_SECRET", credentials["client_secret"])
|
||||
|
||||
user_openid = credentials.get("user_openid", "")
|
||||
|
||||
# ── DM security policy ──
|
||||
print()
|
||||
access_choices = [
|
||||
"Use DM pairing approval (recommended)",
|
||||
"Allow all direct messages",
|
||||
"Only allow listed user OpenIDs",
|
||||
]
|
||||
access_idx = prompt_choice(" How should direct messages be authorized?", access_choices, 0)
|
||||
if access_idx == 0:
|
||||
save_env_value("QQ_ALLOW_ALL_USERS", "false")
|
||||
if user_openid:
|
||||
print()
|
||||
if prompt_yes_no(f" Add yourself ({user_openid}) to the allow list?", True):
|
||||
save_env_value("QQ_ALLOWED_USERS", user_openid)
|
||||
print_success(f" Allow list set to {user_openid}")
|
||||
else:
|
||||
save_env_value("QQ_ALLOWED_USERS", "")
|
||||
else:
|
||||
save_env_value("QQ_ALLOWED_USERS", "")
|
||||
print_success(" DM pairing enabled.")
|
||||
print_info(" Unknown users can request access; approve with `hermes pairing approve`.")
|
||||
elif access_idx == 1:
|
||||
save_env_value("QQ_ALLOW_ALL_USERS", "true")
|
||||
save_env_value("QQ_ALLOWED_USERS", "")
|
||||
print_warning(" Open DM access enabled for QQ Bot.")
|
||||
else:
|
||||
default_allow = user_openid or ""
|
||||
allowlist = prompt(" Allowed user OpenIDs (comma-separated)", default_allow, password=False).replace(" ", "")
|
||||
save_env_value("QQ_ALLOW_ALL_USERS", "false")
|
||||
save_env_value("QQ_ALLOWED_USERS", allowlist)
|
||||
print_success(" Allowlist saved.")
|
||||
|
||||
# ── Home channel ──
|
||||
if user_openid:
|
||||
print()
|
||||
if prompt_yes_no(f" Use your QQ user ID ({user_openid}) as the home channel?", True):
|
||||
save_env_value("QQBOT_HOME_CHANNEL", user_openid)
|
||||
print_success(f" Home channel set to {user_openid}")
|
||||
else:
|
||||
print()
|
||||
home_channel = prompt(" Home channel OpenID (for cron/notifications, or empty)", password=False)
|
||||
if home_channel:
|
||||
save_env_value("QQBOT_HOME_CHANNEL", home_channel.strip())
|
||||
print_success(f" Home channel set to {home_channel.strip()}")
|
||||
|
||||
print()
|
||||
print_success("🐧 QQ Bot configured!")
|
||||
print_info(f" App ID: {credentials['app_id']}")
|
||||
|
||||
|
||||
def _qqbot_render_qr(url: str) -> bool:
|
||||
"""Try to render a QR code in the terminal. Returns True if successful."""
|
||||
try:
|
||||
import qrcode as _qr
|
||||
qr = _qr.QRCode(border=1,error_correction=_qr.constants.ERROR_CORRECT_L)
|
||||
qr.add_data(url)
|
||||
qr.make(fit=True)
|
||||
qr.print_ascii(invert=True)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _qqbot_qr_flow():
|
||||
"""Run the QR-code scan-to-configure flow.
|
||||
|
||||
Returns a dict with app_id, client_secret, user_openid on success,
|
||||
or None on failure/cancel.
|
||||
"""
|
||||
try:
|
||||
from gateway.platforms.qqbot import (
|
||||
create_bind_task, poll_bind_result, build_connect_url,
|
||||
decrypt_secret, BindStatus,
|
||||
)
|
||||
from gateway.platforms.qqbot.constants import ONBOARD_POLL_INTERVAL
|
||||
except Exception as exc:
|
||||
print_error(f" QQBot onboard import failed: {exc}")
|
||||
return None
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
MAX_REFRESHES = 3
|
||||
refresh_count = 0
|
||||
|
||||
while refresh_count <= MAX_REFRESHES:
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
# ── Create bind task ──
|
||||
try:
|
||||
task_id, aes_key = loop.run_until_complete(create_bind_task())
|
||||
except Exception as e:
|
||||
print_warning(f" Failed to create bind task: {e}")
|
||||
loop.close()
|
||||
return None
|
||||
|
||||
url = build_connect_url(task_id)
|
||||
|
||||
# ── Display QR code + URL ──
|
||||
print()
|
||||
if _qqbot_render_qr(url):
|
||||
print(f" Scan the QR code above, or open this URL directly:\n {url}")
|
||||
else:
|
||||
print(f" Open this URL in QQ on your phone:\n {url}")
|
||||
print_info(" Tip: pip install qrcode to show a scannable QR code here")
|
||||
|
||||
# ── Poll loop (silent — keep QR visible at bottom) ──
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
status, app_id, encrypted_secret, user_openid = loop.run_until_complete(
|
||||
poll_bind_result(task_id)
|
||||
)
|
||||
except Exception:
|
||||
time.sleep(ONBOARD_POLL_INTERVAL)
|
||||
continue
|
||||
|
||||
if status == BindStatus.COMPLETED:
|
||||
client_secret = decrypt_secret(encrypted_secret, aes_key)
|
||||
print()
|
||||
print_success(f" QR scan complete! (App ID: {app_id})")
|
||||
if user_openid:
|
||||
print_info(f" Scanner's OpenID: {user_openid}")
|
||||
return {
|
||||
"app_id": app_id,
|
||||
"client_secret": client_secret,
|
||||
"user_openid": user_openid,
|
||||
}
|
||||
|
||||
if status == BindStatus.EXPIRED:
|
||||
refresh_count += 1
|
||||
if refresh_count > MAX_REFRESHES:
|
||||
print()
|
||||
print_warning(f" QR code expired {MAX_REFRESHES} times — giving up.")
|
||||
return None
|
||||
print()
|
||||
print_warning(f" QR code expired, refreshing... ({refresh_count}/{MAX_REFRESHES})")
|
||||
loop.close()
|
||||
break # outer while creates a new task
|
||||
|
||||
time.sleep(ONBOARD_POLL_INTERVAL)
|
||||
except KeyboardInterrupt:
|
||||
loop.close()
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _setup_signal():
|
||||
"""Interactive setup for Signal messenger."""
|
||||
import shutil
|
||||
@@ -2709,6 +3344,10 @@ def gateway_setup():
|
||||
print_systemd_scope_conflict_warning()
|
||||
print()
|
||||
|
||||
if supports_systemd_services() and has_legacy_hermes_units():
|
||||
print_legacy_unit_warning()
|
||||
print()
|
||||
|
||||
if service_installed and service_running:
|
||||
print_success("Gateway service is installed and running.")
|
||||
elif service_installed:
|
||||
@@ -2749,8 +3388,12 @@ def gateway_setup():
|
||||
_setup_signal()
|
||||
elif platform["key"] == "weixin":
|
||||
_setup_weixin()
|
||||
elif platform["key"] == "dingtalk":
|
||||
_setup_dingtalk()
|
||||
elif platform["key"] == "feishu":
|
||||
_setup_feishu()
|
||||
elif platform["key"] == "qqbot":
|
||||
_setup_qqbot()
|
||||
else:
|
||||
_setup_standard_platform(platform)
|
||||
|
||||
@@ -3110,15 +3753,18 @@ def gateway_command(args):
|
||||
elif subcmd == "status":
|
||||
deep = getattr(args, 'deep', False)
|
||||
system = getattr(args, 'system', False)
|
||||
snapshot = get_gateway_runtime_snapshot(system=system)
|
||||
|
||||
# Check for service first
|
||||
if supports_systemd_services() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
systemd_status(deep, system=system)
|
||||
_print_gateway_process_mismatch(snapshot)
|
||||
elif is_macos() and get_launchd_plist_path().exists():
|
||||
launchd_status(deep)
|
||||
_print_gateway_process_mismatch(snapshot)
|
||||
else:
|
||||
# Check for manually running processes
|
||||
pids = find_gateway_pids()
|
||||
pids = list(snapshot.gateway_pids)
|
||||
if pids:
|
||||
print(f"✓ Gateway is running (PID: {', '.join(map(str, pids))})")
|
||||
print(" (Running manually, not as a system service)")
|
||||
@@ -3159,3 +3805,14 @@ def gateway_command(args):
|
||||
else:
|
||||
print(" hermes gateway install # Install as user service")
|
||||
print(" sudo hermes gateway install --system # Install as boot-time system service")
|
||||
|
||||
elif subcmd == "migrate-legacy":
|
||||
# Stop, disable, and remove legacy Hermes gateway unit files from
|
||||
# pre-rename installs (e.g. hermes.service). Profile units and
|
||||
# unrelated third-party services are never touched.
|
||||
dry_run = getattr(args, 'dry_run', False)
|
||||
yes = getattr(args, 'yes', False)
|
||||
if not supports_systemd_services() and not is_macos():
|
||||
print("Legacy unit migration only applies to systemd-based Linux hosts.")
|
||||
return
|
||||
remove_legacy_hermes_units(interactive=not yes, dry_run=dry_run)
|
||||
|
||||
+2585
-687
File diff suppressed because it is too large
Load Diff
@@ -374,7 +374,26 @@ def normalize_model_for_provider(model_input: str, target_provider: str) -> str:
|
||||
return bare
|
||||
return _dots_to_hyphens(bare)
|
||||
|
||||
# --- Copilot: strip matching provider prefix, keep dots ---
|
||||
# --- Copilot / Copilot ACP: delegate to the Copilot-specific
|
||||
# normalizer. It knows about the alias table (vendor-prefix
|
||||
# stripping for Anthropic/OpenAI, dash-to-dot repair for Claude)
|
||||
# and live-catalog lookups. Without this, vendor-prefixed or
|
||||
# dash-notation Claude IDs survive to the Copilot API and hit
|
||||
# HTTP 400 "model_not_supported". See issue #6879.
|
||||
if provider in {"copilot", "copilot-acp"}:
|
||||
try:
|
||||
from hermes_cli.models import normalize_copilot_model_id
|
||||
|
||||
normalized = normalize_copilot_model_id(name)
|
||||
if normalized:
|
||||
return normalized
|
||||
except Exception:
|
||||
# Fall through to the generic strip-vendor behaviour below
|
||||
# if the Copilot-specific path is unavailable for any reason.
|
||||
pass
|
||||
|
||||
# --- Copilot / Copilot ACP / openai-codex fallback:
|
||||
# strip matching provider prefix, keep dots ---
|
||||
if provider in _STRIP_VENDOR_ONLY_PROVIDERS:
|
||||
stripped = _strip_matching_provider_prefix(name, provider)
|
||||
if stripped == name and name.startswith("openai/"):
|
||||
|
||||
@@ -692,12 +692,12 @@ def switch_model(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
validation = {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": None,
|
||||
"message": f"Could not validate `{new_model}`: {e}",
|
||||
}
|
||||
|
||||
if not validation.get("accepted"):
|
||||
|
||||
+58
-29
@@ -26,7 +26,8 @@ COPILOT_REASONING_EFFORTS_O_SERIES = ["low", "medium", "high"]
|
||||
# Fallback OpenRouter snapshot used when the live catalog is unavailable.
|
||||
# (model_id, display description shown in menus)
|
||||
OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("anthropic/claude-opus-4.7", "recommended"),
|
||||
("moonshotai/kimi-k2.5", "recommended"),
|
||||
("anthropic/claude-opus-4.7", ""),
|
||||
("anthropic/claude-opus-4.6", ""),
|
||||
("anthropic/claude-sonnet-4.6", ""),
|
||||
("qwen/qwen3.6-plus", ""),
|
||||
@@ -49,7 +50,6 @@ OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("z-ai/glm-5.1", ""),
|
||||
("z-ai/glm-5v-turbo", ""),
|
||||
("z-ai/glm-5-turbo", ""),
|
||||
("moonshotai/kimi-k2.5", ""),
|
||||
("x-ai/grok-4.20", ""),
|
||||
("nvidia/nemotron-3-super-120b-a12b", ""),
|
||||
("nvidia/nemotron-3-super-120b-a12b:free", "free"),
|
||||
@@ -75,6 +75,7 @@ def _codex_curated_models() -> list[str]:
|
||||
|
||||
_PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"nous": [
|
||||
"moonshotai/kimi-k2.5",
|
||||
"xiaomi/mimo-v2-pro",
|
||||
"anthropic/claude-opus-4.7",
|
||||
"anthropic/claude-opus-4.6",
|
||||
@@ -96,7 +97,6 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"z-ai/glm-5.1",
|
||||
"z-ai/glm-5v-turbo",
|
||||
"z-ai/glm-5-turbo",
|
||||
"moonshotai/kimi-k2.5",
|
||||
"x-ai/grok-4.20-beta",
|
||||
"nvidia/nemotron-3-super-120b-a12b",
|
||||
"nvidia/nemotron-3-super-120b-a12b:free",
|
||||
@@ -135,7 +135,6 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"gemini-2.5-flash-lite",
|
||||
# Gemma open models (also served via AI Studio)
|
||||
"gemma-4-31b-it",
|
||||
"gemma-4-26b-it",
|
||||
],
|
||||
"google-gemini-cli": [
|
||||
"gemini-2.5-pro",
|
||||
@@ -155,9 +154,23 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"grok-4.20-reasoning",
|
||||
"grok-4-1-fast-reasoning",
|
||||
],
|
||||
"nvidia": [
|
||||
# NVIDIA flagship reasoning models
|
||||
"nvidia/nemotron-3-super-120b-a12b",
|
||||
"nvidia/nemotron-3-nano-30b-a3b",
|
||||
"nvidia/llama-3.3-nemotron-super-49b-v1.5",
|
||||
# Third-party agentic models hosted on build.nvidia.com
|
||||
# (map to OpenRouter defaults — users get familiar picks on NIM)
|
||||
"qwen/qwen3.5-397b-a17b",
|
||||
"deepseek-ai/deepseek-v3.2",
|
||||
"moonshotai/kimi-k2.5",
|
||||
"minimaxai/minimax-m2.5",
|
||||
"z-ai/glm5",
|
||||
"openai/gpt-oss-120b",
|
||||
],
|
||||
"kimi-coding": [
|
||||
"kimi-for-coding",
|
||||
"kimi-k2.5",
|
||||
"kimi-for-coding",
|
||||
"kimi-k2-thinking",
|
||||
"kimi-k2-thinking-turbo",
|
||||
"kimi-k2-turbo-preview",
|
||||
@@ -212,6 +225,7 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"trinity-mini",
|
||||
],
|
||||
"opencode-zen": [
|
||||
"kimi-k2.5",
|
||||
"gpt-5.4-pro",
|
||||
"gpt-5.4",
|
||||
"gpt-5.3-codex",
|
||||
@@ -243,16 +257,15 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"glm-5",
|
||||
"glm-4.7",
|
||||
"glm-4.6",
|
||||
"kimi-k2.5",
|
||||
"kimi-k2-thinking",
|
||||
"kimi-k2",
|
||||
"qwen3-coder",
|
||||
"big-pickle",
|
||||
],
|
||||
"opencode-go": [
|
||||
"kimi-k2.5",
|
||||
"glm-5.1",
|
||||
"glm-5",
|
||||
"kimi-k2.5",
|
||||
"mimo-v2-pro",
|
||||
"mimo-v2-omni",
|
||||
"minimax-m2.7",
|
||||
@@ -285,21 +298,21 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
# to https://dashscope-intl.aliyuncs.com/compatible-mode/v1 (OpenAI-compat)
|
||||
# or https://dashscope-intl.aliyuncs.com/apps/anthropic (Anthropic-compat).
|
||||
"alibaba": [
|
||||
"kimi-k2.5",
|
||||
"qwen3.5-plus",
|
||||
"qwen3-coder-plus",
|
||||
"qwen3-coder-next",
|
||||
# Third-party models available on coding-intl
|
||||
"glm-5",
|
||||
"glm-4.7",
|
||||
"kimi-k2.5",
|
||||
"MiniMax-M2.5",
|
||||
],
|
||||
# Curated HF model list — only agentic models that map to OpenRouter defaults.
|
||||
"huggingface": [
|
||||
"moonshotai/Kimi-K2.5",
|
||||
"Qwen/Qwen3.5-397B-A17B",
|
||||
"Qwen/Qwen3.5-35B-A3B",
|
||||
"deepseek-ai/DeepSeek-V3.2",
|
||||
"moonshotai/Kimi-K2.5",
|
||||
"MiniMaxAI/MiniMax-M2.5",
|
||||
"zai-org/GLM-5",
|
||||
"XiaomiMiMo/MiMo-V2-Flash",
|
||||
@@ -536,6 +549,7 @@ CANONICAL_PROVIDERS: list[ProviderEntry] = [
|
||||
ProviderEntry("anthropic", "Anthropic", "Anthropic (Claude models — API key or Claude Code)"),
|
||||
ProviderEntry("openai-codex", "OpenAI Codex", "OpenAI Codex"),
|
||||
ProviderEntry("xiaomi", "Xiaomi MiMo", "Xiaomi MiMo (MiMo-V2 models — pro, omni, flash)"),
|
||||
ProviderEntry("nvidia", "NVIDIA NIM", "NVIDIA NIM (Nemotron models — build.nvidia.com or local NIM)"),
|
||||
ProviderEntry("qwen-oauth", "Qwen OAuth (Portal)", "Qwen OAuth (reuses local Qwen CLI login)"),
|
||||
ProviderEntry("copilot", "GitHub Copilot", "GitHub Copilot (uses GITHUB_TOKEN or gh auth token)"),
|
||||
ProviderEntry("copilot-acp", "GitHub Copilot ACP", "GitHub Copilot ACP (spawns `copilot --acp --stdio`)"),
|
||||
@@ -618,6 +632,10 @@ _PROVIDER_ALIASES = {
|
||||
"grok": "xai",
|
||||
"x-ai": "xai",
|
||||
"x.ai": "xai",
|
||||
"nim": "nvidia",
|
||||
"nvidia-nim": "nvidia",
|
||||
"build-nvidia": "nvidia",
|
||||
"nemotron": "nvidia",
|
||||
"ollama": "custom", # bare "ollama" = local; use "ollama-cloud" for cloud
|
||||
"ollama_cloud": "ollama-cloud",
|
||||
}
|
||||
@@ -1488,6 +1506,19 @@ _COPILOT_MODEL_ALIASES = {
|
||||
"anthropic/claude-sonnet-4.6": "claude-sonnet-4.6",
|
||||
"anthropic/claude-sonnet-4.5": "claude-sonnet-4.5",
|
||||
"anthropic/claude-haiku-4.5": "claude-haiku-4.5",
|
||||
# Dash-notation fallbacks: Hermes' default Claude IDs elsewhere use
|
||||
# hyphens (anthropic native format), but Copilot's API only accepts
|
||||
# dot-notation. Accept both so users who configure copilot + a
|
||||
# default hyphenated Claude model don't hit HTTP 400
|
||||
# "model_not_supported". See issue #6879.
|
||||
"claude-opus-4-6": "claude-opus-4.6",
|
||||
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"claude-sonnet-4-5": "claude-sonnet-4.5",
|
||||
"claude-haiku-4-5": "claude-haiku-4.5",
|
||||
"anthropic/claude-opus-4-6": "claude-opus-4.6",
|
||||
"anthropic/claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"anthropic/claude-sonnet-4-5": "claude-sonnet-4.5",
|
||||
"anthropic/claude-haiku-4-5": "claude-haiku-4.5",
|
||||
}
|
||||
|
||||
|
||||
@@ -2019,8 +2050,8 @@ def validate_requested_model(
|
||||
)
|
||||
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": message,
|
||||
}
|
||||
@@ -2033,8 +2064,8 @@ def validate_requested_model(
|
||||
message += f"\n If this server expects `/v1`, try base URL: `{probe.get('suggested_base_url')}`"
|
||||
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": message,
|
||||
}
|
||||
@@ -2068,12 +2099,11 @@ def validate_requested_model(
|
||||
if suggestions:
|
||||
suggestion_text = "\n Similar models: " + ", ".join(f"`{s}`" for s in suggestions)
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": (
|
||||
f"Note: `{requested}` was not found in the OpenAI Codex model listing. "
|
||||
f"It may still work if your account has access to it."
|
||||
f"Model `{requested}` was not found in the OpenAI Codex model listing."
|
||||
f"{suggestion_text}"
|
||||
),
|
||||
}
|
||||
@@ -2112,16 +2142,15 @@ def validate_requested_model(
|
||||
if suggestions:
|
||||
suggestion_text = "\n Similar models: " + ", ".join(f"`{s}`" for s in suggestions)
|
||||
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"recognized": False,
|
||||
"message": (
|
||||
f"Note: `{requested}` was not found in this provider's model listing. "
|
||||
f"It may still work if your plan supports it."
|
||||
f"{suggestion_text}"
|
||||
),
|
||||
}
|
||||
return {
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": (
|
||||
f"Model `{requested}` was not found in this provider's model listing."
|
||||
f"{suggestion_text}"
|
||||
),
|
||||
}
|
||||
|
||||
# api_models is None — couldn't reach API. Accept and persist,
|
||||
# but warn so typos don't silently break things.
|
||||
@@ -2163,8 +2192,8 @@ def validate_requested_model(
|
||||
|
||||
provider_label = _PROVIDER_LABELS.get(normalized, normalized)
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": (
|
||||
f"Could not reach the {provider_label} API to validate `{requested}`. "
|
||||
|
||||
+3
-12
@@ -300,19 +300,10 @@ def _read_config_model(profile_dir: Path) -> tuple:
|
||||
|
||||
def _check_gateway_running(profile_dir: Path) -> bool:
|
||||
"""Check if a gateway is running for a given profile directory."""
|
||||
pid_file = profile_dir / "gateway.pid"
|
||||
if not pid_file.exists():
|
||||
return False
|
||||
try:
|
||||
raw = pid_file.read_text().strip()
|
||||
if not raw:
|
||||
return False
|
||||
data = json.loads(raw) if raw.startswith("{") else {"pid": int(raw)}
|
||||
pid = int(data["pid"])
|
||||
os.kill(pid, 0) # existence check
|
||||
return True
|
||||
except (json.JSONDecodeError, KeyError, ValueError, TypeError,
|
||||
ProcessLookupError, PermissionError, OSError):
|
||||
from gateway.status import get_running_pid
|
||||
return get_running_pid(profile_dir / "gateway.pid", cleanup_stale=False) is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -137,6 +137,11 @@ HERMES_OVERLAYS: Dict[str, HermesOverlay] = {
|
||||
base_url_override="https://api.x.ai/v1",
|
||||
base_url_env_var="XAI_BASE_URL",
|
||||
),
|
||||
"nvidia": HermesOverlay(
|
||||
transport="openai_chat",
|
||||
base_url_override="https://integrate.api.nvidia.com/v1",
|
||||
base_url_env_var="NVIDIA_BASE_URL",
|
||||
),
|
||||
"xiaomi": HermesOverlay(
|
||||
transport="openai_chat",
|
||||
base_url_env_var="XIAOMI_BASE_URL",
|
||||
@@ -191,6 +196,12 @@ ALIASES: Dict[str, str] = {
|
||||
"x.ai": "xai",
|
||||
"grok": "xai",
|
||||
|
||||
# nvidia
|
||||
"nim": "nvidia",
|
||||
"nvidia-nim": "nvidia",
|
||||
"build-nvidia": "nvidia",
|
||||
"nemotron": "nvidia",
|
||||
|
||||
# kimi-for-coding (models.dev ID)
|
||||
"kimi": "kimi-for-coding",
|
||||
"kimi-coding": "kimi-for-coding",
|
||||
|
||||
+13
-54
@@ -91,7 +91,7 @@ _DEFAULT_PROVIDER_MODELS = {
|
||||
"gemini": [
|
||||
"gemini-3.1-pro-preview", "gemini-3-flash-preview", "gemini-3.1-flash-lite-preview",
|
||||
"gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite",
|
||||
"gemma-4-31b-it", "gemma-4-26b-it",
|
||||
"gemma-4-31b-it",
|
||||
],
|
||||
"zai": ["glm-5.1", "glm-5", "glm-4.7", "glm-4.5", "glm-4.5-flash"],
|
||||
"kimi-coding": ["kimi-k2.5", "kimi-k2-thinking", "kimi-k2-turbo-preview"],
|
||||
@@ -2005,52 +2005,6 @@ def _setup_wecom_callback():
|
||||
_gw_setup()
|
||||
|
||||
|
||||
def _setup_qqbot():
|
||||
"""Configure QQ Bot gateway."""
|
||||
print_header("QQ Bot")
|
||||
existing = get_env_value("QQ_APP_ID")
|
||||
if existing:
|
||||
print_info("QQ Bot: already configured")
|
||||
if not prompt_yes_no("Reconfigure QQ Bot?", False):
|
||||
return
|
||||
|
||||
print_info("Connects Hermes to QQ via the Official QQ Bot API (v2).")
|
||||
print_info(" Requires a QQ Bot application at q.qq.com")
|
||||
print_info(" Reference: https://bot.q.qq.com/wiki/develop/api-v2/")
|
||||
print()
|
||||
|
||||
app_id = prompt("QQ Bot App ID")
|
||||
if not app_id:
|
||||
print_warning("App ID is required — skipping QQ Bot setup")
|
||||
return
|
||||
save_env_value("QQ_APP_ID", app_id.strip())
|
||||
|
||||
client_secret = prompt("QQ Bot App Secret", password=True)
|
||||
if not client_secret:
|
||||
print_warning("App Secret is required — skipping QQ Bot setup")
|
||||
return
|
||||
save_env_value("QQ_CLIENT_SECRET", client_secret)
|
||||
print_success("QQ Bot credentials saved")
|
||||
|
||||
print()
|
||||
print_info("🔒 Security: Restrict who can DM your bot")
|
||||
print_info(" Use QQ user OpenIDs (found in event payloads)")
|
||||
print()
|
||||
allowed_users = prompt("Allowed user OpenIDs (comma-separated, leave empty for open access)")
|
||||
if allowed_users:
|
||||
save_env_value("QQ_ALLOWED_USERS", allowed_users.replace(" ", ""))
|
||||
print_success("QQ Bot allowlist configured")
|
||||
else:
|
||||
print_info("⚠️ No allowlist set — anyone can DM the bot!")
|
||||
|
||||
print()
|
||||
print_info("📬 Home Channel: OpenID for cron job delivery and notifications.")
|
||||
home_channel = prompt("Home channel OpenID (leave empty to set later)")
|
||||
if home_channel:
|
||||
save_env_value("QQ_HOME_CHANNEL", home_channel)
|
||||
|
||||
print()
|
||||
print_success("QQ Bot configured!")
|
||||
|
||||
|
||||
def _setup_bluebubbles():
|
||||
@@ -2119,12 +2073,9 @@ def _setup_bluebubbles():
|
||||
|
||||
|
||||
def _setup_qqbot():
|
||||
"""Configure QQ Bot (Official API v2) via standard platform setup."""
|
||||
from hermes_cli.gateway import _PLATFORMS
|
||||
qq_platform = next((p for p in _PLATFORMS if p["key"] == "qqbot"), None)
|
||||
if qq_platform:
|
||||
from hermes_cli.gateway import _setup_standard_platform
|
||||
_setup_standard_platform(qq_platform)
|
||||
"""Configure QQ Bot (Official API v2) via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_qqbot as _gateway_setup_qqbot
|
||||
_gateway_setup_qqbot()
|
||||
|
||||
|
||||
def _setup_webhooks():
|
||||
@@ -2264,7 +2215,9 @@ def setup_gateway(config: dict):
|
||||
missing_home.append("Slack")
|
||||
if get_env_value("BLUEBUBBLES_SERVER_URL") and not get_env_value("BLUEBUBBLES_HOME_CHANNEL"):
|
||||
missing_home.append("BlueBubbles")
|
||||
if get_env_value("QQ_APP_ID") and not get_env_value("QQ_HOME_CHANNEL"):
|
||||
if get_env_value("QQ_APP_ID") and not (
|
||||
get_env_value("QQBOT_HOME_CHANNEL") or get_env_value("QQ_HOME_CHANNEL")
|
||||
):
|
||||
missing_home.append("QQBot")
|
||||
|
||||
if missing_home:
|
||||
@@ -2289,8 +2242,10 @@ def setup_gateway(config: dict):
|
||||
_is_service_running,
|
||||
supports_systemd_services,
|
||||
has_conflicting_systemd_units,
|
||||
has_legacy_hermes_units,
|
||||
install_linux_gateway_from_setup,
|
||||
print_systemd_scope_conflict_warning,
|
||||
print_legacy_unit_warning,
|
||||
systemd_start,
|
||||
systemd_restart,
|
||||
launchd_install,
|
||||
@@ -2308,6 +2263,10 @@ def setup_gateway(config: dict):
|
||||
print_systemd_scope_conflict_warning()
|
||||
print()
|
||||
|
||||
if supports_systemd and has_legacy_hermes_units():
|
||||
print_legacy_unit_warning()
|
||||
print()
|
||||
|
||||
if service_running:
|
||||
if prompt_yes_no(" Restart the gateway to pick up changes?", True):
|
||||
try:
|
||||
|
||||
@@ -515,6 +515,90 @@ def do_inspect(identifier: str, console: Optional[Console] = None) -> None:
|
||||
c.print()
|
||||
|
||||
|
||||
def browse_skills(page: int = 1, page_size: int = 20, source: str = "all") -> dict:
|
||||
"""Paginated hub browse for programmatic callers (e.g. TUI gateway).
|
||||
|
||||
Returns ``{"items": [...], "page": int, "total_pages": int, "total": int}``.
|
||||
"""
|
||||
from tools.skills_hub import GitHubAuth, create_source_router
|
||||
|
||||
page_size = max(1, min(page_size, 100))
|
||||
_TRUST_RANK = {"builtin": 3, "trusted": 2, "community": 1}
|
||||
_PER_SOURCE_LIMIT = {"official": 100, "skills-sh": 100, "well-known": 25, "github": 100, "clawhub": 50,
|
||||
"claude-marketplace": 50, "lobehub": 50}
|
||||
auth = GitHubAuth()
|
||||
sources = create_source_router(auth)
|
||||
all_results: list = []
|
||||
for src in sources:
|
||||
sid = src.source_id()
|
||||
if source != "all" and sid != source and sid != "official":
|
||||
continue
|
||||
try:
|
||||
limit = _PER_SOURCE_LIMIT.get(sid, 50)
|
||||
all_results.extend(src.search("", limit=limit))
|
||||
except Exception:
|
||||
continue
|
||||
if not all_results:
|
||||
return {"items": [], "page": 1, "total_pages": 1, "total": 0}
|
||||
seen: dict = {}
|
||||
for r in all_results:
|
||||
rank = _TRUST_RANK.get(r.trust_level, 0)
|
||||
if r.name not in seen or rank > _TRUST_RANK.get(seen[r.name].trust_level, 0):
|
||||
seen[r.name] = r
|
||||
deduped = list(seen.values())
|
||||
deduped.sort(key=lambda r: (-_TRUST_RANK.get(r.trust_level, 0), r.source != "official", r.name.lower()))
|
||||
total = len(deduped)
|
||||
total_pages = max(1, (total + page_size - 1) // page_size)
|
||||
page = max(1, min(page, total_pages))
|
||||
start = (page - 1) * page_size
|
||||
page_items = deduped[start : min(start + page_size, total)]
|
||||
return {
|
||||
"items": [{"name": r.name, "description": r.description, "source": r.source,
|
||||
"trust": r.trust_level} for r in page_items],
|
||||
"page": page,
|
||||
"total_pages": total_pages,
|
||||
"total": total,
|
||||
}
|
||||
|
||||
|
||||
def inspect_skill(identifier: str) -> Optional[dict]:
|
||||
"""Skill metadata (+ SKILL.md preview) for programmatic callers."""
|
||||
from tools.skills_hub import GitHubAuth, create_source_router
|
||||
|
||||
class _Q:
|
||||
def print(self, *a, **k):
|
||||
pass
|
||||
|
||||
c = _Q()
|
||||
auth = GitHubAuth()
|
||||
sources = create_source_router(auth)
|
||||
ident = identifier
|
||||
if "/" not in ident:
|
||||
ident = _resolve_short_name(ident, sources, c)
|
||||
if not ident:
|
||||
return None
|
||||
meta, bundle, _ = _resolve_source_meta_and_bundle(ident, sources)
|
||||
if not meta:
|
||||
return None
|
||||
out: dict = {
|
||||
"name": meta.name,
|
||||
"description": meta.description,
|
||||
"source": meta.source,
|
||||
"identifier": meta.identifier,
|
||||
"tags": list(meta.tags) if meta.tags else [],
|
||||
}
|
||||
if bundle and "SKILL.md" in bundle.files:
|
||||
content = bundle.files["SKILL.md"]
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8", errors="replace")
|
||||
lines = content.split("\n")
|
||||
preview = "\n".join(lines[:50])
|
||||
if len(lines) > 50:
|
||||
preview += f"\n\n... ({len(lines) - 50} more lines)"
|
||||
out["skill_md_preview"] = preview
|
||||
return out
|
||||
|
||||
|
||||
def do_list(source_filter: str = "all", console: Optional[Console] = None) -> None:
|
||||
"""List installed skills, distinguishing hub, builtin, and local skills."""
|
||||
from tools.skills_hub import HubLockFile, ensure_hub_dirs
|
||||
|
||||
@@ -23,7 +23,7 @@ All fields are optional. Missing values inherit from the ``default`` skin.
|
||||
banner_dim: "#B8860B" # Dim/muted text (separators, labels)
|
||||
banner_text: "#FFF8DC" # Body text (tool names, skill names)
|
||||
ui_accent: "#FFBF00" # General UI accent
|
||||
ui_label: "#4dd0e1" # UI labels
|
||||
ui_label: "#DAA520" # UI labels (warm gold; teal clashed w/ default banner gold)
|
||||
ui_ok: "#4caf50" # Success indicators
|
||||
ui_error: "#ef5350" # Error indicators
|
||||
ui_warn: "#ffa726" # Warning indicators
|
||||
@@ -163,7 +163,7 @@ _BUILTIN_SKINS: Dict[str, Dict[str, Any]] = {
|
||||
"banner_dim": "#B8860B",
|
||||
"banner_text": "#FFF8DC",
|
||||
"ui_accent": "#FFBF00",
|
||||
"ui_label": "#4dd0e1",
|
||||
"ui_label": "#DAA520",
|
||||
"ui_ok": "#4caf50",
|
||||
"ui_error": "#ef5350",
|
||||
"ui_warn": "#ffa726",
|
||||
|
||||
+30
-64
@@ -317,7 +317,7 @@ def show_status(args):
|
||||
"WeCom Callback": ("WECOM_CALLBACK_CORP_ID", None),
|
||||
"Weixin": ("WEIXIN_ACCOUNT_ID", "WEIXIN_HOME_CHANNEL"),
|
||||
"BlueBubbles": ("BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_HOME_CHANNEL"),
|
||||
"QQBot": ("QQ_APP_ID", "QQ_HOME_CHANNEL"),
|
||||
"QQBot": ("QQ_APP_ID", "QQBOT_HOME_CHANNEL"),
|
||||
}
|
||||
|
||||
for name, (token_var, home_var) in platforms.items():
|
||||
@@ -327,6 +327,9 @@ def show_status(args):
|
||||
home_channel = ""
|
||||
if home_var:
|
||||
home_channel = os.getenv(home_var, "")
|
||||
# Back-compat: QQBot home channel was renamed from QQ_HOME_CHANNEL to QQBOT_HOME_CHANNEL
|
||||
if not home_channel and home_var == "QQBOT_HOME_CHANNEL":
|
||||
home_channel = os.getenv("QQ_HOME_CHANNEL", "")
|
||||
|
||||
status = "configured" if has_token else "not configured"
|
||||
if home_channel:
|
||||
@@ -339,73 +342,36 @@ def show_status(args):
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Gateway Service", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
if _is_termux():
|
||||
try:
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
gateway_pids = find_gateway_pids()
|
||||
except Exception:
|
||||
gateway_pids = []
|
||||
is_running = bool(gateway_pids)
|
||||
|
||||
try:
|
||||
from hermes_cli.gateway import get_gateway_runtime_snapshot, _format_gateway_pids
|
||||
|
||||
snapshot = get_gateway_runtime_snapshot()
|
||||
is_running = snapshot.running
|
||||
print(f" Status: {check_mark(is_running)} {'running' if is_running else 'stopped'}")
|
||||
print(" Manager: Termux / manual process")
|
||||
if gateway_pids:
|
||||
rendered = ", ".join(str(pid) for pid in gateway_pids[:3])
|
||||
if len(gateway_pids) > 3:
|
||||
rendered += ", ..."
|
||||
print(f" PID(s): {rendered}")
|
||||
else:
|
||||
print(f" Manager: {snapshot.manager}")
|
||||
if snapshot.gateway_pids:
|
||||
print(f" PID(s): {_format_gateway_pids(snapshot.gateway_pids)}")
|
||||
if snapshot.has_process_service_mismatch:
|
||||
print(" Service: installed but not managing the current running gateway")
|
||||
elif _is_termux() and not snapshot.gateway_pids:
|
||||
print(" Start with: hermes gateway")
|
||||
print(" Note: Android may stop background jobs when Termux is suspended")
|
||||
|
||||
elif sys.platform.startswith('linux'):
|
||||
from hermes_constants import is_container
|
||||
if is_container():
|
||||
# Docker/Podman: no systemd — check for running gateway processes
|
||||
try:
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
gateway_pids = find_gateway_pids()
|
||||
is_active = len(gateway_pids) > 0
|
||||
except Exception:
|
||||
is_active = False
|
||||
print(f" Status: {check_mark(is_active)} {'running' if is_active else 'stopped'}")
|
||||
print(" Manager: docker (foreground)")
|
||||
elif snapshot.service_installed and not snapshot.service_running:
|
||||
print(" Service: installed but stopped")
|
||||
except Exception:
|
||||
if _is_termux():
|
||||
print(f" Status: {color('unknown', Colors.DIM)}")
|
||||
print(" Manager: Termux / manual process")
|
||||
elif sys.platform.startswith('linux'):
|
||||
print(f" Status: {color('unknown', Colors.DIM)}")
|
||||
print(" Manager: systemd/manual")
|
||||
elif sys.platform == 'darwin':
|
||||
print(f" Status: {color('unknown', Colors.DIM)}")
|
||||
print(" Manager: launchd")
|
||||
else:
|
||||
try:
|
||||
from hermes_cli.gateway import get_service_name
|
||||
_gw_svc = get_service_name()
|
||||
except Exception:
|
||||
_gw_svc = "hermes-gateway"
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["systemctl", "--user", "is-active", _gw_svc],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
is_active = result.stdout.strip() == "active"
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
is_active = False
|
||||
print(f" Status: {check_mark(is_active)} {'running' if is_active else 'stopped'}")
|
||||
print(" Manager: systemd (user)")
|
||||
|
||||
elif sys.platform == 'darwin':
|
||||
from hermes_cli.gateway import get_launchd_label
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["launchctl", "list", get_launchd_label()],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
is_loaded = result.returncode == 0
|
||||
except subprocess.TimeoutExpired:
|
||||
is_loaded = False
|
||||
print(f" Status: {check_mark(is_loaded)} {'loaded' if is_loaded else 'not loaded'}")
|
||||
print(" Manager: launchd")
|
||||
else:
|
||||
print(f" Status: {color('N/A', Colors.DIM)}")
|
||||
print(" Manager: (not supported on this platform)")
|
||||
print(f" Status: {color('N/A', Colors.DIM)}")
|
||||
print(" Manager: (not supported on this platform)")
|
||||
|
||||
# =========================================================================
|
||||
# Cron Jobs
|
||||
|
||||
@@ -512,7 +512,7 @@ def _get_platform_tools(
|
||||
"""Resolve which individual toolset names are enabled for a platform."""
|
||||
from toolsets import resolve_toolset
|
||||
|
||||
platform_toolsets = config.get("platform_toolsets", {})
|
||||
platform_toolsets = config.get("platform_toolsets") or {}
|
||||
toolset_names = platform_toolsets.get(platform)
|
||||
|
||||
if toolset_names is None or not isinstance(toolset_names, list):
|
||||
|
||||
@@ -56,10 +56,10 @@ try:
|
||||
except ImportError:
|
||||
raise SystemExit(
|
||||
"Web UI requires fastapi and uvicorn.\n"
|
||||
"Run 'hermes web' to auto-install, or: pip install hermes-agent[web]"
|
||||
f"Install with: {sys.executable} -m pip install 'fastapi' 'uvicorn[standard]'"
|
||||
)
|
||||
|
||||
WEB_DIST = Path(__file__).parent / "web_dist"
|
||||
WEB_DIST = Path(os.environ["HERMES_WEB_DIST"]) if "HERMES_WEB_DIST" in os.environ else Path(__file__).parent / "web_dist"
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(title="Hermes Agent", version=__version__)
|
||||
@@ -1444,38 +1444,8 @@ def _nous_poller(session_id: str) -> None:
|
||||
auth_state, min_key_ttl_seconds=300, timeout_seconds=15.0,
|
||||
force_refresh=False, force_mint=True,
|
||||
)
|
||||
# Save into credential pool same as auth_commands.py does
|
||||
from agent.credential_pool import (
|
||||
PooledCredential,
|
||||
load_pool,
|
||||
AUTH_TYPE_OAUTH,
|
||||
SOURCE_MANUAL,
|
||||
)
|
||||
pool = load_pool("nous")
|
||||
entry = PooledCredential.from_dict("nous", {
|
||||
**full_state,
|
||||
"label": "dashboard device_code",
|
||||
"auth_type": AUTH_TYPE_OAUTH,
|
||||
"source": f"{SOURCE_MANUAL}:dashboard_device_code",
|
||||
"base_url": full_state.get("inference_base_url"),
|
||||
})
|
||||
pool.add_entry(entry)
|
||||
# Also persist to auth store so get_nous_auth_status() sees it
|
||||
# (matches what _login_nous in auth.py does for the CLI flow).
|
||||
try:
|
||||
from hermes_cli.auth import (
|
||||
_load_auth_store, _save_provider_state, _save_auth_store,
|
||||
_auth_store_lock,
|
||||
)
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
_save_provider_state(auth_store, "nous", full_state)
|
||||
_save_auth_store(auth_store)
|
||||
except Exception as store_exc:
|
||||
_log.warning(
|
||||
"oauth/device: credential pool saved but auth store write failed "
|
||||
"(session=%s): %s", session_id, store_exc,
|
||||
)
|
||||
from hermes_cli.auth import persist_nous_credentials
|
||||
persist_nous_credentials(full_state)
|
||||
with _oauth_sessions_lock:
|
||||
sess["status"] = "approved"
|
||||
_log.info("oauth/device: nous login completed (session=%s)", session_id)
|
||||
|
||||
+2
-1
@@ -14,7 +14,8 @@ def get_hermes_home() -> Path:
|
||||
Reads HERMES_HOME env var, falls back to ~/.hermes.
|
||||
This is the single source of truth — all other copies should import this.
|
||||
"""
|
||||
return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
val = os.environ.get("HERMES_HOME", "").strip()
|
||||
return Path(val) if val else Path.home() / ".hermes"
|
||||
|
||||
|
||||
def get_default_hermes_root() -> Path:
|
||||
|
||||
+57
-2
@@ -987,6 +987,22 @@ class SessionDB:
|
||||
|
||||
return sanitized.strip()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _contains_cjk(text: str) -> bool:
|
||||
"""Check if text contains CJK (Chinese, Japanese, Korean) characters."""
|
||||
for ch in text:
|
||||
cp = ord(ch)
|
||||
if (0x4E00 <= cp <= 0x9FFF or # CJK Unified Ideographs
|
||||
0x3400 <= cp <= 0x4DBF or # CJK Extension A
|
||||
0x20000 <= cp <= 0x2A6DF or # CJK Extension B
|
||||
0x3000 <= cp <= 0x303F or # CJK Symbols
|
||||
0x3040 <= cp <= 0x309F or # Hiragana
|
||||
0x30A0 <= cp <= 0x30FF or # Katakana
|
||||
0xAC00 <= cp <= 0xD7AF): # Hangul Syllables
|
||||
return True
|
||||
return False
|
||||
|
||||
def search_messages(
|
||||
self,
|
||||
query: str,
|
||||
@@ -1062,8 +1078,47 @@ class SessionDB:
|
||||
cursor = self._conn.execute(sql, params)
|
||||
except sqlite3.OperationalError:
|
||||
# FTS5 query syntax error despite sanitization — return empty
|
||||
return []
|
||||
matches = [dict(row) for row in cursor.fetchall()]
|
||||
# unless query contains CJK (fall back to LIKE below)
|
||||
if not self._contains_cjk(query):
|
||||
return []
|
||||
matches = []
|
||||
else:
|
||||
matches = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
# LIKE fallback for CJK queries: FTS5 default tokenizer splits CJK
|
||||
# characters individually, causing multi-character queries to fail.
|
||||
if not matches and self._contains_cjk(query):
|
||||
raw_query = query.strip('"').strip()
|
||||
like_where = ["m.content LIKE ?"]
|
||||
like_params: list = [f"%{raw_query}%"]
|
||||
if source_filter is not None:
|
||||
like_where.append(f"s.source IN ({','.join('?' for _ in source_filter)})")
|
||||
like_params.extend(source_filter)
|
||||
if exclude_sources is not None:
|
||||
like_where.append(f"s.source NOT IN ({','.join('?' for _ in exclude_sources)})")
|
||||
like_params.extend(exclude_sources)
|
||||
if role_filter:
|
||||
like_where.append(f"m.role IN ({','.join('?' for _ in role_filter)})")
|
||||
like_params.extend(role_filter)
|
||||
like_sql = f"""
|
||||
SELECT m.id, m.session_id, m.role,
|
||||
substr(m.content,
|
||||
max(1, instr(m.content, ?) - 40),
|
||||
120) AS snippet,
|
||||
m.content, m.timestamp, m.tool_name,
|
||||
s.source, s.model, s.started_at AS session_started
|
||||
FROM messages m
|
||||
JOIN sessions s ON s.id = m.session_id
|
||||
WHERE {' AND '.join(like_where)}
|
||||
ORDER BY m.timestamp DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
like_params.extend([limit, offset])
|
||||
# instr() parameter goes first in the bound list
|
||||
like_params = [raw_query] + like_params
|
||||
with self._lock:
|
||||
like_cursor = self._conn.execute(like_sql, like_params)
|
||||
matches = [dict(row) for row in like_cursor.fetchall()]
|
||||
|
||||
# Add surrounding context (1 message before + after each match).
|
||||
# Done outside the lock so we don't hold it across N sequential queries.
|
||||
|
||||
+2
-2
@@ -433,7 +433,7 @@ def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP":
|
||||
if not _MCP_SERVER_AVAILABLE:
|
||||
raise ImportError(
|
||||
"MCP server requires the 'mcp' package. "
|
||||
"Install with: pip install 'hermes-agent[mcp]'"
|
||||
f"Install with: {sys.executable} -m pip install 'mcp'"
|
||||
)
|
||||
|
||||
mcp = FastMCP(
|
||||
@@ -838,7 +838,7 @@ def run_mcp_server(verbose: bool = False) -> None:
|
||||
if not _MCP_SERVER_AVAILABLE:
|
||||
print(
|
||||
"Error: MCP server requires the 'mcp' package.\n"
|
||||
"Install with: pip install 'hermes-agent[mcp]'",
|
||||
f"Install with: {sys.executable} -m pip install 'mcp'",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
+20
-6
@@ -43,6 +43,15 @@ from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def _effective_temperature_for_model(model: str) -> Optional[float]:
|
||||
"""Return a fixed temperature for models with strict sampling contracts."""
|
||||
try:
|
||||
from agent.auxiliary_client import _fixed_temperature_for_model
|
||||
except Exception:
|
||||
return None
|
||||
return _fixed_temperature_for_model(model)
|
||||
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -442,12 +451,17 @@ Complete the user's task step by step."""
|
||||
|
||||
# Make API call
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=api_messages,
|
||||
tools=self.tools,
|
||||
timeout=300.0
|
||||
)
|
||||
api_kwargs = {
|
||||
"model": self.model,
|
||||
"messages": api_messages,
|
||||
"tools": self.tools,
|
||||
"timeout": 300.0,
|
||||
}
|
||||
fixed_temperature = _effective_temperature_for_model(self.model)
|
||||
if fixed_temperature is not None:
|
||||
api_kwargs["temperature"] = fixed_temperature
|
||||
|
||||
response = self.client.chat.completions.create(**api_kwargs)
|
||||
except Exception as e:
|
||||
self.logger.error(f"API call failed: {e}")
|
||||
break
|
||||
|
||||
+2
-2
@@ -274,9 +274,9 @@ def get_tool_definitions(
|
||||
# execute_code" even when the API key isn't configured or the toolset is
|
||||
# disabled (#560-discord).
|
||||
if "execute_code" in available_tool_names:
|
||||
from tools.code_execution_tool import SANDBOX_ALLOWED_TOOLS, build_execute_code_schema
|
||||
from tools.code_execution_tool import SANDBOX_ALLOWED_TOOLS, build_execute_code_schema, _get_execution_mode
|
||||
sandbox_enabled = SANDBOX_ALLOWED_TOOLS & available_tool_names
|
||||
dynamic_schema = build_execute_code_schema(sandbox_enabled)
|
||||
dynamic_schema = build_execute_code_schema(sandbox_enabled, mode=_get_execution_mode())
|
||||
for i, td in enumerate(filtered_tools):
|
||||
if td.get("function", {}).get("name") == "execute_code":
|
||||
filtered_tools[i] = {"type": "function", "function": dynamic_schema}
|
||||
|
||||
@@ -103,6 +103,51 @@ json.dump(sorted(leaf_paths(DEFAULT_CONFIG)), sys.stdout, indent=2)
|
||||
echo "ok" > $out/result
|
||||
'';
|
||||
|
||||
# Verify bundled TUI is present and compiled
|
||||
bundled-tui = pkgs.runCommand "hermes-bundled-tui" { } ''
|
||||
set -e
|
||||
echo "=== Checking bundled TUI ==="
|
||||
test -d ${hermes-agent}/ui-tui || (echo "FAIL: ui-tui directory missing"; exit 1)
|
||||
echo "PASS: ui-tui directory exists"
|
||||
|
||||
test -f ${hermes-agent}/ui-tui/dist/entry.js || (echo "FAIL: compiled entry.js missing"; exit 1)
|
||||
echo "PASS: compiled entry.js present"
|
||||
|
||||
test -d ${hermes-agent}/ui-tui/node_modules || (echo "FAIL: node_modules missing"; exit 1)
|
||||
echo "PASS: node_modules present"
|
||||
|
||||
grep -q "HERMES_TUI_DIR" ${hermes-agent}/bin/hermes || \
|
||||
(echo "FAIL: HERMES_TUI_DIR not in wrapper"; exit 1)
|
||||
echo "PASS: HERMES_TUI_DIR set in wrapper"
|
||||
|
||||
echo "=== All bundled TUI checks passed ==="
|
||||
mkdir -p $out
|
||||
echo "ok" > $out/result
|
||||
'';
|
||||
|
||||
# Verify HERMES_NODE is set in wrapper and points to Node 20+
|
||||
# (string-width uses the /v regex flag which requires Node 20+)
|
||||
hermes-node = pkgs.runCommand "hermes-node-version" { } ''
|
||||
set -e
|
||||
echo "=== Checking HERMES_NODE in wrapper ==="
|
||||
grep -q "HERMES_NODE" ${hermes-agent}/bin/hermes || \
|
||||
(echo "FAIL: HERMES_NODE not set in wrapper"; exit 1)
|
||||
echo "PASS: HERMES_NODE present in wrapper"
|
||||
|
||||
HERMES_NODE=$(sed -n "s/^export HERMES_NODE='\(.*\)'/\1/p" ${hermes-agent}/bin/hermes)
|
||||
test -x "$HERMES_NODE" || (echo "FAIL: HERMES_NODE=$HERMES_NODE not executable"; exit 1)
|
||||
echo "PASS: HERMES_NODE executable at $HERMES_NODE"
|
||||
|
||||
NODE_MAJOR=$("$HERMES_NODE" --version | sed 's/^v//' | cut -d. -f1)
|
||||
test "$NODE_MAJOR" -ge 20 || \
|
||||
(echo "FAIL: Node v$NODE_MAJOR < 20, TUI needs /v regex flag support"; exit 1)
|
||||
echo "PASS: Node v$NODE_MAJOR >= 20"
|
||||
|
||||
echo "=== All HERMES_NODE checks passed ==="
|
||||
mkdir -p $out
|
||||
echo "ok" > $out/result
|
||||
'';
|
||||
|
||||
# Verify HERMES_MANAGED guard works on all mutation commands
|
||||
managed-guard = pkgs.runCommand "hermes-managed-guard" { } ''
|
||||
set -e
|
||||
|
||||
+15
-38
@@ -1,49 +1,26 @@
|
||||
# nix/devShell.nix — Fast dev shell with stamp-file optimization
|
||||
# nix/devShell.nix — Dev shell that delegates setup to each package
|
||||
#
|
||||
# Each package in inputsFrom exposes passthru.devShellHook — a bash snippet
|
||||
# with stamp-checked setup logic. This file collects and runs them all.
|
||||
{ inputs, ... }: {
|
||||
perSystem = { pkgs, ... }:
|
||||
perSystem = { pkgs, system, ... }:
|
||||
let
|
||||
python = pkgs.python311;
|
||||
hermes-agent = inputs.self.packages.${system}.default;
|
||||
hermes-tui = inputs.self.packages.${system}.tui;
|
||||
packages = [ hermes-agent hermes-tui ];
|
||||
in {
|
||||
devShells.default = pkgs.mkShell {
|
||||
inputsFrom = packages;
|
||||
packages = with pkgs; [
|
||||
python uv nodejs_20 ripgrep git openssh ffmpeg
|
||||
python311 uv nodejs_22 ripgrep git openssh ffmpeg
|
||||
];
|
||||
|
||||
shellHook = ''
|
||||
shellHook = let
|
||||
hooks = map (p: p.passthru.devShellHook or "") packages;
|
||||
combined = pkgs.lib.concatStringsSep "\n" (builtins.filter (h: h != "") hooks);
|
||||
in ''
|
||||
echo "Hermes Agent dev shell"
|
||||
|
||||
# Composite stamp: changes when nix python or uv change
|
||||
STAMP_VALUE="${python}:${pkgs.uv}"
|
||||
STAMP_FILE=".venv/.nix-stamp"
|
||||
|
||||
# Create venv if missing
|
||||
if [ ! -d .venv ]; then
|
||||
echo "Creating Python 3.11 venv..."
|
||||
uv venv .venv --python ${python}/bin/python3
|
||||
fi
|
||||
|
||||
source .venv/bin/activate
|
||||
|
||||
# Only install if stamp is stale or missing
|
||||
if [ ! -f "$STAMP_FILE" ] || [ "$(cat "$STAMP_FILE")" != "$STAMP_VALUE" ]; then
|
||||
echo "Installing Python dependencies..."
|
||||
uv pip install -e ".[all]"
|
||||
if [ -d mini-swe-agent ]; then
|
||||
uv pip install -e ./mini-swe-agent 2>/dev/null || true
|
||||
fi
|
||||
if [ -d tinker-atropos ]; then
|
||||
uv pip install -e ./tinker-atropos 2>/dev/null || true
|
||||
fi
|
||||
|
||||
# Install npm deps
|
||||
if [ -f package.json ] && [ ! -d node_modules ]; then
|
||||
echo "Installing npm dependencies..."
|
||||
npm install
|
||||
fi
|
||||
|
||||
echo "$STAMP_VALUE" > "$STAMP_FILE"
|
||||
fi
|
||||
|
||||
${combined}
|
||||
echo "Ready. Run 'hermes' to start."
|
||||
'';
|
||||
};
|
||||
|
||||
+11
-3
@@ -121,11 +121,19 @@
|
||||
# ── Provision apt packages (first boot only, cached in writable layer) ──
|
||||
# sudo: agent self-modification
|
||||
# nodejs/npm: writable node so npm i -g works (nix store copies are read-only)
|
||||
# curl: needed for uv installer
|
||||
# Node 22 via NodeSource — Ubuntu 24.04 ships Node 18 which is EOL.
|
||||
# curl: needed for uv installer + NodeSource setup
|
||||
if [ ! -f /var/lib/hermes-tools-provisioned ] && command -v apt-get >/dev/null 2>&1; then
|
||||
echo "First boot: provisioning agent tools..."
|
||||
apt-get update -qq
|
||||
apt-get install -y -qq sudo nodejs npm curl
|
||||
apt-get install -y -qq sudo curl ca-certificates gnupg
|
||||
mkdir -p /etc/apt/keyrings
|
||||
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key \
|
||||
| gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg
|
||||
echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_22.x nodistro main" \
|
||||
> /etc/apt/sources.list.d/nodesource.list
|
||||
apt-get update -qq
|
||||
apt-get install -y -qq nodejs
|
||||
touch /var/lib/hermes-tools-provisioned
|
||||
fi
|
||||
|
||||
@@ -171,7 +179,7 @@
|
||||
# Package and entrypoint use stable symlinks (current-package, current-entrypoint)
|
||||
# so they can update without recreation. Env vars go through $HERMES_HOME/.env.
|
||||
containerIdentity = builtins.hashString "sha256" (builtins.toJSON {
|
||||
schema = 3; # bump when identity inputs change
|
||||
schema = 4; # bump when identity inputs change (4: Node 18→22 via NodeSource)
|
||||
image = cfg.container.image;
|
||||
extraVolumes = cfg.container.extraVolumes;
|
||||
extraOptions = cfg.container.extraOptions;
|
||||
|
||||
+91
-29
@@ -1,54 +1,116 @@
|
||||
# nix/packages.nix — Hermes Agent package built with uv2nix
|
||||
{ inputs, ... }: {
|
||||
perSystem = { pkgs, system, ... }:
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ pkgs, inputs', ... }:
|
||||
let
|
||||
hermesVenv = pkgs.callPackage ./python.nix {
|
||||
inherit (inputs) uv2nix pyproject-nix pyproject-build-systems;
|
||||
};
|
||||
|
||||
hermesTui = pkgs.callPackage ./tui.nix {
|
||||
npm-lockfile-fix = inputs'.npm-lockfile-fix.packages.default;
|
||||
};
|
||||
|
||||
# Import bundled skills, excluding runtime caches
|
||||
bundledSkills = pkgs.lib.cleanSourceWith {
|
||||
src = ../skills;
|
||||
filter = path: _type:
|
||||
!(pkgs.lib.hasInfix "/index-cache/" path);
|
||||
filter = path: _type: !(pkgs.lib.hasInfix "/index-cache/" path);
|
||||
};
|
||||
|
||||
hermesWeb = pkgs.callPackage ./web.nix {
|
||||
npm-lockfile-fix = inputs'.npm-lockfile-fix.packages.default;
|
||||
};
|
||||
|
||||
runtimeDeps = with pkgs; [
|
||||
nodejs_20 ripgrep git openssh ffmpeg tirith
|
||||
nodejs_22
|
||||
ripgrep
|
||||
git
|
||||
openssh
|
||||
ffmpeg
|
||||
tirith
|
||||
];
|
||||
|
||||
runtimePath = pkgs.lib.makeBinPath runtimeDeps;
|
||||
in {
|
||||
packages.default = pkgs.stdenv.mkDerivation {
|
||||
pname = "hermes-agent";
|
||||
version = (builtins.fromTOML (builtins.readFile ../pyproject.toml)).project.version;
|
||||
|
||||
dontUnpack = true;
|
||||
dontBuild = true;
|
||||
nativeBuildInputs = [ pkgs.makeWrapper ];
|
||||
# Lockfile hashes for dev shell stamps
|
||||
pyprojectHash = builtins.hashString "sha256" (builtins.readFile ../pyproject.toml);
|
||||
uvLockHash =
|
||||
if builtins.pathExists ../uv.lock then
|
||||
builtins.hashString "sha256" (builtins.readFile ../uv.lock)
|
||||
else
|
||||
"none";
|
||||
in
|
||||
{
|
||||
packages = {
|
||||
default = pkgs.stdenv.mkDerivation {
|
||||
pname = "hermes-agent";
|
||||
version = (fromTOML (builtins.readFile ../pyproject.toml)).project.version;
|
||||
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
dontUnpack = true;
|
||||
dontBuild = true;
|
||||
nativeBuildInputs = [ pkgs.makeWrapper ];
|
||||
|
||||
mkdir -p $out/share/hermes-agent $out/bin
|
||||
cp -r ${bundledSkills} $out/share/hermes-agent/skills
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
|
||||
${pkgs.lib.concatMapStringsSep "\n" (name: ''
|
||||
makeWrapper ${hermesVenv}/bin/${name} $out/bin/${name} \
|
||||
--suffix PATH : "${runtimePath}" \
|
||||
--set HERMES_BUNDLED_SKILLS $out/share/hermes-agent/skills
|
||||
'') [ "hermes" "hermes-agent" "hermes-acp" ]}
|
||||
mkdir -p $out/share/hermes-agent $out/bin
|
||||
cp -r ${bundledSkills} $out/share/hermes-agent/skills
|
||||
cp -r ${hermesWeb} $out/share/hermes-agent/web_dist
|
||||
|
||||
runHook postInstall
|
||||
'';
|
||||
# copy pre-built TUI (same layout as dev: ui-tui/dist/ + node_modules/)
|
||||
mkdir -p $out/ui-tui
|
||||
cp -r ${hermesTui}/lib/hermes-tui/* $out/ui-tui/
|
||||
|
||||
meta = with pkgs.lib; {
|
||||
description = "AI agent with advanced tool-calling capabilities";
|
||||
homepage = "https://github.com/NousResearch/hermes-agent";
|
||||
mainProgram = "hermes";
|
||||
license = licenses.mit;
|
||||
platforms = platforms.unix;
|
||||
${pkgs.lib.concatMapStringsSep "\n"
|
||||
(name: ''
|
||||
makeWrapper ${hermesVenv}/bin/${name} $out/bin/${name} \
|
||||
--suffix PATH : "${runtimePath}" \
|
||||
--set HERMES_BUNDLED_SKILLS $out/share/hermes-agent/skills \
|
||||
--set HERMES_WEB_DIST $out/share/hermes-agent/web_dist \
|
||||
--set HERMES_TUI_DIR $out/ui-tui \
|
||||
--set HERMES_PYTHON ${hermesVenv}/bin/python3 \
|
||||
--set HERMES_NODE ${pkgs.nodejs_22}/bin/node
|
||||
'')
|
||||
[
|
||||
"hermes"
|
||||
"hermes-agent"
|
||||
"hermes-acp"
|
||||
]
|
||||
}
|
||||
|
||||
runHook postInstall
|
||||
'';
|
||||
|
||||
passthru.devShellHook = ''
|
||||
STAMP=".nix-stamps/hermes-agent"
|
||||
STAMP_VALUE="${pyprojectHash}:${uvLockHash}"
|
||||
if [ ! -f "$STAMP" ] || [ "$(cat "$STAMP")" != "$STAMP_VALUE" ]; then
|
||||
echo "hermes-agent: installing Python dependencies..."
|
||||
uv venv .venv --python ${pkgs.python311}/bin/python3 2>/dev/null || true
|
||||
source .venv/bin/activate
|
||||
uv pip install -e ".[all]"
|
||||
[ -d mini-swe-agent ] && uv pip install -e ./mini-swe-agent 2>/dev/null || true
|
||||
[ -d tinker-atropos ] && uv pip install -e ./tinker-atropos 2>/dev/null || true
|
||||
mkdir -p .nix-stamps
|
||||
echo "$STAMP_VALUE" > "$STAMP"
|
||||
else
|
||||
source .venv/bin/activate
|
||||
export HERMES_PYTHON=${hermesVenv}/bin/python3
|
||||
fi
|
||||
'';
|
||||
|
||||
meta = with pkgs.lib; {
|
||||
description = "AI agent with advanced tool-calling capabilities";
|
||||
homepage = "https://github.com/NousResearch/hermes-agent";
|
||||
mainProgram = "hermes";
|
||||
license = licenses.mit;
|
||||
platforms = platforms.unix;
|
||||
};
|
||||
};
|
||||
|
||||
tui = hermesTui;
|
||||
web = hermesWeb;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
@@ -35,6 +35,20 @@ let
|
||||
};
|
||||
};
|
||||
|
||||
# Legacy alibabacloud packages ship only sdists with setup.py/setup.cfg
|
||||
# and no pyproject.toml, so setuptools isn't declared as a build dep.
|
||||
buildSystemOverrides = final: prev: builtins.mapAttrs
|
||||
(name: _: prev.${name}.overrideAttrs (old: {
|
||||
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [ final.setuptools ];
|
||||
}))
|
||||
(lib.genAttrs [
|
||||
"alibabacloud-credentials-api"
|
||||
"alibabacloud-endpoint-util"
|
||||
"alibabacloud-gateway-dingtalk"
|
||||
"alibabacloud-gateway-spi"
|
||||
"alibabacloud-tea"
|
||||
] (_: null));
|
||||
|
||||
pythonPackageOverrides = final: _prev:
|
||||
if isAarch64Darwin then {
|
||||
numpy = mkPrebuiltOverride final python311.pkgs.numpy { };
|
||||
@@ -75,6 +89,7 @@ let
|
||||
(lib.composeManyExtensions [
|
||||
pyproject-build-systems.overlays.default
|
||||
overlay
|
||||
buildSystemOverrides
|
||||
pythonPackageOverrides
|
||||
]);
|
||||
in
|
||||
|
||||
+77
@@ -0,0 +1,77 @@
|
||||
# nix/tui.nix — Hermes TUI (Ink/React) compiled with tsc and bundled
|
||||
{ pkgs, npm-lockfile-fix, ... }:
|
||||
let
|
||||
src = ../ui-tui;
|
||||
npmDeps = pkgs.fetchNpmDeps {
|
||||
inherit src;
|
||||
hash = "sha256-mG3vpgGi4ljt4X3XIf3I/5mIcm+rVTUAmx2DQ6YVA90=";
|
||||
};
|
||||
|
||||
packageJson = builtins.fromJSON (builtins.readFile (src + "/package.json"));
|
||||
version = packageJson.version;
|
||||
|
||||
npmLockHash = builtins.hashString "sha256" (builtins.readFile ../ui-tui/package-lock.json);
|
||||
in
|
||||
pkgs.buildNpmPackage {
|
||||
pname = "hermes-tui";
|
||||
inherit src npmDeps version;
|
||||
|
||||
doCheck = false;
|
||||
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
|
||||
mkdir -p $out/lib/hermes-tui
|
||||
|
||||
cp -r dist $out/lib/hermes-tui/dist
|
||||
|
||||
# runtime node_modules
|
||||
cp -r node_modules $out/lib/hermes-tui/node_modules
|
||||
|
||||
# @hermes/ink is a file: dependency, we need to copy it in fr
|
||||
rm -f $out/lib/hermes-tui/node_modules/@hermes/ink
|
||||
cp -r packages/hermes-ink $out/lib/hermes-tui/node_modules/@hermes/ink
|
||||
|
||||
# package.json needed for "type": "module" resolution
|
||||
cp package.json $out/lib/hermes-tui/
|
||||
|
||||
runHook postInstall
|
||||
'';
|
||||
|
||||
nativeBuildInputs = [
|
||||
(pkgs.writeShellScriptBin "update_tui_lockfile" ''
|
||||
set -euox pipefail
|
||||
|
||||
# get root of repo
|
||||
REPO_ROOT=$(git rev-parse --show-toplevel)
|
||||
|
||||
# cd into ui-tui and reinstall
|
||||
cd "$REPO_ROOT/ui-tui"
|
||||
rm -rf node_modules/
|
||||
npm cache clean --force
|
||||
CI=true npm install # ci env var to suppress annoying unicode install banner lag
|
||||
${pkgs.lib.getExe npm-lockfile-fix} ./package-lock.json
|
||||
|
||||
NIX_FILE="$REPO_ROOT/nix/tui.nix"
|
||||
# compute the new hash
|
||||
sed -i "s/hash = \"[^\"]*\";/hash = \"\";/" $NIX_FILE
|
||||
NIX_OUTPUT=$(nix build .#tui 2>&1 || true)
|
||||
NEW_HASH=$(echo "$NIX_OUTPUT" | grep 'got:' | awk '{print $2}')
|
||||
echo got new hash $NEW_HASH
|
||||
sed -i "s|hash = \"[^\"]*\";|hash = \"$NEW_HASH\";|" $NIX_FILE
|
||||
nix build .#tui
|
||||
echo "Updated npm hash in $NIX_FILE to $NEW_HASH"
|
||||
'')
|
||||
];
|
||||
|
||||
passthru.devShellHook = ''
|
||||
STAMP=".nix-stamps/hermes-tui"
|
||||
STAMP_VALUE="${npmLockHash}"
|
||||
if [ ! -f "$STAMP" ] || [ "$(cat "$STAMP")" != "$STAMP_VALUE" ]; then
|
||||
echo "hermes-tui: installing npm dependencies..."
|
||||
cd ui-tui && CI=true npm install --silent --no-fund --no-audit 2>/dev/null && cd ..
|
||||
mkdir -p .nix-stamps
|
||||
echo "$STAMP_VALUE" > "$STAMP"
|
||||
fi
|
||||
'';
|
||||
}
|
||||
+63
@@ -0,0 +1,63 @@
|
||||
# nix/web.nix — Hermes Web Dashboard (Vite/React) frontend build
|
||||
{ pkgs, npm-lockfile-fix, ... }:
|
||||
let
|
||||
src = ../web;
|
||||
npmDeps = pkgs.fetchNpmDeps {
|
||||
inherit src;
|
||||
hash = "sha256-Y0pOzdFG8BLjfvCLmsvqYpjxFjAQabXp1i7X9W/cCU4=";
|
||||
};
|
||||
|
||||
npmLockHash = builtins.hashString "sha256" (builtins.readFile ../web/package-lock.json);
|
||||
in
|
||||
pkgs.buildNpmPackage {
|
||||
pname = "hermes-web";
|
||||
version = "0.0.0";
|
||||
inherit src npmDeps;
|
||||
|
||||
doCheck = false;
|
||||
|
||||
buildPhase = ''
|
||||
npx tsc -b
|
||||
npx vite build --outDir dist
|
||||
'';
|
||||
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
cp -r dist $out
|
||||
runHook postInstall
|
||||
'';
|
||||
|
||||
nativeBuildInputs = [
|
||||
(pkgs.writeShellScriptBin "update_web_lockfile" ''
|
||||
set -euox pipefail
|
||||
|
||||
REPO_ROOT=$(git rev-parse --show-toplevel)
|
||||
|
||||
cd "$REPO_ROOT/web"
|
||||
rm -rf node_modules/
|
||||
npm cache clean --force
|
||||
CI=true npm install
|
||||
${pkgs.lib.getExe npm-lockfile-fix} ./package-lock.json
|
||||
|
||||
NIX_FILE="$REPO_ROOT/nix/web.nix"
|
||||
sed -i "s/hash = \"[^\"]*\";/hash = \"\";/" $NIX_FILE
|
||||
NIX_OUTPUT=$(nix build .#web 2>&1 || true)
|
||||
NEW_HASH=$(echo "$NIX_OUTPUT" | grep 'got:' | awk '{print $2}')
|
||||
echo got new hash $NEW_HASH
|
||||
sed -i "s|hash = \"[^\"]*\";|hash = \"$NEW_HASH\";|" $NIX_FILE
|
||||
nix build .#web
|
||||
echo "Updated npm hash in $NIX_FILE to $NEW_HASH"
|
||||
'')
|
||||
];
|
||||
|
||||
passthru.devShellHook = ''
|
||||
STAMP=".nix-stamps/hermes-web"
|
||||
STAMP_VALUE="${npmLockHash}"
|
||||
if [ ! -f "$STAMP" ] || [ "$(cat "$STAMP")" != "$STAMP_VALUE" ]; then
|
||||
echo "hermes-web: installing npm dependencies..."
|
||||
cd web && CI=true npm install --silent --no-fund --no-audit 2>/dev/null && cd ..
|
||||
mkdir -p .nix-stamps
|
||||
echo "$STAMP_VALUE" > "$STAMP"
|
||||
fi
|
||||
'';
|
||||
}
|
||||
@@ -7,7 +7,7 @@ license: MIT
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [atropos, rl, environments, training, reinforcement-learning, reward-functions]
|
||||
related_skills: [axolotl, grpo-rl-training, trl-fine-tuning, lm-evaluation-harness]
|
||||
related_skills: [axolotl, fine-tuning-with-trl, lm-evaluation-harness]
|
||||
---
|
||||
|
||||
# Hermes Agent Atropos Environments
|
||||
|
||||
+5
-5
@@ -39,8 +39,8 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
modal = ["modal>=1.0.0,<2"]
|
||||
daytona = ["daytona>=0.148.0,<1"]
|
||||
dev = ["debugpy>=1.8.0,<2", "pytest>=9.0.2,<10", "pytest-asyncio>=1.3.0,<2", "pytest-xdist>=3.0,<4", "pytest-split>=0.9,<1", "mcp>=1.2.0,<2"]
|
||||
messaging = ["python-telegram-bot[webhooks]>=22.6,<23", "discord.py[voice]>=2.7.1,<3", "aiohttp>=3.13.3,<4", "slack-bolt>=1.18.0,<2", "slack-sdk>=3.27.0,<4"]
|
||||
dev = ["debugpy>=1.8.0,<2", "pytest>=9.0.2,<10", "pytest-asyncio>=1.3.0,<2", "pytest-xdist>=3.0,<4", "mcp>=1.2.0,<2"]
|
||||
messaging = ["python-telegram-bot[webhooks]>=22.6,<23", "discord.py[voice]>=2.7.1,<3", "aiohttp>=3.13.3,<4", "slack-bolt>=1.18.0,<2", "slack-sdk>=3.27.0,<4", "qrcode>=7.0,<8"]
|
||||
cron = ["croniter>=6.0.0,<7"]
|
||||
slack = ["slack-bolt>=1.18.0,<2", "slack-sdk>=3.27.0,<4"]
|
||||
matrix = ["mautrix[encryption]>=0.20,<1", "Markdown>=3.6,<4", "aiosqlite>=0.20", "asyncpg>=0.29"]
|
||||
@@ -76,8 +76,8 @@ termux = [
|
||||
"hermes-agent[honcho]",
|
||||
"hermes-agent[acp]",
|
||||
]
|
||||
dingtalk = ["dingtalk-stream>=0.1.0,<1"]
|
||||
feishu = ["lark-oapi>=1.5.3,<2"]
|
||||
dingtalk = ["dingtalk-stream>=0.20,<1", "alibabacloud-dingtalk>=2.0.0", "qrcode>=7.0,<8"]
|
||||
feishu = ["lark-oapi>=1.5.3,<2", "qrcode>=7.0,<8"]
|
||||
web = ["fastapi>=0.104.0,<1", "uvicorn[standard]>=0.24.0,<1"]
|
||||
rl = [
|
||||
"atroposlib @ git+https://github.com/NousResearch/atropos.git@c20c85256e5a45ad31edf8b7276e9c5ee1995a30",
|
||||
@@ -126,7 +126,7 @@ py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajector
|
||||
hermes_cli = ["web_dist/**/*"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["agent", "tools", "tools.*", "hermes_cli", "gateway", "gateway.*", "cron", "acp_adapter", "plugins", "plugins.*"]
|
||||
include = ["agent", "tools", "tools.*", "hermes_cli", "gateway", "gateway.*", "tui_gateway", "tui_gateway.*", "cron", "acp_adapter", "plugins", "plugins.*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
|
||||
+484
-25
@@ -353,12 +353,50 @@ def _sanitize_surrogates(text: str) -> str:
|
||||
return text
|
||||
|
||||
|
||||
def _sanitize_structure_surrogates(payload: Any) -> bool:
|
||||
"""Replace surrogate code points in nested dict/list payloads in-place.
|
||||
|
||||
Mirror of ``_sanitize_structure_non_ascii`` but for surrogate recovery.
|
||||
Used to scrub nested structured fields (e.g. ``reasoning_details`` — an
|
||||
array of dicts with ``summary``/``text`` strings) that flat per-field
|
||||
checks don't reach. Returns True if any surrogates were replaced.
|
||||
"""
|
||||
found = False
|
||||
|
||||
def _walk(node):
|
||||
nonlocal found
|
||||
if isinstance(node, dict):
|
||||
for key, value in node.items():
|
||||
if isinstance(value, str):
|
||||
if _SURROGATE_RE.search(value):
|
||||
node[key] = _SURROGATE_RE.sub('\ufffd', value)
|
||||
found = True
|
||||
elif isinstance(value, (dict, list)):
|
||||
_walk(value)
|
||||
elif isinstance(node, list):
|
||||
for idx, value in enumerate(node):
|
||||
if isinstance(value, str):
|
||||
if _SURROGATE_RE.search(value):
|
||||
node[idx] = _SURROGATE_RE.sub('\ufffd', value)
|
||||
found = True
|
||||
elif isinstance(value, (dict, list)):
|
||||
_walk(value)
|
||||
|
||||
_walk(payload)
|
||||
return found
|
||||
|
||||
|
||||
def _sanitize_messages_surrogates(messages: list) -> bool:
|
||||
"""Sanitize surrogate characters from all string content in a messages list.
|
||||
|
||||
Walks message dicts in-place. Returns True if any surrogates were found
|
||||
and replaced, False otherwise. Covers content/text, name, and tool call
|
||||
metadata/arguments so retries don't fail on a non-content field.
|
||||
and replaced, False otherwise. Covers content/text, name, tool call
|
||||
metadata/arguments, AND any additional string or nested structured fields
|
||||
(``reasoning``, ``reasoning_content``, ``reasoning_details``, etc.) so
|
||||
retries don't fail on a non-content field. Byte-level reasoning models
|
||||
(xiaomi/mimo, kimi, glm) can emit lone surrogates in reasoning output
|
||||
that flow through to ``api_messages["reasoning_content"]`` on the next
|
||||
turn and crash json.dumps inside the OpenAI SDK.
|
||||
"""
|
||||
found = False
|
||||
for msg in messages:
|
||||
@@ -398,6 +436,21 @@ def _sanitize_messages_surrogates(messages: list) -> bool:
|
||||
if isinstance(fn_args, str) and _SURROGATE_RE.search(fn_args):
|
||||
fn["arguments"] = _SURROGATE_RE.sub('\ufffd', fn_args)
|
||||
found = True
|
||||
# Walk any additional string / nested fields (reasoning,
|
||||
# reasoning_content, reasoning_details, etc.) — surrogates from
|
||||
# byte-level reasoning models (xiaomi/mimo, kimi, glm) can lurk
|
||||
# in these fields and aren't covered by the per-field checks above.
|
||||
# Matches _sanitize_messages_non_ascii's coverage (PR #10537).
|
||||
for key, value in msg.items():
|
||||
if key in {"content", "name", "tool_calls", "role"}:
|
||||
continue
|
||||
if isinstance(value, str):
|
||||
if _SURROGATE_RE.search(value):
|
||||
msg[key] = _SURROGATE_RE.sub('\ufffd', value)
|
||||
found = True
|
||||
elif isinstance(value, (dict, list)):
|
||||
if _sanitize_structure_surrogates(value):
|
||||
found = True
|
||||
return found
|
||||
|
||||
|
||||
@@ -778,6 +831,26 @@ class AIAgent:
|
||||
self._execution_thread_id: int | None = None # Set at run_conversation() start
|
||||
self._interrupt_thread_signal_pending = False
|
||||
self._client_lock = threading.RLock()
|
||||
|
||||
# /steer mechanism — inject a user note into the next tool result
|
||||
# without interrupting the agent. Unlike interrupt(), steer() does
|
||||
# NOT set _interrupt_requested; it waits for the current tool batch
|
||||
# to finish naturally, then the drain hook appends the text to the
|
||||
# last tool result's content so the model sees it on its next
|
||||
# iteration. Message-role alternation is preserved (we modify an
|
||||
# existing tool message rather than inserting a new user turn).
|
||||
self._pending_steer: Optional[str] = None
|
||||
self._pending_steer_lock = threading.Lock()
|
||||
|
||||
# Concurrent-tool worker thread tracking. `_execute_tool_calls_concurrent`
|
||||
# runs each tool on its own ThreadPoolExecutor worker — those worker
|
||||
# threads have tids distinct from `_execution_thread_id`, so
|
||||
# `_set_interrupt(True, _execution_thread_id)` alone does NOT cause
|
||||
# `is_interrupted()` inside the worker to return True. Track the
|
||||
# workers here so `interrupt()` / `clear_interrupt()` can fan out to
|
||||
# their tids explicitly.
|
||||
self._tool_worker_threads: set[int] = set()
|
||||
self._tool_worker_threads_lock = threading.Lock()
|
||||
|
||||
# Subagent delegation state
|
||||
self._delegate_depth = 0 # 0 = top-level agent, incremented for children
|
||||
@@ -981,6 +1054,16 @@ class AIAgent:
|
||||
}
|
||||
elif "portal.qwen.ai" in effective_base.lower():
|
||||
client_kwargs["default_headers"] = _qwen_portal_headers()
|
||||
elif "generativelanguage.googleapis.com" in effective_base.lower():
|
||||
# Google's OpenAI-compatible endpoint only accepts x-goog-api-key.
|
||||
# The OpenAI SDK auto-injects Authorization: Bearer when api_key= is
|
||||
# set to a real value, causing HTTP 400 "Multiple authentication
|
||||
# credentials received". Pass a placeholder so the SDK does not
|
||||
# emit Bearer, and carry the real key via x-goog-api-key instead.
|
||||
# Fixes: https://github.com/NousResearch/hermes-agent/issues/7893
|
||||
real_key = client_kwargs["api_key"]
|
||||
client_kwargs["api_key"] = "not-used"
|
||||
client_kwargs["default_headers"] = {"x-goog-api-key": real_key}
|
||||
else:
|
||||
# No explicit creds — use the centralized provider router
|
||||
from agent.auxiliary_client import resolve_provider_client
|
||||
@@ -3138,6 +3221,25 @@ class AIAgent:
|
||||
# interrupt signal until startup completes instead of targeting
|
||||
# the caller thread by mistake.
|
||||
self._interrupt_thread_signal_pending = True
|
||||
# Fan out to concurrent-tool worker threads. Those workers run tools
|
||||
# on their own tids (ThreadPoolExecutor workers), so `is_interrupted()`
|
||||
# inside a tool only sees an interrupt when their specific tid is in
|
||||
# the `_interrupted_threads` set. Without this propagation, an
|
||||
# already-running concurrent tool (e.g. a terminal command hung on
|
||||
# network I/O) never notices the interrupt and has to run to its own
|
||||
# timeout. See `_run_tool` for the matching entry/exit bookkeeping.
|
||||
# `getattr` fallback covers test stubs that build AIAgent via
|
||||
# object.__new__ and skip __init__.
|
||||
_tracker = getattr(self, "_tool_worker_threads", None)
|
||||
_tracker_lock = getattr(self, "_tool_worker_threads_lock", None)
|
||||
if _tracker is not None and _tracker_lock is not None:
|
||||
with _tracker_lock:
|
||||
_worker_tids = list(_tracker)
|
||||
for _wtid in _worker_tids:
|
||||
try:
|
||||
_set_interrupt(True, _wtid)
|
||||
except Exception:
|
||||
pass
|
||||
# Propagate interrupt to any running child agents (subagent delegation)
|
||||
with self._active_children_lock:
|
||||
children_copy = list(self._active_children)
|
||||
@@ -3156,6 +3258,146 @@ class AIAgent:
|
||||
self._interrupt_thread_signal_pending = False
|
||||
if self._execution_thread_id is not None:
|
||||
_set_interrupt(False, self._execution_thread_id)
|
||||
# Also clear any concurrent-tool worker thread bits. Tracked
|
||||
# workers normally clear their own bit on exit, but an explicit
|
||||
# clear here guarantees no stale interrupt can survive a turn
|
||||
# boundary and fire on a subsequent, unrelated tool call that
|
||||
# happens to get scheduled onto the same recycled worker tid.
|
||||
# `getattr` fallback covers test stubs that build AIAgent via
|
||||
# object.__new__ and skip __init__.
|
||||
_tracker = getattr(self, "_tool_worker_threads", None)
|
||||
_tracker_lock = getattr(self, "_tool_worker_threads_lock", None)
|
||||
if _tracker is not None and _tracker_lock is not None:
|
||||
with _tracker_lock:
|
||||
_worker_tids = list(_tracker)
|
||||
for _wtid in _worker_tids:
|
||||
try:
|
||||
_set_interrupt(False, _wtid)
|
||||
except Exception:
|
||||
pass
|
||||
# A hard interrupt supersedes any pending /steer — the steer was
|
||||
# meant for the agent's next tool-call iteration, which will no
|
||||
# longer happen. Drop it instead of surprising the user with a
|
||||
# late injection on the post-interrupt turn.
|
||||
_steer_lock = getattr(self, "_pending_steer_lock", None)
|
||||
if _steer_lock is not None:
|
||||
with _steer_lock:
|
||||
self._pending_steer = None
|
||||
|
||||
def steer(self, text: str) -> bool:
|
||||
"""
|
||||
Inject a user message into the next tool result without interrupting.
|
||||
|
||||
Unlike interrupt(), this does NOT stop the current tool call. The
|
||||
text is stashed and the agent loop appends it to the LAST tool
|
||||
result's content once the current tool batch finishes. The model
|
||||
sees the steer as part of the tool output on its next iteration.
|
||||
|
||||
Thread-safe: callable from gateway/CLI/TUI threads. Multiple calls
|
||||
before the drain point concatenate with newlines.
|
||||
|
||||
Args:
|
||||
text: The user text to inject. Empty strings are ignored.
|
||||
|
||||
Returns:
|
||||
True if the steer was accepted, False if the text was empty.
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return False
|
||||
cleaned = text.strip()
|
||||
_lock = getattr(self, "_pending_steer_lock", None)
|
||||
if _lock is None:
|
||||
# Test stubs that built AIAgent via object.__new__ skip __init__.
|
||||
# Fall back to direct attribute set; no concurrent callers expected
|
||||
# in those stubs.
|
||||
existing = getattr(self, "_pending_steer", None)
|
||||
self._pending_steer = (existing + "\n" + cleaned) if existing else cleaned
|
||||
return True
|
||||
with _lock:
|
||||
if self._pending_steer:
|
||||
self._pending_steer = self._pending_steer + "\n" + cleaned
|
||||
else:
|
||||
self._pending_steer = cleaned
|
||||
return True
|
||||
|
||||
def _drain_pending_steer(self) -> Optional[str]:
|
||||
"""Return the pending steer text (if any) and clear the slot.
|
||||
|
||||
Safe to call from the agent execution thread after appending tool
|
||||
results. Returns None when no steer is pending.
|
||||
"""
|
||||
_lock = getattr(self, "_pending_steer_lock", None)
|
||||
if _lock is None:
|
||||
text = getattr(self, "_pending_steer", None)
|
||||
self._pending_steer = None
|
||||
return text
|
||||
with _lock:
|
||||
text = self._pending_steer
|
||||
self._pending_steer = None
|
||||
return text
|
||||
|
||||
def _apply_pending_steer_to_tool_results(self, messages: list, num_tool_msgs: int) -> None:
|
||||
"""Append any pending /steer text to the last tool result in this turn.
|
||||
|
||||
Called at the end of a tool-call batch, before the next API call.
|
||||
The steer is appended to the last ``role:"tool"`` message's content
|
||||
with a clear marker so the model understands it came from the user
|
||||
and NOT from the tool itself. Role alternation is preserved —
|
||||
nothing new is inserted, we only modify existing content.
|
||||
|
||||
Args:
|
||||
messages: The running messages list.
|
||||
num_tool_msgs: Number of tool results appended in this batch;
|
||||
used to locate the tail slice safely.
|
||||
"""
|
||||
if num_tool_msgs <= 0 or not messages:
|
||||
return
|
||||
steer_text = self._drain_pending_steer()
|
||||
if not steer_text:
|
||||
return
|
||||
# Find the last tool-role message in the recent tail. Skipping
|
||||
# non-tool messages defends against future code appending
|
||||
# something else at the boundary.
|
||||
target_idx = None
|
||||
for j in range(len(messages) - 1, max(len(messages) - num_tool_msgs - 1, -1), -1):
|
||||
msg = messages[j]
|
||||
if isinstance(msg, dict) and msg.get("role") == "tool":
|
||||
target_idx = j
|
||||
break
|
||||
if target_idx is None:
|
||||
# No tool result in this batch (e.g. all skipped by interrupt);
|
||||
# put the steer back so the caller's fallback path can deliver
|
||||
# it as a normal next-turn user message.
|
||||
_lock = getattr(self, "_pending_steer_lock", None)
|
||||
if _lock is not None:
|
||||
with _lock:
|
||||
if self._pending_steer:
|
||||
self._pending_steer = self._pending_steer + "\n" + steer_text
|
||||
else:
|
||||
self._pending_steer = steer_text
|
||||
else:
|
||||
existing = getattr(self, "_pending_steer", None)
|
||||
self._pending_steer = (existing + "\n" + steer_text) if existing else steer_text
|
||||
return
|
||||
marker = f"\n\n[USER STEER (injected mid-run, not tool output): {steer_text}]"
|
||||
existing_content = messages[target_idx].get("content", "")
|
||||
if not isinstance(existing_content, str):
|
||||
# Anthropic multimodal content blocks — preserve them and append
|
||||
# a text block at the end.
|
||||
try:
|
||||
blocks = list(existing_content) if existing_content else []
|
||||
blocks.append({"type": "text", "text": marker.lstrip()})
|
||||
messages[target_idx]["content"] = blocks
|
||||
except Exception:
|
||||
# Fall back to string replacement if content shape is unexpected.
|
||||
messages[target_idx]["content"] = f"{existing_content}{marker}"
|
||||
else:
|
||||
messages[target_idx]["content"] = existing_content + marker
|
||||
logger.info(
|
||||
"Delivered /steer to agent after tool batch (%d chars): %s",
|
||||
len(steer_text),
|
||||
steer_text[:120] + ("..." if len(steer_text) > 120 else ""),
|
||||
)
|
||||
|
||||
def _touch_activity(self, desc: str) -> None:
|
||||
"""Update the last-activity timestamp and description (thread-safe)."""
|
||||
@@ -3242,6 +3484,53 @@ class AIAgent:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def release_clients(self) -> None:
|
||||
"""Release LLM client resources WITHOUT tearing down session tool state.
|
||||
|
||||
Used by the gateway when evicting this agent from _agent_cache for
|
||||
memory-management reasons (LRU cap or idle TTL) — the session may
|
||||
resume at any time with a freshly-built AIAgent that reuses the
|
||||
same task_id / session_id, so we must NOT kill:
|
||||
- process_registry entries for task_id (user's bg shells)
|
||||
- terminal sandbox for task_id (cwd, env, shell state)
|
||||
- browser daemon for task_id (open tabs, cookies)
|
||||
- memory provider (has its own lifecycle; keeps running)
|
||||
|
||||
We DO close:
|
||||
- OpenAI/httpx client pool (big chunk of held memory + sockets;
|
||||
the rebuilt agent gets a fresh client anyway)
|
||||
- Active child subagents (per-turn artefacts; safe to drop)
|
||||
|
||||
Safe to call multiple times. Distinct from close() — which is the
|
||||
hard teardown for actual session boundaries (/new, /reset, session
|
||||
expiry).
|
||||
"""
|
||||
# Close active child agents (per-turn; no cross-turn persistence).
|
||||
try:
|
||||
with self._active_children_lock:
|
||||
children = list(self._active_children)
|
||||
self._active_children.clear()
|
||||
for child in children:
|
||||
try:
|
||||
child.release_clients()
|
||||
except Exception:
|
||||
# Fall back to full close on children; they're per-turn.
|
||||
try:
|
||||
child.close()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Close the OpenAI/httpx client to release sockets immediately.
|
||||
try:
|
||||
client = getattr(self, "client", None)
|
||||
if client is not None:
|
||||
self._close_openai_client(client, reason="cache_evict", shared=True)
|
||||
self.client = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""Release all resources held by this agent instance.
|
||||
|
||||
@@ -4956,6 +5245,17 @@ class AIAgent:
|
||||
self._client_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.30.0"}
|
||||
elif "portal.qwen.ai" in normalized:
|
||||
self._client_kwargs["default_headers"] = _qwen_portal_headers()
|
||||
elif "generativelanguage.googleapis.com" in normalized:
|
||||
# Google's endpoint rejects Bearer tokens; use x-goog-api-key instead.
|
||||
# Swap the real key out of api_key and into the header so the OpenAI
|
||||
# SDK does not emit Authorization: Bearer.
|
||||
# Fixes: https://github.com/NousResearch/hermes-agent/issues/7893
|
||||
real_key = self._client_kwargs.get("api_key", "")
|
||||
if real_key and real_key != "not-used":
|
||||
self._client_kwargs["api_key"] = "not-used"
|
||||
self._client_kwargs["default_headers"] = {
|
||||
"x-goog-api-key": real_key or self._client_kwargs.get("api_key", ""),
|
||||
}
|
||||
else:
|
||||
self._client_kwargs.pop("default_headers", None)
|
||||
|
||||
@@ -5412,7 +5712,7 @@ class AIAgent:
|
||||
raise result["error"]
|
||||
return result["response"]
|
||||
|
||||
result = {"response": None, "error": None}
|
||||
result = {"response": None, "error": None, "partial_tool_names": []}
|
||||
request_client_holder = {"client": None}
|
||||
first_delta_fired = {"done": False}
|
||||
deltas_were_sent = {"yes": False} # Track if any deltas were fired (for fallback)
|
||||
@@ -5584,6 +5884,14 @@ class AIAgent:
|
||||
tool_gen_notified.add(idx)
|
||||
_fire_first_delta()
|
||||
self._fire_tool_gen_started(name)
|
||||
# Record the partial tool-call name so the outer
|
||||
# stub-builder can surface a user-visible warning
|
||||
# if streaming dies before this tool's arguments
|
||||
# are fully delivered. Without this, a stall
|
||||
# during tool-call JSON generation lets the stub
|
||||
# at line ~6107 return `tool_calls=None`, silently
|
||||
# discarding the attempted action.
|
||||
result["partial_tool_names"].append(name)
|
||||
|
||||
if chunk.choices[0].finish_reason:
|
||||
finish_reason = chunk.choices[0].finish_reason
|
||||
@@ -5794,6 +6102,7 @@ class AIAgent:
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
self._emit_status("🔄 Reconnected — resuming…")
|
||||
continue
|
||||
self._emit_status(
|
||||
"❌ Connection to provider failed after "
|
||||
@@ -5949,13 +6258,44 @@ class AIAgent:
|
||||
_partial_text = (
|
||||
getattr(self, "_current_streamed_assistant_text", "") or ""
|
||||
).strip() or None
|
||||
logger.warning(
|
||||
"Partial stream delivered before error; returning stub "
|
||||
"response with %s chars of recovered content to prevent "
|
||||
"duplicate messages: %s",
|
||||
len(_partial_text or ""),
|
||||
result["error"],
|
||||
)
|
||||
|
||||
# If the stream died while the model was emitting a tool call,
|
||||
# the stub below will silently set `tool_calls=None` and the
|
||||
# agent loop will treat the turn as complete — the attempted
|
||||
# action is lost with no user-facing signal. Append a
|
||||
# human-visible warning to the stub content so (a) the user
|
||||
# knows something failed, and (b) the next turn's model sees
|
||||
# in conversation history what was attempted and can retry.
|
||||
_partial_names = list(result.get("partial_tool_names") or [])
|
||||
if _partial_names:
|
||||
_name_str = ", ".join(_partial_names[:3])
|
||||
if len(_partial_names) > 3:
|
||||
_name_str += f", +{len(_partial_names) - 3} more"
|
||||
_warn = (
|
||||
f"\n\n⚠ Stream stalled mid tool-call "
|
||||
f"({_name_str}); the action was not executed. "
|
||||
f"Ask me to retry if you want to continue."
|
||||
)
|
||||
_partial_text = (_partial_text or "") + _warn
|
||||
# Also fire as a streaming delta so the user sees it now
|
||||
# instead of only in the persisted transcript.
|
||||
try:
|
||||
self._fire_stream_delta(_warn)
|
||||
except Exception:
|
||||
pass
|
||||
logger.warning(
|
||||
"Partial stream dropped tool call(s) %s after %s chars "
|
||||
"of text; surfaced warning to user: %s",
|
||||
_partial_names, len(_partial_text or ""), result["error"],
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Partial stream delivered before error; returning stub "
|
||||
"response with %s chars of recovered content to prevent "
|
||||
"duplicate messages: %s",
|
||||
len(_partial_text or ""),
|
||||
result["error"],
|
||||
)
|
||||
_stub_msg = SimpleNamespace(
|
||||
role="assistant", content=_partial_text, tool_calls=None,
|
||||
reasoning_content=None,
|
||||
@@ -6697,6 +7037,14 @@ class AIAgent:
|
||||
"messages": sanitized_messages,
|
||||
"timeout": float(os.getenv("HERMES_API_TIMEOUT", 1800.0)),
|
||||
}
|
||||
try:
|
||||
from agent.auxiliary_client import _fixed_temperature_for_model
|
||||
except Exception:
|
||||
_fixed_temperature_for_model = None
|
||||
if _fixed_temperature_for_model is not None:
|
||||
fixed_temperature = _fixed_temperature_for_model(self.model)
|
||||
if fixed_temperature is not None:
|
||||
api_kwargs["temperature"] = fixed_temperature
|
||||
if self._is_qwen_portal():
|
||||
api_kwargs["metadata"] = {
|
||||
"sessionId": self.session_id or "hermes",
|
||||
@@ -6902,7 +7250,7 @@ class AIAgent:
|
||||
# (gateway, batch, quiet) still get reasoning.
|
||||
# Any reasoning that wasn't shown during streaming is caught by the
|
||||
# CLI post-response display fallback (cli.py _reasoning_shown_this_turn).
|
||||
if not self.stream_delta_callback:
|
||||
if not self.stream_delta_callback and not self._stream_callback:
|
||||
try:
|
||||
self.reasoning_callback(reasoning_text)
|
||||
except Exception:
|
||||
@@ -7107,14 +7455,22 @@ class AIAgent:
|
||||
|
||||
# Use auxiliary client for the flush call when available --
|
||||
# it's cheaper and avoids Codex Responses API incompatibility.
|
||||
from agent.auxiliary_client import call_llm as _call_llm
|
||||
from agent.auxiliary_client import (
|
||||
call_llm as _call_llm,
|
||||
_fixed_temperature_for_model,
|
||||
)
|
||||
_aux_available = True
|
||||
# Use the fixed-temperature override (e.g. kimi-for-coding → 0.6) if
|
||||
# the model has a strict contract; otherwise the historical 0.3 default.
|
||||
_flush_temperature = _fixed_temperature_for_model(self.model)
|
||||
if _flush_temperature is None:
|
||||
_flush_temperature = 0.3
|
||||
try:
|
||||
response = _call_llm(
|
||||
task="flush_memories",
|
||||
messages=api_messages,
|
||||
tools=[memory_tool_def],
|
||||
temperature=0.3,
|
||||
temperature=_flush_temperature,
|
||||
max_tokens=5120,
|
||||
# timeout resolved from auxiliary.flush_memories.timeout config
|
||||
)
|
||||
@@ -7126,7 +7482,7 @@ class AIAgent:
|
||||
# No auxiliary client -- use the Codex Responses path directly
|
||||
codex_kwargs = self._build_api_kwargs(api_messages)
|
||||
codex_kwargs["tools"] = self._responses_tools([memory_tool_def])
|
||||
codex_kwargs["temperature"] = 0.3
|
||||
codex_kwargs["temperature"] = _flush_temperature
|
||||
if "max_output_tokens" in codex_kwargs:
|
||||
codex_kwargs["max_output_tokens"] = 5120
|
||||
response = self._run_codex_stream(codex_kwargs)
|
||||
@@ -7145,7 +7501,7 @@ class AIAgent:
|
||||
"model": self.model,
|
||||
"messages": api_messages,
|
||||
"tools": [memory_tool_def],
|
||||
"temperature": 0.3,
|
||||
"temperature": _flush_temperature,
|
||||
**self._max_tokens_param(5120),
|
||||
}
|
||||
from agent.auxiliary_client import _get_task_timeout
|
||||
@@ -7536,6 +7892,22 @@ class AIAgent:
|
||||
|
||||
def _run_tool(index, tool_call, function_name, function_args):
|
||||
"""Worker function executed in a thread."""
|
||||
# Register this worker tid so the agent can fan out an interrupt
|
||||
# to it — see AIAgent.interrupt(). Must happen first thing, and
|
||||
# must be paired with discard + clear in the finally block.
|
||||
_worker_tid = threading.current_thread().ident
|
||||
with self._tool_worker_threads_lock:
|
||||
self._tool_worker_threads.add(_worker_tid)
|
||||
# Race: if the agent was interrupted between fan-out (which
|
||||
# snapshotted an empty/earlier set) and our registration, apply
|
||||
# the interrupt to our own tid now so is_interrupted() inside
|
||||
# the tool returns True on the next poll.
|
||||
if self._interrupt_requested:
|
||||
try:
|
||||
from tools.interrupt import set_interrupt as _sif
|
||||
_sif(True, _worker_tid)
|
||||
except Exception:
|
||||
pass
|
||||
# Set the activity callback on THIS worker thread so
|
||||
# _wait_for_process (terminal commands) can fire heartbeats.
|
||||
# The callback is thread-local; the main thread's callback
|
||||
@@ -7558,6 +7930,16 @@ class AIAgent:
|
||||
else:
|
||||
logger.info("tool %s completed (%.2fs, %d chars)", function_name, duration, len(result))
|
||||
results[index] = (function_name, function_args, result, duration, is_error)
|
||||
# Tear down worker-tid tracking. Clear any interrupt bit we may
|
||||
# have set so the next task scheduled onto this recycled tid
|
||||
# starts with a clean slate.
|
||||
with self._tool_worker_threads_lock:
|
||||
self._tool_worker_threads.discard(_worker_tid)
|
||||
try:
|
||||
from tools.interrupt import set_interrupt as _sif
|
||||
_sif(False, _worker_tid)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Start spinner for CLI mode (skip when TUI handles tool progress)
|
||||
spinner = None
|
||||
@@ -7702,6 +8084,13 @@ class AIAgent:
|
||||
turn_tool_msgs = messages[-num_tools:]
|
||||
enforce_turn_budget(turn_tool_msgs, env=get_active_env(effective_task_id))
|
||||
|
||||
# ── /steer injection ──────────────────────────────────────────────
|
||||
# Append any pending user steer text to the last tool result so the
|
||||
# agent sees it on its next iteration. Runs AFTER budget enforcement
|
||||
# so the steer marker is never truncated. See steer() for details.
|
||||
if num_tools > 0:
|
||||
self._apply_pending_steer_to_tool_results(messages, num_tools)
|
||||
|
||||
def _execute_tool_calls_sequential(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
|
||||
"""Execute tool calls sequentially (original behavior). Used for single calls or interactive tools."""
|
||||
for i, tool_call in enumerate(assistant_message.tool_calls, 1):
|
||||
@@ -8081,6 +8470,12 @@ class AIAgent:
|
||||
if num_tools_seq > 0:
|
||||
enforce_turn_budget(messages[-num_tools_seq:], env=get_active_env(effective_task_id))
|
||||
|
||||
# ── /steer injection ──────────────────────────────────────────────
|
||||
# See _execute_tool_calls_parallel for the rationale. Same hook,
|
||||
# applied to sequential execution as well.
|
||||
if num_tools_seq > 0:
|
||||
self._apply_pending_steer_to_tool_results(messages, num_tools_seq)
|
||||
|
||||
|
||||
|
||||
def _handle_max_iterations(self, messages: list, api_call_count: int) -> str:
|
||||
@@ -8118,6 +8513,15 @@ class AIAgent:
|
||||
api_messages.insert(sys_offset + idx, pfm.copy())
|
||||
|
||||
summary_extra_body = {}
|
||||
try:
|
||||
from agent.auxiliary_client import _fixed_temperature_for_model
|
||||
except Exception:
|
||||
_fixed_temperature_for_model = None
|
||||
_summary_temperature = (
|
||||
_fixed_temperature_for_model(self.model)
|
||||
if _fixed_temperature_for_model is not None
|
||||
else None
|
||||
)
|
||||
_is_nous = "nousresearch" in self._base_url_lower
|
||||
if self._supports_reasoning_extra_body():
|
||||
if self.reasoning_config is not None:
|
||||
@@ -8141,6 +8545,8 @@ class AIAgent:
|
||||
"model": self.model,
|
||||
"messages": api_messages,
|
||||
}
|
||||
if _summary_temperature is not None:
|
||||
summary_kwargs["temperature"] = _summary_temperature
|
||||
if self.max_tokens is not None:
|
||||
summary_kwargs.update(self._max_tokens_param(self.max_tokens))
|
||||
|
||||
@@ -8206,6 +8612,8 @@ class AIAgent:
|
||||
"model": self.model,
|
||||
"messages": api_messages,
|
||||
}
|
||||
if _summary_temperature is not None:
|
||||
summary_kwargs["temperature"] = _summary_temperature
|
||||
if self.max_tokens is not None:
|
||||
summary_kwargs.update(self._max_tokens_param(self.max_tokens))
|
||||
if summary_extra_body:
|
||||
@@ -8641,6 +9049,7 @@ class AIAgent:
|
||||
{
|
||||
"name": tc["function"]["name"],
|
||||
"result": _results_by_id.get(tc.get("id")),
|
||||
"arguments": tc["function"].get("arguments"),
|
||||
}
|
||||
for tc in _m["tool_calls"]
|
||||
if isinstance(tc, dict)
|
||||
@@ -9255,8 +9664,7 @@ class AIAgent:
|
||||
"and had none left for the actual response.\n\n"
|
||||
"To fix this:\n"
|
||||
"→ Lower reasoning effort: `/thinkon low` or `/thinkon minimal`\n"
|
||||
"→ Increase the output token limit: "
|
||||
"set `model.max_tokens` in config.yaml"
|
||||
"→ Or switch to a larger/non-reasoning model with `/model`"
|
||||
)
|
||||
self._cleanup_task_resources(effective_task_id)
|
||||
self._persist_session(messages, conversation_history)
|
||||
@@ -9523,13 +9931,51 @@ class AIAgent:
|
||||
if isinstance(api_error, UnicodeEncodeError) and getattr(self, '_unicode_sanitization_passes', 0) < 2:
|
||||
_err_str = str(api_error).lower()
|
||||
_is_ascii_codec = "'ascii'" in _err_str or "ascii" in _err_str
|
||||
# Detect surrogate errors — utf-8 codec refusing to
|
||||
# encode U+D800..U+DFFF. The error text is:
|
||||
# "'utf-8' codec can't encode characters in position
|
||||
# N-M: surrogates not allowed"
|
||||
_is_surrogate_error = (
|
||||
"surrogate" in _err_str
|
||||
or ("'utf-8'" in _err_str and not _is_ascii_codec)
|
||||
)
|
||||
# Sanitize surrogates from both the canonical `messages`
|
||||
# list AND `api_messages` (the API-copy, which may carry
|
||||
# `reasoning_content`/`reasoning_details` transformed
|
||||
# from `reasoning` — fields the canonical list doesn't
|
||||
# have directly). Also clean `api_kwargs` if built and
|
||||
# `prefill_messages` if present. Mirrors the ASCII
|
||||
# codec recovery below.
|
||||
_surrogates_found = _sanitize_messages_surrogates(messages)
|
||||
if _surrogates_found:
|
||||
if isinstance(api_messages, list):
|
||||
if _sanitize_messages_surrogates(api_messages):
|
||||
_surrogates_found = True
|
||||
if isinstance(api_kwargs, dict):
|
||||
if _sanitize_structure_surrogates(api_kwargs):
|
||||
_surrogates_found = True
|
||||
if isinstance(getattr(self, "prefill_messages", None), list):
|
||||
if _sanitize_messages_surrogates(self.prefill_messages):
|
||||
_surrogates_found = True
|
||||
# Gate the retry on the error type, not on whether we
|
||||
# found anything — _force_ascii_payload / the extended
|
||||
# surrogate walker above cover all known paths, but a
|
||||
# new transformed field could still slip through. If
|
||||
# the error was a surrogate encode failure, always let
|
||||
# the retry run; the proactive sanitizer at line ~8781
|
||||
# runs again on the next iteration. Bounded by
|
||||
# _unicode_sanitization_passes < 2 (outer guard).
|
||||
if _surrogates_found or _is_surrogate_error:
|
||||
self._unicode_sanitization_passes += 1
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Stripped invalid surrogate characters from messages. Retrying...",
|
||||
force=True,
|
||||
)
|
||||
if _surrogates_found:
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Stripped invalid surrogate characters from messages. Retrying...",
|
||||
force=True,
|
||||
)
|
||||
else:
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Surrogate encoding error — retrying after full-payload sanitization...",
|
||||
force=True,
|
||||
)
|
||||
continue
|
||||
if _is_ascii_codec:
|
||||
self._force_ascii_payload = True
|
||||
@@ -10297,9 +10743,9 @@ class AIAgent:
|
||||
pass
|
||||
wait_time = _retry_after if _retry_after else jittered_backoff(retry_count, base_delay=2.0, max_delay=60.0)
|
||||
if is_rate_limited:
|
||||
self._emit_status(f"⏱️ Rate limit reached. Waiting {wait_time}s before retry (attempt {retry_count + 1}/{max_retries})...")
|
||||
self._emit_status(f"⏱️ Rate limited. Waiting {wait_time:.1f}s (attempt {retry_count + 1}/{max_retries})...")
|
||||
else:
|
||||
self._emit_status(f"⏳ Retrying in {wait_time}s (attempt {retry_count}/{max_retries})...")
|
||||
self._emit_status(f"⏳ Retrying in {wait_time:.1f}s (attempt {retry_count}/{max_retries})...")
|
||||
logger.warning(
|
||||
"Retrying API call in %ss (attempt %s/%s) %s error=%s",
|
||||
wait_time,
|
||||
@@ -10715,7 +11161,14 @@ class AIAgent:
|
||||
elif self.quiet_mode:
|
||||
clean = self._strip_think_blocks(turn_content).strip()
|
||||
if clean:
|
||||
self._vprint(f" ┊ 💬 {clean}")
|
||||
relayed = False
|
||||
if (
|
||||
self.tool_progress_callback
|
||||
and getattr(self, "platform", "") == "tui"
|
||||
):
|
||||
relayed = True
|
||||
if not relayed:
|
||||
self._vprint(f" ┊ 💬 {clean}")
|
||||
|
||||
# Pop thinking-only prefill message(s) before appending
|
||||
# (tool-call path — same rationale as the final-response path).
|
||||
@@ -11303,6 +11756,12 @@ class AIAgent:
|
||||
"cost_status": self.session_cost_status,
|
||||
"cost_source": self.session_cost_source,
|
||||
}
|
||||
# If a /steer landed after the final assistant turn (no more tool
|
||||
# batches to drain into), hand it back to the caller so it can be
|
||||
# delivered as the next user turn instead of being silently lost.
|
||||
_leftover_steer = self._drain_pending_steer()
|
||||
if _leftover_steer:
|
||||
result["pending_steer"] = _leftover_steer
|
||||
self._response_was_previewed = False
|
||||
|
||||
# Include interrupt message if one triggered the interrupt
|
||||
|
||||
@@ -721,6 +721,20 @@ function Install-NodeDeps {
|
||||
}
|
||||
}
|
||||
|
||||
# Install TUI dependencies
|
||||
$tuiDir = "$InstallDir\ui-tui"
|
||||
if (Test-Path "$tuiDir\package.json") {
|
||||
Write-Info "Installing TUI dependencies..."
|
||||
Push-Location $tuiDir
|
||||
try {
|
||||
npm install --silent 2>&1 | Out-Null
|
||||
Write-Success "TUI dependencies installed"
|
||||
} catch {
|
||||
Write-Warn "TUI npm install failed (hermes --tui may not work)"
|
||||
}
|
||||
Pop-Location
|
||||
}
|
||||
|
||||
# Install WhatsApp bridge dependencies
|
||||
$bridgeDir = "$InstallDir\scripts\whatsapp-bridge"
|
||||
if (Test-Path "$bridgeDir\package.json") {
|
||||
|
||||
@@ -1194,6 +1194,16 @@ install_node_deps() {
|
||||
log_success "Browser engine setup complete"
|
||||
fi
|
||||
|
||||
# Install TUI dependencies
|
||||
if [ -f "$INSTALL_DIR/ui-tui/package.json" ]; then
|
||||
log_info "Installing TUI dependencies..."
|
||||
cd "$INSTALL_DIR/ui-tui"
|
||||
npm install --silent 2>/dev/null || {
|
||||
log_warn "TUI npm install failed (hermes --tui may not work)"
|
||||
}
|
||||
log_success "TUI dependencies installed"
|
||||
fi
|
||||
|
||||
# Install WhatsApp bridge dependencies
|
||||
if [ -f "$INSTALL_DIR/scripts/whatsapp-bridge/package.json" ]; then
|
||||
log_info "Installing WhatsApp bridge dependencies..."
|
||||
|
||||
@@ -0,0 +1,238 @@
|
||||
#!/usr/bin/env bash
|
||||
# ============================================================================
|
||||
# scripts/lib/node-bootstrap.sh
|
||||
# ----------------------------------------------------------------------------
|
||||
# Sourceable helper: ensure Node.js >= MIN_VERSION is available for the TUI
|
||||
# (React + Ink), browser tools, and the WhatsApp bridge.
|
||||
#
|
||||
# Strategy (first hit wins — respects the user's existing tooling):
|
||||
# 1. modern `node` already on PATH
|
||||
# 2. ~/.hermes/node/ from a prior Hermes-managed install
|
||||
# 3. fnm, proto, nvm (in that order) if the user already uses a version manager
|
||||
# 4. Termux `pkg`, macOS Homebrew
|
||||
# 5. pinned nodejs.org tarball into ~/.hermes/node/ (always works, zero shell rc edits)
|
||||
#
|
||||
# Usage:
|
||||
# source scripts/lib/node-bootstrap.sh
|
||||
# ensure_node # returns 0 on success, non-zero on failure
|
||||
# if [ "$HERMES_NODE_AVAILABLE" = true ]; then ...; fi
|
||||
#
|
||||
# Env inputs (set before sourcing to override defaults):
|
||||
# HERMES_NODE_MIN_VERSION (default: 20) — accepted on PATH
|
||||
# HERMES_NODE_TARGET_MAJOR (default: 22) — installed when we install
|
||||
# HERMES_HOME (default: $HOME/.hermes)
|
||||
# ============================================================================
|
||||
|
||||
HERMES_NODE_MIN_VERSION="${HERMES_NODE_MIN_VERSION:-20}"
|
||||
HERMES_NODE_TARGET_MAJOR="${HERMES_NODE_TARGET_MAJOR:-22}"
|
||||
HERMES_HOME="${HERMES_HOME:-$HOME/.hermes}"
|
||||
HERMES_NODE_AVAILABLE=false
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Logging — prefer the host script's log_* helpers when present
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_nb_log() { declare -F log_info >/dev/null 2>&1 && log_info "$*" || printf '→ %s\n' "$*" >&2; }
|
||||
_nb_ok() { declare -F log_success >/dev/null 2>&1 && log_success "$*" || printf '✓ %s\n' "$*" >&2; }
|
||||
_nb_warn() { declare -F log_warn >/dev/null 2>&1 && log_warn "$*" || printf '⚠ %s\n' "$*" >&2; }
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Platform + version helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_nb_is_termux() {
|
||||
[ -n "${TERMUX_VERSION:-}" ] || [[ "${PREFIX:-}" == *"com.termux/files/usr"* ]]
|
||||
}
|
||||
|
||||
_nb_node_major() {
|
||||
local v
|
||||
v=$(node --version 2>/dev/null | sed 's/^v//' | cut -d. -f1)
|
||||
[[ "$v" =~ ^[0-9]+$ ]] && echo "$v" || echo 0
|
||||
}
|
||||
|
||||
_nb_have_modern_node() {
|
||||
command -v node >/dev/null 2>&1 || return 1
|
||||
[ "$(_nb_node_major)" -ge "$HERMES_NODE_MIN_VERSION" ]
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Version-manager paths — respect what the user already uses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_nb_try_fnm() {
|
||||
command -v fnm >/dev/null 2>&1 || return 1
|
||||
_nb_log "fnm detected — installing Node $HERMES_NODE_TARGET_MAJOR..."
|
||||
eval "$(fnm env 2>/dev/null)" || true
|
||||
fnm install "$HERMES_NODE_TARGET_MAJOR" >/dev/null 2>&1 || return 1
|
||||
fnm use "$HERMES_NODE_TARGET_MAJOR" >/dev/null 2>&1 || return 1
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) activated via fnm"
|
||||
return 0
|
||||
}
|
||||
|
||||
_nb_try_proto() {
|
||||
command -v proto >/dev/null 2>&1 || return 1
|
||||
_nb_log "proto detected — installing Node $HERMES_NODE_TARGET_MAJOR..."
|
||||
proto install node "$HERMES_NODE_TARGET_MAJOR" >/dev/null 2>&1 || return 1
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) activated via proto"
|
||||
return 0
|
||||
}
|
||||
|
||||
_nb_try_nvm() {
|
||||
local nvm_sh="${NVM_DIR:-$HOME/.nvm}/nvm.sh"
|
||||
[ -s "$nvm_sh" ] || return 1
|
||||
# shellcheck source=/dev/null
|
||||
\. "$nvm_sh" >/dev/null 2>&1 || return 1
|
||||
_nb_log "nvm detected — installing Node $HERMES_NODE_TARGET_MAJOR..."
|
||||
nvm install "$HERMES_NODE_TARGET_MAJOR" >/dev/null 2>&1 || return 1
|
||||
nvm use "$HERMES_NODE_TARGET_MAJOR" >/dev/null 2>&1 || return 1
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) activated via nvm"
|
||||
return 0
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Platform package managers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_nb_try_termux_pkg() {
|
||||
_nb_is_termux || return 1
|
||||
_nb_log "Installing Node.js via pkg..."
|
||||
pkg install -y nodejs >/dev/null 2>&1 || return 1
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) installed via pkg"
|
||||
return 0
|
||||
}
|
||||
|
||||
_nb_try_brew() {
|
||||
[ "$(uname -s)" = "Darwin" ] || return 1
|
||||
command -v brew >/dev/null 2>&1 || return 1
|
||||
_nb_log "Installing Node via Homebrew..."
|
||||
brew install "node@${HERMES_NODE_TARGET_MAJOR}" >/dev/null 2>&1 \
|
||||
|| brew install node >/dev/null 2>&1 \
|
||||
|| return 1
|
||||
brew link --overwrite --force "node@${HERMES_NODE_TARGET_MAJOR}" >/dev/null 2>&1 || true
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) installed via Homebrew"
|
||||
return 0
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bundled binary fallback — always works, no shell rc edits
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_nb_install_bundled_node() {
|
||||
local arch node_arch os_name node_os
|
||||
arch=$(uname -m)
|
||||
case "$arch" in
|
||||
x86_64) node_arch="x64" ;;
|
||||
aarch64|arm64) node_arch="arm64" ;;
|
||||
armv7l) node_arch="armv7l" ;;
|
||||
*)
|
||||
_nb_warn "Unsupported arch ($arch) — install Node.js manually: https://nodejs.org/"
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
|
||||
os_name=$(uname -s)
|
||||
case "$os_name" in
|
||||
Linux*) node_os="linux" ;;
|
||||
Darwin*) node_os="darwin" ;;
|
||||
*)
|
||||
_nb_warn "Unsupported OS ($os_name) — install Node.js manually: https://nodejs.org/"
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
|
||||
local index_url="https://nodejs.org/dist/latest-v${HERMES_NODE_TARGET_MAJOR}.x/"
|
||||
local tarball
|
||||
tarball=$(curl -fsSL "$index_url" \
|
||||
| grep -oE "node-v${HERMES_NODE_TARGET_MAJOR}\.[0-9]+\.[0-9]+-${node_os}-${node_arch}\.tar\.xz" \
|
||||
| head -1)
|
||||
if [ -z "$tarball" ]; then
|
||||
tarball=$(curl -fsSL "$index_url" \
|
||||
| grep -oE "node-v${HERMES_NODE_TARGET_MAJOR}\.[0-9]+\.[0-9]+-${node_os}-${node_arch}\.tar\.gz" \
|
||||
| head -1)
|
||||
fi
|
||||
if [ -z "$tarball" ]; then
|
||||
_nb_warn "Could not resolve Node $HERMES_NODE_TARGET_MAJOR binary for $node_os-$node_arch"
|
||||
return 1
|
||||
fi
|
||||
|
||||
local tmp
|
||||
tmp=$(mktemp -d)
|
||||
_nb_log "Downloading $tarball..."
|
||||
curl -fsSL "${index_url}${tarball}" -o "$tmp/$tarball" || {
|
||||
_nb_warn "Download failed"; rm -rf "$tmp"; return 1
|
||||
}
|
||||
|
||||
_nb_log "Extracting to $HERMES_HOME/node/..."
|
||||
if [[ "$tarball" == *.tar.xz ]]; then
|
||||
tar xf "$tmp/$tarball" -C "$tmp" || { rm -rf "$tmp"; return 1; }
|
||||
else
|
||||
tar xzf "$tmp/$tarball" -C "$tmp" || { rm -rf "$tmp"; return 1; }
|
||||
fi
|
||||
|
||||
local extracted
|
||||
extracted=$(find "$tmp" -maxdepth 1 -type d -name 'node-v*' 2>/dev/null | head -1)
|
||||
if [ ! -d "$extracted" ]; then
|
||||
_nb_warn "Extraction produced no node-v* directory"
|
||||
rm -rf "$tmp"
|
||||
return 1
|
||||
fi
|
||||
|
||||
mkdir -p "$HERMES_HOME"
|
||||
rm -rf "$HERMES_HOME/node"
|
||||
mv "$extracted" "$HERMES_HOME/node"
|
||||
rm -rf "$tmp"
|
||||
|
||||
mkdir -p "$HOME/.local/bin"
|
||||
ln -sf "$HERMES_HOME/node/bin/node" "$HOME/.local/bin/node"
|
||||
ln -sf "$HERMES_HOME/node/bin/npm" "$HOME/.local/bin/npm"
|
||||
ln -sf "$HERMES_HOME/node/bin/npx" "$HOME/.local/bin/npx"
|
||||
export PATH="$HERMES_HOME/node/bin:$PATH"
|
||||
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) installed to $HERMES_HOME/node/"
|
||||
return 0
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ensure_node() {
|
||||
HERMES_NODE_AVAILABLE=false
|
||||
|
||||
if _nb_have_modern_node; then
|
||||
_nb_ok "Node $(node --version) found"
|
||||
HERMES_NODE_AVAILABLE=true
|
||||
return 0
|
||||
fi
|
||||
|
||||
if [ -x "$HERMES_HOME/node/bin/node" ]; then
|
||||
export PATH="$HERMES_HOME/node/bin:$PATH"
|
||||
if _nb_have_modern_node; then
|
||||
_nb_ok "Node $(node --version) found (Hermes-managed)"
|
||||
HERMES_NODE_AVAILABLE=true
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
|
||||
# Version managers first — respect the user's existing setup.
|
||||
_nb_try_fnm && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
_nb_try_proto && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
_nb_try_nvm && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
|
||||
# Platform package managers.
|
||||
_nb_try_termux_pkg && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
_nb_try_brew && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
|
||||
# Last resort: pinned nodejs.org tarball.
|
||||
_nb_install_bundled_node && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
|
||||
_nb_warn "Node.js install failed — TUI and browser tools will be unavailable."
|
||||
_nb_warn "Install manually: https://nodejs.org/en/download/ (or: \`brew install node\`, \`fnm install $HERMES_NODE_TARGET_MAJOR\`, etc.)"
|
||||
return 1
|
||||
}
|
||||
@@ -44,12 +44,14 @@ AUTHOR_MAP = {
|
||||
"teknium@nousresearch.com": "teknium1",
|
||||
"127238744+teknium1@users.noreply.github.com": "teknium1",
|
||||
# contributors (from noreply pattern)
|
||||
"snreynolds2506@gmail.com": "snreynolds",
|
||||
"35742124+0xbyt4@users.noreply.github.com": "0xbyt4",
|
||||
"82637225+kshitijk4poor@users.noreply.github.com": "kshitijk4poor",
|
||||
"kshitijk4poor@users.noreply.github.com": "kshitijk4poor",
|
||||
"16443023+stablegenius49@users.noreply.github.com": "stablegenius49",
|
||||
"185121704+stablegenius49@users.noreply.github.com": "stablegenius49",
|
||||
"101283333+batuhankocyigit@users.noreply.github.com": "batuhankocyigit",
|
||||
"valdi.jorge@gmail.com": "jvcl",
|
||||
"126368201+vilkasdev@users.noreply.github.com": "vilkasdev",
|
||||
"137614867+cutepawss@users.noreply.github.com": "cutepawss",
|
||||
"96793918+memosr@users.noreply.github.com": "memosr",
|
||||
@@ -70,8 +72,11 @@ AUTHOR_MAP = {
|
||||
"27917469+nosleepcassette@users.noreply.github.com": "nosleepcassette",
|
||||
"241404605+MestreY0d4-Uninter@users.noreply.github.com": "MestreY0d4-Uninter",
|
||||
"109555139+davetist@users.noreply.github.com": "davetist",
|
||||
"39405770+yyq4193@users.noreply.github.com": "yyq4193",
|
||||
"Asunfly@users.noreply.github.com": "Asunfly",
|
||||
# contributors (manual mapping from git names)
|
||||
"ahmedsherif95@gmail.com": "asheriif",
|
||||
"liujinkun@bytedance.com": "liujinkun2025",
|
||||
"dmayhem93@gmail.com": "dmahan93",
|
||||
"samherring99@gmail.com": "samherring99",
|
||||
"desaiaum08@gmail.com": "Aum08Desai",
|
||||
@@ -83,6 +88,8 @@ AUTHOR_MAP = {
|
||||
"abdullahfarukozden@gmail.com": "Farukest",
|
||||
"lovre.pesut@gmail.com": "rovle",
|
||||
"kevinskysunny@gmail.com": "kevinskysunny",
|
||||
"xiewenxuan462@gmail.com": "yule975",
|
||||
"yiweimeng.dlut@hotmail.com": "meng93",
|
||||
"hakanerten02@hotmail.com": "teyrebaz33",
|
||||
"ruzzgarcn@gmail.com": "Ruzzgar",
|
||||
"alireza78.crypto@gmail.com": "alireza78a",
|
||||
@@ -90,13 +97,16 @@ AUTHOR_MAP = {
|
||||
"4317663+helix4u@users.noreply.github.com": "helix4u",
|
||||
"331214+counterposition@users.noreply.github.com": "counterposition",
|
||||
"blspear@gmail.com": "BrennerSpear",
|
||||
"akhater@gmail.com": "akhater",
|
||||
"239876380+handsdiff@users.noreply.github.com": "handsdiff",
|
||||
"gpickett00@gmail.com": "gpickett00",
|
||||
"mcosma@gmail.com": "wakamex",
|
||||
"clawdia.nash@proton.me": "clawdia-nash",
|
||||
"pickett.austin@gmail.com": "austinpickett",
|
||||
"dangtc94@gmail.com": "dieutx",
|
||||
"jaisehgal11299@gmail.com": "jaisup",
|
||||
"percydikec@gmail.com": "PercyDikec",
|
||||
"noonou7@gmail.com": "HenkDz",
|
||||
"dean.kerr@gmail.com": "deankerr",
|
||||
"socrates1024@gmail.com": "socrates1024",
|
||||
"satelerd@gmail.com": "satelerd",
|
||||
@@ -109,6 +119,7 @@ AUTHOR_MAP = {
|
||||
"vincentcharlebois@gmail.com": "vincentcharlebois",
|
||||
"aryan@synvoid.com": "aryansingh",
|
||||
"johnsonblake1@gmail.com": "blakejohnson",
|
||||
"hcn518@gmail.com": "pedh",
|
||||
"greer.guthrie@gmail.com": "g-guthrie",
|
||||
"kennyx102@gmail.com": "bobashopcashier",
|
||||
"shokatalishaikh95@gmail.com": "areu01or00",
|
||||
@@ -177,6 +188,7 @@ AUTHOR_MAP = {
|
||||
"juan.ovalle@mistral.ai": "jjovalle99",
|
||||
"julien.talbot@ergonomia.re": "Julientalbot",
|
||||
"kagura.chen28@gmail.com": "kagura-agent",
|
||||
"1342088860@qq.com": "youngDoo",
|
||||
"kamil@gwozdz.me": "kamil-gwozdz",
|
||||
"karamusti912@gmail.com": "MustafaKara7",
|
||||
"kira@ariaki.me": "kira-ariaki",
|
||||
@@ -195,6 +207,7 @@ AUTHOR_MAP = {
|
||||
"cola-runner@users.noreply.github.com": "cola-runner",
|
||||
"ygd58@users.noreply.github.com": "ygd58",
|
||||
"vominh1919@users.noreply.github.com": "vominh1919",
|
||||
"iamagenius00@users.noreply.github.com": "iamagenius00",
|
||||
"trevmanthony@gmail.com": "trevthefoolish",
|
||||
"ziliangpeng@users.noreply.github.com": "ziliangpeng",
|
||||
"centripetal-star@users.noreply.github.com": "centripetal-star",
|
||||
@@ -231,7 +244,26 @@ AUTHOR_MAP = {
|
||||
"zaynjarvis@gmail.com": "ZaynJarvis",
|
||||
"zhiheng.liu@bytedance.com": "ZaynJarvis",
|
||||
"mbelleau@Michels-MacBook-Pro.local": "malaiwah",
|
||||
"michel.belleau@malaiwah.com": "malaiwah",
|
||||
"gnanasekaran.sekareee@gmail.com": "gnanam1990",
|
||||
"jz.pentest@gmail.com": "0xyg3n",
|
||||
"hypnosis.mda@gmail.com": "Hypn0sis",
|
||||
"ywt000818@gmail.com": "OwenYWT",
|
||||
"dhandhalyabhavik@gmail.com": "v1k22",
|
||||
"rucchizhao@zhaochenfeideMacBook-Pro.local": "RucchiZ",
|
||||
"lehaolin98@outlook.com": "LehaoLin",
|
||||
"yuewang1@microsoft.com": "imink",
|
||||
"1736355688@qq.com": "hedgeho9X",
|
||||
"bernylinville@devopsthink.org": "bernylinville",
|
||||
"brian@bde.io": "briandevans",
|
||||
"hubin_ll@qq.com": "LLQWQ",
|
||||
"memosr_email@gmail.com": "memosr",
|
||||
"anthhub@163.com": "anthhub",
|
||||
"shenuu@gmail.com": "shenuu",
|
||||
"xiayh17@gmail.com": "xiayh0107",
|
||||
"asurla@nvidia.com": "anniesurla",
|
||||
"limkuan24@gmail.com": "WideLee",
|
||||
"aviralarora002@gmail.com": "AviArora02-commits",
|
||||
}
|
||||
|
||||
|
||||
|
||||
Executable
+104
@@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env bash
|
||||
# Canonical test runner for hermes-agent. Run this instead of calling
|
||||
# `pytest` directly to guarantee your local run matches CI behavior.
|
||||
#
|
||||
# What this script enforces:
|
||||
# * -n 4 xdist workers (CI has 4 cores; -n auto diverges locally)
|
||||
# * TZ=UTC, LANG=C.UTF-8, PYTHONHASHSEED=0 (deterministic)
|
||||
# * Credential env vars blanked (conftest.py also does this, but this
|
||||
# is belt-and-suspenders for anyone running `pytest` outside of
|
||||
# our conftest path — e.g. calling pytest on a single file)
|
||||
# * Proper venv activation
|
||||
#
|
||||
# Usage:
|
||||
# scripts/run_tests.sh # full suite
|
||||
# scripts/run_tests.sh tests/agent/ # one directory
|
||||
# scripts/run_tests.sh tests/agent/test_foo.py::TestClass::test_method
|
||||
# scripts/run_tests.sh --tb=long -v # pass-through pytest args
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# ── Locate repo root ────────────────────────────────────────────────────────
|
||||
# Works whether this is the main checkout or a worktree.
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
|
||||
|
||||
# ── Activate venv ───────────────────────────────────────────────────────────
|
||||
# Prefer a .venv in the current tree, fall back to the main checkout's venv
|
||||
# (useful for worktrees where we don't always duplicate the venv).
|
||||
VENV=""
|
||||
for candidate in "$REPO_ROOT/.venv" "$REPO_ROOT/venv" "$HOME/.hermes/hermes-agent/venv"; do
|
||||
if [ -f "$candidate/bin/activate" ]; then
|
||||
VENV="$candidate"
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
if [ -z "$VENV" ]; then
|
||||
echo "error: no virtualenv found in $REPO_ROOT/.venv or $REPO_ROOT/venv" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PYTHON="$VENV/bin/python"
|
||||
|
||||
# ── Ensure pytest-split is installed (required for shard-equivalent runs) ──
|
||||
if ! "$PYTHON" -c "import pytest_split" 2>/dev/null; then
|
||||
echo "→ installing pytest-split into $VENV"
|
||||
"$PYTHON" -m pip install --quiet "pytest-split>=0.9,<1"
|
||||
fi
|
||||
|
||||
# ── Hermetic environment ────────────────────────────────────────────────────
|
||||
# Mirror what CI does in .github/workflows/tests.yml + what conftest.py does.
|
||||
# Unset every credential-shaped var currently in the environment.
|
||||
while IFS='=' read -r name _; do
|
||||
case "$name" in
|
||||
*_API_KEY|*_TOKEN|*_SECRET|*_PASSWORD|*_CREDENTIALS|*_ACCESS_KEY| \
|
||||
*_SECRET_ACCESS_KEY|*_PRIVATE_KEY|*_OAUTH_TOKEN|*_WEBHOOK_SECRET| \
|
||||
*_ENCRYPT_KEY|*_APP_SECRET|*_CLIENT_SECRET|*_CORP_SECRET|*_AES_KEY| \
|
||||
AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_SESSION_TOKEN|FAL_KEY| \
|
||||
GH_TOKEN|GITHUB_TOKEN)
|
||||
unset "$name"
|
||||
;;
|
||||
esac
|
||||
done < <(env)
|
||||
|
||||
# Unset HERMES_* behavioral vars too.
|
||||
unset HERMES_YOLO_MODE HERMES_INTERACTIVE HERMES_QUIET HERMES_TOOL_PROGRESS \
|
||||
HERMES_TOOL_PROGRESS_MODE HERMES_MAX_ITERATIONS HERMES_SESSION_PLATFORM \
|
||||
HERMES_SESSION_CHAT_ID HERMES_SESSION_CHAT_NAME HERMES_SESSION_THREAD_ID \
|
||||
HERMES_SESSION_SOURCE HERMES_SESSION_KEY HERMES_GATEWAY_SESSION \
|
||||
HERMES_PLATFORM HERMES_INFERENCE_PROVIDER HERMES_MANAGED HERMES_DEV \
|
||||
HERMES_CONTAINER HERMES_EPHEMERAL_SYSTEM_PROMPT HERMES_TIMEZONE \
|
||||
HERMES_REDACT_SECRETS HERMES_BACKGROUND_NOTIFICATIONS HERMES_EXEC_ASK \
|
||||
HERMES_HOME_MODE 2>/dev/null || true
|
||||
|
||||
# Pin deterministic runtime.
|
||||
export TZ=UTC
|
||||
export LANG=C.UTF-8
|
||||
export LC_ALL=C.UTF-8
|
||||
export PYTHONHASHSEED=0
|
||||
|
||||
# ── Worker count ────────────────────────────────────────────────────────────
|
||||
# CI uses `-n auto` on ubuntu-latest which gives 4 workers. A 20-core
|
||||
# workstation with `-n auto` gets 20 workers and exposes test-ordering
|
||||
# flakes that CI will never see. Pin to 4 so local matches CI.
|
||||
WORKERS="${HERMES_TEST_WORKERS:-4}"
|
||||
|
||||
# ── Run pytest ──────────────────────────────────────────────────────────────
|
||||
cd "$REPO_ROOT"
|
||||
|
||||
# If the first argument starts with `-` treat all args as pytest flags;
|
||||
# otherwise treat them as test paths.
|
||||
ARGS=("$@")
|
||||
|
||||
echo "▶ running pytest with $WORKERS workers, hermetic env, in $REPO_ROOT"
|
||||
echo " (TZ=UTC LANG=C.UTF-8 PYTHONHASHSEED=0; all credential env vars unset)"
|
||||
|
||||
# -o "addopts=" clears pyproject.toml's `-n auto` so our -n wins.
|
||||
exec "$PYTHON" -m pytest \
|
||||
-o "addopts=" \
|
||||
-n "$WORKERS" \
|
||||
--ignore=tests/integration \
|
||||
--ignore=tests/e2e \
|
||||
-m "not integration" \
|
||||
"${ARGS[@]}"
|
||||
@@ -1,430 +0,0 @@
|
||||
---
|
||||
name: gguf-quantization
|
||||
description: GGUF format and llama.cpp quantization for efficient CPU/GPU inference. Use when deploying models on consumer hardware, Apple Silicon, or when needing flexible quantization from 2-8 bit without GPU requirements.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [llama-cpp-python>=0.2.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [GGUF, Quantization, llama.cpp, CPU Inference, Apple Silicon, Model Compression, Optimization]
|
||||
|
||||
---
|
||||
|
||||
# GGUF - Quantization Format for llama.cpp
|
||||
|
||||
The GGUF (GPT-Generated Unified Format) is the standard file format for llama.cpp, enabling efficient inference on CPUs, Apple Silicon, and GPUs with flexible quantization options.
|
||||
|
||||
## When to use GGUF
|
||||
|
||||
**Use GGUF when:**
|
||||
- Deploying on consumer hardware (laptops, desktops)
|
||||
- Running on Apple Silicon (M1/M2/M3) with Metal acceleration
|
||||
- Need CPU inference without GPU requirements
|
||||
- Want flexible quantization (Q2_K to Q8_0)
|
||||
- Using local AI tools (LM Studio, Ollama, text-generation-webui)
|
||||
|
||||
**Key advantages:**
|
||||
- **Universal hardware**: CPU, Apple Silicon, NVIDIA, AMD support
|
||||
- **No Python runtime**: Pure C/C++ inference
|
||||
- **Flexible quantization**: 2-8 bit with various methods (K-quants)
|
||||
- **Ecosystem support**: LM Studio, Ollama, koboldcpp, and more
|
||||
- **imatrix**: Importance matrix for better low-bit quality
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **AWQ/GPTQ**: Maximum accuracy with calibration on NVIDIA GPUs
|
||||
- **HQQ**: Fast calibration-free quantization for HuggingFace
|
||||
- **bitsandbytes**: Simple integration with transformers library
|
||||
- **TensorRT-LLM**: Production NVIDIA deployment with maximum speed
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Clone llama.cpp
|
||||
git clone https://github.com/ggml-org/llama.cpp
|
||||
cd llama.cpp
|
||||
|
||||
# Build (CPU)
|
||||
make
|
||||
|
||||
# Build with CUDA (NVIDIA)
|
||||
make GGML_CUDA=1
|
||||
|
||||
# Build with Metal (Apple Silicon)
|
||||
make GGML_METAL=1
|
||||
|
||||
# Install Python bindings (optional)
|
||||
pip install llama-cpp-python
|
||||
```
|
||||
|
||||
### Convert model to GGUF
|
||||
|
||||
```bash
|
||||
# Install requirements
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Convert HuggingFace model to GGUF (FP16)
|
||||
python convert_hf_to_gguf.py ./path/to/model --outfile model-f16.gguf
|
||||
|
||||
# Or specify output type
|
||||
python convert_hf_to_gguf.py ./path/to/model \
|
||||
--outfile model-f16.gguf \
|
||||
--outtype f16
|
||||
```
|
||||
|
||||
### Quantize model
|
||||
|
||||
```bash
|
||||
# Basic quantization to Q4_K_M
|
||||
./llama-quantize model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
|
||||
# Quantize with importance matrix (better quality)
|
||||
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
|
||||
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
### Run inference
|
||||
|
||||
```bash
|
||||
# CLI inference
|
||||
./llama-cli -m model-q4_k_m.gguf -p "Hello, how are you?"
|
||||
|
||||
# Interactive mode
|
||||
./llama-cli -m model-q4_k_m.gguf --interactive
|
||||
|
||||
# With GPU offload
|
||||
./llama-cli -m model-q4_k_m.gguf -ngl 35 -p "Hello!"
|
||||
```
|
||||
|
||||
## Quantization types
|
||||
|
||||
### K-quant methods (recommended)
|
||||
|
||||
| Type | Bits | Size (7B) | Quality | Use Case |
|
||||
|------|------|-----------|---------|----------|
|
||||
| Q2_K | 2.5 | ~2.8 GB | Low | Extreme compression |
|
||||
| Q3_K_S | 3.0 | ~3.0 GB | Low-Med | Memory constrained |
|
||||
| Q3_K_M | 3.3 | ~3.3 GB | Medium | Balance |
|
||||
| Q4_K_S | 4.0 | ~3.8 GB | Med-High | Good balance |
|
||||
| Q4_K_M | 4.5 | ~4.1 GB | High | **Recommended default** |
|
||||
| Q5_K_S | 5.0 | ~4.6 GB | High | Quality focused |
|
||||
| Q5_K_M | 5.5 | ~4.8 GB | Very High | High quality |
|
||||
| Q6_K | 6.0 | ~5.5 GB | Excellent | Near-original |
|
||||
| Q8_0 | 8.0 | ~7.2 GB | Best | Maximum quality |
|
||||
|
||||
### Legacy methods
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| Q4_0 | 4-bit, basic |
|
||||
| Q4_1 | 4-bit with delta |
|
||||
| Q5_0 | 5-bit, basic |
|
||||
| Q5_1 | 5-bit with delta |
|
||||
|
||||
**Recommendation**: Use K-quant methods (Q4_K_M, Q5_K_M) for best quality/size ratio.
|
||||
|
||||
## Conversion workflows
|
||||
|
||||
### Workflow 1: HuggingFace to GGUF
|
||||
|
||||
```bash
|
||||
# 1. Download model
|
||||
huggingface-cli download meta-llama/Llama-3.1-8B --local-dir ./llama-3.1-8b
|
||||
|
||||
# 2. Convert to GGUF (FP16)
|
||||
python convert_hf_to_gguf.py ./llama-3.1-8b \
|
||||
--outfile llama-3.1-8b-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# 3. Quantize
|
||||
./llama-quantize llama-3.1-8b-f16.gguf llama-3.1-8b-q4_k_m.gguf Q4_K_M
|
||||
|
||||
# 4. Test
|
||||
./llama-cli -m llama-3.1-8b-q4_k_m.gguf -p "Hello!" -n 50
|
||||
```
|
||||
|
||||
### Workflow 2: With importance matrix (better quality)
|
||||
|
||||
```bash
|
||||
# 1. Convert to GGUF
|
||||
python convert_hf_to_gguf.py ./model --outfile model-f16.gguf
|
||||
|
||||
# 2. Create calibration text (diverse samples)
|
||||
cat > calibration.txt << 'EOF'
|
||||
The quick brown fox jumps over the lazy dog.
|
||||
Machine learning is a subset of artificial intelligence.
|
||||
Python is a popular programming language.
|
||||
# Add more diverse text samples...
|
||||
EOF
|
||||
|
||||
# 3. Generate importance matrix
|
||||
./llama-imatrix -m model-f16.gguf \
|
||||
-f calibration.txt \
|
||||
--chunk 512 \
|
||||
-o model.imatrix \
|
||||
-ngl 35 # GPU layers if available
|
||||
|
||||
# 4. Quantize with imatrix
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf \
|
||||
model-q4_k_m.gguf \
|
||||
Q4_K_M
|
||||
```
|
||||
|
||||
### Workflow 3: Multiple quantizations
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
MODEL="llama-3.1-8b-f16.gguf"
|
||||
IMATRIX="llama-3.1-8b.imatrix"
|
||||
|
||||
# Generate imatrix once
|
||||
./llama-imatrix -m $MODEL -f wiki.txt -o $IMATRIX -ngl 35
|
||||
|
||||
# Create multiple quantizations
|
||||
for QUANT in Q4_K_M Q5_K_M Q6_K Q8_0; do
|
||||
OUTPUT="llama-3.1-8b-${QUANT,,}.gguf"
|
||||
./llama-quantize --imatrix $IMATRIX $MODEL $OUTPUT $QUANT
|
||||
echo "Created: $OUTPUT ($(du -h $OUTPUT | cut -f1))"
|
||||
done
|
||||
```
|
||||
|
||||
## Python usage
|
||||
|
||||
### llama-cpp-python
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# Load model
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096, # Context window
|
||||
n_gpu_layers=35, # GPU offload (0 for CPU only)
|
||||
n_threads=8 # CPU threads
|
||||
)
|
||||
|
||||
# Generate
|
||||
output = llm(
|
||||
"What is machine learning?",
|
||||
max_tokens=256,
|
||||
temperature=0.7,
|
||||
stop=["</s>", "\n\n"]
|
||||
)
|
||||
print(output["choices"][0]["text"])
|
||||
```
|
||||
|
||||
### Chat completion
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35,
|
||||
chat_format="llama-3" # Or "chatml", "mistral", etc.
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is Python?"}
|
||||
]
|
||||
|
||||
response = llm.create_chat_completion(
|
||||
messages=messages,
|
||||
max_tokens=256,
|
||||
temperature=0.7
|
||||
)
|
||||
print(response["choices"][0]["message"]["content"])
|
||||
```
|
||||
|
||||
### Streaming
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(model_path="./model-q4_k_m.gguf", n_gpu_layers=35)
|
||||
|
||||
# Stream tokens
|
||||
for chunk in llm(
|
||||
"Explain quantum computing:",
|
||||
max_tokens=256,
|
||||
stream=True
|
||||
):
|
||||
print(chunk["choices"][0]["text"], end="", flush=True)
|
||||
```
|
||||
|
||||
## Server mode
|
||||
|
||||
### Start OpenAI-compatible server
|
||||
|
||||
```bash
|
||||
# Start server
|
||||
./llama-server -m model-q4_k_m.gguf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080 \
|
||||
-ngl 35 \
|
||||
-c 4096
|
||||
|
||||
# Or with Python bindings
|
||||
python -m llama_cpp.server \
|
||||
--model model-q4_k_m.gguf \
|
||||
--n_gpu_layers 35 \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080
|
||||
```
|
||||
|
||||
### Use with OpenAI client
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1",
|
||||
api_key="not-needed"
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="local-model",
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
max_tokens=256
|
||||
)
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
## Hardware optimization
|
||||
|
||||
### Apple Silicon (Metal)
|
||||
|
||||
```bash
|
||||
# Build with Metal
|
||||
make clean && make GGML_METAL=1
|
||||
|
||||
# Run with Metal acceleration
|
||||
./llama-cli -m model.gguf -ngl 99 -p "Hello"
|
||||
|
||||
# Python with Metal
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=99, # Offload all layers
|
||||
n_threads=1 # Metal handles parallelism
|
||||
)
|
||||
```
|
||||
|
||||
### NVIDIA CUDA
|
||||
|
||||
```bash
|
||||
# Build with CUDA
|
||||
make clean && make GGML_CUDA=1
|
||||
|
||||
# Run with CUDA
|
||||
./llama-cli -m model.gguf -ngl 35 -p "Hello"
|
||||
|
||||
# Specify GPU
|
||||
CUDA_VISIBLE_DEVICES=0 ./llama-cli -m model.gguf -ngl 35
|
||||
```
|
||||
|
||||
### CPU optimization
|
||||
|
||||
```bash
|
||||
# Build with AVX2/AVX512
|
||||
make clean && make
|
||||
|
||||
# Run with optimal threads
|
||||
./llama-cli -m model.gguf -t 8 -p "Hello"
|
||||
|
||||
# Python CPU config
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=0, # CPU only
|
||||
n_threads=8, # Match physical cores
|
||||
n_batch=512 # Batch size for prompt processing
|
||||
)
|
||||
```
|
||||
|
||||
## Integration with tools
|
||||
|
||||
### Ollama
|
||||
|
||||
```bash
|
||||
# Create Modelfile
|
||||
cat > Modelfile << 'EOF'
|
||||
FROM ./model-q4_k_m.gguf
|
||||
TEMPLATE """{{ .System }}
|
||||
{{ .Prompt }}"""
|
||||
PARAMETER temperature 0.7
|
||||
PARAMETER num_ctx 4096
|
||||
EOF
|
||||
|
||||
# Create Ollama model
|
||||
ollama create mymodel -f Modelfile
|
||||
|
||||
# Run
|
||||
ollama run mymodel "Hello!"
|
||||
```
|
||||
|
||||
### LM Studio
|
||||
|
||||
1. Place GGUF file in `~/.cache/lm-studio/models/`
|
||||
2. Open LM Studio and select the model
|
||||
3. Configure context length and GPU offload
|
||||
4. Start inference
|
||||
|
||||
### text-generation-webui
|
||||
|
||||
```bash
|
||||
# Place in models folder
|
||||
cp model-q4_k_m.gguf text-generation-webui/models/
|
||||
|
||||
# Start with llama.cpp loader
|
||||
python server.py --model model-q4_k_m.gguf --loader llama.cpp --n-gpu-layers 35
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use K-quants**: Q4_K_M offers best quality/size balance
|
||||
2. **Use imatrix**: Always use importance matrix for Q4 and below
|
||||
3. **GPU offload**: Offload as many layers as VRAM allows
|
||||
4. **Context length**: Start with 4096, increase if needed
|
||||
5. **Thread count**: Match physical CPU cores, not logical
|
||||
6. **Batch size**: Increase n_batch for faster prompt processing
|
||||
|
||||
## Common issues
|
||||
|
||||
**Model loads slowly:**
|
||||
```bash
|
||||
# Use mmap for faster loading
|
||||
./llama-cli -m model.gguf --mmap
|
||||
```
|
||||
|
||||
**Out of memory:**
|
||||
```bash
|
||||
# Reduce GPU layers
|
||||
./llama-cli -m model.gguf -ngl 20 # Reduce from 35
|
||||
|
||||
# Or use smaller quantization
|
||||
./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M
|
||||
```
|
||||
|
||||
**Poor quality at low bits:**
|
||||
```bash
|
||||
# Always use imatrix for Q4 and below
|
||||
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
|
||||
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Batching, speculative decoding, custom builds
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues, debugging, benchmarks
|
||||
|
||||
## Resources
|
||||
|
||||
- **Repository**: https://github.com/ggml-org/llama.cpp
|
||||
- **Python Bindings**: https://github.com/abetlen/llama-cpp-python
|
||||
- **Pre-quantized Models**: https://huggingface.co/TheBloke
|
||||
- **GGUF Converter**: https://huggingface.co/spaces/ggml-org/gguf-my-repo
|
||||
- **License**: MIT
|
||||
@@ -1,138 +1,271 @@
|
||||
---
|
||||
name: llama-cpp
|
||||
description: Runs LLM inference on CPU, Apple Silicon, and consumer GPUs without NVIDIA hardware. Use for edge deployment, M1/M2/M3 Macs, AMD/Intel GPUs, or when CUDA is unavailable. Supports GGUF quantization (1.5-8 bit) for reduced memory and 4-10× speedup vs PyTorch on CPU.
|
||||
version: 1.0.0
|
||||
description: Run LLM inference with llama.cpp on CPU, Apple Silicon, AMD/Intel GPUs, or NVIDIA — plus GGUF model conversion and quantization (2–8 bit with K-quants and imatrix). Covers CLI, Python bindings, OpenAI-compatible server, and Ollama/LM Studio integration. Use for edge deployment, M1/M2/M3/M4 Macs, CUDA-less environments, or flexible local quantization.
|
||||
version: 2.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [llama-cpp-python]
|
||||
dependencies: [llama-cpp-python>=0.2.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Inference Serving, Llama.cpp, CPU Inference, Apple Silicon, Edge Deployment, GGUF, Quantization, Non-NVIDIA, AMD GPUs, Intel GPUs, Embedded]
|
||||
|
||||
tags: [llama.cpp, GGUF, Quantization, CPU Inference, Apple Silicon, Edge Deployment, Non-NVIDIA, AMD GPUs, Intel GPUs, Embedded, Model Compression]
|
||||
---
|
||||
|
||||
# llama.cpp
|
||||
# llama.cpp + GGUF
|
||||
|
||||
Pure C/C++ LLM inference with minimal dependencies, optimized for CPUs and non-NVIDIA hardware.
|
||||
Pure C/C++ LLM inference with minimal dependencies, plus the GGUF (GPT-Generated Unified Format) standard used for quantized weights. One toolchain covers conversion, quantization, and serving.
|
||||
|
||||
## When to use llama.cpp
|
||||
## When to use
|
||||
|
||||
**Use llama.cpp when:**
|
||||
- Running on CPU-only machines
|
||||
- Deploying on Apple Silicon (M1/M2/M3/M4)
|
||||
- Using AMD or Intel GPUs (no CUDA)
|
||||
- Edge deployment (Raspberry Pi, embedded systems)
|
||||
- Need simple deployment without Docker/Python
|
||||
**Use llama.cpp + GGUF when:**
|
||||
- Running on CPU-only machines or Apple Silicon (M1/M2/M3/M4) with Metal acceleration
|
||||
- Using AMD (ROCm) or Intel GPUs where CUDA isn't available
|
||||
- Edge deployment (Raspberry Pi, embedded systems, consumer laptops)
|
||||
- Need flexible quantization (2–8 bit with K-quants)
|
||||
- Want local AI tools (LM Studio, Ollama, text-generation-webui, koboldcpp)
|
||||
- Want a single binary deploy without Docker/Python
|
||||
|
||||
**Use TensorRT-LLM instead when:**
|
||||
- Have NVIDIA GPUs (A100/H100)
|
||||
- Need maximum throughput (100K+ tok/s)
|
||||
- Running in datacenter with CUDA
|
||||
**Key advantages:**
|
||||
- Universal hardware: CPU, Apple Silicon, NVIDIA, AMD, Intel
|
||||
- No Python runtime required (pure C/C++)
|
||||
- K-quants + imatrix for better low-bit quality
|
||||
- OpenAI-compatible server built in
|
||||
- Rich ecosystem (Ollama, LM Studio, llama-cpp-python)
|
||||
|
||||
**Use vLLM instead when:**
|
||||
- Have NVIDIA GPUs
|
||||
- Need Python-first API
|
||||
- Want PagedAttention
|
||||
**Use alternatives instead:**
|
||||
- **vLLM** — NVIDIA GPUs, PagedAttention, Python-first, max throughput
|
||||
- **TensorRT-LLM** — Production NVIDIA (A100/H100), maximum speed
|
||||
- **AWQ/GPTQ** — Calibrated quantization for NVIDIA-only deployments
|
||||
- **bitsandbytes** — Simple HuggingFace transformers integration
|
||||
- **HQQ** — Fast calibration-free quantization
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
### Install
|
||||
|
||||
```bash
|
||||
# macOS/Linux
|
||||
# macOS / Linux (simplest)
|
||||
brew install llama.cpp
|
||||
|
||||
# Or build from source
|
||||
git clone https://github.com/ggerganov/llama.cpp
|
||||
git clone https://github.com/ggml-org/llama.cpp
|
||||
cd llama.cpp
|
||||
make
|
||||
make # CPU
|
||||
make GGML_METAL=1 # Apple Silicon
|
||||
make GGML_CUDA=1 # NVIDIA CUDA
|
||||
make LLAMA_HIP=1 # AMD ROCm
|
||||
|
||||
# With Metal (Apple Silicon)
|
||||
make LLAMA_METAL=1
|
||||
|
||||
# With CUDA (NVIDIA)
|
||||
make LLAMA_CUDA=1
|
||||
|
||||
# With ROCm (AMD)
|
||||
make LLAMA_HIP=1
|
||||
# Python bindings (optional)
|
||||
pip install llama-cpp-python
|
||||
# With CUDA: CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
|
||||
# With Metal: CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
|
||||
```
|
||||
|
||||
### Download model
|
||||
### Download a pre-quantized GGUF
|
||||
|
||||
```bash
|
||||
# Download from HuggingFace (GGUF format)
|
||||
# TheBloke hosts most popular models pre-quantized
|
||||
huggingface-cli download \
|
||||
TheBloke/Llama-2-7B-Chat-GGUF \
|
||||
llama-2-7b-chat.Q4_K_M.gguf \
|
||||
--local-dir models/
|
||||
```
|
||||
|
||||
# Or convert from HuggingFace
|
||||
python convert_hf_to_gguf.py models/llama-2-7b-chat/
|
||||
### Or convert a HuggingFace model to GGUF
|
||||
|
||||
```bash
|
||||
# 1. Download HF model
|
||||
huggingface-cli download meta-llama/Llama-3.1-8B --local-dir ./llama-3.1-8b
|
||||
|
||||
# 2. Convert to FP16 GGUF
|
||||
python convert_hf_to_gguf.py ./llama-3.1-8b \
|
||||
--outfile llama-3.1-8b-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# 3. Quantize to Q4_K_M
|
||||
./llama-quantize llama-3.1-8b-f16.gguf llama-3.1-8b-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
### Run inference
|
||||
|
||||
```bash
|
||||
# Simple chat
|
||||
./llama-cli \
|
||||
-m models/llama-2-7b-chat.Q4_K_M.gguf \
|
||||
-p "Explain quantum computing" \
|
||||
-n 256 # Max tokens
|
||||
# One-shot prompt
|
||||
./llama-cli -m model.Q4_K_M.gguf -p "Explain quantum computing" -n 256
|
||||
|
||||
# Interactive chat
|
||||
./llama-cli \
|
||||
-m models/llama-2-7b-chat.Q4_K_M.gguf \
|
||||
--interactive
|
||||
./llama-cli -m model.Q4_K_M.gguf --interactive
|
||||
|
||||
# With GPU offload
|
||||
./llama-cli -m model.Q4_K_M.gguf -ngl 35 -p "Hello!"
|
||||
```
|
||||
|
||||
### Server mode
|
||||
### Serve an OpenAI-compatible API
|
||||
|
||||
```bash
|
||||
# Start OpenAI-compatible server
|
||||
./llama-server \
|
||||
-m models/llama-2-7b-chat.Q4_K_M.gguf \
|
||||
-m model.Q4_K_M.gguf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080 \
|
||||
-ngl 32 # Offload 32 layers to GPU
|
||||
-ngl 35 \
|
||||
-c 4096 \
|
||||
--parallel 4 \
|
||||
--cont-batching
|
||||
```
|
||||
|
||||
# Client request
|
||||
```bash
|
||||
curl http://localhost:8080/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "llama-2-7b-chat",
|
||||
"model": "local",
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 100
|
||||
}'
|
||||
```
|
||||
|
||||
## Quantization formats
|
||||
## Quantization formats (GGUF)
|
||||
|
||||
### GGUF format overview
|
||||
### K-quant methods (recommended)
|
||||
|
||||
| Format | Bits | Size (7B) | Speed | Quality | Use Case |
|
||||
|--------|------|-----------|-------|---------|----------|
|
||||
| **Q4_K_M** | 4.5 | 4.1 GB | Fast | Good | **Recommended default** |
|
||||
| Q4_K_S | 4.3 | 3.9 GB | Faster | Lower | Speed critical |
|
||||
| Q5_K_M | 5.5 | 4.8 GB | Medium | Better | Quality critical |
|
||||
| Q6_K | 6.5 | 5.5 GB | Slower | Best | Maximum quality |
|
||||
| Q8_0 | 8.0 | 7.0 GB | Slow | Excellent | Minimal degradation |
|
||||
| Q2_K | 2.5 | 2.7 GB | Fastest | Poor | Testing only |
|
||||
| Type | Bits | Size (7B) | Quality | Use Case |
|
||||
|------|------|-----------|---------|----------|
|
||||
| Q2_K | 2.5 | ~2.8 GB | Low | Extreme compression (testing only) |
|
||||
| Q3_K_S | 3.0 | ~3.0 GB | Low-Med | Memory constrained |
|
||||
| Q3_K_M | 3.3 | ~3.3 GB | Medium | Fits small devices |
|
||||
| Q4_K_S | 4.0 | ~3.8 GB | Med-High | Speed critical |
|
||||
| **Q4_K_M** | 4.5 | ~4.1 GB | High | **Recommended default** |
|
||||
| Q5_K_S | 5.0 | ~4.6 GB | High | Quality focused |
|
||||
| Q5_K_M | 5.5 | ~4.8 GB | Very High | High quality |
|
||||
| Q6_K | 6.0 | ~5.5 GB | Excellent | Near-original |
|
||||
| Q8_0 | 8.0 | ~7.2 GB | Best | Maximum quality, minimal degradation |
|
||||
|
||||
### Choosing quantization
|
||||
**Variant suffixes** — `_S` (Small, faster, lower quality), `_M` (Medium, balanced), `_L` (Large, better quality).
|
||||
|
||||
**Legacy (Q4_0/Q4_1/Q5_0/Q5_1) exist** but always prefer K-quants for better quality/size ratio.
|
||||
|
||||
**IQ quantization** — ultra-low-bit with importance-aware methods: IQ2_XXS, IQ2_XS, IQ2_S, IQ3_XXS, IQ3_XS, IQ3_S, IQ4_XS. Require `--imatrix`.
|
||||
|
||||
**Task-specific defaults:**
|
||||
- General chat / assistants: Q4_K_M, or Q5_K_M if RAM allows
|
||||
- Code generation: Q5_K_M or Q6_K (higher precision helps)
|
||||
- Technical / medical: Q6_K or Q8_0
|
||||
- Very large (70B, 405B) on consumer hardware: Q3_K_M or Q4_K_S
|
||||
- Raspberry Pi / edge: Q2_K or Q3_K_S
|
||||
|
||||
## Conversion workflows
|
||||
|
||||
### Basic: HF → GGUF → quantized
|
||||
|
||||
```bash
|
||||
# General use (balanced)
|
||||
Q4_K_M # 4-bit, medium quality
|
||||
python convert_hf_to_gguf.py ./model --outfile model-f16.gguf --outtype f16
|
||||
./llama-quantize model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
./llama-cli -m model-q4_k_m.gguf -p "Hello!" -n 50
|
||||
```
|
||||
|
||||
# Maximum speed (more degradation)
|
||||
Q2_K or Q3_K_M
|
||||
### With importance matrix (imatrix) — better low-bit quality
|
||||
|
||||
# Maximum quality (slower)
|
||||
Q6_K or Q8_0
|
||||
`imatrix` gives 10–20% perplexity improvement at Q4, essential at Q3 and below.
|
||||
|
||||
# Very large models (70B, 405B)
|
||||
Q3_K_M or Q4_K_S # Lower bits to fit in memory
|
||||
```bash
|
||||
# 1. Convert to FP16 GGUF
|
||||
python convert_hf_to_gguf.py ./model --outfile model-f16.gguf
|
||||
|
||||
# 2. Prepare calibration data (diverse text, ~100MB is ideal)
|
||||
cat > calibration.txt << 'EOF'
|
||||
The quick brown fox jumps over the lazy dog.
|
||||
Machine learning is a subset of artificial intelligence.
|
||||
# Add more diverse text samples...
|
||||
EOF
|
||||
|
||||
# 3. Generate importance matrix
|
||||
./llama-imatrix -m model-f16.gguf \
|
||||
-f calibration.txt \
|
||||
--chunk 512 \
|
||||
-o model.imatrix \
|
||||
-ngl 35
|
||||
|
||||
# 4. Quantize with imatrix
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
### Multi-quant batch
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
MODEL="llama-3.1-8b-f16.gguf"
|
||||
IMATRIX="llama-3.1-8b.imatrix"
|
||||
|
||||
./llama-imatrix -m $MODEL -f wiki.txt -o $IMATRIX -ngl 35
|
||||
|
||||
for QUANT in Q4_K_M Q5_K_M Q6_K Q8_0; do
|
||||
OUTPUT="llama-3.1-8b-${QUANT,,}.gguf"
|
||||
./llama-quantize --imatrix $IMATRIX $MODEL $OUTPUT $QUANT
|
||||
echo "Created: $OUTPUT ($(du -h $OUTPUT | cut -f1))"
|
||||
done
|
||||
```
|
||||
|
||||
### Quality testing (perplexity)
|
||||
|
||||
```bash
|
||||
./llama-perplexity -m model.gguf -f wikitext-2-raw/wiki.test.raw -c 512
|
||||
# Baseline FP16: ~5.96 | Q4_K_M: ~6.06 (+1.7%) | Q2_K: ~6.87 (+15.3%)
|
||||
```
|
||||
|
||||
## Python bindings (llama-cpp-python)
|
||||
|
||||
### Basic generation
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35, # 0 for CPU only, 99 to offload everything
|
||||
n_threads=8,
|
||||
)
|
||||
|
||||
output = llm(
|
||||
"What is machine learning?",
|
||||
max_tokens=256,
|
||||
temperature=0.7,
|
||||
stop=["</s>", "\n\n"],
|
||||
)
|
||||
print(output["choices"][0]["text"])
|
||||
```
|
||||
|
||||
### Chat completion + streaming
|
||||
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35,
|
||||
chat_format="llama-3", # Or "chatml", "mistral", etc.
|
||||
)
|
||||
|
||||
# Non-streaming
|
||||
response = llm.create_chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is Python?"},
|
||||
],
|
||||
max_tokens=256,
|
||||
temperature=0.7,
|
||||
)
|
||||
print(response["choices"][0]["message"]["content"])
|
||||
|
||||
# Streaming
|
||||
for chunk in llm("Explain quantum computing:", max_tokens=256, stream=True):
|
||||
print(chunk["choices"][0]["text"], end="", flush=True)
|
||||
```
|
||||
|
||||
### Embeddings
|
||||
|
||||
```python
|
||||
llm = Llama(model_path="./model-q4_k_m.gguf", embedding=True, n_gpu_layers=35)
|
||||
vec = llm.embed("This is a test sentence.")
|
||||
print(f"Embedding dimension: {len(vec)}")
|
||||
```
|
||||
|
||||
## Hardware acceleration
|
||||
@@ -140,122 +273,166 @@ Q3_K_M or Q4_K_S # Lower bits to fit in memory
|
||||
### Apple Silicon (Metal)
|
||||
|
||||
```bash
|
||||
# Build with Metal
|
||||
make LLAMA_METAL=1
|
||||
|
||||
# Run with GPU acceleration (automatic)
|
||||
./llama-cli -m model.gguf -ngl 999 # Offload all layers
|
||||
|
||||
# Performance: M3 Max 40-60 tokens/sec (Llama 2-7B Q4_K_M)
|
||||
make clean && make GGML_METAL=1
|
||||
./llama-cli -m model.gguf -ngl 99 -p "Hello" # offload all layers
|
||||
```
|
||||
|
||||
### NVIDIA GPUs (CUDA)
|
||||
|
||||
```bash
|
||||
# Build with CUDA
|
||||
make LLAMA_CUDA=1
|
||||
|
||||
# Offload layers to GPU
|
||||
./llama-cli -m model.gguf -ngl 35 # Offload 35/40 layers
|
||||
|
||||
# Hybrid CPU+GPU for large models
|
||||
./llama-cli -m llama-70b.Q4_K_M.gguf -ngl 20 # GPU: 20 layers, CPU: rest
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=99, # Offload everything
|
||||
n_threads=1, # Metal handles parallelism
|
||||
)
|
||||
```
|
||||
|
||||
### AMD GPUs (ROCm)
|
||||
Performance: M3 Max ~40–60 tok/s on Llama 2-7B Q4_K_M.
|
||||
|
||||
### NVIDIA (CUDA)
|
||||
|
||||
```bash
|
||||
make clean && make GGML_CUDA=1
|
||||
./llama-cli -m model.gguf -ngl 35 -p "Hello"
|
||||
|
||||
# Hybrid for large models
|
||||
./llama-cli -m llama-70b.Q4_K_M.gguf -ngl 20 # GPU: 20 layers, CPU: rest
|
||||
|
||||
# Multi-GPU split
|
||||
./llama-cli -m large-model.gguf --tensor-split 0.5,0.5 -ngl 60
|
||||
```
|
||||
|
||||
### AMD (ROCm)
|
||||
|
||||
```bash
|
||||
# Build with ROCm
|
||||
make LLAMA_HIP=1
|
||||
|
||||
# Run with AMD GPU
|
||||
./llama-cli -m model.gguf -ngl 999
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Batch processing
|
||||
### CPU
|
||||
|
||||
```bash
|
||||
# Process multiple prompts from file
|
||||
cat prompts.txt | ./llama-cli \
|
||||
-m model.gguf \
|
||||
--batch-size 512 \
|
||||
-n 100
|
||||
# Match PHYSICAL cores, not logical
|
||||
./llama-cli -m model.gguf -t 8 -p "Hello"
|
||||
|
||||
# BLAS acceleration (2–3× speedup)
|
||||
make LLAMA_OPENBLAS=1
|
||||
```
|
||||
|
||||
### Constrained generation
|
||||
|
||||
```bash
|
||||
# JSON output with grammar
|
||||
./llama-cli \
|
||||
-m model.gguf \
|
||||
-p "Generate a person: " \
|
||||
--grammar-file grammars/json.gbnf
|
||||
|
||||
# Outputs valid JSON only
|
||||
```
|
||||
|
||||
### Context size
|
||||
|
||||
```bash
|
||||
# Increase context (default 512)
|
||||
./llama-cli \
|
||||
-m model.gguf \
|
||||
-c 4096 # 4K context window
|
||||
|
||||
# Very long context (if model supports)
|
||||
./llama-cli -m model.gguf -c 32768 # 32K context
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=0,
|
||||
n_threads=8,
|
||||
n_batch=512, # Larger batch = faster prompt processing
|
||||
)
|
||||
```
|
||||
|
||||
## Performance benchmarks
|
||||
|
||||
### CPU performance (Llama 2-7B Q4_K_M)
|
||||
### CPU (Llama 2-7B Q4_K_M)
|
||||
|
||||
| CPU | Threads | Speed | Cost |
|
||||
|-----|---------|-------|------|
|
||||
| Apple M3 Max | 16 | 50 tok/s | $0 (local) |
|
||||
| AMD Ryzen 9 7950X | 32 | 35 tok/s | $0.50/hour |
|
||||
| Intel i9-13900K | 32 | 30 tok/s | $0.40/hour |
|
||||
| AWS c7i.16xlarge | 64 | 40 tok/s | $2.88/hour |
|
||||
| CPU | Threads | Speed |
|
||||
|-----|---------|-------|
|
||||
| Apple M3 Max (Metal) | 16 | 50 tok/s |
|
||||
| AMD Ryzen 9 7950X | 32 | 35 tok/s |
|
||||
| Intel i9-13900K | 32 | 30 tok/s |
|
||||
|
||||
### GPU acceleration (Llama 2-7B Q4_K_M)
|
||||
### GPU offloading on RTX 4090
|
||||
|
||||
| GPU | Speed | vs CPU | Cost |
|
||||
|-----|-------|--------|------|
|
||||
| NVIDIA RTX 4090 | 120 tok/s | 3-4× | $0 (local) |
|
||||
| NVIDIA A10 | 80 tok/s | 2-3× | $1.00/hour |
|
||||
| AMD MI250 | 70 tok/s | 2× | $2.00/hour |
|
||||
| Apple M3 Max (Metal) | 50 tok/s | ~Same | $0 (local) |
|
||||
| Layers GPU | Speed | VRAM |
|
||||
|------------|-------|------|
|
||||
| 0 (CPU only) | 30 tok/s | 0 GB |
|
||||
| 20 (hybrid) | 80 tok/s | 8 GB |
|
||||
| 35 (all) | 120 tok/s | 12 GB |
|
||||
|
||||
## Supported models
|
||||
|
||||
**LLaMA family**:
|
||||
- Llama 2 (7B, 13B, 70B)
|
||||
- Llama 3 (8B, 70B, 405B)
|
||||
- Code Llama
|
||||
- **LLaMA family**: Llama 2 (7B/13B/70B), Llama 3 (8B/70B/405B), Code Llama
|
||||
- **Mistral family**: Mistral 7B, Mixtral 8x7B/8x22B
|
||||
- **Other**: Falcon, BLOOM, GPT-J, Phi-3, Gemma, Qwen, LLaVA (vision), Whisper (audio)
|
||||
|
||||
**Mistral family**:
|
||||
- Mistral 7B
|
||||
- Mixtral 8x7B, 8x22B
|
||||
Find GGUF models: https://huggingface.co/models?library=gguf
|
||||
|
||||
**Other**:
|
||||
- Falcon, BLOOM, GPT-J
|
||||
- Phi-3, Gemma, Qwen
|
||||
- LLaVA (vision), Whisper (audio)
|
||||
## Ecosystem integrations
|
||||
|
||||
**Find models**: https://huggingface.co/models?library=gguf
|
||||
### Ollama
|
||||
|
||||
```bash
|
||||
cat > Modelfile << 'EOF'
|
||||
FROM ./model-q4_k_m.gguf
|
||||
TEMPLATE """{{ .System }}
|
||||
{{ .Prompt }}"""
|
||||
PARAMETER temperature 0.7
|
||||
PARAMETER num_ctx 4096
|
||||
EOF
|
||||
|
||||
ollama create mymodel -f Modelfile
|
||||
ollama run mymodel "Hello!"
|
||||
```
|
||||
|
||||
### LM Studio
|
||||
|
||||
1. Place GGUF file in `~/.cache/lm-studio/models/`
|
||||
2. Open LM Studio and select the model
|
||||
3. Configure context length and GPU offload, start inference
|
||||
|
||||
### text-generation-webui
|
||||
|
||||
```bash
|
||||
cp model-q4_k_m.gguf text-generation-webui/models/
|
||||
python server.py --model model-q4_k_m.gguf --loader llama.cpp --n-gpu-layers 35
|
||||
```
|
||||
|
||||
### OpenAI client → llama-server
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(base_url="http://localhost:8080/v1", api_key="not-needed")
|
||||
response = client.chat.completions.create(
|
||||
model="local-model",
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
max_tokens=256,
|
||||
)
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use K-quants** — Q4_K_M is the recommended default
|
||||
2. **Use imatrix** for Q4 and below (calibration improves quality substantially)
|
||||
3. **Offload as many layers as VRAM allows** — start high, reduce by 5 on OOM
|
||||
4. **Thread count** — match physical cores, not logical
|
||||
5. **Batch size** — increase `n_batch` (e.g. 512) for faster prompt processing
|
||||
6. **Context** — start at 4096, grow only as needed (memory scales with ctx)
|
||||
7. **Flash Attention** — add `--flash-attn` if your build supports it
|
||||
|
||||
## Common issues (quick fixes)
|
||||
|
||||
**Model loads slowly** — use `--mmap` for memory-mapped loading.
|
||||
|
||||
**Out of memory (GPU)** — reduce `-ngl`, use a smaller quant (Q4_K_S / Q3_K_M), or quantize the KV cache:
|
||||
```python
|
||||
Llama(model_path="...", type_k=2, type_v=2, n_gpu_layers=35) # Q4_0 KV cache
|
||||
```
|
||||
|
||||
**Garbage output** — wrong `chat_format`, temperature too high, or model file corrupted. Test with `temperature=0.1` and verify FP16 baseline works.
|
||||
|
||||
**Connection refused (server)** — bind to `--host 0.0.0.0`, check `lsof -i :8080`.
|
||||
|
||||
See `references/troubleshooting.md` for the full playbook.
|
||||
|
||||
## References
|
||||
|
||||
- **[Quantization Guide](references/quantization.md)** - GGUF formats, conversion, quality comparison
|
||||
- **[Server Deployment](references/server.md)** - API endpoints, Docker, monitoring
|
||||
- **[Optimization](references/optimization.md)** - Performance tuning, hybrid CPU+GPU
|
||||
- **[advanced-usage.md](references/advanced-usage.md)** — speculative decoding, batched inference, grammar-constrained generation, LoRA, multi-GPU, custom builds, benchmark scripts
|
||||
- **[quantization.md](references/quantization.md)** — perplexity tables, use-case guide, model size scaling (7B/13B/70B RAM needs), imatrix deep dive
|
||||
- **[server.md](references/server.md)** — OpenAI API endpoints, Docker deployment, NGINX load balancing, monitoring
|
||||
- **[optimization.md](references/optimization.md)** — CPU threading, BLAS, GPU offload heuristics, batch tuning, benchmarks
|
||||
- **[troubleshooting.md](references/troubleshooting.md)** — install/convert/quantize/inference/server issues, Apple Silicon, debugging
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/ggerganov/llama.cpp
|
||||
- **Models**: https://huggingface.co/models?library=gguf
|
||||
- **Discord**: https://discord.gg/llama-cpp
|
||||
|
||||
|
||||
- **GitHub**: https://github.com/ggml-org/llama.cpp
|
||||
- **Python bindings**: https://github.com/abetlen/llama-cpp-python
|
||||
- **Pre-quantized models**: https://huggingface.co/TheBloke
|
||||
- **GGUF converter Space**: https://huggingface.co/spaces/ggml-org/gguf-my-repo
|
||||
- **License**: MIT
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
# GRPO/RL Training Skill
|
||||
|
||||
**Expert-level guidance for Group Relative Policy Optimization with TRL**
|
||||
|
||||
## 📁 Skill Structure
|
||||
|
||||
```
|
||||
grpo-rl-training/
|
||||
├── SKILL.md # Main skill documentation (READ THIS FIRST)
|
||||
├── README.md # This file
|
||||
├── templates/
|
||||
│ └── basic_grpo_training.py # Production-ready training template
|
||||
└── examples/
|
||||
└── reward_functions_library.py # 20+ reward function examples
|
||||
```
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
1. **Read SKILL.md** - Comprehensive guide with all concepts and patterns
|
||||
2. **Copy `templates/basic_grpo_training.py`** - Start with working code
|
||||
3. **Browse `examples/reward_functions_library.py`** - Pick reward functions for your task
|
||||
4. **Modify for your use case** - Adapt dataset, rewards, and config
|
||||
|
||||
## 💡 What's Inside
|
||||
|
||||
### SKILL.md (Main Documentation)
|
||||
- Core GRPO concepts and algorithm fundamentals
|
||||
- Complete implementation workflow (dataset → rewards → training → deployment)
|
||||
- 10+ reward function examples with code
|
||||
- Hyperparameter tuning guide
|
||||
- Training insights (loss behavior, metrics, debugging)
|
||||
- Troubleshooting guide
|
||||
- Production best practices
|
||||
|
||||
### Templates
|
||||
- **basic_grpo_training.py**: Minimal, production-ready training script
|
||||
- Uses Qwen 2.5 1.5B Instruct
|
||||
- 3 reward functions (format + correctness)
|
||||
- LoRA for efficient training
|
||||
- Fully documented and ready to run
|
||||
|
||||
### Examples
|
||||
- **reward_functions_library.py**: 20+ battle-tested reward functions
|
||||
- Correctness rewards (exact match, fuzzy match, numeric, code execution)
|
||||
- Format rewards (XML, JSON, strict/soft)
|
||||
- Length rewards (ideal length, min/max)
|
||||
- Style rewards (reasoning quality, citations, repetition penalty)
|
||||
- Combined rewards (multi-objective optimization)
|
||||
- Preset collections for common tasks
|
||||
|
||||
## 📖 Usage for Agents
|
||||
|
||||
When this skill is loaded in your agent's context:
|
||||
|
||||
1. **Always read SKILL.md first** before implementing
|
||||
2. **Start simple** - Use length-based reward to validate setup
|
||||
3. **Build incrementally** - Add one reward function at a time
|
||||
4. **Reference examples** - Copy patterns from reward_functions_library.py
|
||||
5. **Monitor training** - Watch reward metrics (not loss!)
|
||||
|
||||
## 🎯 Common Use Cases
|
||||
|
||||
| Task Type | Recommended Rewards | Template |
|
||||
|-----------|---------------------|----------|
|
||||
| Math reasoning | `MATH_REASONING_REWARDS` preset | basic_grpo_training.py |
|
||||
| Code generation | `CODE_GENERATION_REWARDS` preset | Modify dataset in template |
|
||||
| Summarization | `SUMMARIZATION_REWARDS` preset | Adjust prompts + rewards |
|
||||
| Q&A | `QA_REWARDS` preset | Use fuzzy match + citations |
|
||||
|
||||
## ⚠️ Critical Reminders
|
||||
|
||||
- **Loss goes UP during training** - This is normal (it's KL divergence)
|
||||
- **Use 3-5 reward functions** - Single rewards often fail
|
||||
- **Test rewards before training** - Debug each function independently
|
||||
- **Monitor reward_std** - Should stay > 0.1 (avoid mode collapse)
|
||||
- **Start with num_generations=4-8** - Scale up if GPU allows
|
||||
|
||||
## 🔗 External Resources
|
||||
|
||||
- [TRL Documentation](https://huggingface.co/docs/trl)
|
||||
- [DeepSeek R1 Paper](https://arxiv.org/abs/2501.12948)
|
||||
- [Open R1 Implementation](https://github.com/huggingface/open-r1)
|
||||
- [Unsloth (2-3x faster)](https://docs.unsloth.ai/)
|
||||
|
||||
## 📝 Version
|
||||
|
||||
**v1.0.0** - Initial release (January 2025)
|
||||
|
||||
## 👨💻 Maintained By
|
||||
|
||||
Orchestra Research
|
||||
For questions or improvements, see https://orchestra.com
|
||||
|
||||
---
|
||||
|
||||
**License:** MIT
|
||||
**Last Updated:** January 2025
|
||||
@@ -252,6 +252,8 @@ trl dpo \
|
||||
|
||||
Train with reinforcement learning using minimal memory.
|
||||
|
||||
For in-depth GRPO guidance — reward function design, critical training insights (loss behavior, mode collapse, tuning), and advanced multi-stage patterns — see **[references/grpo-training.md](references/grpo-training.md)**. A production-ready training script is in **[templates/basic_grpo_training.py](templates/basic_grpo_training.py)**.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
@@ -428,6 +430,8 @@ config = PPOConfig(
|
||||
|
||||
**Online RL methods**: See [references/online-rl.md](references/online-rl.md) for PPO, GRPO, RLOO, and OnlineDPO with detailed configurations.
|
||||
|
||||
**GRPO deep dive**: See [references/grpo-training.md](references/grpo-training.md) for expert-level GRPO patterns — reward function design philosophy, training insights (why loss increases, mode collapse detection), hyperparameter tuning, multi-stage training, and troubleshooting. Production-ready template in [templates/basic_grpo_training.py](templates/basic_grpo_training.py).
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA (CUDA required)
|
||||
|
||||
+129
-200
@@ -1,51 +1,36 @@
|
||||
---
|
||||
name: grpo-rl-training
|
||||
description: Expert guidance for GRPO/RL fine-tuning with TRL for reasoning and task-specific model training
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [transformers>=4.47.0, trl>=0.14.0, datasets>=3.2.0, peft>=0.14.0, torch]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Post-Training, Reinforcement Learning, GRPO, TRL, RLHF, Reward Modeling, Reasoning, DPO, PPO, Structured Output]
|
||||
# GRPO (Group Relative Policy Optimization) — Deep Guide
|
||||
|
||||
---
|
||||
Expert-level patterns, critical insights, and production-ready workflows for fine-tuning language models with custom reward functions using TRL's `GRPOTrainer`. This is the deep reference for the GRPO workflow summarized in the main skill.
|
||||
|
||||
# GRPO/RL Training with TRL
|
||||
## When to use GRPO
|
||||
|
||||
Expert-level guidance for implementing Group Relative Policy Optimization (GRPO) using the Transformer Reinforcement Learning (TRL) library. This skill provides battle-tested patterns, critical insights, and production-ready workflows for fine-tuning language models with custom reward functions.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use GRPO training when you need to:
|
||||
- **Enforce specific output formats** (e.g., XML tags, JSON, structured reasoning)
|
||||
Use GRPO when you need to:
|
||||
- **Enforce specific output formats** (XML tags, JSON, structured reasoning)
|
||||
- **Teach verifiable tasks** with objective correctness metrics (math, coding, fact-checking)
|
||||
- **Improve reasoning capabilities** by rewarding chain-of-thought patterns
|
||||
- **Align models to domain-specific behaviors** without labeled preference data
|
||||
- **Optimize for multiple objectives** simultaneously (format + correctness + style)
|
||||
|
||||
**Do NOT use GRPO for:**
|
||||
- Simple supervised fine-tuning tasks (use SFT instead)
|
||||
- Simple supervised fine-tuning tasks → use SFT
|
||||
- Tasks without clear reward signals
|
||||
- When you already have high-quality preference pairs (use DPO/PPO instead)
|
||||
- When you already have high-quality preference pairs → use DPO/PPO
|
||||
|
||||
---
|
||||
## Core concepts
|
||||
|
||||
## Core Concepts
|
||||
### 1. GRPO algorithm fundamentals
|
||||
|
||||
### 1. GRPO Algorithm Fundamentals
|
||||
|
||||
**Key Mechanism:**
|
||||
- Generates **multiple completions** for each prompt (group size: 4-16)
|
||||
**Key mechanism:**
|
||||
- Generates **multiple completions** per prompt (group size: 4–16)
|
||||
- Compares completions within each group using reward functions
|
||||
- Updates policy to favor higher-rewarded responses relative to the group
|
||||
|
||||
**Critical Difference from PPO:**
|
||||
**Critical differences from PPO:**
|
||||
- No separate reward model needed
|
||||
- More sample-efficient (learns from within-group comparisons)
|
||||
- Simpler to implement and debug
|
||||
|
||||
**Mathematical Intuition:**
|
||||
**Mathematical intuition:**
|
||||
```
|
||||
For each prompt p:
|
||||
1. Generate N completions: {c₁, c₂, ..., cₙ}
|
||||
@@ -54,35 +39,32 @@ For each prompt p:
|
||||
relative to low-reward ones in the same group
|
||||
```
|
||||
|
||||
### 2. Reward Function Design Philosophy
|
||||
### 2. Reward function design philosophy
|
||||
|
||||
**Golden Rules:**
|
||||
1. **Compose multiple reward functions** - Each handles one aspect (format, correctness, style)
|
||||
2. **Scale rewards appropriately** - Higher weight = stronger signal
|
||||
3. **Use incremental rewards** - Partial credit for partial compliance
|
||||
4. **Test rewards independently** - Debug each reward function in isolation
|
||||
**Golden rules:**
|
||||
1. **Compose multiple reward functions** — each handles one aspect (format, correctness, style)
|
||||
2. **Scale rewards appropriately** — higher weight = stronger signal
|
||||
3. **Use incremental rewards** — partial credit for partial compliance
|
||||
4. **Test rewards independently** — debug each reward function in isolation
|
||||
|
||||
**Reward Function Types:**
|
||||
**Reward function types:**
|
||||
|
||||
| Type | Use Case | Example Weight |
|
||||
|------|----------|----------------|
|
||||
| **Correctness** | Verifiable tasks (math, code) | 2.0 (highest) |
|
||||
| **Format** | Strict structure enforcement | 0.5-1.0 |
|
||||
| **Length** | Encourage verbosity/conciseness | 0.1-0.5 |
|
||||
| **Style** | Penalize unwanted patterns | -0.5 to 0.5 |
|
||||
| **Format** | Strict structure enforcement | 0.5–1.0 |
|
||||
| **Length** | Encourage verbosity/conciseness | 0.1–0.5 |
|
||||
| **Style** | Penalize unwanted patterns | −0.5 to 0.5 |
|
||||
|
||||
---
|
||||
## Implementation workflow
|
||||
|
||||
## Implementation Workflow
|
||||
### Step 1: Dataset preparation
|
||||
|
||||
### Step 1: Dataset Preparation
|
||||
|
||||
**Critical Requirements:**
|
||||
- Prompts in chat format (list of dicts with 'role' and 'content')
|
||||
**Critical requirements:**
|
||||
- Prompts in chat format (list of dicts with `role` and `content`)
|
||||
- Include system prompts to set expectations
|
||||
- For verifiable tasks, include ground truth answers as additional columns
|
||||
|
||||
**Example Structure:**
|
||||
```python
|
||||
from datasets import load_dataset, Dataset
|
||||
|
||||
@@ -97,8 +79,7 @@ Respond in the following format:
|
||||
"""
|
||||
|
||||
def prepare_dataset(raw_data):
|
||||
"""
|
||||
Transform raw data into GRPO-compatible format.
|
||||
"""Transform raw data into GRPO-compatible format.
|
||||
|
||||
Returns: Dataset with columns:
|
||||
- 'prompt': List[Dict] with role/content (system + user messages)
|
||||
@@ -113,14 +94,14 @@ def prepare_dataset(raw_data):
|
||||
})
|
||||
```
|
||||
|
||||
**Pro Tips:**
|
||||
- Use one-shot or few-shot examples in system prompt for complex formats
|
||||
- Keep prompts concise (max_prompt_length: 256-512 tokens)
|
||||
**Pro tips:**
|
||||
- Use one-shot or few-shot examples in the system prompt for complex formats
|
||||
- Keep prompts concise (max_prompt_length: 256–512 tokens)
|
||||
- Validate data quality before training (garbage in = garbage out)
|
||||
|
||||
### Step 2: Reward Function Implementation
|
||||
### Step 2: Reward function implementation
|
||||
|
||||
**Template Structure:**
|
||||
**Template structure:**
|
||||
```python
|
||||
def reward_function_name(
|
||||
prompts, # List[List[Dict]]: Original prompts
|
||||
@@ -128,24 +109,16 @@ def reward_function_name(
|
||||
answer=None, # Optional: Ground truth from dataset
|
||||
**kwargs # Additional dataset columns
|
||||
) -> list[float]:
|
||||
"""
|
||||
Evaluate completions and return rewards.
|
||||
|
||||
Returns: List of floats (one per completion)
|
||||
"""
|
||||
# Extract completion text
|
||||
"""Evaluate completions and return rewards (one per completion)."""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
|
||||
# Compute rewards
|
||||
rewards = []
|
||||
for response in responses:
|
||||
score = compute_score(response)
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
```
|
||||
|
||||
**Example 1: Correctness Reward (Math/Coding)**
|
||||
**Example 1: correctness reward (math/coding)**
|
||||
```python
|
||||
def correctness_reward(prompts, completions, answer, **kwargs):
|
||||
"""Reward correct answers with high score."""
|
||||
@@ -155,7 +128,7 @@ def correctness_reward(prompts, completions, answer, **kwargs):
|
||||
for ans, gt in zip(extracted, answer)]
|
||||
```
|
||||
|
||||
**Example 2: Format Reward (Structured Output)**
|
||||
**Example 2: format reward (structured output)**
|
||||
```python
|
||||
import re
|
||||
|
||||
@@ -167,7 +140,7 @@ def format_reward(completions, **kwargs):
|
||||
for r in responses]
|
||||
```
|
||||
|
||||
**Example 3: Incremental Format Reward (Partial Credit)**
|
||||
**Example 3: incremental format reward (partial credit)**
|
||||
```python
|
||||
def incremental_format_reward(completions, **kwargs):
|
||||
"""Award partial credit for format compliance."""
|
||||
@@ -176,14 +149,10 @@ def incremental_format_reward(completions, **kwargs):
|
||||
|
||||
for r in responses:
|
||||
score = 0.0
|
||||
if '<reasoning>' in r:
|
||||
score += 0.25
|
||||
if '</reasoning>' in r:
|
||||
score += 0.25
|
||||
if '<answer>' in r:
|
||||
score += 0.25
|
||||
if '</answer>' in r:
|
||||
score += 0.25
|
||||
if '<reasoning>' in r: score += 0.25
|
||||
if '</reasoning>' in r: score += 0.25
|
||||
if '<answer>' in r: score += 0.25
|
||||
if '</answer>' in r: score += 0.25
|
||||
# Penalize extra text after closing tag
|
||||
if r.count('</answer>') == 1:
|
||||
extra_text = r.split('</answer>')[-1].strip()
|
||||
@@ -193,12 +162,11 @@ def incremental_format_reward(completions, **kwargs):
|
||||
return rewards
|
||||
```
|
||||
|
||||
**Critical Insight:**
|
||||
Combine 3-5 reward functions for robust training. Order matters less than diversity of signals.
|
||||
**Critical insight:** Combine 3–5 reward functions for robust training. Order matters less than diversity of signals.
|
||||
|
||||
### Step 3: Training Configuration
|
||||
### Step 3: Training configuration
|
||||
|
||||
**Memory-Optimized Config (Small GPU)**
|
||||
**Memory-optimized config (small GPU)**
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
@@ -218,13 +186,13 @@ training_args = GRPOConfig(
|
||||
gradient_accumulation_steps=4, # Effective batch = 4
|
||||
|
||||
# GRPO-specific
|
||||
num_generations=8, # Group size: 8-16 recommended
|
||||
num_generations=8, # Group size: 8–16 recommended
|
||||
max_prompt_length=256,
|
||||
max_completion_length=512,
|
||||
|
||||
# Training duration
|
||||
num_train_epochs=1,
|
||||
max_steps=None, # Or set fixed steps (e.g., 500)
|
||||
max_steps=None,
|
||||
|
||||
# Optimization
|
||||
bf16=True, # Faster on A100/H100
|
||||
@@ -234,11 +202,11 @@ training_args = GRPOConfig(
|
||||
# Logging
|
||||
logging_steps=1,
|
||||
save_steps=100,
|
||||
report_to="wandb", # Or "none" for no logging
|
||||
report_to="wandb",
|
||||
)
|
||||
```
|
||||
|
||||
**High-Performance Config (Large GPU)**
|
||||
**High-performance config (large GPU)**
|
||||
```python
|
||||
training_args = GRPOConfig(
|
||||
output_dir="outputs/grpo-model",
|
||||
@@ -255,31 +223,30 @@ training_args = GRPOConfig(
|
||||
)
|
||||
```
|
||||
|
||||
**Critical Hyperparameters:**
|
||||
**Critical hyperparameters:**
|
||||
|
||||
| Parameter | Impact | Tuning Advice |
|
||||
|-----------|--------|---------------|
|
||||
| `num_generations` | Group size for comparison | Start with 8, increase to 16 if GPU allows |
|
||||
| `num_generations` | Group size for comparison | Start 8, increase to 16 if GPU allows |
|
||||
| `learning_rate` | Convergence speed/stability | 5e-6 (safe), 1e-5 (faster, riskier) |
|
||||
| `max_completion_length` | Output verbosity | Match your task (512 for reasoning, 256 for short answers) |
|
||||
| `max_completion_length` | Output verbosity | Match your task (512 reasoning, 256 short answers) |
|
||||
| `gradient_accumulation_steps` | Effective batch size | Increase if GPU memory limited |
|
||||
|
||||
### Step 4: Model Setup and Training
|
||||
### Step 4: Model setup and training
|
||||
|
||||
**Standard Setup (Transformers)**
|
||||
**Standard setup (Transformers + TRL)**
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import LoraConfig
|
||||
from trl import GRPOTrainer
|
||||
|
||||
# Load model
|
||||
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2", # 2-3x faster
|
||||
device_map="auto"
|
||||
attn_implementation="flash_attention_2", # 2–3× faster
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
@@ -287,17 +254,16 @@ tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Optional: LoRA for parameter-efficient training
|
||||
peft_config = LoraConfig(
|
||||
r=16, # Rank (higher = more capacity)
|
||||
lora_alpha=32, # Scaling factor (typically 2*r)
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"
|
||||
"gate_proj", "up_proj", "down_proj",
|
||||
],
|
||||
task_type="CAUSAL_LM",
|
||||
lora_dropout=0.05,
|
||||
)
|
||||
|
||||
# Initialize trainer
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
@@ -308,17 +274,14 @@ trainer = GRPOTrainer(
|
||||
],
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=peft_config, # Remove for full fine-tuning
|
||||
peft_config=peft_config, # Remove for full fine-tuning
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer.train()
|
||||
|
||||
# Save
|
||||
trainer.save_model("final_model")
|
||||
```
|
||||
|
||||
**Unsloth Setup (2-3x Faster)**
|
||||
**Unsloth setup (2–3× faster)**
|
||||
```python
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
@@ -339,28 +302,26 @@ model = FastLanguageModel.get_peft_model(
|
||||
use_gradient_checkpointing="unsloth",
|
||||
)
|
||||
|
||||
# Rest is identical to standard setup
|
||||
# Rest is identical to the standard setup
|
||||
trainer = GRPOTrainer(model=model, ...)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
---
|
||||
## Critical training insights
|
||||
|
||||
## Critical Training Insights
|
||||
### 1. Loss behavior (EXPECTED pattern)
|
||||
- **Loss starts near 0 and INCREASES during training** — this is CORRECT
|
||||
- Loss measures KL divergence from initial policy; the model is learning (diverging from original behavior to optimize rewards)
|
||||
- **Monitor reward metrics, not loss, for progress**
|
||||
|
||||
### 1. Loss Behavior (EXPECTED PATTERN)
|
||||
- **Loss starts near 0 and INCREASES during training**
|
||||
- This is CORRECT - loss measures KL divergence from initial policy
|
||||
- Model is learning (diverging from original behavior to optimize rewards)
|
||||
- Monitor reward metrics instead of loss for progress
|
||||
### 2. Reward tracking
|
||||
|
||||
### 2. Reward Tracking
|
||||
Key metrics to watch:
|
||||
- `reward`: Average across all completions
|
||||
- `reward_std`: Diversity within groups (should remain > 0)
|
||||
- `kl`: KL divergence from reference (should grow moderately)
|
||||
- `reward` — average across all completions
|
||||
- `reward_std` — diversity within groups (should remain > 0)
|
||||
- `kl` — KL divergence from reference (should grow moderately)
|
||||
|
||||
**Healthy Training Pattern:**
|
||||
**Healthy pattern:**
|
||||
```
|
||||
Step Reward Reward_Std KL
|
||||
100 0.5 0.3 0.02
|
||||
@@ -369,12 +330,12 @@ Step Reward Reward_Std KL
|
||||
400 1.5 0.15 0.12
|
||||
```
|
||||
|
||||
**Warning Signs:**
|
||||
- Reward std → 0 (model collapsing to single response)
|
||||
- KL exploding (> 0.5) (diverging too much, reduce LR)
|
||||
- Reward stuck (reward functions too harsh or model capacity issue)
|
||||
**Warning signs:**
|
||||
- `reward_std` → 0 (model collapsing to a single response)
|
||||
- `kl` exploding (> 0.5) — diverging too much, reduce LR
|
||||
- Reward stuck — reward functions too harsh or model capacity issue
|
||||
|
||||
### 3. Common Pitfalls and Solutions
|
||||
### 3. Common pitfalls and solutions
|
||||
|
||||
| Problem | Symptom | Solution |
|
||||
|---------|---------|----------|
|
||||
@@ -384,15 +345,14 @@ Step Reward Reward_Std KL
|
||||
| **Slow training** | < 1 it/s | Enable `use_vllm=True`, use Unsloth, reduce seq length |
|
||||
| **Format ignored** | Model doesn't follow structure | Increase format reward weight, add incremental rewards |
|
||||
|
||||
---
|
||||
## Advanced patterns
|
||||
|
||||
## Advanced Patterns
|
||||
### 1. Multi-stage training
|
||||
|
||||
### 1. Multi-Stage Training
|
||||
For complex tasks, train in stages:
|
||||
|
||||
```python
|
||||
# Stage 1: Format compliance (epochs=1)
|
||||
# Stage 1: Format compliance
|
||||
trainer_stage1 = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs=[incremental_format_reward, format_reward],
|
||||
@@ -400,7 +360,7 @@ trainer_stage1 = GRPOTrainer(
|
||||
)
|
||||
trainer_stage1.train()
|
||||
|
||||
# Stage 2: Correctness (epochs=1)
|
||||
# Stage 2: Correctness
|
||||
trainer_stage2 = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs=[format_reward, correctness_reward],
|
||||
@@ -409,7 +369,8 @@ trainer_stage2 = GRPOTrainer(
|
||||
trainer_stage2.train()
|
||||
```
|
||||
|
||||
### 2. Adaptive Reward Scaling
|
||||
### 2. Adaptive reward scaling
|
||||
|
||||
```python
|
||||
class AdaptiveReward:
|
||||
def __init__(self, base_reward_func, initial_weight=1.0):
|
||||
@@ -428,148 +389,116 @@ class AdaptiveReward:
|
||||
self.weight *= 0.9
|
||||
```
|
||||
|
||||
### 3. Custom Dataset Integration
|
||||
### 3. Custom dataset integration
|
||||
|
||||
```python
|
||||
def load_custom_knowledge_base(csv_path):
|
||||
"""Example: School communication platform docs."""
|
||||
import pandas as pd
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
dataset = Dataset.from_pandas(df).map(lambda x: {
|
||||
return Dataset.from_pandas(df).map(lambda x: {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': CUSTOM_SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': x['expert_answer']
|
||||
})
|
||||
return dataset
|
||||
```
|
||||
|
||||
---
|
||||
## Deployment and inference
|
||||
|
||||
## Deployment and Inference
|
||||
|
||||
### Save and Merge LoRA
|
||||
### Save and merge LoRA
|
||||
```python
|
||||
# Merge LoRA adapters into base model
|
||||
if hasattr(trainer.model, 'merge_and_unload'):
|
||||
merged_model = trainer.model.merge_and_unload()
|
||||
merged_model.save_pretrained("production_model")
|
||||
tokenizer.save_pretrained("production_model")
|
||||
```
|
||||
|
||||
### Inference Example
|
||||
### Inference
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
generator = pipeline(
|
||||
"text-generation",
|
||||
model="production_model",
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
generator = pipeline("text-generation", model="production_model", tokenizer=tokenizer)
|
||||
|
||||
result = generator(
|
||||
[
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': "What is 15 + 27?"}
|
||||
{'role': 'user', 'content': "What is 15 + 27?"},
|
||||
],
|
||||
max_new_tokens=256,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_p=0.9
|
||||
top_p=0.9,
|
||||
)
|
||||
print(result[0]['generated_text'])
|
||||
```
|
||||
|
||||
---
|
||||
## Best practices checklist
|
||||
|
||||
## Best Practices Checklist
|
||||
|
||||
**Before Training:**
|
||||
**Before training:**
|
||||
- [ ] Validate dataset format (prompts as List[Dict])
|
||||
- [ ] Test reward functions on sample data
|
||||
- [ ] Calculate expected max_prompt_length from data
|
||||
- [ ] Choose appropriate num_generations based on GPU memory
|
||||
- [ ] Calculate expected `max_prompt_length` from data
|
||||
- [ ] Choose `num_generations` based on GPU memory
|
||||
- [ ] Set up logging (wandb recommended)
|
||||
|
||||
**During Training:**
|
||||
**During training:**
|
||||
- [ ] Monitor reward progression (should increase)
|
||||
- [ ] Check reward_std (should stay > 0.1)
|
||||
- [ ] Check `reward_std` (should stay > 0.1)
|
||||
- [ ] Watch for OOM errors (reduce batch size if needed)
|
||||
- [ ] Sample generations every 50-100 steps
|
||||
- [ ] Sample generations every 50–100 steps
|
||||
- [ ] Validate format compliance on holdout set
|
||||
|
||||
**After Training:**
|
||||
**After training:**
|
||||
- [ ] Merge LoRA weights if using PEFT
|
||||
- [ ] Test on diverse prompts
|
||||
- [ ] Compare to baseline model
|
||||
- [ ] Document reward weights and hyperparameters
|
||||
- [ ] Save reproducibility config
|
||||
|
||||
---
|
||||
## Troubleshooting
|
||||
|
||||
## Troubleshooting Guide
|
||||
### Debugging workflow
|
||||
1. **Isolate reward functions** — test each independently
|
||||
2. **Check data distribution** — ensure diversity in prompts
|
||||
3. **Reduce complexity** — start with single reward, add gradually
|
||||
4. **Monitor generations** — print samples every N steps
|
||||
5. **Validate extraction logic** — ensure answer parsing works
|
||||
|
||||
### Debugging Workflow
|
||||
1. **Isolate reward functions** - Test each independently
|
||||
2. **Check data distribution** - Ensure diversity in prompts
|
||||
3. **Reduce complexity** - Start with single reward, add gradually
|
||||
4. **Monitor generations** - Print samples every N steps
|
||||
5. **Validate extraction logic** - Ensure answer parsing works
|
||||
|
||||
### Quick Fixes
|
||||
### Quick debug reward
|
||||
```python
|
||||
# Debug reward function
|
||||
def debug_reward(completions, **kwargs):
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
for i, r in enumerate(responses[:2]): # Print first 2
|
||||
for i, r in enumerate(responses[:2]):
|
||||
print(f"Response {i}: {r[:200]}...")
|
||||
return [1.0] * len(responses) # Dummy rewards
|
||||
return [1.0] * len(responses)
|
||||
|
||||
# Test without training
|
||||
trainer = GRPOTrainer(..., reward_funcs=[debug_reward])
|
||||
trainer.generate_completions(dataset[:1]) # Generate without updating
|
||||
trainer.generate_completions(dataset[:1])
|
||||
```
|
||||
|
||||
---
|
||||
## Template
|
||||
|
||||
## References and Resources
|
||||
A production-ready training script lives at **`../templates/basic_grpo_training.py`**. It uses Qwen 2.5-1.5B-Instruct with LoRA and three reward functions (incremental format, strict format, correctness) on GSM8K. Copy and adapt:
|
||||
1. `get_dataset()` — swap in your data loader
|
||||
2. Reward functions — tune to your task
|
||||
3. `SYSTEM_PROMPT` — match your output format
|
||||
4. `GRPOConfig` — adjust hyperparameters for your GPU
|
||||
|
||||
## References and resources
|
||||
|
||||
**Official Documentation:**
|
||||
- TRL GRPO Trainer: https://huggingface.co/docs/trl/grpo_trainer
|
||||
- DeepSeek R1 Paper: https://arxiv.org/abs/2501.12948
|
||||
- Unsloth Docs: https://docs.unsloth.ai/
|
||||
|
||||
**Example Repositories:**
|
||||
- Open R1 Implementation: https://github.com/huggingface/open-r1
|
||||
- TRL Examples: https://github.com/huggingface/trl/tree/main/examples
|
||||
|
||||
**Recommended Reading:**
|
||||
- Progressive Disclosure Pattern for agent instructions
|
||||
- Reward shaping in RL (Ng et al.)
|
||||
- LoRA paper (Hu et al., 2021)
|
||||
|
||||
---
|
||||
|
||||
## Usage Instructions for Agents
|
||||
|
||||
When this skill is loaded:
|
||||
|
||||
1. **Read this entire file** before implementing GRPO training
|
||||
2. **Start with the simplest reward function** (e.g., length-based) to validate setup
|
||||
3. **Use the templates** in `templates/` directory as starting points
|
||||
4. **Reference examples** in `examples/` for task-specific implementations
|
||||
5. **Follow the workflow** sequentially (don't skip steps)
|
||||
6. **Debug incrementally** - add one reward function at a time
|
||||
|
||||
**Critical Reminders:**
|
||||
- Always use multiple reward functions (3-5 is optimal)
|
||||
- Monitor reward metrics, not loss
|
||||
- Test reward functions before training
|
||||
- Start small (num_generations=4), scale up gradually
|
||||
- Save checkpoints frequently (every 100 steps)
|
||||
|
||||
This skill is designed for **expert-level implementation**. Beginners should start with supervised fine-tuning before attempting GRPO.
|
||||
|
||||
- GRPO paper (DeepSeek): https://arxiv.org/abs/2402.03300
|
||||
- DeepSeek R1 paper: https://arxiv.org/abs/2501.12948
|
||||
- Open R1 implementation: https://github.com/huggingface/open-r1
|
||||
- TRL examples: https://github.com/huggingface/trl/tree/main/examples
|
||||
- Unsloth (faster training): https://docs.unsloth.ai/
|
||||
|
||||
## Critical reminders
|
||||
|
||||
- **Loss goes UP during training** — this is normal (it's KL divergence)
|
||||
- **Use 3–5 reward functions** — single rewards often fail
|
||||
- **Test rewards before training** — debug each function independently
|
||||
- **Monitor `reward_std`** — should stay > 0.1 (avoid mode collapse)
|
||||
- **Start with `num_generations=4–8`** — scale up if GPU allows
|
||||
+59
-12
@@ -42,9 +42,10 @@ class TestToolProgressCallback:
|
||||
def test_emits_tool_call_start(self, mock_conn, event_loop_fixture):
|
||||
"""Tool progress should emit a ToolCallStart update."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
# Run callback in the event loop context
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
@@ -66,9 +67,10 @@ class TestToolProgressCallback:
|
||||
def test_handles_string_args(self, mock_conn, event_loop_fixture):
|
||||
"""If args is a JSON string, it should be parsed."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
@@ -82,9 +84,10 @@ class TestToolProgressCallback:
|
||||
def test_handles_non_dict_args(self, mock_conn, event_loop_fixture):
|
||||
"""If args is not a dict, it should be wrapped."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
@@ -98,10 +101,11 @@ class TestToolProgressCallback:
|
||||
def test_duplicate_same_name_tool_calls_use_fifo_ids(self, mock_conn, event_loop_fixture):
|
||||
"""Multiple same-name tool calls should be tracked independently in order."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
@@ -163,7 +167,7 @@ class TestStepCallback:
|
||||
tool_call_ids = {"terminal": "tc-abc123"}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
@@ -181,7 +185,7 @@ class TestStepCallback:
|
||||
tool_call_ids = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
cb(1, [{"name": "unknown_tool", "result": "ok"}])
|
||||
@@ -193,7 +197,7 @@ class TestStepCallback:
|
||||
tool_call_ids = {"read_file": "tc-def456"}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
@@ -212,7 +216,7 @@ class TestStepCallback:
|
||||
tool_call_ids = {"terminal": deque(["tc-xyz789"])}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
@@ -224,7 +228,7 @@ class TestStepCallback:
|
||||
cb(1, [{"name": "terminal", "result": '{"output": "hello"}'}])
|
||||
|
||||
mock_btc.assert_called_once_with(
|
||||
"tc-xyz789", "terminal", result='{"output": "hello"}'
|
||||
"tc-xyz789", "terminal", result='{"output": "hello"}', function_args=None, snapshot=None
|
||||
)
|
||||
|
||||
def test_none_result_passed_through(self, mock_conn, event_loop_fixture):
|
||||
@@ -234,7 +238,7 @@ class TestStepCallback:
|
||||
tool_call_ids = {"web_search": deque(["tc-aaa"])}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
@@ -244,7 +248,50 @@ class TestStepCallback:
|
||||
|
||||
cb(1, [{"name": "web_search", "result": None}])
|
||||
|
||||
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None)
|
||||
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None, function_args=None, snapshot=None)
|
||||
|
||||
def test_step_callback_passes_arguments_and_snapshot(self, mock_conn, event_loop_fixture):
|
||||
from collections import deque
|
||||
|
||||
tool_call_ids = {"write_file": deque(["tc-write"])}
|
||||
tool_call_meta = {"tc-write": {"args": {"path": "fallback.txt"}, "snapshot": "snap"}}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb(1, [{"name": "write_file", "result": '{"bytes_written": 23}', "arguments": {"path": "diff-test.txt"}}])
|
||||
|
||||
mock_btc.assert_called_once_with(
|
||||
"tc-write",
|
||||
"write_file",
|
||||
result='{"bytes_written": 23}',
|
||||
function_args={"path": "diff-test.txt"},
|
||||
snapshot="snap",
|
||||
)
|
||||
|
||||
def test_tool_progress_captures_snapshot_metadata(self, mock_conn, event_loop_fixture):
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
with patch("acp_adapter.events.make_tool_call_id", return_value="tc-meta"), \
|
||||
patch("acp_adapter.events._send_update") as mock_send, \
|
||||
patch("agent.display.capture_local_edit_snapshot", return_value="snapshot"):
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
cb("tool.started", "write_file", None, {"path": "diff-test.txt", "content": "hello"})
|
||||
|
||||
assert list(tool_call_ids["write_file"]) == ["tc-meta"]
|
||||
assert tool_call_meta["tc-meta"] == {
|
||||
"args": {"path": "diff-test.txt", "content": "hello"},
|
||||
"snapshot": "snapshot",
|
||||
}
|
||||
mock_send.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -29,6 +29,7 @@ from acp.schema import (
|
||||
|
||||
from acp_adapter.server import HermesACPAgent
|
||||
from acp_adapter.session import SessionManager
|
||||
from acp_adapter.tools import build_tool_start
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -181,6 +182,25 @@ class TestMcpRegistrationE2E:
|
||||
assert complete_event.raw_output is not None
|
||||
assert "hello" in str(complete_event.raw_output)
|
||||
|
||||
def test_patch_mode_tool_start_emits_diff_blocks_for_v4a_patch(self):
|
||||
update = build_tool_start(
|
||||
"tc-1",
|
||||
"patch",
|
||||
{
|
||||
"mode": "patch",
|
||||
"patch": "*** Begin Patch\n*** Update File: src/app.py\n@@\n-old line\n+new line\n*** Add File: src/new.py\n+hello\n*** End Patch",
|
||||
},
|
||||
)
|
||||
|
||||
assert len(update.content) == 2
|
||||
assert update.content[0].type == "diff"
|
||||
assert update.content[0].path == "src/app.py"
|
||||
assert update.content[0].old_text == "old line"
|
||||
assert update.content[0].new_text == "new line"
|
||||
assert update.content[1].type == "diff"
|
||||
assert update.content[1].path == "src/new.py"
|
||||
assert update.content[1].new_text == "hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_tool_results_paired_by_call_id(self, acp_agent, mock_manager):
|
||||
"""The ToolCallUpdate's toolCallId must match the ToolCallStart's."""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user