Compare commits
360 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b4d4fee6fe | |||
| a1f9961f51 | |||
| 3741ee08d2 | |||
| 9cd3050a08 | |||
| 4670f66a33 | |||
| c9479c6c6f | |||
| 5a5d7ec2a2 | |||
| 87995cd9c5 | |||
| 8fd8def544 | |||
| 1d6a92103a | |||
| a692859ddb | |||
| 5127567d5d | |||
| cc4514076b | |||
| 8ecd7aed2c | |||
| e0dbbdb2c9 | |||
| eb2127c1dc | |||
| 5a1e2a307a | |||
| 41d9d08078 | |||
| b7bcae49c6 | |||
| 915df02bbf | |||
| 75fcbc44ce | |||
| be416cdfa9 | |||
| b8b1f24fd7 | |||
| a2847ea7f0 | |||
| 58ca875e19 | |||
| 3f95e741a7 | |||
| 03396627a6 | |||
| 22cfad157b | |||
| 867eefdd9f | |||
| a8df7f9964 | |||
| 1519c4d477 | |||
| 005786c55d | |||
| ad764d3513 | |||
| f008ee1019 | |||
| 60fdb58ce4 | |||
| 18d28c63a7 | |||
| 3c57eaf744 | |||
| 2d232c9991 | |||
| 0375b2a0d7 | |||
| 08fa326bb0 | |||
| bde45f5a2a | |||
| 716e616d28 | |||
| bdccdd67a1 | |||
| 148f46620f | |||
| 6610c377ba | |||
| e5d14445ef | |||
| 72250b5f62 | |||
| 243ee67529 | |||
| 3a86328847 | |||
| db241ae6ce | |||
| 41ee207a5e | |||
| e9e7fb0683 | |||
| 76ed15dd4d | |||
| a8e02c7d49 | |||
| b81d49dc45 | |||
| 3a7907b278 | |||
| b7b3294c4a | |||
| 62f8aa9b03 | |||
| 2c719f0701 | |||
| c6fe75e99b | |||
| 36af1f3baf | |||
| 43af094ae3 | |||
| 9989e579da | |||
| 4a56e2cd88 | |||
| 26bfdc22b4 | |||
| 0426bb745f | |||
| c511e087e0 | |||
| c07c17f5f2 | |||
| cbf195e806 | |||
| 08d3be0412 | |||
| 156b50358b | |||
| 59575d6a91 | |||
| f46542b6c6 | |||
| 5b29ff50f8 | |||
| 7258311710 | |||
| 910ec7eb38 | |||
| 4b45f65858 | |||
| b374f52063 | |||
| bd43a43f07 | |||
| 432ba3b709 | |||
| 712cebc40f | |||
| 45f57c2012 | |||
| 41081d718c | |||
| 281100e2df | |||
| 0d7f739675 | |||
| 9783c9d5c1 | |||
| 0cfc1f88a3 | |||
| 3bc953a666 | |||
| bd6b138e85 | |||
| 9792bde31a | |||
| 9d1e13019e | |||
| 37cabc47d3 | |||
| f7f30aaab9 | |||
| d218cf9118 | |||
| 841401f588 | |||
| 77bcaba2d7 | |||
| e0cfc089da | |||
| 7126524e8d | |||
| f83c27e26f | |||
| ab548a9b5e | |||
| 73e66eb3c0 | |||
| 14cf2d85ca | |||
| 8bb1d15da4 | |||
| 861624d4e9 | |||
| e4033b2baf | |||
| 94e3d9adbf | |||
| 0dcd6ab2f2 | |||
| b6461903ff | |||
| 8f6ef042c1 | |||
| 099dfca6db | |||
| 68ab37e891 | |||
| 65dace1b1a | |||
| 650b400c98 | |||
| 61949f0af7 | |||
| 52c5e491f5 | |||
| f665351740 | |||
| fba73a60e3 | |||
| 114e636b7d | |||
| 20cc1731f4 | |||
| b2a6b012fe | |||
| 42fec19151 | |||
| 5dbe2d9d73 | |||
| c6f4515f73 | |||
| fd292e676b | |||
| e5691eed38 | |||
| ab4ba8163a | |||
| 80cc27eb9d | |||
| 1b24a226ea | |||
| 9b32f846a8 | |||
| 7ca22ea11b | |||
| ef47531617 | |||
| b36fe9282a | |||
| 1e9ff53a74 | |||
| 27c023e071 | |||
| 9231a335d4 | |||
| 7efaa5968d | |||
| 8ee4f32819 | |||
| 689344430c | |||
| 618f15dda9 | |||
| 481915587e | |||
| 0b993c1e07 | |||
| 9718334962 | |||
| ebcb81b649 | |||
| ac5b8a478a | |||
| 624e4a8e7a | |||
| 177e43259f | |||
| c9b76057d4 | |||
| 745859babb | |||
| ad1bf16f28 | |||
| e2c81c6e2f | |||
| 677b11d84c | |||
| ee3f3e756d | |||
| 02b38b93cb | |||
| 2233f764af | |||
| 98b5570961 | |||
| 773d3bb4df | |||
| a312ee7b4c | |||
| 2e524272b1 | |||
| ce39f9cc44 | |||
| 18cbd18fa9 | |||
| b641ee88f4 | |||
| 2f1c4fb01f | |||
| 4313b8aff6 | |||
| 87e2626cf6 | |||
| 1345e93393 | |||
| 6e97a3b338 | |||
| 8416bc2142 | |||
| 48b5bc6038 | |||
| 4ff73fb32c | |||
| 73a88a02fe | |||
| f9c2565ab4 | |||
| ad5f973a8d | |||
| 0791efe2c3 | |||
| 934fbe3c06 | |||
| 6302e56e7c | |||
| 868b3c07e3 | |||
| 9d6148316c | |||
| 7da0822456 | |||
| d35df0db71 | |||
| 93dc5dee6f | |||
| 2d8fad8230 | |||
| ca2958ff98 | |||
| f60ebc7bf2 | |||
| b072737193 | |||
| 3b509da571 | |||
| 5ddb6a191f | |||
| 1b5fb36c9d | |||
| 942f6eac94 | |||
| 2b3c1d81f0 | |||
| 1f21ef7488 | |||
| b799bca7a3 | |||
| b2b4a9ee7d | |||
| ed805f57ff | |||
| e93b539a8f | |||
| fa6f069577 | |||
| cd2280d1a3 | |||
| 5e5ad634a1 | |||
| 55a27a3fb8 | |||
| 8587cddd6c | |||
| 2bd8e5cb23 | |||
| bfe4baa6ed | |||
| 72a6d7dffe | |||
| afe2f0abe1 | |||
| 09fd007c6e | |||
| 24cf2a7954 | |||
| be3eb62047 | |||
| 9c32fed184 | |||
| 6435d69a6d | |||
| a2276177a3 | |||
| ebd0291ef2 | |||
| 0510ee056d | |||
| 44b572a9e0 | |||
| f9c2ad48c2 | |||
| c275aa4732 | |||
| ff071fc74c | |||
| 8d528e0045 | |||
| fd32e3d6e8 | |||
| 34be3f8be6 | |||
| 3037450c77 | |||
| b7091f93b1 | |||
| ab3cbfc99d | |||
| 26030266d2 | |||
| edda0e324b | |||
| 5407d12bc6 | |||
| 2de42ba690 | |||
| f3301a31d5 | |||
| e6a708aa04 | |||
| e80489135b | |||
| a53db44d40 | |||
| 0698ddb496 | |||
| 0962cbb2e5 | |||
| f69c47d9ae | |||
| 027fc1a85a | |||
| f84230527c | |||
| 0e64a48743 | |||
| ffa8b562e9 | |||
| 56b0104154 | |||
| c0c13e4ed4 | |||
| 89befcaf33 | |||
| 0f1c970179 | |||
| 57d3ac0c0b | |||
| a9f9c60efd | |||
| e109a8b502 | |||
| b81926def6 | |||
| 8cb7864110 | |||
| 7cd9f9ed48 | |||
| 2c2334d4db | |||
| 21ffadc2a6 | |||
| 241f966b1a | |||
| 7d0e4510b8 | |||
| 306e67f32d | |||
| 5c8d7d5d6f | |||
| 0b370f2dd9 | |||
| 887e8a8d84 | |||
| 189214a69d | |||
| cd6d24f111 | |||
| c01cfe4f9a | |||
| fbbe9e6030 | |||
| 43bca6d107 | |||
| 669c60a6bb | |||
| dd39003a9b | |||
| 4bded44b6a | |||
| ec22635b47 | |||
| 29d0541ac9 | |||
| a0f411c87d | |||
| 862d5224dd | |||
| e664bc7632 | |||
| f9052d7ecf | |||
| 7dff34ba4e | |||
| dbc25a386e | |||
| 0ea7d0ec80 | |||
| 1d28b4699b | |||
| e0ca46cd73 | |||
| 5454a55269 | |||
| 40c9a13476 | |||
| bd49bce278 | |||
| 52dd479214 | |||
| c57d5cbdde | |||
| 525caadd8c | |||
| f9fa7421cb | |||
| 342096b4bd | |||
| 55510cbad2 | |||
| 3ab50376b0 | |||
| f8fb61d4ad | |||
| 0d68446323 | |||
| 81dbf4309a | |||
| febfe1c268 | |||
| 2a5f86ed6d | |||
| d3659c8ca0 | |||
| f7f75de7c3 | |||
| f58902818d | |||
| 8da410ed95 | |||
| da44c196b6 | |||
| 36079c6646 | |||
| 135448f513 | |||
| 2e143fd15c | |||
| 0b9526b476 | |||
| f304bc63b8 | |||
| decc7851f2 | |||
| 97108db038 | |||
| 1f1fa71d0c | |||
| 2988334fe5 | |||
| 292d12bed4 | |||
| 509cff6e5c | |||
| 29520df44f | |||
| 9be42e49f9 | |||
| 42cef9c282 | |||
| 3a71099dac | |||
| 356122e990 | |||
| aefcdd6f7f | |||
| 3835a8d5df | |||
| e8188a56c7 | |||
| c42a18e9e5 | |||
| b73d221324 | |||
| cc51ffdb57 | |||
| c8971db435 | |||
| c4e787d47b | |||
| fb48b8f0c5 | |||
| 67600d0a0b | |||
| 5a9ab09bc3 | |||
| 2c06ec5f51 | |||
| d70e07fc45 | |||
| fff7203049 | |||
| 5663980015 | |||
| 8304a7716d | |||
| 523d8c38f9 | |||
| e6299960cc | |||
| fb6d41237c | |||
| e183744cb5 | |||
| 07112e4e98 | |||
| bc15f6cca3 | |||
| 3921fb973c | |||
| 6408b4ad53 | |||
| 326b146d68 | |||
| 1830db0476 | |||
| 3ba6043c62 | |||
| f4a74d3ac7 | |||
| e75f58420c | |||
| 28bb0e770f | |||
| 06f4df52f1 | |||
| a03cbcd5f9 | |||
| df67ae730b | |||
| 9305164bf3 | |||
| 453f4c5175 | |||
| 37a9979459 | |||
| 713f2f73da | |||
| 237499d102 | |||
| 3f811f52fd | |||
| 2ea8054304 | |||
| 488a30e879 | |||
| bc3f425212 | |||
| fd1d6c03cb | |||
| 58b52dfb2f | |||
| 651e92fbbf | |||
| 779619f742 | |||
| 96a5e9fc11 | |||
| eb537b5db4 | |||
| 2da79b13df | |||
| 870ebb8850 | |||
| 10d719ac1b |
@@ -0,0 +1,40 @@
|
||||
name: Nix
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
paths:
|
||||
- 'flake.nix'
|
||||
- 'flake.lock'
|
||||
- 'nix/**'
|
||||
- 'pyproject.toml'
|
||||
- 'uv.lock'
|
||||
- 'hermes_cli/**'
|
||||
- 'run_agent.py'
|
||||
- 'acp_adapter/**'
|
||||
|
||||
concurrency:
|
||||
group: nix-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
nix:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: DeterminateSystems/nix-installer-action@main
|
||||
- uses: DeterminateSystems/magic-nix-cache-action@main
|
||||
- name: Check flake
|
||||
if: runner.os == 'Linux'
|
||||
run: nix flake check --print-build-logs
|
||||
- name: Build package
|
||||
if: runner.os == 'Linux'
|
||||
run: nix build --print-build-logs
|
||||
- name: Evaluate flake (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
run: nix flake show --json > /dev/null
|
||||
@@ -0,0 +1,192 @@
|
||||
name: Supply Chain Audit
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
|
||||
permissions:
|
||||
pull-requests: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
scan:
|
||||
name: Scan PR for supply chain risks
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Scan diff for suspicious patterns
|
||||
id: scan
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
BASE="${{ github.event.pull_request.base.sha }}"
|
||||
HEAD="${{ github.event.pull_request.head.sha }}"
|
||||
|
||||
# Get the full diff (added lines only)
|
||||
DIFF=$(git diff "$BASE".."$HEAD" -- . ':!uv.lock' ':!*.lock' ':!package-lock.json' ':!yarn.lock' || true)
|
||||
|
||||
FINDINGS=""
|
||||
CRITICAL=false
|
||||
|
||||
# --- .pth files (auto-execute on Python startup) ---
|
||||
PTH_FILES=$(git diff --name-only "$BASE".."$HEAD" | grep '\.pth$' || true)
|
||||
if [ -n "$PTH_FILES" ]; then
|
||||
CRITICAL=true
|
||||
FINDINGS="${FINDINGS}
|
||||
### 🚨 CRITICAL: .pth file added or modified
|
||||
Python \`.pth\` files in \`site-packages/\` execute automatically when the interpreter starts — no import required. This is the exact mechanism used in the [litellm supply chain attack](https://github.com/BerriAI/litellm/issues/24512).
|
||||
|
||||
**Files:**
|
||||
\`\`\`
|
||||
${PTH_FILES}
|
||||
\`\`\`
|
||||
"
|
||||
fi
|
||||
|
||||
# --- base64 + exec/eval combo (the litellm attack pattern) ---
|
||||
B64_EXEC_HITS=$(echo "$DIFF" | grep -n '^\+' | grep -iE 'base64\.(b64decode|decodebytes|urlsafe_b64decode)' | grep -iE 'exec\(|eval\(' | head -10 || true)
|
||||
if [ -n "$B64_EXEC_HITS" ]; then
|
||||
CRITICAL=true
|
||||
FINDINGS="${FINDINGS}
|
||||
### 🚨 CRITICAL: base64 decode + exec/eval combo
|
||||
This is the exact pattern used in the [litellm supply chain attack](https://github.com/BerriAI/litellm/issues/24512) — base64-decoded strings passed to exec/eval to hide credential-stealing payloads.
|
||||
|
||||
**Matches:**
|
||||
\`\`\`
|
||||
${B64_EXEC_HITS}
|
||||
\`\`\`
|
||||
"
|
||||
fi
|
||||
|
||||
# --- base64 decode/encode (alone — legitimate uses exist) ---
|
||||
B64_HITS=$(echo "$DIFF" | grep -n '^\+' | grep -iE 'base64\.(b64decode|b64encode|decodebytes|encodebytes|urlsafe_b64decode)|atob\(|btoa\(|Buffer\.from\(.*base64' | head -20 || true)
|
||||
if [ -n "$B64_HITS" ]; then
|
||||
FINDINGS="${FINDINGS}
|
||||
### ⚠️ WARNING: base64 encoding/decoding detected
|
||||
Base64 has legitimate uses (images, JWT, etc.) but is also commonly used to obfuscate malicious payloads. Verify the usage is appropriate.
|
||||
|
||||
**Matches (first 20):**
|
||||
\`\`\`
|
||||
${B64_HITS}
|
||||
\`\`\`
|
||||
"
|
||||
fi
|
||||
|
||||
# --- exec/eval with string arguments ---
|
||||
EXEC_HITS=$(echo "$DIFF" | grep -n '^\+' | grep -E '(exec|eval)\s*\(' | grep -v '^\+\s*#' | grep -v 'test_\|mock\|assert\|# ' | head -20 || true)
|
||||
if [ -n "$EXEC_HITS" ]; then
|
||||
FINDINGS="${FINDINGS}
|
||||
### ⚠️ WARNING: exec() or eval() usage
|
||||
Dynamic code execution can hide malicious behavior, especially when combined with base64 or network fetches.
|
||||
|
||||
**Matches (first 20):**
|
||||
\`\`\`
|
||||
${EXEC_HITS}
|
||||
\`\`\`
|
||||
"
|
||||
fi
|
||||
|
||||
# --- subprocess with encoded/obfuscated commands ---
|
||||
PROC_HITS=$(echo "$DIFF" | grep -n '^\+' | grep -E 'subprocess\.(Popen|call|run)\s*\(' | grep -iE 'base64|decode|encode|\\x|chr\(' | head -10 || true)
|
||||
if [ -n "$PROC_HITS" ]; then
|
||||
CRITICAL=true
|
||||
FINDINGS="${FINDINGS}
|
||||
### 🚨 CRITICAL: subprocess with encoded/obfuscated command
|
||||
Subprocess calls with encoded arguments are a strong indicator of payload execution.
|
||||
|
||||
**Matches:**
|
||||
\`\`\`
|
||||
${PROC_HITS}
|
||||
\`\`\`
|
||||
"
|
||||
fi
|
||||
|
||||
# --- Network calls to non-standard domains ---
|
||||
EXFIL_HITS=$(echo "$DIFF" | grep -n '^\+' | grep -iE 'requests\.(post|put)\(|httpx\.(post|put)\(|urllib\.request\.urlopen' | grep -v '^\+\s*#' | grep -v 'test_\|mock\|assert' | head -10 || true)
|
||||
if [ -n "$EXFIL_HITS" ]; then
|
||||
FINDINGS="${FINDINGS}
|
||||
### ⚠️ WARNING: Outbound network calls (POST/PUT)
|
||||
Outbound POST/PUT requests in new code could be data exfiltration. Verify the destination URLs are legitimate.
|
||||
|
||||
**Matches (first 10):**
|
||||
\`\`\`
|
||||
${EXFIL_HITS}
|
||||
\`\`\`
|
||||
"
|
||||
fi
|
||||
|
||||
# --- setup.py / setup.cfg install hooks ---
|
||||
SETUP_HITS=$(git diff --name-only "$BASE".."$HEAD" | grep -E '(setup\.py|setup\.cfg|__init__\.pth|sitecustomize\.py|usercustomize\.py)$' || true)
|
||||
if [ -n "$SETUP_HITS" ]; then
|
||||
FINDINGS="${FINDINGS}
|
||||
### ⚠️ WARNING: Install hook files modified
|
||||
These files can execute code during package installation or interpreter startup.
|
||||
|
||||
**Files:**
|
||||
\`\`\`
|
||||
${SETUP_HITS}
|
||||
\`\`\`
|
||||
"
|
||||
fi
|
||||
|
||||
# --- Compile/marshal/pickle (code object injection) ---
|
||||
MARSHAL_HITS=$(echo "$DIFF" | grep -n '^\+' | grep -iE 'marshal\.loads|pickle\.loads|compile\(' | grep -v '^\+\s*#' | grep -v 'test_\|re\.compile\|ast\.compile' | head -10 || true)
|
||||
if [ -n "$MARSHAL_HITS" ]; then
|
||||
FINDINGS="${FINDINGS}
|
||||
### ⚠️ WARNING: marshal/pickle/compile usage
|
||||
These can deserialize or construct executable code objects.
|
||||
|
||||
**Matches:**
|
||||
\`\`\`
|
||||
${MARSHAL_HITS}
|
||||
\`\`\`
|
||||
"
|
||||
fi
|
||||
|
||||
# --- Output results ---
|
||||
if [ -n "$FINDINGS" ]; then
|
||||
echo "found=true" >> "$GITHUB_OUTPUT"
|
||||
if [ "$CRITICAL" = true ]; then
|
||||
echo "critical=true" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "critical=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
# Write findings to a file (multiline env vars are fragile)
|
||||
echo "$FINDINGS" > /tmp/findings.md
|
||||
else
|
||||
echo "found=false" >> "$GITHUB_OUTPUT"
|
||||
echo "critical=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Post warning comment
|
||||
if: steps.scan.outputs.found == 'true'
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
SEVERITY="⚠️ Supply Chain Risk Detected"
|
||||
if [ "${{ steps.scan.outputs.critical }}" = "true" ]; then
|
||||
SEVERITY="🚨 CRITICAL Supply Chain Risk Detected"
|
||||
fi
|
||||
|
||||
BODY="## ${SEVERITY}
|
||||
|
||||
This PR contains patterns commonly associated with supply chain attacks. This does **not** mean the PR is malicious — but these patterns require careful human review before merging.
|
||||
|
||||
$(cat /tmp/findings.md)
|
||||
|
||||
---
|
||||
*Automated scan triggered by [supply-chain-audit](/.github/workflows/supply-chain-audit.yml). If this is a false positive, a maintainer can approve after manual review.*"
|
||||
|
||||
gh pr comment "${{ github.event.pull_request.number }}" --body "$BODY"
|
||||
|
||||
- name: Fail on critical findings
|
||||
if: steps.scan.outputs.critical == 'true'
|
||||
run: |
|
||||
echo "::error::CRITICAL supply chain risk patterns detected in this PR. See the PR comment for details."
|
||||
exit 1
|
||||
@@ -53,3 +53,8 @@ environments/benchmarks/evals/
|
||||
|
||||
# Release script temp files
|
||||
.release_notes.md
|
||||
mini-swe-agent/
|
||||
|
||||
# Nix
|
||||
.direnv/
|
||||
result
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
[submodule "mini-swe-agent"]
|
||||
path = mini-swe-agent
|
||||
url = https://github.com/SWE-agent/mini-swe-agent
|
||||
[submodule "tinker-atropos"]
|
||||
path = tinker-atropos
|
||||
url = https://github.com/nousresearch/tinker-atropos
|
||||
|
||||
@@ -38,6 +38,7 @@ hermes-agent/
|
||||
│ ├── tools_config.py # `hermes tools` — enable/disable tools per platform
|
||||
│ ├── skills_hub.py # `/skills` slash command (search, browse, install)
|
||||
│ ├── models.py # Model catalog, provider model lists
|
||||
│ ├── model_switch.py # Shared /model switch pipeline (CLI + gateway)
|
||||
│ └── auth.py # Provider credential resolution
|
||||
├── tools/ # Tool implementations (one file per tool)
|
||||
│ ├── registry.py # Central tool registry (schemas, handlers, dispatch)
|
||||
@@ -172,6 +173,7 @@ if canonical == "mycommand":
|
||||
- `args_hint` — argument placeholder shown in help (e.g. `"<prompt>"`, `"[name]"`)
|
||||
- `cli_only` — only available in the interactive CLI
|
||||
- `gateway_only` — only available in messaging platforms
|
||||
- `gateway_config_gate` — config dotpath (e.g. `"display.tool_progress_command"`); when set on a `cli_only` command, the command becomes available in the gateway if the config value is truthy. `GATEWAY_KNOWN_COMMANDS` always includes config-gated commands so the gateway can dispatch them; help/menus only show them when the gate is open.
|
||||
|
||||
**Adding an alias** requires only adding it to the `aliases` tuple on the existing `CommandDef`. No other file changes needed — dispatch, help text, Telegram menu, Slack mapping, and autocomplete all update automatically.
|
||||
|
||||
|
||||
+3
-2
@@ -72,8 +72,9 @@ export VIRTUAL_ENV="$(pwd)/venv"
|
||||
|
||||
# Install with all extras (messaging, cron, CLI menus, dev tools)
|
||||
uv pip install -e ".[all,dev]"
|
||||
uv pip install -e "./mini-swe-agent"
|
||||
uv pip install -e "./tinker-atropos"
|
||||
|
||||
# Optional: RL training submodule
|
||||
# git submodule update --init tinker-atropos && uv pip install -e "./tinker-atropos"
|
||||
|
||||
# Optional: browser tools
|
||||
npm install
|
||||
|
||||
@@ -144,16 +144,14 @@ Quick start for contributors:
|
||||
```bash
|
||||
git clone https://github.com/NousResearch/hermes-agent.git
|
||||
cd hermes-agent
|
||||
git submodule update --init mini-swe-agent # required terminal backend
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
uv venv venv --python 3.11
|
||||
source venv/bin/activate
|
||||
uv pip install -e ".[all,dev]"
|
||||
uv pip install -e "./mini-swe-agent"
|
||||
python -m pytest tests/ -q
|
||||
```
|
||||
|
||||
> **RL Training (optional):** To work on the RL/Tinker-Atropos integration, also run:
|
||||
> **RL Training (optional):** To work on the RL/Tinker-Atropos integration:
|
||||
> ```bash
|
||||
> git submodule update --init tinker-atropos
|
||||
> uv pip install -e "./tinker-atropos"
|
||||
|
||||
@@ -0,0 +1,400 @@
|
||||
# Hermes Agent v0.4.0 (v2026.3.23)
|
||||
|
||||
**Release Date:** March 23, 2026
|
||||
|
||||
> The platform expansion release — OpenAI-compatible API server, 6 new messaging adapters, 4 new inference providers, MCP server management with OAuth 2.1, @ context references, gateway prompt caching, streaming enabled by default, and a sweeping reliability pass with 200+ bug fixes.
|
||||
|
||||
---
|
||||
|
||||
## ✨ Highlights
|
||||
|
||||
- **OpenAI-compatible API server** — Expose Hermes as an `/v1/chat/completions` endpoint with a new `/api/jobs` REST API for cron job management, hardened with input limits, field whitelists, SQLite-backed response persistence, and CORS origin protection ([#1756](https://github.com/NousResearch/hermes-agent/pull/1756), [#2450](https://github.com/NousResearch/hermes-agent/pull/2450), [#2456](https://github.com/NousResearch/hermes-agent/pull/2456), [#2451](https://github.com/NousResearch/hermes-agent/pull/2451), [#2472](https://github.com/NousResearch/hermes-agent/pull/2472))
|
||||
|
||||
- **6 new messaging platform adapters** — Signal, DingTalk, SMS (Twilio), Mattermost, Matrix, and Webhook adapters join Telegram, Discord, and WhatsApp. Gateway auto-reconnects failed platforms with exponential backoff ([#2206](https://github.com/NousResearch/hermes-agent/pull/2206), [#1685](https://github.com/NousResearch/hermes-agent/pull/1685), [#1688](https://github.com/NousResearch/hermes-agent/pull/1688), [#1683](https://github.com/NousResearch/hermes-agent/pull/1683), [#2166](https://github.com/NousResearch/hermes-agent/pull/2166), [#2584](https://github.com/NousResearch/hermes-agent/pull/2584))
|
||||
|
||||
- **@ context references** — Claude Code-style `@file` and `@url` context injection with tab completions in the CLI ([#2343](https://github.com/NousResearch/hermes-agent/pull/2343), [#2482](https://github.com/NousResearch/hermes-agent/pull/2482))
|
||||
|
||||
- **4 new inference providers** — GitHub Copilot (OAuth + token validation), Alibaba Cloud / DashScope, Kilo Code, and OpenCode Zen/Go ([#1924](https://github.com/NousResearch/hermes-agent/pull/1924), [#1879](https://github.com/NousResearch/hermes-agent/pull/1879) by @mchzimm, [#1673](https://github.com/NousResearch/hermes-agent/pull/1673), [#1666](https://github.com/NousResearch/hermes-agent/pull/1666), [#1650](https://github.com/NousResearch/hermes-agent/pull/1650))
|
||||
|
||||
- **MCP server management CLI** — `hermes mcp` commands for installing, configuring, and authenticating MCP servers with full OAuth 2.1 PKCE flow ([#2465](https://github.com/NousResearch/hermes-agent/pull/2465))
|
||||
|
||||
- **Gateway prompt caching** — Cache AIAgent instances per session, preserving Anthropic prompt cache across turns for dramatic cost reduction on long conversations ([#2282](https://github.com/NousResearch/hermes-agent/pull/2282), [#2284](https://github.com/NousResearch/hermes-agent/pull/2284), [#2361](https://github.com/NousResearch/hermes-agent/pull/2361))
|
||||
|
||||
- **Context compression overhaul** — Structured summaries with iterative updates, token-budget tail protection, configurable summary endpoint, and fallback model support ([#2323](https://github.com/NousResearch/hermes-agent/pull/2323), [#1727](https://github.com/NousResearch/hermes-agent/pull/1727), [#2224](https://github.com/NousResearch/hermes-agent/pull/2224))
|
||||
|
||||
- **Streaming enabled by default** — CLI streaming on by default with proper spinner/tool progress display during streaming mode, plus extensive linebreak and concatenation fixes ([#2340](https://github.com/NousResearch/hermes-agent/pull/2340), [#2161](https://github.com/NousResearch/hermes-agent/pull/2161), [#2258](https://github.com/NousResearch/hermes-agent/pull/2258))
|
||||
|
||||
---
|
||||
|
||||
## 🖥️ CLI & User Experience
|
||||
|
||||
### New Commands & Interactions
|
||||
- **@ context completions** — Tab-completable `@file`/`@url` references that inject file content or web pages into the conversation ([#2482](https://github.com/NousResearch/hermes-agent/pull/2482), [#2343](https://github.com/NousResearch/hermes-agent/pull/2343))
|
||||
- **`/statusbar`** — Toggle a persistent config bar showing model + provider info in the prompt ([#2240](https://github.com/NousResearch/hermes-agent/pull/2240), [#1917](https://github.com/NousResearch/hermes-agent/pull/1917))
|
||||
- **`/queue`** — Queue prompts for the agent without interrupting the current run ([#2191](https://github.com/NousResearch/hermes-agent/pull/2191), [#2469](https://github.com/NousResearch/hermes-agent/pull/2469))
|
||||
- **`/permission`** — Switch approval mode dynamically during a session ([#2207](https://github.com/NousResearch/hermes-agent/pull/2207))
|
||||
- **`/browser`** — Interactive browser sessions from the CLI ([#2273](https://github.com/NousResearch/hermes-agent/pull/2273), [#1814](https://github.com/NousResearch/hermes-agent/pull/1814))
|
||||
- **`/cost`** — Live pricing and usage tracking in gateway mode ([#2180](https://github.com/NousResearch/hermes-agent/pull/2180))
|
||||
- **`/approve` and `/deny`** — Replaced bare text approval in gateway with explicit commands ([#2002](https://github.com/NousResearch/hermes-agent/pull/2002))
|
||||
|
||||
### Streaming & Display
|
||||
- Streaming enabled by default in CLI ([#2340](https://github.com/NousResearch/hermes-agent/pull/2340))
|
||||
- Show spinners and tool progress during streaming mode ([#2161](https://github.com/NousResearch/hermes-agent/pull/2161))
|
||||
- Show reasoning/thinking blocks when `show_reasoning` enabled ([#2118](https://github.com/NousResearch/hermes-agent/pull/2118))
|
||||
- Context pressure warnings for CLI and gateway ([#2159](https://github.com/NousResearch/hermes-agent/pull/2159))
|
||||
- Fix: streaming chunks concatenated without whitespace ([#2258](https://github.com/NousResearch/hermes-agent/pull/2258))
|
||||
- Fix: iteration boundary linebreak prevents stream concatenation ([#2413](https://github.com/NousResearch/hermes-agent/pull/2413))
|
||||
- Fix: defer streaming linebreak to prevent blank line stacking ([#2473](https://github.com/NousResearch/hermes-agent/pull/2473))
|
||||
- Fix: suppress spinner animation in non-TTY environments ([#2216](https://github.com/NousResearch/hermes-agent/pull/2216))
|
||||
- Fix: display provider and endpoint in API error messages ([#2266](https://github.com/NousResearch/hermes-agent/pull/2266))
|
||||
- Fix: resolve garbled ANSI escape codes in status printouts ([#2448](https://github.com/NousResearch/hermes-agent/pull/2448))
|
||||
- Fix: update gold ANSI color to true-color format ([#2246](https://github.com/NousResearch/hermes-agent/pull/2246))
|
||||
- Fix: normalize toolset labels and use skin colors in banner ([#1912](https://github.com/NousResearch/hermes-agent/pull/1912))
|
||||
|
||||
### CLI Polish
|
||||
- Fix: prevent 'Press ENTER to continue...' on exit ([#2555](https://github.com/NousResearch/hermes-agent/pull/2555))
|
||||
- Fix: flush stdout during agent loop to prevent macOS display freeze ([#1654](https://github.com/NousResearch/hermes-agent/pull/1654))
|
||||
- Fix: show human-readable error when `hermes setup` hits permissions error ([#2196](https://github.com/NousResearch/hermes-agent/pull/2196))
|
||||
- Fix: `/stop` command crash + UnboundLocalError in streaming media delivery ([#2463](https://github.com/NousResearch/hermes-agent/pull/2463))
|
||||
- Fix: allow custom/local endpoints without API key ([#2556](https://github.com/NousResearch/hermes-agent/pull/2556))
|
||||
- Fix: Kitty keyboard protocol Shift+Enter for Ghostty/WezTerm (attempted + reverted due to prompt_toolkit crash) ([#2345](https://github.com/NousResearch/hermes-agent/pull/2345), [#2349](https://github.com/NousResearch/hermes-agent/pull/2349))
|
||||
|
||||
### Configuration
|
||||
- **`${ENV_VAR}` substitution** in config.yaml ([#2684](https://github.com/NousResearch/hermes-agent/pull/2684))
|
||||
- **Real-time config reload** — config.yaml changes apply without restart ([#2210](https://github.com/NousResearch/hermes-agent/pull/2210))
|
||||
- **`custom_models.yaml`** for user-managed model additions ([#2214](https://github.com/NousResearch/hermes-agent/pull/2214))
|
||||
- **Priority-based context file selection** + CLAUDE.md support ([#2301](https://github.com/NousResearch/hermes-agent/pull/2301))
|
||||
- **Merge nested YAML sections** instead of replacing on config update ([#2213](https://github.com/NousResearch/hermes-agent/pull/2213))
|
||||
- Fix: config.yaml provider key overrides env var silently ([#2272](https://github.com/NousResearch/hermes-agent/pull/2272))
|
||||
- Fix: log warning instead of silently swallowing config.yaml errors ([#2683](https://github.com/NousResearch/hermes-agent/pull/2683))
|
||||
- Fix: disabled toolsets re-enable themselves after `hermes tools` ([#2268](https://github.com/NousResearch/hermes-agent/pull/2268))
|
||||
- Fix: platform default toolsets silently override tool deselection ([#2624](https://github.com/NousResearch/hermes-agent/pull/2624))
|
||||
- Fix: honor bare YAML `approvals.mode: off` ([#2620](https://github.com/NousResearch/hermes-agent/pull/2620))
|
||||
- Fix: `hermes update` use `.[all]` extras with fallback ([#1728](https://github.com/NousResearch/hermes-agent/pull/1728))
|
||||
- Fix: `hermes update` prompt before resetting working tree on stash conflicts ([#2390](https://github.com/NousResearch/hermes-agent/pull/2390))
|
||||
- Fix: use git pull --rebase in update/install to avoid divergent branch error ([#2274](https://github.com/NousResearch/hermes-agent/pull/2274))
|
||||
- Fix: add zprofile fallback and create zshrc on fresh macOS installs ([#2320](https://github.com/NousResearch/hermes-agent/pull/2320))
|
||||
- Fix: remove `ANTHROPIC_BASE_URL` env var to avoid collisions ([#1675](https://github.com/NousResearch/hermes-agent/pull/1675))
|
||||
- Fix: don't ask IMAP password if already in keyring or env ([#2212](https://github.com/NousResearch/hermes-agent/pull/2212))
|
||||
- Fix: OpenCode Zen/Go show OpenRouter models instead of their own ([#2277](https://github.com/NousResearch/hermes-agent/pull/2277))
|
||||
|
||||
---
|
||||
|
||||
## 🏗️ Core Agent & Architecture
|
||||
|
||||
### New Providers
|
||||
- **GitHub Copilot** — Full OAuth auth, API routing, token validation, and 400k context. ([#1924](https://github.com/NousResearch/hermes-agent/pull/1924), [#1896](https://github.com/NousResearch/hermes-agent/pull/1896), [#1879](https://github.com/NousResearch/hermes-agent/pull/1879) by @mchzimm, [#2507](https://github.com/NousResearch/hermes-agent/pull/2507))
|
||||
- **Alibaba Cloud / DashScope** — Full integration with DashScope v1 runtime, model dot preservation, and 401 auth fixes ([#1673](https://github.com/NousResearch/hermes-agent/pull/1673), [#2332](https://github.com/NousResearch/hermes-agent/pull/2332), [#2459](https://github.com/NousResearch/hermes-agent/pull/2459))
|
||||
- **Kilo Code** — First-class inference provider ([#1666](https://github.com/NousResearch/hermes-agent/pull/1666))
|
||||
- **OpenCode Zen and OpenCode Go** — New provider backends ([#1650](https://github.com/NousResearch/hermes-agent/pull/1650), [#2393](https://github.com/NousResearch/hermes-agent/pull/2393) by @0xbyt4)
|
||||
- **NeuTTS** — Local TTS provider backend with built-in setup flow, replacing the old optional skill ([#1657](https://github.com/NousResearch/hermes-agent/pull/1657), [#1664](https://github.com/NousResearch/hermes-agent/pull/1664))
|
||||
|
||||
### Provider Improvements
|
||||
- **Eager fallback** to backup model on rate-limit errors ([#1730](https://github.com/NousResearch/hermes-agent/pull/1730))
|
||||
- **Endpoint metadata** for custom model context and pricing; query local servers for actual context window size ([#1906](https://github.com/NousResearch/hermes-agent/pull/1906), [#2091](https://github.com/NousResearch/hermes-agent/pull/2091) by @dusterbloom)
|
||||
- **Context length detection overhaul** — models.dev integration, provider-aware resolution, fuzzy matching for custom endpoints, `/v1/props` for llama.cpp ([#2158](https://github.com/NousResearch/hermes-agent/pull/2158), [#2051](https://github.com/NousResearch/hermes-agent/pull/2051), [#2403](https://github.com/NousResearch/hermes-agent/pull/2403))
|
||||
- **Model catalog updates** — gpt-5.4-mini, gpt-5.4-nano, healer-alpha, haiku-4.5, minimax-m2.7, claude 4.6 at 1M context ([#1913](https://github.com/NousResearch/hermes-agent/pull/1913), [#1915](https://github.com/NousResearch/hermes-agent/pull/1915), [#1900](https://github.com/NousResearch/hermes-agent/pull/1900), [#2155](https://github.com/NousResearch/hermes-agent/pull/2155), [#2474](https://github.com/NousResearch/hermes-agent/pull/2474))
|
||||
- **Custom endpoint improvements** — `model.base_url` in config.yaml, `api_mode` override for responses API, allow endpoints without API key, fail fast on missing keys ([#2330](https://github.com/NousResearch/hermes-agent/pull/2330), [#1651](https://github.com/NousResearch/hermes-agent/pull/1651), [#2556](https://github.com/NousResearch/hermes-agent/pull/2556), [#2445](https://github.com/NousResearch/hermes-agent/pull/2445), [#1994](https://github.com/NousResearch/hermes-agent/pull/1994), [#1998](https://github.com/NousResearch/hermes-agent/pull/1998))
|
||||
- Inject model and provider into system prompt ([#1929](https://github.com/NousResearch/hermes-agent/pull/1929))
|
||||
- Tie `api_mode` to provider config instead of env var ([#1656](https://github.com/NousResearch/hermes-agent/pull/1656))
|
||||
- Fix: prevent Anthropic token leaking to third-party `anthropic_messages` providers ([#2389](https://github.com/NousResearch/hermes-agent/pull/2389))
|
||||
- Fix: prevent Anthropic fallback from inheriting non-Anthropic `base_url` ([#2388](https://github.com/NousResearch/hermes-agent/pull/2388))
|
||||
- Fix: `auxiliary_is_nous` flag never resets — leaked Nous tags to other providers ([#1713](https://github.com/NousResearch/hermes-agent/pull/1713))
|
||||
- Fix: Anthropic `tool_choice 'none'` still allowed tool calls ([#1714](https://github.com/NousResearch/hermes-agent/pull/1714))
|
||||
- Fix: Mistral parser nested JSON fallback extraction ([#2335](https://github.com/NousResearch/hermes-agent/pull/2335))
|
||||
- Fix: MiniMax 401 auth resolved by defaulting to `anthropic_messages` ([#2103](https://github.com/NousResearch/hermes-agent/pull/2103))
|
||||
- Fix: case-insensitive model family matching ([#2350](https://github.com/NousResearch/hermes-agent/pull/2350))
|
||||
- Fix: ignore placeholder provider keys in activation checks ([#2358](https://github.com/NousResearch/hermes-agent/pull/2358))
|
||||
- Fix: Preserve Ollama model:tag colons in context length detection ([#2149](https://github.com/NousResearch/hermes-agent/pull/2149))
|
||||
- Fix: recognize Claude Code OAuth credentials in startup gate ([#1663](https://github.com/NousResearch/hermes-agent/pull/1663))
|
||||
- Fix: detect Claude Code version dynamically for OAuth user-agent ([#1670](https://github.com/NousResearch/hermes-agent/pull/1670))
|
||||
- Fix: OAuth flag stale after refresh/fallback ([#1890](https://github.com/NousResearch/hermes-agent/pull/1890))
|
||||
- Fix: auxiliary client skips expired Codex JWT ([#2397](https://github.com/NousResearch/hermes-agent/pull/2397))
|
||||
|
||||
### Agent Loop
|
||||
- **Gateway prompt caching** — Cache AIAgent per session, keep assistant turns, fix session restore ([#2282](https://github.com/NousResearch/hermes-agent/pull/2282), [#2284](https://github.com/NousResearch/hermes-agent/pull/2284), [#2361](https://github.com/NousResearch/hermes-agent/pull/2361))
|
||||
- **Context compression overhaul** — Structured summaries, iterative updates, token-budget tail protection, configurable `summary_base_url` ([#2323](https://github.com/NousResearch/hermes-agent/pull/2323), [#1727](https://github.com/NousResearch/hermes-agent/pull/1727), [#2224](https://github.com/NousResearch/hermes-agent/pull/2224))
|
||||
- **Pre-call sanitization and post-call tool guardrails** ([#1732](https://github.com/NousResearch/hermes-agent/pull/1732))
|
||||
- **Auto-recover** from provider-rejected `tool_choice` by retrying without ([#2174](https://github.com/NousResearch/hermes-agent/pull/2174))
|
||||
- **Background memory/skill review** replaces inline nudges ([#2235](https://github.com/NousResearch/hermes-agent/pull/2235))
|
||||
- **SOUL.md as primary agent identity** instead of hardcoded default ([#1922](https://github.com/NousResearch/hermes-agent/pull/1922))
|
||||
- Fix: prevent silent tool result loss during context compression ([#1993](https://github.com/NousResearch/hermes-agent/pull/1993))
|
||||
- Fix: handle empty/null function arguments in tool call recovery ([#2163](https://github.com/NousResearch/hermes-agent/pull/2163))
|
||||
- Fix: handle API refusal responses gracefully instead of crashing ([#2156](https://github.com/NousResearch/hermes-agent/pull/2156))
|
||||
- Fix: prevent stuck agent loop on malformed tool calls ([#2114](https://github.com/NousResearch/hermes-agent/pull/2114))
|
||||
- Fix: return JSON parse error to model instead of dispatching with empty args ([#2342](https://github.com/NousResearch/hermes-agent/pull/2342))
|
||||
- Fix: consecutive assistant message merge drops content on mixed types ([#1703](https://github.com/NousResearch/hermes-agent/pull/1703))
|
||||
- Fix: message role alternation violations in JSON recovery and error handler ([#1722](https://github.com/NousResearch/hermes-agent/pull/1722))
|
||||
- Fix: `compression_attempts` resets each iteration — allowed unlimited compressions ([#1723](https://github.com/NousResearch/hermes-agent/pull/1723))
|
||||
- Fix: `length_continue_retries` never resets — later truncations got fewer retries ([#1717](https://github.com/NousResearch/hermes-agent/pull/1717))
|
||||
- Fix: compressor summary role violated consecutive-role constraint ([#1720](https://github.com/NousResearch/hermes-agent/pull/1720), [#1743](https://github.com/NousResearch/hermes-agent/pull/1743))
|
||||
- Fix: remove hardcoded `gemini-3-flash-preview` as default summary model ([#2464](https://github.com/NousResearch/hermes-agent/pull/2464))
|
||||
- Fix: correctly handle empty tool results ([#2201](https://github.com/NousResearch/hermes-agent/pull/2201))
|
||||
- Fix: crash on None entry in `tool_calls` list ([#2209](https://github.com/NousResearch/hermes-agent/pull/2209) by @0xbyt4, [#2316](https://github.com/NousResearch/hermes-agent/pull/2316))
|
||||
- Fix: per-thread persistent event loops in worker threads ([#2214](https://github.com/NousResearch/hermes-agent/pull/2214) by @jquesnelle)
|
||||
- Fix: prevent 'event loop already running' when async tools run in parallel ([#2207](https://github.com/NousResearch/hermes-agent/pull/2207))
|
||||
- Fix: strip ANSI at the source — clean terminal output before it reaches the model ([#2115](https://github.com/NousResearch/hermes-agent/pull/2115))
|
||||
- Fix: skip top-level `cache_control` on role:tool for OpenRouter ([#2391](https://github.com/NousResearch/hermes-agent/pull/2391))
|
||||
- Fix: delegate tool — save parent tool names before child construction mutates global ([#2083](https://github.com/NousResearch/hermes-agent/pull/2083) by @ygd58, [#1894](https://github.com/NousResearch/hermes-agent/pull/1894))
|
||||
- Fix: only strip last assistant message if empty string ([#2326](https://github.com/NousResearch/hermes-agent/pull/2326))
|
||||
|
||||
### Session & Memory
|
||||
- **Session search** and management slash commands ([#2198](https://github.com/NousResearch/hermes-agent/pull/2198))
|
||||
- **Auto session titles** and `.hermes.md` project config ([#1712](https://github.com/NousResearch/hermes-agent/pull/1712))
|
||||
- Fix: concurrent memory writes silently drop entries — added file locking ([#1726](https://github.com/NousResearch/hermes-agent/pull/1726))
|
||||
- Fix: search all sources by default in `session_search` ([#1892](https://github.com/NousResearch/hermes-agent/pull/1892))
|
||||
- Fix: handle hyphenated FTS5 queries and preserve quoted literals ([#1776](https://github.com/NousResearch/hermes-agent/pull/1776))
|
||||
- Fix: skip corrupt lines in `load_transcript` instead of crashing ([#1744](https://github.com/NousResearch/hermes-agent/pull/1744))
|
||||
- Fix: normalize session keys to prevent case-sensitive duplicates ([#2157](https://github.com/NousResearch/hermes-agent/pull/2157))
|
||||
- Fix: prevent `session_search` crash when no sessions exist ([#2194](https://github.com/NousResearch/hermes-agent/pull/2194))
|
||||
- Fix: reset token counters on new session for accurate usage display ([#2101](https://github.com/NousResearch/hermes-agent/pull/2101) by @InB4DevOps)
|
||||
- Fix: prevent stale memory overwrites by flush agent ([#2687](https://github.com/NousResearch/hermes-agent/pull/2687))
|
||||
- Fix: remove synthetic error message injection, fix session resume after repeated failures ([#2303](https://github.com/NousResearch/hermes-agent/pull/2303))
|
||||
- Fix: quiet mode with `--resume` now passes conversation_history ([#2357](https://github.com/NousResearch/hermes-agent/pull/2357))
|
||||
- Fix: unify resume logic in batch mode ([#2331](https://github.com/NousResearch/hermes-agent/pull/2331))
|
||||
|
||||
### Honcho Memory
|
||||
- Honcho config fixes and @ context reference integration ([#2343](https://github.com/NousResearch/hermes-agent/pull/2343))
|
||||
- Self-hosted / Docker configuration documentation ([#2475](https://github.com/NousResearch/hermes-agent/pull/2475))
|
||||
|
||||
---
|
||||
|
||||
## 📱 Messaging Platforms (Gateway)
|
||||
|
||||
### New Platform Adapters
|
||||
- **Signal Messenger** — Full adapter with attachment handling, group message filtering, and Note to Self echo-back protection ([#2206](https://github.com/NousResearch/hermes-agent/pull/2206), [#2400](https://github.com/NousResearch/hermes-agent/pull/2400), [#2297](https://github.com/NousResearch/hermes-agent/pull/2297), [#2156](https://github.com/NousResearch/hermes-agent/pull/2156))
|
||||
- **DingTalk** — Adapter with gateway wiring and setup docs ([#1685](https://github.com/NousResearch/hermes-agent/pull/1685), [#1690](https://github.com/NousResearch/hermes-agent/pull/1690), [#1692](https://github.com/NousResearch/hermes-agent/pull/1692))
|
||||
- **SMS (Twilio)** ([#1688](https://github.com/NousResearch/hermes-agent/pull/1688))
|
||||
- **Mattermost** — With @-mention-only channel filter ([#1683](https://github.com/NousResearch/hermes-agent/pull/1683), [#2443](https://github.com/NousResearch/hermes-agent/pull/2443))
|
||||
- **Matrix** — With vision support and image caching ([#1683](https://github.com/NousResearch/hermes-agent/pull/1683), [#2520](https://github.com/NousResearch/hermes-agent/pull/2520))
|
||||
- **Webhook** — Platform adapter for external event triggers ([#2166](https://github.com/NousResearch/hermes-agent/pull/2166))
|
||||
- **OpenAI-compatible API server** — `/v1/chat/completions` endpoint with `/api/jobs` cron management ([#1756](https://github.com/NousResearch/hermes-agent/pull/1756), [#2450](https://github.com/NousResearch/hermes-agent/pull/2450), [#2456](https://github.com/NousResearch/hermes-agent/pull/2456))
|
||||
|
||||
### Telegram Improvements
|
||||
- MarkdownV2 support — strikethrough, spoiler, blockquotes, escape parentheses/braces/backslashes/backticks ([#2199](https://github.com/NousResearch/hermes-agent/pull/2199), [#2200](https://github.com/NousResearch/hermes-agent/pull/2200) by @llbn, [#2386](https://github.com/NousResearch/hermes-agent/pull/2386))
|
||||
- Auto-detect HTML tags and use `parse_mode=HTML` ([#1709](https://github.com/NousResearch/hermes-agent/pull/1709))
|
||||
- Telegram group vision support + thread-based sessions ([#2153](https://github.com/NousResearch/hermes-agent/pull/2153))
|
||||
- Auto-reconnect polling after network interruption ([#2517](https://github.com/NousResearch/hermes-agent/pull/2517))
|
||||
- Aggregate split text messages before dispatching ([#1674](https://github.com/NousResearch/hermes-agent/pull/1674))
|
||||
- Fix: streaming config bridge, not-modified, flood control ([#1782](https://github.com/NousResearch/hermes-agent/pull/1782), [#1783](https://github.com/NousResearch/hermes-agent/pull/1783))
|
||||
- Fix: edited_message event crashes ([#2074](https://github.com/NousResearch/hermes-agent/pull/2074))
|
||||
- Fix: retry 409 polling conflicts before giving up ([#2312](https://github.com/NousResearch/hermes-agent/pull/2312))
|
||||
- Fix: topic delivery via `platform:chat_id:thread_id` format ([#2455](https://github.com/NousResearch/hermes-agent/pull/2455))
|
||||
|
||||
### Discord Improvements
|
||||
- Document caching and text-file injection ([#2503](https://github.com/NousResearch/hermes-agent/pull/2503))
|
||||
- Persistent typing indicator for DMs ([#2468](https://github.com/NousResearch/hermes-agent/pull/2468))
|
||||
- Discord DM vision — inline images + attachment analysis ([#2186](https://github.com/NousResearch/hermes-agent/pull/2186))
|
||||
- Persist thread participation across gateway restarts ([#1661](https://github.com/NousResearch/hermes-agent/pull/1661))
|
||||
- Fix: gateway crash on non-ASCII guild names ([#2302](https://github.com/NousResearch/hermes-agent/pull/2302))
|
||||
- Fix: thread permission errors ([#2073](https://github.com/NousResearch/hermes-agent/pull/2073))
|
||||
- Fix: slash event routing in threads ([#2460](https://github.com/NousResearch/hermes-agent/pull/2460))
|
||||
- Fix: remove bugged followup messages + `/ask` command ([#1836](https://github.com/NousResearch/hermes-agent/pull/1836))
|
||||
- Fix: graceful WebSocket reconnection ([#2127](https://github.com/NousResearch/hermes-agent/pull/2127))
|
||||
- Fix: voice channel TTS when streaming enabled ([#2322](https://github.com/NousResearch/hermes-agent/pull/2322))
|
||||
|
||||
### WhatsApp & Other Adapters
|
||||
- WhatsApp: outbound `send_message` routing ([#1769](https://github.com/NousResearch/hermes-agent/pull/1769) by @sai-samarth), LID format self-chat ([#1667](https://github.com/NousResearch/hermes-agent/pull/1667)), `reply_prefix` config fix ([#1923](https://github.com/NousResearch/hermes-agent/pull/1923)), restart on bridge child exit ([#2334](https://github.com/NousResearch/hermes-agent/pull/2334)), image/bridge improvements ([#2181](https://github.com/NousResearch/hermes-agent/pull/2181))
|
||||
- Matrix: correct `reply_to_message_id` parameter ([#1895](https://github.com/NousResearch/hermes-agent/pull/1895)), bare media types fix ([#1736](https://github.com/NousResearch/hermes-agent/pull/1736))
|
||||
- Mattermost: MIME types for media attachments ([#2329](https://github.com/NousResearch/hermes-agent/pull/2329))
|
||||
|
||||
### Gateway Core
|
||||
- **Auto-reconnect** failed platforms with exponential backoff ([#2584](https://github.com/NousResearch/hermes-agent/pull/2584))
|
||||
- **Notify users when session auto-resets** ([#2519](https://github.com/NousResearch/hermes-agent/pull/2519))
|
||||
- **Reply-to message context** for out-of-session replies ([#1662](https://github.com/NousResearch/hermes-agent/pull/1662))
|
||||
- **Ignore unauthorized DMs** config option ([#1919](https://github.com/NousResearch/hermes-agent/pull/1919))
|
||||
- Fix: `/reset` in thread-mode resets global session instead of thread ([#2254](https://github.com/NousResearch/hermes-agent/pull/2254))
|
||||
- Fix: deliver MEDIA: files after streaming responses ([#2382](https://github.com/NousResearch/hermes-agent/pull/2382))
|
||||
- Fix: cap interrupt recursion depth to prevent resource exhaustion ([#1659](https://github.com/NousResearch/hermes-agent/pull/1659))
|
||||
- Fix: detect stopped processes and release stale locks on `--replace` ([#2406](https://github.com/NousResearch/hermes-agent/pull/2406), [#1908](https://github.com/NousResearch/hermes-agent/pull/1908))
|
||||
- Fix: PID-based wait with force-kill for gateway restart ([#1902](https://github.com/NousResearch/hermes-agent/pull/1902))
|
||||
- Fix: prevent `--replace` mode from killing the caller process ([#2185](https://github.com/NousResearch/hermes-agent/pull/2185))
|
||||
- Fix: `/model` shows active fallback model instead of config default ([#1660](https://github.com/NousResearch/hermes-agent/pull/1660))
|
||||
- Fix: `/title` command fails when session doesn't exist in SQLite yet ([#2379](https://github.com/NousResearch/hermes-agent/pull/2379) by @ten-jampa)
|
||||
- Fix: process `/queue`'d messages after agent completion ([#2469](https://github.com/NousResearch/hermes-agent/pull/2469))
|
||||
- Fix: strip orphaned `tool_results` + let `/reset` bypass running agent ([#2180](https://github.com/NousResearch/hermes-agent/pull/2180))
|
||||
- Fix: prevent agents from starting gateway outside systemd management ([#2617](https://github.com/NousResearch/hermes-agent/pull/2617))
|
||||
- Fix: prevent systemd restart storm on gateway connection failure ([#2327](https://github.com/NousResearch/hermes-agent/pull/2327))
|
||||
- Fix: include resolved node path in systemd unit ([#1767](https://github.com/NousResearch/hermes-agent/pull/1767) by @sai-samarth)
|
||||
- Fix: send error details to user in gateway outer exception handler ([#1966](https://github.com/NousResearch/hermes-agent/pull/1966))
|
||||
- Fix: improve error handling for 429 usage limits and 500 context overflow ([#1839](https://github.com/NousResearch/hermes-agent/pull/1839))
|
||||
- Fix: add all missing platform allowlist env vars to startup warning check ([#2628](https://github.com/NousResearch/hermes-agent/pull/2628))
|
||||
- Fix: media delivery fails for file paths containing spaces ([#2621](https://github.com/NousResearch/hermes-agent/pull/2621))
|
||||
- Fix: duplicate session-key collision in multi-platform gateway ([#2171](https://github.com/NousResearch/hermes-agent/pull/2171))
|
||||
- Fix: Matrix and Mattermost never report as connected ([#1711](https://github.com/NousResearch/hermes-agent/pull/1711))
|
||||
- Fix: PII redaction config never read — missing yaml import ([#1701](https://github.com/NousResearch/hermes-agent/pull/1701))
|
||||
- Fix: NameError on skill slash commands ([#1697](https://github.com/NousResearch/hermes-agent/pull/1697))
|
||||
- Fix: persist watcher metadata in checkpoint for crash recovery ([#1706](https://github.com/NousResearch/hermes-agent/pull/1706))
|
||||
- Fix: pass `message_thread_id` in send_image_file, send_document, send_video ([#2339](https://github.com/NousResearch/hermes-agent/pull/2339))
|
||||
- Fix: media-group aggregation on rapid successive photo messages ([#2160](https://github.com/NousResearch/hermes-agent/pull/2160))
|
||||
|
||||
---
|
||||
|
||||
## 🔧 Tool System
|
||||
|
||||
### MCP Enhancements
|
||||
- **MCP server management CLI** + OAuth 2.1 PKCE auth ([#2465](https://github.com/NousResearch/hermes-agent/pull/2465))
|
||||
- **Expose MCP servers as standalone toolsets** ([#1907](https://github.com/NousResearch/hermes-agent/pull/1907))
|
||||
- **Interactive MCP tool configuration** in `hermes tools` ([#1694](https://github.com/NousResearch/hermes-agent/pull/1694))
|
||||
- Fix: MCP-OAuth port mismatch, path traversal, and shared handler state ([#2552](https://github.com/NousResearch/hermes-agent/pull/2552))
|
||||
- Fix: preserve MCP tool registrations across session resets ([#2124](https://github.com/NousResearch/hermes-agent/pull/2124))
|
||||
- Fix: concurrent file access crash + duplicate MCP registration ([#2154](https://github.com/NousResearch/hermes-agent/pull/2154))
|
||||
- Fix: normalise MCP schemas + expand session list columns ([#2102](https://github.com/NousResearch/hermes-agent/pull/2102))
|
||||
- Fix: `tool_choice` `mcp_` prefix handling ([#1775](https://github.com/NousResearch/hermes-agent/pull/1775))
|
||||
|
||||
### Web Tool Backends
|
||||
- **Tavily** as web search/extract/crawl backend ([#1731](https://github.com/NousResearch/hermes-agent/pull/1731))
|
||||
- **Parallel** as alternative web search/extract backend ([#1696](https://github.com/NousResearch/hermes-agent/pull/1696))
|
||||
- **Configurable web backend** — Firecrawl/BeautifulSoup/Playwright selection ([#2256](https://github.com/NousResearch/hermes-agent/pull/2256))
|
||||
- Fix: whitespace-only env vars bypass web backend detection ([#2341](https://github.com/NousResearch/hermes-agent/pull/2341))
|
||||
|
||||
### New Tools
|
||||
- **IMAP email** reading and sending ([#2173](https://github.com/NousResearch/hermes-agent/pull/2173))
|
||||
- **STT (speech-to-text)** tool using Whisper API ([#2072](https://github.com/NousResearch/hermes-agent/pull/2072))
|
||||
- **Route-aware pricing estimates** ([#1695](https://github.com/NousResearch/hermes-agent/pull/1695))
|
||||
|
||||
### Tool Improvements
|
||||
- TTS: `base_url` support for OpenAI TTS provider ([#2064](https://github.com/NousResearch/hermes-agent/pull/2064) by @hanai)
|
||||
- Vision: configurable timeout, tilde expansion in file paths, DM vision with multi-image and base64 fallback ([#2480](https://github.com/NousResearch/hermes-agent/pull/2480), [#2585](https://github.com/NousResearch/hermes-agent/pull/2585), [#2211](https://github.com/NousResearch/hermes-agent/pull/2211))
|
||||
- Browser: race condition fix in session creation ([#1721](https://github.com/NousResearch/hermes-agent/pull/1721)), TypeError on unexpected LLM params ([#1735](https://github.com/NousResearch/hermes-agent/pull/1735))
|
||||
- File tools: strip ANSI escape codes from write_file and patch content ([#2532](https://github.com/NousResearch/hermes-agent/pull/2532)), include pagination args in repeated search key ([#1824](https://github.com/NousResearch/hermes-agent/pull/1824) by @cutepawss), improve fuzzy matching accuracy + position calculation refactor ([#2096](https://github.com/NousResearch/hermes-agent/pull/2096), [#1681](https://github.com/NousResearch/hermes-agent/pull/1681))
|
||||
- Code execution: resource leak and double socket close fix ([#2381](https://github.com/NousResearch/hermes-agent/pull/2381))
|
||||
- Delegate: thread safety for concurrent subagent delegation ([#1672](https://github.com/NousResearch/hermes-agent/pull/1672)), preserve parent agent's tool list after delegation ([#1778](https://github.com/NousResearch/hermes-agent/pull/1778))
|
||||
- Fix: make concurrent tool batching path-aware for file mutations ([#1914](https://github.com/NousResearch/hermes-agent/pull/1914))
|
||||
- Fix: chunk long messages in `send_message_tool` before platform dispatch ([#1646](https://github.com/NousResearch/hermes-agent/pull/1646))
|
||||
- Fix: add missing 'messaging' toolset ([#1718](https://github.com/NousResearch/hermes-agent/pull/1718))
|
||||
- Fix: prevent unavailable tool names from leaking into model schemas ([#2072](https://github.com/NousResearch/hermes-agent/pull/2072))
|
||||
- Fix: pass visited set by reference to prevent diamond dependency duplication ([#2311](https://github.com/NousResearch/hermes-agent/pull/2311))
|
||||
- Fix: Daytona sandbox lookup migrated from `find_one` to `get/list` ([#2063](https://github.com/NousResearch/hermes-agent/pull/2063) by @rovle)
|
||||
|
||||
---
|
||||
|
||||
## 🧩 Skills Ecosystem
|
||||
|
||||
### Skills System Improvements
|
||||
- **Agent-created skills** — Caution-level findings allowed, dangerous skills ask instead of block ([#1840](https://github.com/NousResearch/hermes-agent/pull/1840), [#2446](https://github.com/NousResearch/hermes-agent/pull/2446))
|
||||
- **`--yes` flag** to bypass confirmation in `/skills install` and uninstall ([#1647](https://github.com/NousResearch/hermes-agent/pull/1647))
|
||||
- **Disabled skills respected** across banner, system prompt, and slash commands ([#1897](https://github.com/NousResearch/hermes-agent/pull/1897))
|
||||
- Fix: skills custom_tools import crash + sandbox file_tools integration ([#2239](https://github.com/NousResearch/hermes-agent/pull/2239))
|
||||
- Fix: agent-created skills with pip requirements crash on install ([#2145](https://github.com/NousResearch/hermes-agent/pull/2145))
|
||||
- Fix: race condition in `Skills.__init__` when `hub.yaml` missing ([#2242](https://github.com/NousResearch/hermes-agent/pull/2242))
|
||||
- Fix: validate skill metadata before install and block duplicates ([#2241](https://github.com/NousResearch/hermes-agent/pull/2241))
|
||||
- Fix: skills hub inspect/resolve — 4 bugs in inspect, redirects, discovery, tap list ([#2447](https://github.com/NousResearch/hermes-agent/pull/2447))
|
||||
- Fix: agent-created skills keep working after session reset ([#2121](https://github.com/NousResearch/hermes-agent/pull/2121))
|
||||
|
||||
### New Skills
|
||||
- **OCR-and-documents** — PDF/DOCX/XLS/PPTX/image OCR with optional GPU ([#2236](https://github.com/NousResearch/hermes-agent/pull/2236), [#2461](https://github.com/NousResearch/hermes-agent/pull/2461))
|
||||
- **Huggingface-hub** bundled skill ([#1921](https://github.com/NousResearch/hermes-agent/pull/1921))
|
||||
- **Sherlock OSINT** username search ([#1671](https://github.com/NousResearch/hermes-agent/pull/1671))
|
||||
- **Meme-generation** — Image generator with Pillow ([#2344](https://github.com/NousResearch/hermes-agent/pull/2344))
|
||||
- **Bioinformatics** gateway skill — index to 400+ bio skills ([#2387](https://github.com/NousResearch/hermes-agent/pull/2387))
|
||||
- **Inference.sh** skill (terminal-based) ([#1686](https://github.com/NousResearch/hermes-agent/pull/1686))
|
||||
- **Base blockchain** optional skill ([#1643](https://github.com/NousResearch/hermes-agent/pull/1643))
|
||||
- **3D-model-viewer** optional skill ([#2226](https://github.com/NousResearch/hermes-agent/pull/2226))
|
||||
- **FastMCP** optional skill ([#2113](https://github.com/NousResearch/hermes-agent/pull/2113))
|
||||
- **Hermes-agent-setup** skill ([#1905](https://github.com/NousResearch/hermes-agent/pull/1905))
|
||||
|
||||
---
|
||||
|
||||
## 🔌 Plugin System Enhancements
|
||||
|
||||
- **TUI extension hooks** — Build custom CLIs on top of Hermes ([#2333](https://github.com/NousResearch/hermes-agent/pull/2333))
|
||||
- **`hermes plugins install/remove/list`** commands ([#2337](https://github.com/NousResearch/hermes-agent/pull/2337))
|
||||
- **Slash command registration** for plugins ([#2359](https://github.com/NousResearch/hermes-agent/pull/2359))
|
||||
- **`session:end` lifecycle event** hook ([#1725](https://github.com/NousResearch/hermes-agent/pull/1725))
|
||||
- Fix: require opt-in for project plugin discovery ([#2215](https://github.com/NousResearch/hermes-agent/pull/2215))
|
||||
|
||||
---
|
||||
|
||||
## 🔒 Security & Reliability
|
||||
|
||||
### Security
|
||||
- **SSRF protection** for vision_tools and web_tools ([#2679](https://github.com/NousResearch/hermes-agent/pull/2679))
|
||||
- **Shell injection prevention** in `_expand_path` via `~user` path suffix ([#2685](https://github.com/NousResearch/hermes-agent/pull/2685))
|
||||
- **Block untrusted browser-origin** API server access ([#2451](https://github.com/NousResearch/hermes-agent/pull/2451))
|
||||
- **Block sandbox backend creds** from subprocess env ([#1658](https://github.com/NousResearch/hermes-agent/pull/1658))
|
||||
- **Block @ references** from reading secrets outside workspace ([#2601](https://github.com/NousResearch/hermes-agent/pull/2601) by @Gutslabs)
|
||||
- **Malicious code pattern pre-exec scanner** for terminal_tool ([#2245](https://github.com/NousResearch/hermes-agent/pull/2245))
|
||||
- **Harden terminal safety** and sandbox file writes ([#1653](https://github.com/NousResearch/hermes-agent/pull/1653))
|
||||
- **PKCE verifier leak** fix + OAuth refresh Content-Type ([#1775](https://github.com/NousResearch/hermes-agent/pull/1775))
|
||||
- **Eliminate SQL string formatting** in `execute()` calls ([#2061](https://github.com/NousResearch/hermes-agent/pull/2061) by @dusterbloom)
|
||||
- **Harden jobs API** — input limits, field whitelist, startup check ([#2456](https://github.com/NousResearch/hermes-agent/pull/2456))
|
||||
|
||||
### Reliability
|
||||
- Thread locks on 4 SessionDB methods ([#1704](https://github.com/NousResearch/hermes-agent/pull/1704))
|
||||
- File locking for concurrent memory writes ([#1726](https://github.com/NousResearch/hermes-agent/pull/1726))
|
||||
- Handle OpenRouter errors gracefully ([#2112](https://github.com/NousResearch/hermes-agent/pull/2112))
|
||||
- Guard print() calls against OSError ([#1668](https://github.com/NousResearch/hermes-agent/pull/1668))
|
||||
- Safely handle non-string inputs in redacting formatter ([#2392](https://github.com/NousResearch/hermes-agent/pull/2392), [#1700](https://github.com/NousResearch/hermes-agent/pull/1700))
|
||||
- ACP: preserve session provider on model switch, persist sessions to disk ([#2380](https://github.com/NousResearch/hermes-agent/pull/2380), [#2071](https://github.com/NousResearch/hermes-agent/pull/2071))
|
||||
- API server: persist ResponseStore to SQLite across restarts ([#2472](https://github.com/NousResearch/hermes-agent/pull/2472))
|
||||
- Fix: `fetch_nous_models` always TypeError from positional args ([#1699](https://github.com/NousResearch/hermes-agent/pull/1699))
|
||||
- Fix: resolve merge conflict markers in cli.py breaking startup ([#2347](https://github.com/NousResearch/hermes-agent/pull/2347))
|
||||
- Fix: `minisweagent_path.py` missing from wheel ([#2098](https://github.com/NousResearch/hermes-agent/pull/2098) by @JiwaniZakir)
|
||||
|
||||
### Cron System
|
||||
- **`[SILENT]` response** — cron agents can suppress delivery ([#1833](https://github.com/NousResearch/hermes-agent/pull/1833))
|
||||
- **Scale missed-job grace window** with schedule frequency ([#2449](https://github.com/NousResearch/hermes-agent/pull/2449))
|
||||
- **Recover recent one-shot jobs** ([#1918](https://github.com/NousResearch/hermes-agent/pull/1918))
|
||||
- Fix: normalize `repeat<=0` to None — jobs deleted after first run when LLM passes -1 ([#2612](https://github.com/NousResearch/hermes-agent/pull/2612) by @Mibayy)
|
||||
- Fix: Matrix added to scheduler delivery platform_map ([#2167](https://github.com/NousResearch/hermes-agent/pull/2167) by @buntingszn)
|
||||
- Fix: naive ISO timestamps without timezone — jobs fire at wrong time ([#1729](https://github.com/NousResearch/hermes-agent/pull/1729))
|
||||
- Fix: `get_due_jobs` reads `jobs.json` twice — race condition ([#1716](https://github.com/NousResearch/hermes-agent/pull/1716))
|
||||
- Fix: silent jobs return empty response for delivery skip ([#2442](https://github.com/NousResearch/hermes-agent/pull/2442))
|
||||
- Fix: stop injecting cron outputs into gateway session history ([#2313](https://github.com/NousResearch/hermes-agent/pull/2313))
|
||||
- Fix: close abandoned coroutine when `asyncio.run()` raises RuntimeError ([#2317](https://github.com/NousResearch/hermes-agent/pull/2317))
|
||||
|
||||
---
|
||||
|
||||
## 🧪 Testing
|
||||
|
||||
- Resolve all consistently failing tests ([#2488](https://github.com/NousResearch/hermes-agent/pull/2488))
|
||||
- Replace `FakePath` with `monkeypatch` for Python 3.12 compat ([#2444](https://github.com/NousResearch/hermes-agent/pull/2444))
|
||||
- Align Hermes setup and full-suite expectations ([#1710](https://github.com/NousResearch/hermes-agent/pull/1710))
|
||||
|
||||
---
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
- Comprehensive docs update for recent features ([#1693](https://github.com/NousResearch/hermes-agent/pull/1693), [#2183](https://github.com/NousResearch/hermes-agent/pull/2183))
|
||||
- Alibaba Cloud and DingTalk setup guides ([#1687](https://github.com/NousResearch/hermes-agent/pull/1687), [#1692](https://github.com/NousResearch/hermes-agent/pull/1692))
|
||||
- Detailed skills documentation ([#2244](https://github.com/NousResearch/hermes-agent/pull/2244))
|
||||
- Honcho self-hosted / Docker configuration ([#2475](https://github.com/NousResearch/hermes-agent/pull/2475))
|
||||
- Context length detection FAQ and quickstart references ([#2179](https://github.com/NousResearch/hermes-agent/pull/2179))
|
||||
- Fix docs inconsistencies across reference and user guides ([#1995](https://github.com/NousResearch/hermes-agent/pull/1995))
|
||||
- Fix MCP install commands — use uv, not bare pip ([#1909](https://github.com/NousResearch/hermes-agent/pull/1909))
|
||||
- Replace ASCII diagrams with Mermaid/lists ([#2402](https://github.com/NousResearch/hermes-agent/pull/2402))
|
||||
- Gemini OAuth provider implementation plan ([#2467](https://github.com/NousResearch/hermes-agent/pull/2467))
|
||||
- Discord Server Members Intent marked as required ([#2330](https://github.com/NousResearch/hermes-agent/pull/2330))
|
||||
- Fix MDX build error in api-server.md ([#1787](https://github.com/NousResearch/hermes-agent/pull/1787))
|
||||
- Align venv path to match installer ([#2114](https://github.com/NousResearch/hermes-agent/pull/2114))
|
||||
- New skills added to hub index ([#2281](https://github.com/NousResearch/hermes-agent/pull/2281))
|
||||
|
||||
---
|
||||
|
||||
## 👥 Contributors
|
||||
|
||||
### Core
|
||||
- **@teknium1** (Teknium) — 280 PRs
|
||||
|
||||
### Community Contributors
|
||||
- **@mchzimm** (to_the_max) — GitHub Copilot provider integration ([#1879](https://github.com/NousResearch/hermes-agent/pull/1879))
|
||||
- **@jquesnelle** (Jeffrey Quesnelle) — Per-thread persistent event loops fix ([#2214](https://github.com/NousResearch/hermes-agent/pull/2214))
|
||||
- **@llbn** (lbn) — Telegram MarkdownV2 strikethrough, spoiler, blockquotes, and escape fixes ([#2199](https://github.com/NousResearch/hermes-agent/pull/2199), [#2200](https://github.com/NousResearch/hermes-agent/pull/2200))
|
||||
- **@dusterbloom** — SQL injection prevention + local server context window querying ([#2061](https://github.com/NousResearch/hermes-agent/pull/2061), [#2091](https://github.com/NousResearch/hermes-agent/pull/2091))
|
||||
- **@0xbyt4** — Anthropic tool_calls None guard + OpenCode-Go provider config fix ([#2209](https://github.com/NousResearch/hermes-agent/pull/2209), [#2393](https://github.com/NousResearch/hermes-agent/pull/2393))
|
||||
- **@sai-samarth** (Saisamarth) — WhatsApp send_message routing + systemd node path ([#1769](https://github.com/NousResearch/hermes-agent/pull/1769), [#1767](https://github.com/NousResearch/hermes-agent/pull/1767))
|
||||
- **@Gutslabs** (Guts) — Block @ references from reading secrets ([#2601](https://github.com/NousResearch/hermes-agent/pull/2601))
|
||||
- **@Mibayy** (Mibay) — Cron job repeat normalization ([#2612](https://github.com/NousResearch/hermes-agent/pull/2612))
|
||||
- **@ten-jampa** (Tenzin Jampa) — Gateway /title command fix ([#2379](https://github.com/NousResearch/hermes-agent/pull/2379))
|
||||
- **@cutepawss** (lila) — File tools search pagination fix ([#1824](https://github.com/NousResearch/hermes-agent/pull/1824))
|
||||
- **@hanai** (Hanai) — OpenAI TTS base_url support ([#2064](https://github.com/NousResearch/hermes-agent/pull/2064))
|
||||
- **@rovle** (Lovre Pešut) — Daytona sandbox API migration ([#2063](https://github.com/NousResearch/hermes-agent/pull/2063))
|
||||
- **@buntingszn** (bunting szn) — Matrix cron delivery support ([#2167](https://github.com/NousResearch/hermes-agent/pull/2167))
|
||||
- **@InB4DevOps** — Token counter reset on new session ([#2101](https://github.com/NousResearch/hermes-agent/pull/2101))
|
||||
- **@JiwaniZakir** (Zakir Jiwani) — Missing file in wheel fix ([#2098](https://github.com/NousResearch/hermes-agent/pull/2098))
|
||||
- **@ygd58** (buray) — Delegate tool parent tool names fix ([#2083](https://github.com/NousResearch/hermes-agent/pull/2083))
|
||||
|
||||
---
|
||||
|
||||
**Full Changelog**: [v2026.3.17...v2026.3.23](https://github.com/NousResearch/hermes-agent/compare/v2026.3.17...v2026.3.23)
|
||||
@@ -18,6 +18,7 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from hermes_constants import get_hermes_home
|
||||
|
||||
|
||||
def _setup_logging() -> None:
|
||||
@@ -44,7 +45,7 @@ def _load_env() -> None:
|
||||
"""Load .env from HERMES_HOME (default ``~/.hermes``)."""
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
|
||||
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
hermes_home = get_hermes_home()
|
||||
loaded = load_hermes_dotenv(hermes_home=hermes_home)
|
||||
if loaded:
|
||||
for env_file in loaded:
|
||||
|
||||
@@ -10,7 +10,7 @@ thread while the event loop lives on the main thread).
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict, deque
|
||||
from collections import deque
|
||||
from typing import Any, Callable, Deque, Dict
|
||||
|
||||
import acp
|
||||
|
||||
@@ -5,14 +5,11 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import TimeoutError as FutureTimeout
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Callable
|
||||
|
||||
from acp.schema import (
|
||||
AllowedOutcome,
|
||||
DeniedOutcome,
|
||||
PermissionOption,
|
||||
RequestPermissionRequest,
|
||||
SelectedPermissionOutcome,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -383,11 +383,11 @@ class HermesACPAgent(acp.Agent):
|
||||
|
||||
new_model = args.strip()
|
||||
target_provider = None
|
||||
current_provider = getattr(state.agent, "provider", None) or "openrouter"
|
||||
|
||||
# Auto-detect provider for the requested model
|
||||
try:
|
||||
from hermes_cli.models import parse_model_input, detect_provider_for_model
|
||||
current_provider = getattr(state.agent, "provider", None) or "openrouter"
|
||||
target_provider, new_model = parse_model_input(new_model, current_provider)
|
||||
if target_provider == current_provider:
|
||||
detected = detect_provider_for_model(new_model, current_provider)
|
||||
@@ -401,9 +401,10 @@ class HermesACPAgent(acp.Agent):
|
||||
session_id=state.session_id,
|
||||
cwd=state.cwd,
|
||||
model=new_model,
|
||||
requested_provider=target_provider or current_provider,
|
||||
)
|
||||
self.session_manager.save_session(state.session_id)
|
||||
provider_label = target_provider or getattr(state.agent, "provider", "auto")
|
||||
provider_label = getattr(state.agent, "provider", None) or target_provider or current_provider
|
||||
logger.info("Session %s: model switched to %s", state.session_id, new_model)
|
||||
return f"Model switched to: {new_model}\nProvider: {provider_label}"
|
||||
|
||||
@@ -475,10 +476,16 @@ class HermesACPAgent(acp.Agent):
|
||||
state = self.session_manager.get_session(session_id)
|
||||
if state:
|
||||
state.model = model_id
|
||||
current_provider = getattr(state.agent, "provider", None)
|
||||
current_base_url = getattr(state.agent, "base_url", None)
|
||||
current_api_mode = getattr(state.agent, "api_mode", None)
|
||||
state.agent = self.session_manager._make_agent(
|
||||
session_id=session_id,
|
||||
cwd=state.cwd,
|
||||
model=model_id,
|
||||
requested_provider=current_provider,
|
||||
base_url=current_base_url,
|
||||
api_mode=current_api_mode,
|
||||
)
|
||||
self.session_manager.save_session(session_id)
|
||||
logger.info("Session %s: model switched to %s", session_id, model_id)
|
||||
|
||||
+39
-9
@@ -8,6 +8,8 @@ history.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
@@ -251,7 +253,7 @@ class SessionManager:
|
||||
import os
|
||||
from pathlib import Path
|
||||
from hermes_state import SessionDB
|
||||
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
hermes_home = get_hermes_home()
|
||||
self._db_instance = SessionDB(db_path=hermes_home / "state.db")
|
||||
return self._db_instance
|
||||
except Exception:
|
||||
@@ -270,7 +272,17 @@ class SessionManager:
|
||||
|
||||
# Ensure model is a plain string (not a MagicMock or other proxy).
|
||||
model_str = str(state.model) if state.model else None
|
||||
cwd_json = json.dumps({"cwd": state.cwd})
|
||||
session_meta = {"cwd": state.cwd}
|
||||
provider = getattr(state.agent, "provider", None)
|
||||
base_url = getattr(state.agent, "base_url", None)
|
||||
api_mode = getattr(state.agent, "api_mode", None)
|
||||
if isinstance(provider, str) and provider.strip():
|
||||
session_meta["provider"] = provider.strip()
|
||||
if isinstance(base_url, str) and base_url.strip():
|
||||
session_meta["base_url"] = base_url.strip()
|
||||
if isinstance(api_mode, str) and api_mode.strip():
|
||||
session_meta["api_mode"] = api_mode.strip()
|
||||
cwd_json = json.dumps(session_meta)
|
||||
|
||||
try:
|
||||
# Ensure the session record exists.
|
||||
@@ -331,10 +343,18 @@ class SessionManager:
|
||||
|
||||
# Extract cwd from model_config.
|
||||
cwd = "."
|
||||
requested_provider = row.get("billing_provider")
|
||||
restored_base_url = row.get("billing_base_url")
|
||||
restored_api_mode = None
|
||||
mc = row.get("model_config")
|
||||
if mc:
|
||||
try:
|
||||
cwd = json.loads(mc).get("cwd", ".")
|
||||
meta = json.loads(mc)
|
||||
if isinstance(meta, dict):
|
||||
cwd = meta.get("cwd", ".")
|
||||
requested_provider = meta.get("provider") or requested_provider
|
||||
restored_base_url = meta.get("base_url") or restored_base_url
|
||||
restored_api_mode = meta.get("api_mode") or restored_api_mode
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
@@ -348,7 +368,14 @@ class SessionManager:
|
||||
history = []
|
||||
|
||||
try:
|
||||
agent = self._make_agent(session_id=session_id, cwd=cwd, model=model)
|
||||
agent = self._make_agent(
|
||||
session_id=session_id,
|
||||
cwd=cwd,
|
||||
model=model,
|
||||
requested_provider=requested_provider,
|
||||
base_url=restored_base_url,
|
||||
api_mode=restored_api_mode,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to recreate agent for ACP session %s", session_id, exc_info=True)
|
||||
return None
|
||||
@@ -386,6 +413,9 @@ class SessionManager:
|
||||
session_id: str,
|
||||
cwd: str,
|
||||
model: str | None = None,
|
||||
requested_provider: str | None = None,
|
||||
base_url: str | None = None,
|
||||
api_mode: str | None = None,
|
||||
):
|
||||
if self._agent_factory is not None:
|
||||
return self._agent_factory()
|
||||
@@ -397,10 +427,10 @@ class SessionManager:
|
||||
config = load_config()
|
||||
model_cfg = config.get("model")
|
||||
default_model = "anthropic/claude-opus-4.6"
|
||||
requested_provider = None
|
||||
config_provider = None
|
||||
if isinstance(model_cfg, dict):
|
||||
default_model = str(model_cfg.get("default") or default_model)
|
||||
requested_provider = model_cfg.get("provider")
|
||||
config_provider = model_cfg.get("provider")
|
||||
elif isinstance(model_cfg, str) and model_cfg.strip():
|
||||
default_model = model_cfg.strip()
|
||||
|
||||
@@ -413,12 +443,12 @@ class SessionManager:
|
||||
}
|
||||
|
||||
try:
|
||||
runtime = resolve_runtime_provider(requested=requested_provider)
|
||||
runtime = resolve_runtime_provider(requested=requested_provider or config_provider)
|
||||
kwargs.update(
|
||||
{
|
||||
"provider": runtime.get("provider"),
|
||||
"api_mode": runtime.get("api_mode"),
|
||||
"base_url": runtime.get("base_url"),
|
||||
"api_mode": api_mode or runtime.get("api_mode"),
|
||||
"base_url": base_url or runtime.get("base_url"),
|
||||
"api_key": runtime.get("api_key"),
|
||||
"command": runtime.get("command"),
|
||||
"args": list(runtime.get("args") or []),
|
||||
|
||||
+71
-256
@@ -14,6 +14,8 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
@@ -57,6 +59,7 @@ _OAUTH_ONLY_BETAS = [
|
||||
# The version must stay reasonably current — Anthropic rejects OAuth requests
|
||||
# when the spoofed user-agent version is too far behind the actual release.
|
||||
_CLAUDE_CODE_VERSION_FALLBACK = "2.1.74"
|
||||
_claude_code_version_cache: Optional[str] = None
|
||||
|
||||
|
||||
def _detect_claude_code_version() -> str:
|
||||
@@ -84,11 +87,18 @@ def _detect_claude_code_version() -> str:
|
||||
return _CLAUDE_CODE_VERSION_FALLBACK
|
||||
|
||||
|
||||
_CLAUDE_CODE_VERSION = _detect_claude_code_version()
|
||||
_CLAUDE_CODE_SYSTEM_PREFIX = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||
_MCP_TOOL_PREFIX = "mcp_"
|
||||
|
||||
|
||||
def _get_claude_code_version() -> str:
|
||||
"""Lazily detect the installed Claude Code version when OAuth headers need it."""
|
||||
global _claude_code_version_cache
|
||||
if _claude_code_version_cache is None:
|
||||
_claude_code_version_cache = _detect_claude_code_version()
|
||||
return _claude_code_version_cache
|
||||
|
||||
|
||||
def _is_oauth_token(key: str) -> bool:
|
||||
"""Check if the key is an OAuth/setup token (not a regular Console API key).
|
||||
|
||||
@@ -130,7 +140,7 @@ def build_anthropic_client(api_key: str, base_url: str = None):
|
||||
kwargs["auth_token"] = api_key
|
||||
kwargs["default_headers"] = {
|
||||
"anthropic-beta": ",".join(all_betas),
|
||||
"user-agent": f"claude-cli/{_CLAUDE_CODE_VERSION} (external, cli)",
|
||||
"user-agent": f"claude-cli/{_get_claude_code_version()} (external, cli)",
|
||||
"x-app": "cli",
|
||||
}
|
||||
else:
|
||||
@@ -208,9 +218,12 @@ def _refresh_oauth_token(creds: Dict[str, Any]) -> Optional[str]:
|
||||
Only works for credentials that have a refresh token (from claude /login
|
||||
or claude setup-token with OAuth flow).
|
||||
|
||||
Tries the new platform.claude.com endpoint first (Claude Code >=2.1.81),
|
||||
then falls back to console.anthropic.com for older tokens.
|
||||
|
||||
Returns the new access token, or None if refresh fails.
|
||||
"""
|
||||
import urllib.parse
|
||||
import time
|
||||
import urllib.request
|
||||
|
||||
refresh_token = creds.get("refreshToken", "")
|
||||
@@ -221,38 +234,42 @@ def _refresh_oauth_token(creds: Dict[str, Any]) -> Optional[str]:
|
||||
# Client ID used by Claude Code's OAuth flow
|
||||
CLIENT_ID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
|
||||
data = urllib.parse.urlencode({
|
||||
# Anthropic migrated OAuth from console.anthropic.com to platform.claude.com
|
||||
# (Claude Code v2.1.81+). Try new endpoint first, fall back to old.
|
||||
token_endpoints = [
|
||||
"https://platform.claude.com/v1/oauth/token",
|
||||
"https://console.anthropic.com/v1/oauth/token",
|
||||
]
|
||||
|
||||
payload = json.dumps({
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": CLIENT_ID,
|
||||
}).encode()
|
||||
|
||||
req = urllib.request.Request(
|
||||
"https://console.anthropic.com/v1/oauth/token",
|
||||
data=data,
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": f"claude-cli/{_CLAUDE_CODE_VERSION} (external, cli)",
|
||||
},
|
||||
method="POST",
|
||||
)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"claude-cli/{_get_claude_code_version()} (external, cli)",
|
||||
}
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
result = json.loads(resp.read().decode())
|
||||
new_access = result.get("access_token", "")
|
||||
new_refresh = result.get("refresh_token", refresh_token)
|
||||
expires_in = result.get("expires_in", 3600) # seconds
|
||||
for endpoint in token_endpoints:
|
||||
req = urllib.request.Request(
|
||||
endpoint, data=payload, headers=headers, method="POST",
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
result = json.loads(resp.read().decode())
|
||||
new_access = result.get("access_token", "")
|
||||
new_refresh = result.get("refresh_token", refresh_token)
|
||||
expires_in = result.get("expires_in", 3600)
|
||||
|
||||
if new_access:
|
||||
import time
|
||||
new_expires_ms = int(time.time() * 1000) + (expires_in * 1000)
|
||||
# Write refreshed credentials back to ~/.claude/.credentials.json
|
||||
_write_claude_code_credentials(new_access, new_refresh, new_expires_ms)
|
||||
logger.debug("Successfully refreshed Claude Code OAuth token")
|
||||
return new_access
|
||||
except Exception as e:
|
||||
logger.debug("Failed to refresh Claude Code token: %s", e)
|
||||
if new_access:
|
||||
new_expires_ms = int(time.time() * 1000) + (expires_in * 1000)
|
||||
_write_claude_code_credentials(new_access, new_refresh, new_expires_ms)
|
||||
logger.debug("Refreshed Claude Code OAuth token via %s", endpoint)
|
||||
return new_access
|
||||
except Exception as e:
|
||||
logger.debug("Token refresh failed at %s: %s", endpoint, e)
|
||||
|
||||
return None
|
||||
|
||||
@@ -376,24 +393,12 @@ def resolve_anthropic_token() -> Optional[str]:
|
||||
return preferred
|
||||
return cc_token
|
||||
|
||||
# 3. Hermes-managed OAuth credentials (~/.hermes/.anthropic_oauth.json)
|
||||
hermes_creds = read_hermes_oauth_credentials()
|
||||
if hermes_creds:
|
||||
if is_claude_code_token_valid(hermes_creds):
|
||||
logger.debug("Using Hermes-managed OAuth credentials")
|
||||
return hermes_creds["accessToken"]
|
||||
# Expired — try refresh
|
||||
logger.debug("Hermes OAuth token expired — attempting refresh")
|
||||
refreshed = refresh_hermes_oauth_token()
|
||||
if refreshed:
|
||||
return refreshed
|
||||
|
||||
# 4. Claude Code credential file
|
||||
# 3. Claude Code credential file
|
||||
resolved_claude_token = _resolve_claude_code_token_from_credentials(creds)
|
||||
if resolved_claude_token:
|
||||
return resolved_claude_token
|
||||
|
||||
# 5. Regular API key, or a legacy OAuth token saved in ANTHROPIC_API_KEY.
|
||||
# 4. Regular API key, or a legacy OAuth token saved in ANTHROPIC_API_KEY.
|
||||
# This remains as a compatibility fallback for pre-migration Hermes configs.
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY", "").strip()
|
||||
if api_key:
|
||||
@@ -442,213 +447,10 @@ def run_oauth_setup_token() -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
# ── Hermes-native PKCE OAuth flow ────────────────────────────────────────
|
||||
# Mirrors the flow used by Claude Code, pi-ai, and OpenCode.
|
||||
# Stores credentials in ~/.hermes/.anthropic_oauth.json (our own file).
|
||||
|
||||
_OAUTH_CLIENT_ID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
_OAUTH_TOKEN_URL = "https://console.anthropic.com/v1/oauth/token"
|
||||
_OAUTH_REDIRECT_URI = "https://console.anthropic.com/oauth/code/callback"
|
||||
_OAUTH_SCOPES = "org:create_api_key user:profile user:inference"
|
||||
_HERMES_OAUTH_FILE = Path(os.getenv("HERMES_HOME", str(Path.home() / ".hermes"))) / ".anthropic_oauth.json"
|
||||
|
||||
|
||||
def _generate_pkce() -> tuple:
|
||||
"""Generate PKCE code_verifier and code_challenge (S256)."""
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode()
|
||||
challenge = base64.urlsafe_b64encode(
|
||||
hashlib.sha256(verifier.encode()).digest()
|
||||
).rstrip(b"=").decode()
|
||||
return verifier, challenge
|
||||
|
||||
|
||||
def run_hermes_oauth_login() -> Optional[str]:
|
||||
"""Run Hermes-native OAuth PKCE flow for Claude Pro/Max subscription.
|
||||
|
||||
Opens a browser to claude.ai for authorization, prompts for the code,
|
||||
exchanges it for tokens, and stores them in ~/.hermes/.anthropic_oauth.json.
|
||||
|
||||
Returns the access token on success, None on failure.
|
||||
"""
|
||||
import time
|
||||
import webbrowser
|
||||
|
||||
verifier, challenge = _generate_pkce()
|
||||
|
||||
# Build authorization URL
|
||||
params = {
|
||||
"code": "true",
|
||||
"client_id": _OAUTH_CLIENT_ID,
|
||||
"response_type": "code",
|
||||
"redirect_uri": _OAUTH_REDIRECT_URI,
|
||||
"scope": _OAUTH_SCOPES,
|
||||
"code_challenge": challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"state": verifier,
|
||||
}
|
||||
from urllib.parse import urlencode
|
||||
auth_url = f"https://claude.ai/oauth/authorize?{urlencode(params)}"
|
||||
|
||||
print()
|
||||
print("Authorize Hermes with your Claude Pro/Max subscription.")
|
||||
print()
|
||||
print("╭─ Claude Pro/Max Authorization ────────────────────╮")
|
||||
print("│ │")
|
||||
print("│ Open this link in your browser: │")
|
||||
print("╰───────────────────────────────────────────────────╯")
|
||||
print()
|
||||
print(f" {auth_url}")
|
||||
print()
|
||||
|
||||
# Try to open browser automatically (works on desktop, silently fails on headless/SSH)
|
||||
try:
|
||||
webbrowser.open(auth_url)
|
||||
print(" (Browser opened automatically)")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print()
|
||||
print("After authorizing, you'll see a code. Paste it below.")
|
||||
print()
|
||||
try:
|
||||
auth_code = input("Authorization code: ").strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return None
|
||||
|
||||
if not auth_code:
|
||||
print("No code entered.")
|
||||
return None
|
||||
|
||||
# Split code#state format
|
||||
splits = auth_code.split("#")
|
||||
code = splits[0]
|
||||
state = splits[1] if len(splits) > 1 else ""
|
||||
|
||||
# Exchange code for tokens
|
||||
try:
|
||||
import urllib.request
|
||||
exchange_data = json.dumps({
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": _OAUTH_CLIENT_ID,
|
||||
"code": code,
|
||||
"state": state,
|
||||
"redirect_uri": _OAUTH_REDIRECT_URI,
|
||||
"code_verifier": verifier,
|
||||
}).encode()
|
||||
|
||||
req = urllib.request.Request(
|
||||
_OAUTH_TOKEN_URL,
|
||||
data=exchange_data,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"claude-cli/{_CLAUDE_CODE_VERSION} (external, cli)",
|
||||
},
|
||||
method="POST",
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(req, timeout=15) as resp:
|
||||
result = json.loads(resp.read().decode())
|
||||
except Exception as e:
|
||||
print(f"Token exchange failed: {e}")
|
||||
return None
|
||||
|
||||
access_token = result.get("access_token", "")
|
||||
refresh_token = result.get("refresh_token", "")
|
||||
expires_in = result.get("expires_in", 3600)
|
||||
|
||||
if not access_token:
|
||||
print("No access token in response.")
|
||||
return None
|
||||
|
||||
# Store credentials
|
||||
expires_at_ms = int(time.time() * 1000) + (expires_in * 1000)
|
||||
_save_hermes_oauth_credentials(access_token, refresh_token, expires_at_ms)
|
||||
|
||||
# Also write to Claude Code's credential file for backward compat
|
||||
_write_claude_code_credentials(access_token, refresh_token, expires_at_ms)
|
||||
|
||||
print("Authentication successful!")
|
||||
return access_token
|
||||
|
||||
|
||||
def _save_hermes_oauth_credentials(access_token: str, refresh_token: str, expires_at_ms: int) -> None:
|
||||
"""Save OAuth credentials to ~/.hermes/.anthropic_oauth.json."""
|
||||
data = {
|
||||
"accessToken": access_token,
|
||||
"refreshToken": refresh_token,
|
||||
"expiresAt": expires_at_ms,
|
||||
}
|
||||
try:
|
||||
_HERMES_OAUTH_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_HERMES_OAUTH_FILE.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
||||
_HERMES_OAUTH_FILE.chmod(0o600)
|
||||
except (OSError, IOError) as e:
|
||||
logger.debug("Failed to save Hermes OAuth credentials: %s", e)
|
||||
|
||||
|
||||
def read_hermes_oauth_credentials() -> Optional[Dict[str, Any]]:
|
||||
"""Read Hermes-managed OAuth credentials from ~/.hermes/.anthropic_oauth.json."""
|
||||
if _HERMES_OAUTH_FILE.exists():
|
||||
try:
|
||||
data = json.loads(_HERMES_OAUTH_FILE.read_text(encoding="utf-8"))
|
||||
if data.get("accessToken"):
|
||||
return data
|
||||
except (json.JSONDecodeError, OSError, IOError) as e:
|
||||
logger.debug("Failed to read Hermes OAuth credentials: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def refresh_hermes_oauth_token() -> Optional[str]:
|
||||
"""Refresh the Hermes-managed OAuth token using the stored refresh token.
|
||||
|
||||
Returns the new access token, or None if refresh fails.
|
||||
"""
|
||||
import time
|
||||
import urllib.request
|
||||
|
||||
creds = read_hermes_oauth_credentials()
|
||||
if not creds or not creds.get("refreshToken"):
|
||||
return None
|
||||
|
||||
try:
|
||||
data = json.dumps({
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": creds["refreshToken"],
|
||||
"client_id": _OAUTH_CLIENT_ID,
|
||||
}).encode()
|
||||
|
||||
req = urllib.request.Request(
|
||||
_OAUTH_TOKEN_URL,
|
||||
data=data,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"claude-cli/{_CLAUDE_CODE_VERSION} (external, cli)",
|
||||
},
|
||||
method="POST",
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
result = json.loads(resp.read().decode())
|
||||
|
||||
new_access = result.get("access_token", "")
|
||||
new_refresh = result.get("refresh_token", creds["refreshToken"])
|
||||
expires_in = result.get("expires_in", 3600)
|
||||
|
||||
if new_access:
|
||||
new_expires_ms = int(time.time() * 1000) + (expires_in * 1000)
|
||||
_save_hermes_oauth_credentials(new_access, new_refresh, new_expires_ms)
|
||||
# Also update Claude Code's credential file
|
||||
_write_claude_code_credentials(new_access, new_refresh, new_expires_ms)
|
||||
logger.debug("Successfully refreshed Hermes OAuth token")
|
||||
return new_access
|
||||
except Exception as e:
|
||||
logger.debug("Failed to refresh Hermes OAuth token: %s", e)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -656,19 +458,21 @@ def refresh_hermes_oauth_token() -> Optional[str]:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def normalize_model_name(model: str) -> str:
|
||||
def normalize_model_name(model: str, preserve_dots: bool = False) -> str:
|
||||
"""Normalize a model name for the Anthropic API.
|
||||
|
||||
- Strips 'anthropic/' prefix (OpenRouter format, case-insensitive)
|
||||
- Converts dots to hyphens in version numbers (OpenRouter uses dots,
|
||||
Anthropic uses hyphens: claude-opus-4.6 → claude-opus-4-6)
|
||||
Anthropic uses hyphens: claude-opus-4.6 → claude-opus-4-6), unless
|
||||
preserve_dots is True (e.g. for Alibaba/DashScope: qwen3.5-plus).
|
||||
"""
|
||||
lower = model.lower()
|
||||
if lower.startswith("anthropic/"):
|
||||
model = model[len("anthropic/"):]
|
||||
# OpenRouter uses dots for version separators (claude-opus-4.6),
|
||||
# Anthropic uses hyphens (claude-opus-4-6). Convert dots to hyphens.
|
||||
model = model.replace(".", "-")
|
||||
if not preserve_dots:
|
||||
# OpenRouter uses dots for version separators (claude-opus-4.6),
|
||||
# Anthropic uses hyphens (claude-opus-4-6). Convert dots to hyphens.
|
||||
model = model.replace(".", "-")
|
||||
return model
|
||||
|
||||
|
||||
@@ -910,14 +714,21 @@ def convert_messages_to_anthropic(
|
||||
result.append({"role": "user", "content": [tool_result]})
|
||||
continue
|
||||
|
||||
# Regular user message
|
||||
# Regular user message — validate non-empty content (Anthropic rejects empty)
|
||||
if isinstance(content, list):
|
||||
converted_blocks = _convert_content_to_anthropic(content)
|
||||
result.append({
|
||||
"role": "user",
|
||||
"content": converted_blocks or [{"type": "text", "text": ""}],
|
||||
})
|
||||
# Check if all text blocks are empty
|
||||
if not converted_blocks or all(
|
||||
b.get("text", "").strip() == ""
|
||||
for b in converted_blocks
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
):
|
||||
converted_blocks = [{"type": "text", "text": "(empty message)"}]
|
||||
result.append({"role": "user", "content": converted_blocks})
|
||||
else:
|
||||
# Validate string content is non-empty
|
||||
if not content or (isinstance(content, str) and not content.strip()):
|
||||
content = "(empty message)"
|
||||
result.append({"role": "user", "content": content})
|
||||
|
||||
# Strip orphaned tool_use blocks (no matching tool_result follows)
|
||||
@@ -1006,16 +817,20 @@ def build_anthropic_kwargs(
|
||||
reasoning_config: Optional[Dict[str, Any]],
|
||||
tool_choice: Optional[str] = None,
|
||||
is_oauth: bool = False,
|
||||
preserve_dots: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build kwargs for anthropic.messages.create().
|
||||
|
||||
When *is_oauth* is True, applies Claude Code compatibility transforms:
|
||||
system prompt prefix, tool name prefixing, and prompt sanitization.
|
||||
|
||||
When *preserve_dots* is True, model name dots are not converted to hyphens
|
||||
(for Alibaba/DashScope anthropic-compatible endpoints: qwen3.5-plus).
|
||||
"""
|
||||
system, anthropic_messages = convert_messages_to_anthropic(messages)
|
||||
anthropic_tools = convert_tools_to_anthropic(tools) if tools else []
|
||||
|
||||
model = normalize_model_name(model)
|
||||
model = normalize_model_name(model, preserve_dots=preserve_dots)
|
||||
effective_max_tokens = max_tokens or 16384
|
||||
|
||||
# ── OAuth: Claude Code identity ──────────────────────────────────
|
||||
|
||||
+197
-30
@@ -40,7 +40,8 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from pathlib import Path
|
||||
import time
|
||||
from pathlib import Path # noqa: F401 — used by test mocks
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
@@ -81,7 +82,7 @@ auxiliary_is_nous: bool = False
|
||||
|
||||
# Default auxiliary models per provider
|
||||
_OPENROUTER_MODEL = "google/gemini-3-flash-preview"
|
||||
_NOUS_MODEL = "gemini-3-flash"
|
||||
_NOUS_MODEL = "google/gemini-3-flash-preview"
|
||||
_NOUS_DEFAULT_BASE_URL = "https://inference-api.nousresearch.com/v1"
|
||||
_ANTHROPIC_DEFAULT_BASE_URL = "https://api.anthropic.com"
|
||||
_AUTH_JSON_PATH = get_hermes_home() / "auth.json"
|
||||
@@ -325,9 +326,10 @@ class AsyncCodexAuxiliaryClient:
|
||||
class _AnthropicCompletionsAdapter:
|
||||
"""OpenAI-client-compatible adapter for Anthropic Messages API."""
|
||||
|
||||
def __init__(self, real_client: Any, model: str):
|
||||
def __init__(self, real_client: Any, model: str, is_oauth: bool = False):
|
||||
self._client = real_client
|
||||
self._model = model
|
||||
self._is_oauth = is_oauth
|
||||
|
||||
def create(self, **kwargs) -> Any:
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs, normalize_anthropic_response
|
||||
@@ -356,6 +358,7 @@ class _AnthropicCompletionsAdapter:
|
||||
max_tokens=max_tokens,
|
||||
reasoning_config=None,
|
||||
tool_choice=normalized_tool_choice,
|
||||
is_oauth=self._is_oauth,
|
||||
)
|
||||
if temperature is not None:
|
||||
anthropic_kwargs["temperature"] = temperature
|
||||
@@ -394,9 +397,9 @@ class _AnthropicChatShim:
|
||||
class AnthropicAuxiliaryClient:
|
||||
"""OpenAI-client-compatible wrapper over a native Anthropic client."""
|
||||
|
||||
def __init__(self, real_client: Any, model: str, api_key: str, base_url: str):
|
||||
def __init__(self, real_client: Any, model: str, api_key: str, base_url: str, is_oauth: bool = False):
|
||||
self._real_client = real_client
|
||||
adapter = _AnthropicCompletionsAdapter(real_client, model)
|
||||
adapter = _AnthropicCompletionsAdapter(real_client, model, is_oauth=is_oauth)
|
||||
self.chat = _AnthropicChatShim(adapter)
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
@@ -463,15 +466,30 @@ def _nous_base_url() -> str:
|
||||
|
||||
|
||||
def _read_codex_access_token() -> Optional[str]:
|
||||
"""Read a valid Codex OAuth access token from Hermes auth store (~/.hermes/auth.json)."""
|
||||
"""Read a valid, non-expired Codex OAuth access token from Hermes auth store."""
|
||||
try:
|
||||
from hermes_cli.auth import _read_codex_tokens
|
||||
data = _read_codex_tokens()
|
||||
tokens = data.get("tokens", {})
|
||||
access_token = tokens.get("access_token")
|
||||
if isinstance(access_token, str) and access_token.strip():
|
||||
return access_token.strip()
|
||||
return None
|
||||
if not isinstance(access_token, str) or not access_token.strip():
|
||||
return None
|
||||
|
||||
# Check JWT expiry — expired tokens block the auto chain and
|
||||
# prevent fallback to working providers (e.g. Anthropic).
|
||||
try:
|
||||
import base64
|
||||
payload = access_token.split(".")[1]
|
||||
payload += "=" * (-len(payload) % 4)
|
||||
claims = json.loads(base64.urlsafe_b64decode(payload))
|
||||
exp = claims.get("exp", 0)
|
||||
if exp and time.time() > exp:
|
||||
logger.debug("Codex access token expired (exp=%s), skipping", exp)
|
||||
return None
|
||||
except Exception:
|
||||
pass # Non-JWT token or decode error — use as-is
|
||||
|
||||
return access_token.strip()
|
||||
except Exception as exc:
|
||||
logger.debug("Could not read Codex auth for auxiliary client: %s", exc)
|
||||
return None
|
||||
@@ -654,23 +672,35 @@ def _try_anthropic() -> Tuple[Optional[Any], Optional[str]]:
|
||||
if not token:
|
||||
return None, None
|
||||
|
||||
# Allow base URL override from config.yaml model.base_url
|
||||
# Allow base URL override from config.yaml model.base_url, but only
|
||||
# when the configured provider is anthropic — otherwise a non-Anthropic
|
||||
# base_url (e.g. Codex endpoint) would leak into Anthropic requests.
|
||||
base_url = _ANTHROPIC_DEFAULT_BASE_URL
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
cfg = load_config()
|
||||
model_cfg = cfg.get("model")
|
||||
if isinstance(model_cfg, dict):
|
||||
cfg_base_url = (model_cfg.get("base_url") or "").strip().rstrip("/")
|
||||
if cfg_base_url:
|
||||
base_url = cfg_base_url
|
||||
cfg_provider = str(model_cfg.get("provider") or "").strip().lower()
|
||||
if cfg_provider == "anthropic":
|
||||
cfg_base_url = (model_cfg.get("base_url") or "").strip().rstrip("/")
|
||||
if cfg_base_url:
|
||||
base_url = cfg_base_url
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from agent.anthropic_adapter import _is_oauth_token
|
||||
is_oauth = _is_oauth_token(token)
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get("anthropic", "claude-haiku-4-5-20251001")
|
||||
logger.debug("Auxiliary client: Anthropic native (%s) at %s", model, base_url)
|
||||
real_client = build_anthropic_client(token, base_url)
|
||||
return AnthropicAuxiliaryClient(real_client, model, token, base_url), model
|
||||
logger.debug("Auxiliary client: Anthropic native (%s) at %s (oauth=%s)", model, base_url, is_oauth)
|
||||
try:
|
||||
real_client = build_anthropic_client(token, base_url)
|
||||
except ImportError:
|
||||
# The anthropic_adapter module imports fine but the SDK itself is
|
||||
# missing — build_anthropic_client raises ImportError at call time
|
||||
# when _anthropic_sdk is None. Treat as unavailable.
|
||||
return None, None
|
||||
return AnthropicAuxiliaryClient(real_client, model, token, base_url, is_oauth=is_oauth), model
|
||||
|
||||
|
||||
def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
@@ -1107,7 +1137,13 @@ def resolve_vision_provider_client(
|
||||
return "custom", client, final_model
|
||||
|
||||
if requested == "auto":
|
||||
for candidate in get_available_vision_backends():
|
||||
ordered = list(_VISION_AUTO_PROVIDER_ORDER)
|
||||
preferred = _preferred_main_vision_provider()
|
||||
if preferred in ordered:
|
||||
ordered.remove(preferred)
|
||||
ordered.insert(0, preferred)
|
||||
|
||||
for candidate in ordered:
|
||||
sync_client, default_model = _resolve_strict_vision_backend(candidate)
|
||||
if sync_client is not None:
|
||||
return _finalize(candidate, sync_client, default_model)
|
||||
@@ -1180,6 +1216,105 @@ _client_cache: Dict[tuple, tuple] = {}
|
||||
_client_cache_lock = threading.Lock()
|
||||
|
||||
|
||||
def neuter_async_httpx_del() -> None:
|
||||
"""Monkey-patch ``AsyncHttpxClientWrapper.__del__`` to be a no-op.
|
||||
|
||||
The OpenAI SDK's ``AsyncHttpxClientWrapper.__del__`` schedules
|
||||
``self.aclose()`` via ``asyncio.get_running_loop().create_task()``.
|
||||
When an ``AsyncOpenAI`` client is garbage-collected while
|
||||
prompt_toolkit's event loop is running (the common CLI idle state),
|
||||
the ``aclose()`` task runs on prompt_toolkit's loop but the
|
||||
underlying TCP transport is bound to a *different* loop (the worker
|
||||
thread's loop that the client was originally created on). If that
|
||||
loop is closed or its thread is dead, the transport's
|
||||
``self._loop.call_soon()`` raises ``RuntimeError("Event loop is
|
||||
closed")``, which prompt_toolkit surfaces as "Unhandled exception
|
||||
in event loop ... Press ENTER to continue...".
|
||||
|
||||
Neutering ``__del__`` is safe because:
|
||||
- Cached clients are explicitly cleaned via ``_force_close_async_httpx``
|
||||
on stale-loop detection and ``shutdown_cached_clients`` on exit.
|
||||
- Uncached clients' TCP connections are cleaned up by the OS when the
|
||||
process exits.
|
||||
- The OpenAI SDK itself marks this as a TODO (``# TODO(someday):
|
||||
support non asyncio runtimes here``).
|
||||
|
||||
Call this once at CLI startup, before any ``AsyncOpenAI`` clients are
|
||||
created.
|
||||
"""
|
||||
try:
|
||||
from openai._base_client import AsyncHttpxClientWrapper
|
||||
AsyncHttpxClientWrapper.__del__ = lambda self: None # type: ignore[assignment]
|
||||
except (ImportError, AttributeError):
|
||||
pass # Graceful degradation if the SDK changes its internals
|
||||
|
||||
|
||||
def _force_close_async_httpx(client: Any) -> None:
|
||||
"""Mark the httpx AsyncClient inside an AsyncOpenAI client as closed.
|
||||
|
||||
This prevents ``AsyncHttpxClientWrapper.__del__`` from scheduling
|
||||
``aclose()`` on a (potentially closed) event loop, which causes
|
||||
``RuntimeError: Event loop is closed`` → prompt_toolkit's
|
||||
"Press ENTER to continue..." handler.
|
||||
|
||||
We intentionally do NOT run the full async close path — the
|
||||
connections will be dropped by the OS when the process exits.
|
||||
"""
|
||||
try:
|
||||
from httpx._client import ClientState
|
||||
inner = getattr(client, "_client", None)
|
||||
if inner is not None and not getattr(inner, "is_closed", True):
|
||||
inner._state = ClientState.CLOSED
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def shutdown_cached_clients() -> None:
|
||||
"""Close all cached clients (sync and async) to prevent event-loop errors.
|
||||
|
||||
Call this during CLI shutdown, *before* the event loop is closed, to
|
||||
avoid ``AsyncHttpxClientWrapper.__del__`` raising on a dead loop.
|
||||
"""
|
||||
import inspect
|
||||
|
||||
with _client_cache_lock:
|
||||
for key, entry in list(_client_cache.items()):
|
||||
client = entry[0]
|
||||
if client is None:
|
||||
continue
|
||||
# Mark any async httpx transport as closed first (prevents __del__
|
||||
# from scheduling aclose() on a dead event loop).
|
||||
_force_close_async_httpx(client)
|
||||
# Sync clients: close the httpx connection pool cleanly.
|
||||
# Async clients: skip — we already neutered __del__ above.
|
||||
try:
|
||||
close_fn = getattr(client, "close", None)
|
||||
if close_fn and not inspect.iscoroutinefunction(close_fn):
|
||||
close_fn()
|
||||
except Exception:
|
||||
pass
|
||||
_client_cache.clear()
|
||||
|
||||
|
||||
def cleanup_stale_async_clients() -> None:
|
||||
"""Force-close cached async clients whose event loop is closed.
|
||||
|
||||
Call this after each agent turn to proactively clean up stale clients
|
||||
before GC can trigger ``AsyncHttpxClientWrapper.__del__`` on them.
|
||||
This is defense-in-depth — the primary fix is ``neuter_async_httpx_del``
|
||||
which disables ``__del__`` entirely.
|
||||
"""
|
||||
with _client_cache_lock:
|
||||
stale_keys = []
|
||||
for key, entry in _client_cache.items():
|
||||
client, _default, cached_loop = entry
|
||||
if cached_loop is not None and cached_loop.is_closed():
|
||||
_force_close_async_httpx(client)
|
||||
stale_keys.append(key)
|
||||
for key in stale_keys:
|
||||
del _client_cache[key]
|
||||
|
||||
|
||||
def _get_cached_client(
|
||||
provider: str,
|
||||
model: str = None,
|
||||
@@ -1187,17 +1322,38 @@ def _get_cached_client(
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
) -> Tuple[Optional[Any], Optional[str]]:
|
||||
"""Get or create a cached client for the given provider."""
|
||||
cache_key = (provider, async_mode, base_url or "", api_key or "")
|
||||
"""Get or create a cached client for the given provider.
|
||||
|
||||
Async clients (AsyncOpenAI) use httpx.AsyncClient internally, which
|
||||
binds to the event loop that was current when the client was created.
|
||||
Using such a client on a *different* loop causes deadlocks or
|
||||
RuntimeError. To prevent cross-loop issues (especially in gateway
|
||||
mode where _run_async() may spawn fresh loops in worker threads), the
|
||||
cache key for async clients includes the current event loop's identity
|
||||
so each loop gets its own client instance.
|
||||
"""
|
||||
# Include loop identity for async clients to prevent cross-loop reuse.
|
||||
# httpx.AsyncClient (inside AsyncOpenAI) is bound to the loop where it
|
||||
# was created — reusing it on a different loop causes deadlocks (#2681).
|
||||
loop_id = 0
|
||||
current_loop = None
|
||||
if async_mode:
|
||||
try:
|
||||
import asyncio as _aio
|
||||
current_loop = _aio.get_event_loop()
|
||||
loop_id = id(current_loop)
|
||||
except RuntimeError:
|
||||
pass
|
||||
cache_key = (provider, async_mode, base_url or "", api_key or "", loop_id)
|
||||
with _client_cache_lock:
|
||||
if cache_key in _client_cache:
|
||||
cached_client, cached_default, cached_loop = _client_cache[cache_key]
|
||||
if async_mode:
|
||||
# Async clients are bound to the event loop that created them.
|
||||
# A cached async client whose loop has been closed will raise
|
||||
# "Event loop is closed" when httpx tries to clean up its
|
||||
# transport. Discard the stale client and create a fresh one.
|
||||
if cached_loop is not None and cached_loop.is_closed():
|
||||
_force_close_async_httpx(cached_client)
|
||||
del _client_cache[cache_key]
|
||||
else:
|
||||
return cached_client, model or cached_default
|
||||
@@ -1214,13 +1370,7 @@ def _get_cached_client(
|
||||
if client is not None:
|
||||
# For async clients, remember which loop they were created on so we
|
||||
# can detect stale entries later.
|
||||
bound_loop = None
|
||||
if async_mode:
|
||||
try:
|
||||
import asyncio as _aio
|
||||
bound_loop = _aio.get_event_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
bound_loop = current_loop
|
||||
with _client_cache_lock:
|
||||
if cache_key not in _client_cache:
|
||||
_client_cache[cache_key] = (client, default_model, bound_loop)
|
||||
@@ -1427,8 +1577,18 @@ def call_llm(
|
||||
api_key=resolved_api_key,
|
||||
)
|
||||
if client is None:
|
||||
# Fallback: try openrouter
|
||||
if resolved_provider != "openrouter" and not resolved_base_url:
|
||||
# When the user explicitly chose a non-OpenRouter provider but no
|
||||
# credentials were found, fail fast instead of silently routing
|
||||
# through OpenRouter (which causes confusing 404s).
|
||||
_explicit = (resolved_provider or "").strip().lower()
|
||||
if _explicit and _explicit not in ("auto", "openrouter", "custom"):
|
||||
raise RuntimeError(
|
||||
f"Provider '{_explicit}' is set in config.yaml but no API key "
|
||||
f"was found. Set the {_explicit.upper()}_API_KEY environment "
|
||||
f"variable, or switch to a different provider with `hermes model`."
|
||||
)
|
||||
# For auto/custom, fall back to OpenRouter
|
||||
if not resolved_base_url:
|
||||
logger.warning("Provider %s unavailable, falling back to openrouter",
|
||||
resolved_provider)
|
||||
client, final_model = _get_cached_client(
|
||||
@@ -1510,7 +1670,14 @@ async def async_call_llm(
|
||||
api_key=resolved_api_key,
|
||||
)
|
||||
if client is None:
|
||||
if resolved_provider != "openrouter" and not resolved_base_url:
|
||||
_explicit = (resolved_provider or "").strip().lower()
|
||||
if _explicit and _explicit not in ("auto", "openrouter", "custom"):
|
||||
raise RuntimeError(
|
||||
f"Provider '{_explicit}' is set in config.yaml but no API key "
|
||||
f"was found. Set the {_explicit.upper()}_API_KEY environment "
|
||||
f"variable, or switch to a different provider with `hermes model`."
|
||||
)
|
||||
if not resolved_base_url:
|
||||
logger.warning("Provider %s unavailable, falling back to openrouter",
|
||||
resolved_provider)
|
||||
client, final_model = _get_cached_client(
|
||||
|
||||
+326
-46
@@ -1,12 +1,19 @@
|
||||
"""Automatic context window compression for long conversations.
|
||||
|
||||
Self-contained class with its own OpenAI client for summarization.
|
||||
Uses Gemini Flash (cheap/fast) to summarize middle turns while
|
||||
Uses auxiliary model (cheap/fast) to summarize middle turns while
|
||||
protecting head and tail context.
|
||||
|
||||
Improvements over v1:
|
||||
- Structured summary template (Goal, Progress, Decisions, Files, Next Steps)
|
||||
- Iterative summary updates (preserves info across multiple compactions)
|
||||
- Token-budget tail protection instead of fixed message count
|
||||
- Tool output pruning before LLM summarization (cheap pre-pass)
|
||||
- Scaled summary budget (proportional to compressed content)
|
||||
- Richer tool call/result detail in summarizer input
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.auxiliary_client import call_llm
|
||||
@@ -27,12 +34,29 @@ SUMMARY_PREFIX = (
|
||||
)
|
||||
LEGACY_SUMMARY_PREFIX = "[CONTEXT SUMMARY]:"
|
||||
|
||||
# Minimum tokens for the summary output
|
||||
_MIN_SUMMARY_TOKENS = 2000
|
||||
# Proportion of compressed content to allocate for summary
|
||||
_SUMMARY_RATIO = 0.20
|
||||
# Absolute ceiling for summary tokens (even on very large context windows)
|
||||
_SUMMARY_TOKENS_CEILING = 12_000
|
||||
|
||||
# Placeholder used when pruning old tool results
|
||||
_PRUNED_TOOL_PLACEHOLDER = "[Old tool output cleared to save context space]"
|
||||
|
||||
# Chars per token rough estimate
|
||||
_CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
class ContextCompressor:
|
||||
"""Compresses conversation context when approaching the model's context limit.
|
||||
|
||||
Algorithm: protect first N + last N turns, summarize everything in between.
|
||||
Token tracking uses actual counts from API responses for accuracy.
|
||||
Algorithm:
|
||||
1. Prune old tool results (cheap, no LLM call)
|
||||
2. Protect head messages (system prompt + first exchange)
|
||||
3. Protect tail messages by token budget (most recent ~20K tokens)
|
||||
4. Summarize middle turns with structured LLM prompt
|
||||
5. On subsequent compactions, iteratively update the previous summary
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -40,8 +64,8 @@ class ContextCompressor:
|
||||
model: str,
|
||||
threshold_percent: float = 0.50,
|
||||
protect_first_n: int = 3,
|
||||
protect_last_n: int = 4,
|
||||
summary_target_tokens: int = 2500,
|
||||
protect_last_n: int = 20,
|
||||
summary_target_ratio: float = 0.20,
|
||||
quiet_mode: bool = False,
|
||||
summary_model_override: str = None,
|
||||
base_url: str = "",
|
||||
@@ -56,7 +80,7 @@ class ContextCompressor:
|
||||
self.threshold_percent = threshold_percent
|
||||
self.protect_first_n = protect_first_n
|
||||
self.protect_last_n = protect_last_n
|
||||
self.summary_target_tokens = summary_target_tokens
|
||||
self.summary_target_ratio = max(0.10, min(summary_target_ratio, 0.80))
|
||||
self.quiet_mode = quiet_mode
|
||||
|
||||
self.context_length = get_model_context_length(
|
||||
@@ -66,6 +90,24 @@ class ContextCompressor:
|
||||
)
|
||||
self.threshold_tokens = int(self.context_length * threshold_percent)
|
||||
self.compression_count = 0
|
||||
|
||||
# Derive token budgets: ratio is relative to the threshold, not total context
|
||||
target_tokens = int(self.threshold_tokens * self.summary_target_ratio)
|
||||
self.tail_token_budget = target_tokens
|
||||
self.max_summary_tokens = min(
|
||||
int(self.context_length * 0.05), _SUMMARY_TOKENS_CEILING,
|
||||
)
|
||||
|
||||
if not quiet_mode:
|
||||
logger.info(
|
||||
"Context compressor initialized: model=%s context_length=%d "
|
||||
"threshold=%d (%.0f%%) target_ratio=%.0f%% tail_budget=%d "
|
||||
"provider=%s base_url=%s",
|
||||
model, self.context_length, self.threshold_tokens,
|
||||
threshold_percent * 100, self.summary_target_ratio * 100,
|
||||
self.tail_token_budget,
|
||||
provider or "none", base_url or "none",
|
||||
)
|
||||
self._context_probed = False # True after a step-down from context error
|
||||
|
||||
self.last_prompt_tokens = 0
|
||||
@@ -74,6 +116,9 @@ class ContextCompressor:
|
||||
|
||||
self.summary_model = summary_model_override or ""
|
||||
|
||||
# Stores the previous compaction summary for iterative updates
|
||||
self._previous_summary: Optional[str] = None
|
||||
|
||||
def update_from_response(self, usage: Dict[str, Any]):
|
||||
"""Update tracked token usage from API response."""
|
||||
self.last_prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
@@ -100,53 +145,209 @@ class ContextCompressor:
|
||||
"compression_count": self.compression_count,
|
||||
}
|
||||
|
||||
def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
|
||||
"""Generate a concise summary of conversation turns.
|
||||
# ------------------------------------------------------------------
|
||||
# Tool output pruning (cheap pre-pass, no LLM call)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
Tries the auxiliary model first, then falls back to the user's main
|
||||
model. Returns None if all attempts fail — the caller should drop
|
||||
def _prune_old_tool_results(
|
||||
self, messages: List[Dict[str, Any]], protect_tail_count: int,
|
||||
) -> tuple[List[Dict[str, Any]], int]:
|
||||
"""Replace old tool result contents with a short placeholder.
|
||||
|
||||
Walks backward from the end, protecting the most recent
|
||||
``protect_tail_count`` messages. Older tool results get their
|
||||
content replaced with a placeholder string.
|
||||
|
||||
Returns (pruned_messages, pruned_count).
|
||||
"""
|
||||
if not messages:
|
||||
return messages, 0
|
||||
|
||||
result = [m.copy() for m in messages]
|
||||
pruned = 0
|
||||
prune_boundary = len(result) - protect_tail_count
|
||||
|
||||
for i in range(prune_boundary):
|
||||
msg = result[i]
|
||||
if msg.get("role") != "tool":
|
||||
continue
|
||||
content = msg.get("content", "")
|
||||
if not content or content == _PRUNED_TOOL_PLACEHOLDER:
|
||||
continue
|
||||
# Only prune if the content is substantial (>200 chars)
|
||||
if len(content) > 200:
|
||||
result[i] = {**msg, "content": _PRUNED_TOOL_PLACEHOLDER}
|
||||
pruned += 1
|
||||
|
||||
return result, pruned
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Summarization
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compute_summary_budget(self, turns_to_summarize: List[Dict[str, Any]]) -> int:
|
||||
"""Scale summary token budget with the amount of content being compressed.
|
||||
|
||||
The maximum scales with the model's context window (5% of context,
|
||||
capped at ``_SUMMARY_TOKENS_CEILING``) so large-context models get
|
||||
richer summaries instead of being hard-capped at 8K tokens.
|
||||
"""
|
||||
content_tokens = estimate_messages_tokens_rough(turns_to_summarize)
|
||||
budget = int(content_tokens * _SUMMARY_RATIO)
|
||||
return max(_MIN_SUMMARY_TOKENS, min(budget, self.max_summary_tokens))
|
||||
|
||||
def _serialize_for_summary(self, turns: List[Dict[str, Any]]) -> str:
|
||||
"""Serialize conversation turns into labeled text for the summarizer.
|
||||
|
||||
Includes tool call arguments and result content (up to 3000 chars
|
||||
per message) so the summarizer can preserve specific details like
|
||||
file paths, commands, and outputs.
|
||||
"""
|
||||
parts = []
|
||||
for msg in turns:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content") or ""
|
||||
|
||||
# Tool results: keep more content than before (3000 chars)
|
||||
if role == "tool":
|
||||
tool_id = msg.get("tool_call_id", "")
|
||||
if len(content) > 3000:
|
||||
content = content[:2000] + "\n...[truncated]...\n" + content[-800:]
|
||||
parts.append(f"[TOOL RESULT {tool_id}]: {content}")
|
||||
continue
|
||||
|
||||
# Assistant messages: include tool call names AND arguments
|
||||
if role == "assistant":
|
||||
if len(content) > 3000:
|
||||
content = content[:2000] + "\n...[truncated]...\n" + content[-800:]
|
||||
tool_calls = msg.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
tc_parts = []
|
||||
for tc in tool_calls:
|
||||
if isinstance(tc, dict):
|
||||
fn = tc.get("function", {})
|
||||
name = fn.get("name", "?")
|
||||
args = fn.get("arguments", "")
|
||||
# Truncate long arguments but keep enough for context
|
||||
if len(args) > 500:
|
||||
args = args[:400] + "..."
|
||||
tc_parts.append(f" {name}({args})")
|
||||
else:
|
||||
fn = getattr(tc, "function", None)
|
||||
name = getattr(fn, "name", "?") if fn else "?"
|
||||
tc_parts.append(f" {name}(...)")
|
||||
content += "\n[Tool calls:\n" + "\n".join(tc_parts) + "\n]"
|
||||
parts.append(f"[ASSISTANT]: {content}")
|
||||
continue
|
||||
|
||||
# User and other roles
|
||||
if len(content) > 3000:
|
||||
content = content[:2000] + "\n...[truncated]...\n" + content[-800:]
|
||||
parts.append(f"[{role.upper()}]: {content}")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
|
||||
"""Generate a structured summary of conversation turns.
|
||||
|
||||
Uses a structured template (Goal, Progress, Decisions, Files, Next Steps)
|
||||
inspired by Pi-mono and OpenCode. When a previous summary exists,
|
||||
generates an iterative update instead of summarizing from scratch.
|
||||
|
||||
Returns None if all attempts fail — the caller should drop
|
||||
the middle turns without a summary rather than inject a useless
|
||||
placeholder.
|
||||
"""
|
||||
parts = []
|
||||
for msg in turns_to_summarize:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content") or ""
|
||||
if len(content) > 2000:
|
||||
content = content[:1000] + "\n...[truncated]...\n" + content[-500:]
|
||||
tool_calls = msg.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
tool_names = [tc.get("function", {}).get("name", "?") for tc in tool_calls if isinstance(tc, dict)]
|
||||
content += f"\n[Tool calls: {', '.join(tool_names)}]"
|
||||
parts.append(f"[{role.upper()}]: {content}")
|
||||
summary_budget = self._compute_summary_budget(turns_to_summarize)
|
||||
content_to_summarize = self._serialize_for_summary(turns_to_summarize)
|
||||
|
||||
content_to_summarize = "\n\n".join(parts)
|
||||
prompt = f"""Create a concise handoff summary for a later assistant that will continue this conversation after earlier turns are compacted.
|
||||
if self._previous_summary:
|
||||
# Iterative update: preserve existing info, add new progress
|
||||
prompt = f"""You are updating a context compaction summary. A previous compaction produced the summary below. New conversation turns have occurred since then and need to be incorporated.
|
||||
|
||||
Describe:
|
||||
1. What actions were taken (tool calls, searches, file operations)
|
||||
2. Key information or results obtained
|
||||
3. Important decisions, constraints, or user preferences
|
||||
4. Relevant data, file names, outputs, or next steps needed to continue
|
||||
PREVIOUS SUMMARY:
|
||||
{self._previous_summary}
|
||||
|
||||
Keep it factual, concise, and focused on helping the next assistant resume without repeating work. Target ~{self.summary_target_tokens} tokens.
|
||||
NEW TURNS TO INCORPORATE:
|
||||
{content_to_summarize}
|
||||
|
||||
Update the summary using this exact structure. PRESERVE all existing information that is still relevant. ADD new progress. Move items from "In Progress" to "Done" when completed. Remove information only if it is clearly obsolete.
|
||||
|
||||
## Goal
|
||||
[What the user is trying to accomplish — preserve from previous summary, update if goal evolved]
|
||||
|
||||
## Constraints & Preferences
|
||||
[User preferences, coding style, constraints, important decisions — accumulate across compactions]
|
||||
|
||||
## Progress
|
||||
### Done
|
||||
[Completed work — include specific file paths, commands run, results obtained]
|
||||
### In Progress
|
||||
[Work currently underway]
|
||||
### Blocked
|
||||
[Any blockers or issues encountered]
|
||||
|
||||
## Key Decisions
|
||||
[Important technical decisions and why they were made]
|
||||
|
||||
## Relevant Files
|
||||
[Files read, modified, or created — with brief note on each. Accumulate across compactions.]
|
||||
|
||||
## Next Steps
|
||||
[What needs to happen next to continue the work]
|
||||
|
||||
## Critical Context
|
||||
[Any specific values, error messages, configuration details, or data that would be lost without explicit preservation]
|
||||
|
||||
Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions.
|
||||
|
||||
Write only the summary body. Do not include any preamble or prefix."""
|
||||
else:
|
||||
# First compaction: summarize from scratch
|
||||
prompt = f"""Create a structured handoff summary for a later assistant that will continue this conversation after earlier turns are compacted.
|
||||
|
||||
---
|
||||
TURNS TO SUMMARIZE:
|
||||
{content_to_summarize}
|
||||
---
|
||||
|
||||
Write only the summary body. Do not include any preamble or prefix; the system will add the handoff wrapper."""
|
||||
Use this exact structure:
|
||||
|
||||
## Goal
|
||||
[What the user is trying to accomplish]
|
||||
|
||||
## Constraints & Preferences
|
||||
[User preferences, coding style, constraints, important decisions]
|
||||
|
||||
## Progress
|
||||
### Done
|
||||
[Completed work — include specific file paths, commands run, results obtained]
|
||||
### In Progress
|
||||
[Work currently underway]
|
||||
### Blocked
|
||||
[Any blockers or issues encountered]
|
||||
|
||||
## Key Decisions
|
||||
[Important technical decisions and why they were made]
|
||||
|
||||
## Relevant Files
|
||||
[Files read, modified, or created — with brief note on each]
|
||||
|
||||
## Next Steps
|
||||
[What needs to happen next to continue the work]
|
||||
|
||||
## Critical Context
|
||||
[Any specific values, error messages, configuration details, or data that would be lost without explicit preservation]
|
||||
|
||||
Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions. The goal is to prevent the next assistant from repeating work or losing important details.
|
||||
|
||||
Write only the summary body. Do not include any preamble or prefix."""
|
||||
|
||||
# Use the centralized LLM router — handles provider resolution,
|
||||
# auth, and fallback internally.
|
||||
try:
|
||||
call_kwargs = {
|
||||
"task": "compression",
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.3,
|
||||
"max_tokens": self.summary_target_tokens * 2,
|
||||
"timeout": 30.0,
|
||||
"max_tokens": summary_budget * 2,
|
||||
"timeout": 45.0,
|
||||
}
|
||||
if self.summary_model:
|
||||
call_kwargs["model"] = self.summary_model
|
||||
@@ -156,6 +357,8 @@ Write only the summary body. Do not include any preamble or prefix; the system w
|
||||
if not isinstance(content, str):
|
||||
content = str(content) if content else ""
|
||||
summary = content.strip()
|
||||
# Store for iterative updates on next compaction
|
||||
self._previous_summary = summary
|
||||
return self._with_summary_prefix(summary)
|
||||
except RuntimeError:
|
||||
logging.warning("Context compression: no provider available for "
|
||||
@@ -280,10 +483,75 @@ Write only the summary body. Do not include any preamble or prefix; the system w
|
||||
idx = check
|
||||
return idx
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tail protection by token budget
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _find_tail_cut_by_tokens(
|
||||
self, messages: List[Dict[str, Any]], head_end: int,
|
||||
token_budget: int | None = None,
|
||||
) -> int:
|
||||
"""Walk backward from the end of messages, accumulating tokens until
|
||||
the budget is reached. Returns the index where the tail starts.
|
||||
|
||||
``token_budget`` defaults to ``self.tail_token_budget`` which is
|
||||
derived from ``summary_target_ratio * context_length``, so it
|
||||
scales automatically with the model's context window.
|
||||
|
||||
Never cuts inside a tool_call/result group. Falls back to the old
|
||||
``protect_last_n`` if the budget would protect fewer messages.
|
||||
"""
|
||||
if token_budget is None:
|
||||
token_budget = self.tail_token_budget
|
||||
n = len(messages)
|
||||
min_tail = self.protect_last_n
|
||||
accumulated = 0
|
||||
cut_idx = n # start from beyond the end
|
||||
|
||||
for i in range(n - 1, head_end - 1, -1):
|
||||
msg = messages[i]
|
||||
content = msg.get("content") or ""
|
||||
msg_tokens = len(content) // _CHARS_PER_TOKEN + 10 # +10 for role/metadata
|
||||
# Include tool call arguments in estimate
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
if isinstance(tc, dict):
|
||||
args = tc.get("function", {}).get("arguments", "")
|
||||
msg_tokens += len(args) // _CHARS_PER_TOKEN
|
||||
if accumulated + msg_tokens > token_budget and (n - i) >= min_tail:
|
||||
break
|
||||
accumulated += msg_tokens
|
||||
cut_idx = i
|
||||
|
||||
# Ensure we protect at least protect_last_n messages
|
||||
fallback_cut = n - min_tail
|
||||
if cut_idx > fallback_cut:
|
||||
cut_idx = fallback_cut
|
||||
|
||||
# If the token budget would protect everything (small conversations),
|
||||
# fall back to the fixed protect_last_n approach so compression can
|
||||
# still remove middle turns.
|
||||
if cut_idx <= head_end:
|
||||
cut_idx = fallback_cut
|
||||
|
||||
# Align to avoid splitting tool groups
|
||||
cut_idx = self._align_boundary_backward(messages, cut_idx)
|
||||
|
||||
return max(cut_idx, head_end + 1)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Main compression entry point
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None) -> List[Dict[str, Any]]:
|
||||
"""Compress conversation messages by summarizing middle turns.
|
||||
|
||||
Keeps first N + last N turns, summarizes everything in between.
|
||||
Algorithm:
|
||||
1. Prune old tool results (cheap pre-pass, no LLM call)
|
||||
2. Protect head messages (system prompt + first exchange)
|
||||
3. Find tail boundary by token budget (~20K tokens of recent context)
|
||||
4. Summarize middle turns with structured LLM prompt
|
||||
5. On re-compression, iteratively update the previous summary
|
||||
|
||||
After compression, orphaned tool_call / tool_result pairs are cleaned
|
||||
up so the API never receives mismatched IDs.
|
||||
"""
|
||||
@@ -297,19 +565,26 @@ Write only the summary body. Do not include any preamble or prefix; the system w
|
||||
)
|
||||
return messages
|
||||
|
||||
compress_start = self.protect_first_n
|
||||
compress_end = n_messages - self.protect_last_n
|
||||
if compress_start >= compress_end:
|
||||
return messages
|
||||
display_tokens = current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
|
||||
|
||||
# Adjust boundaries to avoid splitting tool_call/result groups.
|
||||
# Phase 1: Prune old tool results (cheap, no LLM call)
|
||||
messages, pruned_count = self._prune_old_tool_results(
|
||||
messages, protect_tail_count=self.protect_last_n * 3,
|
||||
)
|
||||
if pruned_count and not self.quiet_mode:
|
||||
logger.info("Pre-compression: pruned %d old tool result(s)", pruned_count)
|
||||
|
||||
# Phase 2: Determine boundaries
|
||||
compress_start = self.protect_first_n
|
||||
compress_start = self._align_boundary_forward(messages, compress_start)
|
||||
compress_end = self._align_boundary_backward(messages, compress_end)
|
||||
|
||||
# Use token-budget tail protection instead of fixed message count
|
||||
compress_end = self._find_tail_cut_by_tokens(messages, compress_start)
|
||||
|
||||
if compress_start >= compress_end:
|
||||
return messages
|
||||
|
||||
turns_to_summarize = messages[compress_start:compress_end]
|
||||
display_tokens = current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
|
||||
|
||||
if not self.quiet_mode:
|
||||
logger.info(
|
||||
@@ -323,15 +598,20 @@ Write only the summary body. Do not include any preamble or prefix; the system w
|
||||
self.threshold_percent * 100,
|
||||
self.threshold_tokens,
|
||||
)
|
||||
tail_msgs = n_messages - compress_end
|
||||
logger.info(
|
||||
"Summarizing turns %d-%d (%d turns)",
|
||||
"Summarizing turns %d-%d (%d turns), protecting %d head + %d tail messages",
|
||||
compress_start + 1,
|
||||
compress_end,
|
||||
len(turns_to_summarize),
|
||||
compress_start,
|
||||
tail_msgs,
|
||||
)
|
||||
|
||||
# Phase 3: Generate structured summary
|
||||
summary = self._generate_summary(turns_to_summarize)
|
||||
|
||||
# Phase 4: Assemble compressed message list
|
||||
compressed = []
|
||||
for i in range(compress_start):
|
||||
msg = messages[i].copy()
|
||||
|
||||
@@ -0,0 +1,485 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from agent.model_metadata import estimate_tokens_rough
|
||||
|
||||
REFERENCE_PATTERN = re.compile(
|
||||
r"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>\S+))"
|
||||
)
|
||||
TRAILING_PUNCTUATION = ",.;!?"
|
||||
_SENSITIVE_HOME_DIRS = (".ssh", ".aws", ".gnupg", ".kube")
|
||||
_SENSITIVE_HERMES_DIRS = (Path("skills") / ".hub",)
|
||||
_SENSITIVE_HOME_FILES = (
|
||||
Path(".ssh") / "authorized_keys",
|
||||
Path(".ssh") / "id_rsa",
|
||||
Path(".ssh") / "id_ed25519",
|
||||
Path(".ssh") / "config",
|
||||
Path(".bashrc"),
|
||||
Path(".zshrc"),
|
||||
Path(".profile"),
|
||||
Path(".bash_profile"),
|
||||
Path(".zprofile"),
|
||||
Path(".netrc"),
|
||||
Path(".pgpass"),
|
||||
Path(".npmrc"),
|
||||
Path(".pypirc"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ContextReference:
|
||||
raw: str
|
||||
kind: str
|
||||
target: str
|
||||
start: int
|
||||
end: int
|
||||
line_start: int | None = None
|
||||
line_end: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextReferenceResult:
|
||||
message: str
|
||||
original_message: str
|
||||
references: list[ContextReference] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
injected_tokens: int = 0
|
||||
expanded: bool = False
|
||||
blocked: bool = False
|
||||
|
||||
|
||||
def parse_context_references(message: str) -> list[ContextReference]:
|
||||
refs: list[ContextReference] = []
|
||||
if not message:
|
||||
return refs
|
||||
|
||||
for match in REFERENCE_PATTERN.finditer(message):
|
||||
simple = match.group("simple")
|
||||
if simple:
|
||||
refs.append(
|
||||
ContextReference(
|
||||
raw=match.group(0),
|
||||
kind=simple,
|
||||
target="",
|
||||
start=match.start(),
|
||||
end=match.end(),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
kind = match.group("kind")
|
||||
value = _strip_trailing_punctuation(match.group("value") or "")
|
||||
line_start = None
|
||||
line_end = None
|
||||
target = value
|
||||
|
||||
if kind == "file":
|
||||
range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value)
|
||||
if range_match:
|
||||
target = range_match.group("path")
|
||||
line_start = int(range_match.group("start"))
|
||||
line_end = int(range_match.group("end") or range_match.group("start"))
|
||||
|
||||
refs.append(
|
||||
ContextReference(
|
||||
raw=match.group(0),
|
||||
kind=kind,
|
||||
target=target,
|
||||
start=match.start(),
|
||||
end=match.end(),
|
||||
line_start=line_start,
|
||||
line_end=line_end,
|
||||
)
|
||||
)
|
||||
|
||||
return refs
|
||||
|
||||
|
||||
def preprocess_context_references(
|
||||
message: str,
|
||||
*,
|
||||
cwd: str | Path,
|
||||
context_length: int,
|
||||
url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
|
||||
allowed_root: str | Path | None = None,
|
||||
) -> ContextReferenceResult:
|
||||
coro = preprocess_context_references_async(
|
||||
message,
|
||||
cwd=cwd,
|
||||
context_length=context_length,
|
||||
url_fetcher=url_fetcher,
|
||||
allowed_root=allowed_root,
|
||||
)
|
||||
# Safe for both CLI (no loop) and gateway (loop already running).
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
if loop and loop.is_running():
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(asyncio.run, coro).result()
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
async def preprocess_context_references_async(
|
||||
message: str,
|
||||
*,
|
||||
cwd: str | Path,
|
||||
context_length: int,
|
||||
url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
|
||||
allowed_root: str | Path | None = None,
|
||||
) -> ContextReferenceResult:
|
||||
refs = parse_context_references(message)
|
||||
if not refs:
|
||||
return ContextReferenceResult(message=message, original_message=message)
|
||||
|
||||
cwd_path = Path(cwd).expanduser().resolve()
|
||||
# Default to the current working directory so @ references cannot escape
|
||||
# the active workspace unless a caller explicitly widens the root.
|
||||
allowed_root_path = (
|
||||
Path(allowed_root).expanduser().resolve() if allowed_root is not None else cwd_path
|
||||
)
|
||||
warnings: list[str] = []
|
||||
blocks: list[str] = []
|
||||
injected_tokens = 0
|
||||
|
||||
for ref in refs:
|
||||
warning, block = await _expand_reference(
|
||||
ref,
|
||||
cwd_path,
|
||||
url_fetcher=url_fetcher,
|
||||
allowed_root=allowed_root_path,
|
||||
)
|
||||
if warning:
|
||||
warnings.append(warning)
|
||||
if block:
|
||||
blocks.append(block)
|
||||
injected_tokens += estimate_tokens_rough(block)
|
||||
|
||||
hard_limit = max(1, int(context_length * 0.50))
|
||||
soft_limit = max(1, int(context_length * 0.25))
|
||||
if injected_tokens > hard_limit:
|
||||
warnings.append(
|
||||
f"@ context injection refused: {injected_tokens} tokens exceeds the 50% hard limit ({hard_limit})."
|
||||
)
|
||||
return ContextReferenceResult(
|
||||
message=message,
|
||||
original_message=message,
|
||||
references=refs,
|
||||
warnings=warnings,
|
||||
injected_tokens=injected_tokens,
|
||||
expanded=False,
|
||||
blocked=True,
|
||||
)
|
||||
|
||||
if injected_tokens > soft_limit:
|
||||
warnings.append(
|
||||
f"@ context injection warning: {injected_tokens} tokens exceeds the 25% soft limit ({soft_limit})."
|
||||
)
|
||||
|
||||
stripped = _remove_reference_tokens(message, refs)
|
||||
final = stripped
|
||||
if warnings:
|
||||
final = f"{final}\n\n--- Context Warnings ---\n" + "\n".join(f"- {warning}" for warning in warnings)
|
||||
if blocks:
|
||||
final = f"{final}\n\n--- Attached Context ---\n\n" + "\n\n".join(blocks)
|
||||
|
||||
return ContextReferenceResult(
|
||||
message=final.strip(),
|
||||
original_message=message,
|
||||
references=refs,
|
||||
warnings=warnings,
|
||||
injected_tokens=injected_tokens,
|
||||
expanded=bool(blocks or warnings),
|
||||
blocked=False,
|
||||
)
|
||||
|
||||
|
||||
async def _expand_reference(
|
||||
ref: ContextReference,
|
||||
cwd: Path,
|
||||
*,
|
||||
url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
|
||||
allowed_root: Path | None = None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
try:
|
||||
if ref.kind == "file":
|
||||
return _expand_file_reference(ref, cwd, allowed_root=allowed_root)
|
||||
if ref.kind == "folder":
|
||||
return _expand_folder_reference(ref, cwd, allowed_root=allowed_root)
|
||||
if ref.kind == "diff":
|
||||
return _expand_git_reference(ref, cwd, ["diff"], "git diff")
|
||||
if ref.kind == "staged":
|
||||
return _expand_git_reference(ref, cwd, ["diff", "--staged"], "git diff --staged")
|
||||
if ref.kind == "git":
|
||||
count = max(1, min(int(ref.target or "1"), 10))
|
||||
return _expand_git_reference(ref, cwd, ["log", f"-{count}", "-p"], f"git log -{count} -p")
|
||||
if ref.kind == "url":
|
||||
content = await _fetch_url_content(ref.target, url_fetcher=url_fetcher)
|
||||
if not content:
|
||||
return f"{ref.raw}: no content extracted", None
|
||||
return None, f"🌐 {ref.raw} ({estimate_tokens_rough(content)} tokens)\n{content}"
|
||||
except Exception as exc:
|
||||
return f"{ref.raw}: {exc}", None
|
||||
|
||||
return f"{ref.raw}: unsupported reference type", None
|
||||
|
||||
|
||||
def _expand_file_reference(
|
||||
ref: ContextReference,
|
||||
cwd: Path,
|
||||
*,
|
||||
allowed_root: Path | None = None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
path = _resolve_path(cwd, ref.target, allowed_root=allowed_root)
|
||||
_ensure_reference_path_allowed(path)
|
||||
if not path.exists():
|
||||
return f"{ref.raw}: file not found", None
|
||||
if not path.is_file():
|
||||
return f"{ref.raw}: path is not a file", None
|
||||
if _is_binary_file(path):
|
||||
return f"{ref.raw}: binary files are not supported", None
|
||||
|
||||
text = path.read_text(encoding="utf-8")
|
||||
if ref.line_start is not None:
|
||||
lines = text.splitlines()
|
||||
start_idx = max(ref.line_start - 1, 0)
|
||||
end_idx = min(ref.line_end or ref.line_start, len(lines))
|
||||
text = "\n".join(lines[start_idx:end_idx])
|
||||
|
||||
lang = _code_fence_language(path)
|
||||
label = ref.raw
|
||||
return None, f"📄 {label} ({estimate_tokens_rough(text)} tokens)\n```{lang}\n{text}\n```"
|
||||
|
||||
|
||||
def _expand_folder_reference(
|
||||
ref: ContextReference,
|
||||
cwd: Path,
|
||||
*,
|
||||
allowed_root: Path | None = None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
path = _resolve_path(cwd, ref.target, allowed_root=allowed_root)
|
||||
_ensure_reference_path_allowed(path)
|
||||
if not path.exists():
|
||||
return f"{ref.raw}: folder not found", None
|
||||
if not path.is_dir():
|
||||
return f"{ref.raw}: path is not a folder", None
|
||||
|
||||
listing = _build_folder_listing(path, cwd)
|
||||
return None, f"📁 {ref.raw} ({estimate_tokens_rough(listing)} tokens)\n{listing}"
|
||||
|
||||
|
||||
def _expand_git_reference(
|
||||
ref: ContextReference,
|
||||
cwd: Path,
|
||||
args: list[str],
|
||||
label: str,
|
||||
) -> tuple[str | None, str | None]:
|
||||
result = subprocess.run(
|
||||
["git", *args],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
stderr = (result.stderr or "").strip() or "git command failed"
|
||||
return f"{ref.raw}: {stderr}", None
|
||||
content = result.stdout.strip()
|
||||
if not content:
|
||||
content = "(no output)"
|
||||
return None, f"🧾 {label} ({estimate_tokens_rough(content)} tokens)\n```diff\n{content}\n```"
|
||||
|
||||
|
||||
async def _fetch_url_content(
|
||||
url: str,
|
||||
*,
|
||||
url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
|
||||
) -> str:
|
||||
fetcher = url_fetcher or _default_url_fetcher
|
||||
content = fetcher(url)
|
||||
if inspect.isawaitable(content):
|
||||
content = await content
|
||||
return str(content or "").strip()
|
||||
|
||||
|
||||
async def _default_url_fetcher(url: str) -> str:
|
||||
from tools.web_tools import web_extract_tool
|
||||
|
||||
raw = await web_extract_tool([url], format="markdown", use_llm_processing=True)
|
||||
payload = json.loads(raw)
|
||||
docs = payload.get("data", {}).get("documents", [])
|
||||
if not docs:
|
||||
return ""
|
||||
doc = docs[0]
|
||||
return str(doc.get("content") or doc.get("raw_content") or "").strip()
|
||||
|
||||
|
||||
def _resolve_path(cwd: Path, target: str, *, allowed_root: Path | None = None) -> Path:
|
||||
path = Path(os.path.expanduser(target))
|
||||
if not path.is_absolute():
|
||||
path = cwd / path
|
||||
resolved = path.resolve()
|
||||
if allowed_root is not None:
|
||||
try:
|
||||
resolved.relative_to(allowed_root)
|
||||
except ValueError as exc:
|
||||
raise ValueError("path is outside the allowed workspace") from exc
|
||||
return resolved
|
||||
|
||||
|
||||
def _ensure_reference_path_allowed(path: Path) -> None:
|
||||
home = Path(os.path.expanduser("~")).resolve()
|
||||
hermes_home = Path(
|
||||
os.getenv("HERMES_HOME", str(home / ".hermes"))
|
||||
).expanduser().resolve()
|
||||
|
||||
blocked_exact = {home / rel for rel in _SENSITIVE_HOME_FILES}
|
||||
blocked_exact.add(hermes_home / ".env")
|
||||
blocked_dirs = [home / rel for rel in _SENSITIVE_HOME_DIRS]
|
||||
blocked_dirs.extend(hermes_home / rel for rel in _SENSITIVE_HERMES_DIRS)
|
||||
|
||||
if path in blocked_exact:
|
||||
raise ValueError("path is a sensitive credential file and cannot be attached")
|
||||
|
||||
for blocked_dir in blocked_dirs:
|
||||
try:
|
||||
path.relative_to(blocked_dir)
|
||||
except ValueError:
|
||||
continue
|
||||
raise ValueError("path is a sensitive credential or internal Hermes path and cannot be attached")
|
||||
|
||||
|
||||
def _strip_trailing_punctuation(value: str) -> str:
|
||||
stripped = value.rstrip(TRAILING_PUNCTUATION)
|
||||
while stripped.endswith((")", "]", "}")):
|
||||
closer = stripped[-1]
|
||||
opener = {")": "(", "]": "[", "}": "{"}[closer]
|
||||
if stripped.count(closer) > stripped.count(opener):
|
||||
stripped = stripped[:-1]
|
||||
continue
|
||||
break
|
||||
return stripped
|
||||
|
||||
|
||||
def _remove_reference_tokens(message: str, refs: list[ContextReference]) -> str:
|
||||
pieces: list[str] = []
|
||||
cursor = 0
|
||||
for ref in refs:
|
||||
pieces.append(message[cursor:ref.start])
|
||||
cursor = ref.end
|
||||
pieces.append(message[cursor:])
|
||||
text = "".join(pieces)
|
||||
text = re.sub(r"\s{2,}", " ", text)
|
||||
text = re.sub(r"\s+([,.;:!?])", r"\1", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _is_binary_file(path: Path) -> bool:
|
||||
mime, _ = mimetypes.guess_type(path.name)
|
||||
if mime and not mime.startswith("text/") and not any(
|
||||
path.name.endswith(ext) for ext in (".py", ".md", ".txt", ".json", ".yaml", ".yml", ".toml", ".js", ".ts")
|
||||
):
|
||||
return True
|
||||
chunk = path.read_bytes()[:4096]
|
||||
return b"\x00" in chunk
|
||||
|
||||
|
||||
def _build_folder_listing(path: Path, cwd: Path, limit: int = 200) -> str:
|
||||
lines = [f"{path.relative_to(cwd)}/"]
|
||||
entries = _iter_visible_entries(path, cwd, limit=limit)
|
||||
for entry in entries:
|
||||
rel = entry.relative_to(cwd)
|
||||
indent = " " * max(len(rel.parts) - len(path.relative_to(cwd).parts) - 1, 0)
|
||||
if entry.is_dir():
|
||||
lines.append(f"{indent}- {entry.name}/")
|
||||
else:
|
||||
meta = _file_metadata(entry)
|
||||
lines.append(f"{indent}- {entry.name} ({meta})")
|
||||
if len(entries) >= limit:
|
||||
lines.append("- ...")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _iter_visible_entries(path: Path, cwd: Path, limit: int) -> list[Path]:
|
||||
rg_entries = _rg_files(path, cwd, limit=limit)
|
||||
if rg_entries is not None:
|
||||
output: list[Path] = []
|
||||
seen_dirs: set[Path] = set()
|
||||
for rel in rg_entries:
|
||||
full = cwd / rel
|
||||
for parent in full.parents:
|
||||
if parent == cwd or parent in seen_dirs or path not in {parent, *parent.parents}:
|
||||
continue
|
||||
seen_dirs.add(parent)
|
||||
output.append(parent)
|
||||
output.append(full)
|
||||
return sorted({p for p in output if p.exists()}, key=lambda p: (not p.is_dir(), str(p)))
|
||||
|
||||
output = []
|
||||
for root, dirs, files in os.walk(path):
|
||||
dirs[:] = sorted(d for d in dirs if not d.startswith(".") and d != "__pycache__")
|
||||
files = sorted(f for f in files if not f.startswith("."))
|
||||
root_path = Path(root)
|
||||
for d in dirs:
|
||||
output.append(root_path / d)
|
||||
if len(output) >= limit:
|
||||
return output
|
||||
for f in files:
|
||||
output.append(root_path / f)
|
||||
if len(output) >= limit:
|
||||
return output
|
||||
return output
|
||||
|
||||
|
||||
def _rg_files(path: Path, cwd: Path, limit: int) -> list[Path] | None:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["rg", "--files", str(path.relative_to(cwd))],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
files = [Path(line.strip()) for line in result.stdout.splitlines() if line.strip()]
|
||||
return files[:limit]
|
||||
|
||||
|
||||
def _file_metadata(path: Path) -> str:
|
||||
if _is_binary_file(path):
|
||||
return f"{path.stat().st_size} bytes"
|
||||
try:
|
||||
line_count = path.read_text(encoding="utf-8").count("\n") + 1
|
||||
except Exception:
|
||||
return f"{path.stat().st_size} bytes"
|
||||
return f"{line_count} lines"
|
||||
|
||||
|
||||
def _code_fence_language(path: Path) -> str:
|
||||
mapping = {
|
||||
".py": "python",
|
||||
".js": "javascript",
|
||||
".ts": "typescript",
|
||||
".tsx": "tsx",
|
||||
".jsx": "jsx",
|
||||
".json": "json",
|
||||
".md": "markdown",
|
||||
".sh": "bash",
|
||||
".yml": "yaml",
|
||||
".yaml": "yaml",
|
||||
".toml": "toml",
|
||||
}
|
||||
return mapping.get(path.suffix.lower(), "")
|
||||
+62
-40
@@ -231,7 +231,7 @@ class KawaiiSpinner:
|
||||
"analyzing", "computing", "synthesizing", "formulating", "brainstorming",
|
||||
]
|
||||
|
||||
def __init__(self, message: str = "", spinner_type: str = 'dots'):
|
||||
def __init__(self, message: str = "", spinner_type: str = 'dots', print_fn=None):
|
||||
self.message = message
|
||||
self.spinner_frames = self.SPINNERS.get(spinner_type, self.SPINNERS['dots'])
|
||||
self.running = False
|
||||
@@ -239,13 +239,26 @@ class KawaiiSpinner:
|
||||
self.frame_idx = 0
|
||||
self.start_time = None
|
||||
self.last_line_len = 0
|
||||
self._last_flush_time = 0.0 # Rate-limit flushes for patch_stdout compat
|
||||
# Optional callable to route all output through (e.g. a no-op for silent
|
||||
# background agents). When set, bypasses self._out entirely so that
|
||||
# agents with _print_fn overridden remain fully silent.
|
||||
self._print_fn = print_fn
|
||||
# Capture stdout NOW, before any redirect_stdout(devnull) from
|
||||
# child agents can replace sys.stdout with a black hole.
|
||||
self._out = sys.stdout
|
||||
|
||||
def _write(self, text: str, end: str = '\n', flush: bool = False):
|
||||
"""Write to the stdout captured at spinner creation time."""
|
||||
"""Write to the stdout captured at spinner creation time.
|
||||
|
||||
If a print_fn was supplied at construction, all output is routed through
|
||||
it instead — allowing callers to silence the spinner with a no-op lambda.
|
||||
"""
|
||||
if self._print_fn is not None:
|
||||
try:
|
||||
self._print_fn(text)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
try:
|
||||
self._out.write(text + end)
|
||||
if flush:
|
||||
@@ -253,16 +266,50 @@ class KawaiiSpinner:
|
||||
except (ValueError, OSError):
|
||||
pass
|
||||
|
||||
@property
|
||||
def _is_tty(self) -> bool:
|
||||
"""Check if output is a real terminal, safe against closed streams."""
|
||||
try:
|
||||
return hasattr(self._out, 'isatty') and self._out.isatty()
|
||||
except (ValueError, OSError):
|
||||
return False
|
||||
|
||||
def _is_patch_stdout_proxy(self) -> bool:
|
||||
"""Return True when stdout is prompt_toolkit's StdoutProxy.
|
||||
|
||||
patch_stdout wraps sys.stdout in a StdoutProxy that queues writes and
|
||||
injects newlines around each flush(). The \\r overwrite never lands on
|
||||
the correct line — each spinner frame ends up on its own line.
|
||||
|
||||
The CLI already drives a TUI widget (_spinner_text) for spinner display,
|
||||
so KawaiiSpinner's \\r-based animation is redundant under StdoutProxy.
|
||||
"""
|
||||
out = self._out
|
||||
# StdoutProxy has a 'raw' attribute (bool) that plain file objects lack.
|
||||
if hasattr(out, 'raw') and type(out).__name__ == 'StdoutProxy':
|
||||
return True
|
||||
return False
|
||||
|
||||
def _animate(self):
|
||||
# When stdout is not a real terminal (e.g. Docker, systemd, pipe),
|
||||
# skip the animation entirely — it creates massive log bloat.
|
||||
# Just log the start once and let stop() log the completion.
|
||||
if not hasattr(self._out, 'isatty') or not self._out.isatty():
|
||||
if not self._is_tty:
|
||||
self._write(f" [tool] {self.message}", flush=True)
|
||||
while self.running:
|
||||
time.sleep(0.5)
|
||||
return
|
||||
|
||||
# When running inside prompt_toolkit's patch_stdout context the CLI
|
||||
# renders spinner state via a dedicated TUI widget (_spinner_text).
|
||||
# Driving a \r-based animation here too causes visual overdraw: the
|
||||
# StdoutProxy injects newlines around each flush, so every frame lands
|
||||
# on a new line and overwrites the status bar.
|
||||
if self._is_patch_stdout_proxy():
|
||||
while self.running:
|
||||
time.sleep(0.1)
|
||||
return
|
||||
|
||||
# Cache skin wings at start (avoid per-frame imports)
|
||||
skin = _get_skin()
|
||||
wings = skin.get_spinner_wings() if skin else []
|
||||
@@ -279,18 +326,7 @@ class KawaiiSpinner:
|
||||
else:
|
||||
line = f" {frame} {self.message} ({elapsed:.1f}s)"
|
||||
pad = max(self.last_line_len - len(line), 0)
|
||||
# Rate-limit flush() calls to avoid spinner spam under
|
||||
# prompt_toolkit's patch_stdout. Each flush() pushes a queue
|
||||
# item that may trigger a separate run_in_terminal() call; if
|
||||
# items are processed one-at-a-time the \r overwrite is lost
|
||||
# and every frame appears on its own line. By flushing at
|
||||
# most every 0.4s we guarantee multiple \r-frames are batched
|
||||
# into a single write, so the terminal collapses them correctly.
|
||||
now = time.time()
|
||||
should_flush = (now - self._last_flush_time) >= 0.4
|
||||
self._write(f"\r{line}{' ' * pad}", end='', flush=should_flush)
|
||||
if should_flush:
|
||||
self._last_flush_time = now
|
||||
self._write(f"\r{line}{' ' * pad}", end='', flush=True)
|
||||
self.last_line_len = len(line)
|
||||
self.frame_idx += 1
|
||||
time.sleep(0.12)
|
||||
@@ -329,7 +365,7 @@ class KawaiiSpinner:
|
||||
if self.thread:
|
||||
self.thread.join(timeout=0.5)
|
||||
|
||||
is_tty = hasattr(self._out, 'isatty') and self._out.isatty()
|
||||
is_tty = self._is_tty
|
||||
if is_tty:
|
||||
# Clear the spinner line with spaces instead of \033[K to avoid
|
||||
# garbled escape codes when prompt_toolkit's patch_stdout is active.
|
||||
@@ -657,10 +693,6 @@ def format_context_pressure(
|
||||
The bar and percentage show progress toward the compaction threshold,
|
||||
NOT the raw context window. 100% = compaction fires.
|
||||
|
||||
Uses ANSI colors:
|
||||
- cyan at ~60% to compaction = informational
|
||||
- bold yellow at ~85% to compaction = warning
|
||||
|
||||
Args:
|
||||
compaction_progress: How close to compaction (0.0–1.0, 1.0 = fires).
|
||||
threshold_tokens: Compaction threshold in tokens.
|
||||
@@ -674,18 +706,12 @@ def format_context_pressure(
|
||||
threshold_k = f"{threshold_tokens // 1000}k" if threshold_tokens >= 1000 else str(threshold_tokens)
|
||||
threshold_pct_int = int(threshold_percent * 100)
|
||||
|
||||
# Tier styling
|
||||
if compaction_progress >= 0.85:
|
||||
color = f"{_BOLD}{_YELLOW}"
|
||||
icon = "⚠"
|
||||
if compression_enabled:
|
||||
hint = "compaction imminent"
|
||||
else:
|
||||
hint = "no auto-compaction"
|
||||
color = f"{_BOLD}{_YELLOW}"
|
||||
icon = "⚠"
|
||||
if compression_enabled:
|
||||
hint = "compaction approaching"
|
||||
else:
|
||||
color = _CYAN
|
||||
icon = "◐"
|
||||
hint = "approaching compaction"
|
||||
hint = "no auto-compaction"
|
||||
|
||||
return (
|
||||
f" {color}{icon} context {bar} {pct_int}% to compaction{_ANSI_RESET}"
|
||||
@@ -709,14 +735,10 @@ def format_context_pressure_gateway(
|
||||
|
||||
threshold_pct_int = int(threshold_percent * 100)
|
||||
|
||||
if compaction_progress >= 0.85:
|
||||
icon = "⚠️"
|
||||
if compression_enabled:
|
||||
hint = f"Context compaction is imminent (threshold: {threshold_pct_int}% of window)."
|
||||
else:
|
||||
hint = "Auto-compaction is disabled — context may be truncated."
|
||||
icon = "⚠️"
|
||||
if compression_enabled:
|
||||
hint = f"Context compaction approaching (threshold: {threshold_pct_int}% of window)."
|
||||
else:
|
||||
icon = "ℹ️"
|
||||
hint = f"Compaction threshold is at {threshold_pct_int}% of context window."
|
||||
hint = "Auto-compaction is disabled — context may be truncated."
|
||||
|
||||
return f"{icon} Context: {bar} {pct_int}% to compaction\n{hint}"
|
||||
|
||||
+1
-1
@@ -666,7 +666,7 @@ class InsightsEngine:
|
||||
cost_cell = " N/A"
|
||||
lines.append(f" {model_name:<30} {m['sessions']:>8} {m['total_tokens']:>12,} {cost_cell}")
|
||||
if o.get("models_without_pricing"):
|
||||
lines.append(f" * Cost N/A for custom/self-hosted models")
|
||||
lines.append(" * Cost N/A for custom/self-hosted models")
|
||||
lines.append("")
|
||||
|
||||
# Platform breakdown
|
||||
|
||||
+42
-7
@@ -164,6 +164,8 @@ _URL_TO_PROVIDER: Dict[str, str] = {
|
||||
"openrouter.ai": "openrouter",
|
||||
"inference-api.nousresearch.com": "nous",
|
||||
"api.deepseek.com": "deepseek",
|
||||
"api.githubcopilot.com": "copilot",
|
||||
"models.github.ai": "copilot",
|
||||
}
|
||||
|
||||
|
||||
@@ -260,9 +262,11 @@ def detect_local_server_type(base_url: str) -> Optional[str]:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
# llama.cpp exposes /props
|
||||
# llama.cpp exposes /v1/props (older builds used /props without the /v1 prefix)
|
||||
try:
|
||||
r = client.get(f"{server_url}/props")
|
||||
r = client.get(f"{server_url}/v1/props")
|
||||
if r.status_code != 200:
|
||||
r = client.get(f"{server_url}/props") # fallback for older builds
|
||||
if r.status_code == 200 and "default_generation_settings" in r.text:
|
||||
return "llamacpp"
|
||||
except Exception:
|
||||
@@ -455,8 +459,11 @@ def fetch_endpoint_model_metadata(
|
||||
)
|
||||
if is_llamacpp:
|
||||
try:
|
||||
props_url = candidate.rstrip("/").replace("/v1", "") + "/props"
|
||||
props_resp = requests.get(props_url, headers=headers, timeout=5)
|
||||
# Try /v1/props first (current llama.cpp); fall back to /props for older builds
|
||||
base = candidate.rstrip("/").replace("/v1", "")
|
||||
props_resp = requests.get(base + "/v1/props", headers=headers, timeout=5)
|
||||
if not props_resp.ok:
|
||||
props_resp = requests.get(base + "/props", headers=headers, timeout=5)
|
||||
if props_resp.ok:
|
||||
props = props_resp.json()
|
||||
gen_settings = props.get("default_generation_settings", {})
|
||||
@@ -783,8 +790,12 @@ def get_model_context_length(
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
# 2. Active endpoint metadata for explicit custom routes
|
||||
if _is_custom_endpoint(base_url):
|
||||
# 2. Active endpoint metadata for truly custom/unknown endpoints.
|
||||
# Known providers (Copilot, OpenAI, Anthropic, etc.) skip this — their
|
||||
# /models endpoint may report a provider-imposed limit (e.g. Copilot
|
||||
# returns 128k) instead of the model's full context (400k). models.dev
|
||||
# has the correct per-provider values and is checked at step 5+.
|
||||
if _is_custom_endpoint(base_url) and not _is_known_provider_base_url(base_url):
|
||||
endpoint_metadata = fetch_endpoint_model_metadata(base_url, api_key=api_key)
|
||||
matched = endpoint_metadata.get(model)
|
||||
if not matched:
|
||||
@@ -855,10 +866,11 @@ def get_model_context_length(
|
||||
# Only check `default_model in model` (is the key a substring of the input).
|
||||
# The reverse (`model in default_model`) causes shorter names like
|
||||
# "claude-sonnet-4" to incorrectly match "claude-sonnet-4-6" and return 1M.
|
||||
model_lower = model.lower()
|
||||
for default_model, length in sorted(
|
||||
DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True
|
||||
):
|
||||
if default_model in model:
|
||||
if default_model in model_lower:
|
||||
return length
|
||||
|
||||
# 9. Query local server as last resort
|
||||
@@ -883,3 +895,26 @@ def estimate_messages_tokens_rough(messages: List[Dict[str, Any]]) -> int:
|
||||
"""Rough token estimate for a message list (pre-flight only)."""
|
||||
total_chars = sum(len(str(msg)) for msg in messages)
|
||||
return total_chars // 4
|
||||
|
||||
|
||||
def estimate_request_tokens_rough(
|
||||
messages: List[Dict[str, Any]],
|
||||
*,
|
||||
system_prompt: str = "",
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> int:
|
||||
"""Rough token estimate for a full chat-completions request.
|
||||
|
||||
Includes the major payload buckets Hermes sends to providers:
|
||||
system prompt, conversation messages, and tool schemas. With 50+
|
||||
tools enabled, schemas alone can add 20-30K tokens — a significant
|
||||
blind spot when only counting messages.
|
||||
"""
|
||||
total_chars = 0
|
||||
if system_prompt:
|
||||
total_chars += len(system_prompt)
|
||||
if messages:
|
||||
total_chars += sum(len(str(msg)) for msg in messages)
|
||||
if tools:
|
||||
total_chars += len(str(tools))
|
||||
return total_chars // 4
|
||||
|
||||
+268
-133
@@ -4,12 +4,27 @@ All functions are stateless. AIAgent._build_system_prompt() calls these to
|
||||
assemble pieces, then combines them with memory and ephemeral prompts.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from typing import Optional
|
||||
|
||||
from agent.skill_utils import (
|
||||
extract_skill_conditions,
|
||||
extract_skill_description,
|
||||
get_disabled_skill_names,
|
||||
iter_skill_index_files,
|
||||
parse_frontmatter,
|
||||
skill_matches_platform,
|
||||
)
|
||||
from utils import atomic_json_write
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -228,6 +243,111 @@ CONTEXT_TRUNCATE_HEAD_RATIO = 0.7
|
||||
CONTEXT_TRUNCATE_TAIL_RATIO = 0.2
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Skills prompt cache
|
||||
# =========================================================================
|
||||
|
||||
_SKILLS_PROMPT_CACHE_MAX = 8
|
||||
_SKILLS_PROMPT_CACHE: OrderedDict[tuple, str] = OrderedDict()
|
||||
_SKILLS_PROMPT_CACHE_LOCK = threading.Lock()
|
||||
_SKILLS_SNAPSHOT_VERSION = 1
|
||||
|
||||
|
||||
def _skills_prompt_snapshot_path() -> Path:
|
||||
return get_hermes_home() / ".skills_prompt_snapshot.json"
|
||||
|
||||
|
||||
def clear_skills_system_prompt_cache(*, clear_snapshot: bool = False) -> None:
|
||||
"""Drop the in-process skills prompt cache (and optionally the disk snapshot)."""
|
||||
with _SKILLS_PROMPT_CACHE_LOCK:
|
||||
_SKILLS_PROMPT_CACHE.clear()
|
||||
if clear_snapshot:
|
||||
try:
|
||||
_skills_prompt_snapshot_path().unlink(missing_ok=True)
|
||||
except OSError as e:
|
||||
logger.debug("Could not remove skills prompt snapshot: %s", e)
|
||||
|
||||
|
||||
def _build_skills_manifest(skills_dir: Path) -> dict[str, list[int]]:
|
||||
"""Build an mtime/size manifest of all SKILL.md and DESCRIPTION.md files."""
|
||||
manifest: dict[str, list[int]] = {}
|
||||
for filename in ("SKILL.md", "DESCRIPTION.md"):
|
||||
for path in iter_skill_index_files(skills_dir, filename):
|
||||
try:
|
||||
st = path.stat()
|
||||
except OSError:
|
||||
continue
|
||||
manifest[str(path.relative_to(skills_dir))] = [st.st_mtime_ns, st.st_size]
|
||||
return manifest
|
||||
|
||||
|
||||
def _load_skills_snapshot(skills_dir: Path) -> Optional[dict]:
|
||||
"""Load the disk snapshot if it exists and its manifest still matches."""
|
||||
snapshot_path = _skills_prompt_snapshot_path()
|
||||
if not snapshot_path.exists():
|
||||
return None
|
||||
try:
|
||||
snapshot = json.loads(snapshot_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
return None
|
||||
if not isinstance(snapshot, dict):
|
||||
return None
|
||||
if snapshot.get("version") != _SKILLS_SNAPSHOT_VERSION:
|
||||
return None
|
||||
if snapshot.get("manifest") != _build_skills_manifest(skills_dir):
|
||||
return None
|
||||
return snapshot
|
||||
|
||||
|
||||
def _write_skills_snapshot(
|
||||
skills_dir: Path,
|
||||
manifest: dict[str, list[int]],
|
||||
skill_entries: list[dict],
|
||||
category_descriptions: dict[str, str],
|
||||
) -> None:
|
||||
"""Persist skill metadata to disk for fast cold-start reuse."""
|
||||
payload = {
|
||||
"version": _SKILLS_SNAPSHOT_VERSION,
|
||||
"manifest": manifest,
|
||||
"skills": skill_entries,
|
||||
"category_descriptions": category_descriptions,
|
||||
}
|
||||
try:
|
||||
atomic_json_write(_skills_prompt_snapshot_path(), payload)
|
||||
except Exception as e:
|
||||
logger.debug("Could not write skills prompt snapshot: %s", e)
|
||||
|
||||
|
||||
def _build_snapshot_entry(
|
||||
skill_file: Path,
|
||||
skills_dir: Path,
|
||||
frontmatter: dict,
|
||||
description: str,
|
||||
) -> dict:
|
||||
"""Build a serialisable metadata dict for one skill."""
|
||||
rel_path = skill_file.relative_to(skills_dir)
|
||||
parts = rel_path.parts
|
||||
if len(parts) >= 2:
|
||||
skill_name = parts[-2]
|
||||
category = "/".join(parts[:-2]) if len(parts) > 2 else parts[0]
|
||||
else:
|
||||
category = "general"
|
||||
skill_name = skill_file.parent.name
|
||||
|
||||
platforms = frontmatter.get("platforms") or []
|
||||
if isinstance(platforms, str):
|
||||
platforms = [platforms]
|
||||
|
||||
return {
|
||||
"skill_name": skill_name,
|
||||
"category": category,
|
||||
"frontmatter_name": str(frontmatter.get("name", skill_name)),
|
||||
"description": description,
|
||||
"platforms": [str(p).strip() for p in platforms if str(p).strip()],
|
||||
"conditions": extract_skill_conditions(frontmatter),
|
||||
}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Skills index
|
||||
# =========================================================================
|
||||
@@ -239,22 +359,13 @@ def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]:
|
||||
(True, {}, "") to err on the side of showing the skill.
|
||||
"""
|
||||
try:
|
||||
from tools.skills_tool import _parse_frontmatter, skill_matches_platform
|
||||
|
||||
raw = skill_file.read_text(encoding="utf-8")[:2000]
|
||||
frontmatter, _ = _parse_frontmatter(raw)
|
||||
frontmatter, _ = parse_frontmatter(raw)
|
||||
|
||||
if not skill_matches_platform(frontmatter):
|
||||
return False, {}, ""
|
||||
return False, frontmatter, ""
|
||||
|
||||
desc = ""
|
||||
raw_desc = frontmatter.get("description", "")
|
||||
if raw_desc:
|
||||
desc = str(raw_desc).strip().strip("'\"")
|
||||
if len(desc) > 60:
|
||||
desc = desc[:57] + "..."
|
||||
|
||||
return True, frontmatter, desc
|
||||
return True, frontmatter, extract_skill_description(frontmatter)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to parse skill file %s: %s", skill_file, e)
|
||||
return True, {}, ""
|
||||
@@ -263,16 +374,9 @@ def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]:
|
||||
def _read_skill_conditions(skill_file: Path) -> dict:
|
||||
"""Extract conditional activation fields from SKILL.md frontmatter."""
|
||||
try:
|
||||
from tools.skills_tool import _parse_frontmatter
|
||||
raw = skill_file.read_text(encoding="utf-8")[:2000]
|
||||
frontmatter, _ = _parse_frontmatter(raw)
|
||||
hermes = frontmatter.get("metadata", {}).get("hermes", {})
|
||||
return {
|
||||
"fallback_for_toolsets": hermes.get("fallback_for_toolsets", []),
|
||||
"requires_toolsets": hermes.get("requires_toolsets", []),
|
||||
"fallback_for_tools": hermes.get("fallback_for_tools", []),
|
||||
"requires_tools": hermes.get("requires_tools", []),
|
||||
}
|
||||
frontmatter, _ = parse_frontmatter(raw)
|
||||
return extract_skill_conditions(frontmatter)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to read skill conditions from %s: %s", skill_file, e)
|
||||
return {}
|
||||
@@ -315,102 +419,153 @@ def build_skills_system_prompt(
|
||||
) -> str:
|
||||
"""Build a compact skill index for the system prompt.
|
||||
|
||||
Scans ~/.hermes/skills/ for SKILL.md files grouped by category.
|
||||
Includes per-skill descriptions from frontmatter so the model can
|
||||
match skills by meaning, not just name.
|
||||
Filters out skills incompatible with the current OS platform.
|
||||
Two-layer cache:
|
||||
1. In-process LRU dict keyed by (skills_dir, tools, toolsets)
|
||||
2. Disk snapshot (``.skills_prompt_snapshot.json``) validated by
|
||||
mtime/size manifest — survives process restarts
|
||||
|
||||
Falls back to a full filesystem scan when both layers miss.
|
||||
"""
|
||||
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
hermes_home = get_hermes_home()
|
||||
skills_dir = hermes_home / "skills"
|
||||
|
||||
if not skills_dir.exists():
|
||||
return ""
|
||||
|
||||
# Collect skills with descriptions, grouped by category.
|
||||
# Each entry: (skill_name, description)
|
||||
# Supports sub-categories: skills/mlops/training/axolotl/SKILL.md
|
||||
# -> category "mlops/training", skill "axolotl"
|
||||
# Load disabled skill names once for the entire scan
|
||||
try:
|
||||
from tools.skills_tool import _get_disabled_skill_names
|
||||
disabled = _get_disabled_skill_names()
|
||||
except Exception:
|
||||
disabled = set()
|
||||
# ── Layer 1: in-process LRU cache ─────────────────────────────────
|
||||
cache_key = (
|
||||
str(skills_dir.resolve()),
|
||||
tuple(sorted(str(t) for t in (available_tools or set()))),
|
||||
tuple(sorted(str(ts) for ts in (available_toolsets or set()))),
|
||||
)
|
||||
with _SKILLS_PROMPT_CACHE_LOCK:
|
||||
cached = _SKILLS_PROMPT_CACHE.get(cache_key)
|
||||
if cached is not None:
|
||||
_SKILLS_PROMPT_CACHE.move_to_end(cache_key)
|
||||
return cached
|
||||
|
||||
disabled = get_disabled_skill_names()
|
||||
|
||||
# ── Layer 2: disk snapshot ────────────────────────────────────────
|
||||
snapshot = _load_skills_snapshot(skills_dir)
|
||||
|
||||
skills_by_category: dict[str, list[tuple[str, str]]] = {}
|
||||
for skill_file in skills_dir.rglob("SKILL.md"):
|
||||
is_compatible, frontmatter, desc = _parse_skill_file(skill_file)
|
||||
if not is_compatible:
|
||||
continue
|
||||
rel_path = skill_file.relative_to(skills_dir)
|
||||
parts = rel_path.parts
|
||||
if len(parts) >= 2:
|
||||
skill_name = parts[-2]
|
||||
category = "/".join(parts[:-2]) if len(parts) > 2 else parts[0]
|
||||
else:
|
||||
category = "general"
|
||||
skill_name = skill_file.parent.name
|
||||
# Respect user's disabled skills config
|
||||
fm_name = frontmatter.get("name", skill_name)
|
||||
if fm_name in disabled or skill_name in disabled:
|
||||
continue
|
||||
# Skip skills whose conditional activation rules exclude them
|
||||
conditions = _read_skill_conditions(skill_file)
|
||||
if not _skill_should_show(conditions, available_tools, available_toolsets):
|
||||
continue
|
||||
skills_by_category.setdefault(category, []).append((skill_name, desc))
|
||||
category_descriptions: dict[str, str] = {}
|
||||
|
||||
if not skills_by_category:
|
||||
return ""
|
||||
if snapshot is not None:
|
||||
# Fast path: use pre-parsed metadata from disk
|
||||
for entry in snapshot.get("skills", []):
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
skill_name = entry.get("skill_name") or ""
|
||||
category = entry.get("category") or "general"
|
||||
frontmatter_name = entry.get("frontmatter_name") or skill_name
|
||||
platforms = entry.get("platforms") or []
|
||||
if not skill_matches_platform({"platforms": platforms}):
|
||||
continue
|
||||
if frontmatter_name in disabled or skill_name in disabled:
|
||||
continue
|
||||
if not _skill_should_show(
|
||||
entry.get("conditions") or {},
|
||||
available_tools,
|
||||
available_toolsets,
|
||||
):
|
||||
continue
|
||||
skills_by_category.setdefault(category, []).append(
|
||||
(skill_name, entry.get("description", ""))
|
||||
)
|
||||
category_descriptions = {
|
||||
str(k): str(v)
|
||||
for k, v in (snapshot.get("category_descriptions") or {}).items()
|
||||
}
|
||||
else:
|
||||
# Cold path: full filesystem scan + write snapshot for next time
|
||||
skill_entries: list[dict] = []
|
||||
for skill_file in iter_skill_index_files(skills_dir, "SKILL.md"):
|
||||
is_compatible, frontmatter, desc = _parse_skill_file(skill_file)
|
||||
entry = _build_snapshot_entry(skill_file, skills_dir, frontmatter, desc)
|
||||
skill_entries.append(entry)
|
||||
if not is_compatible:
|
||||
continue
|
||||
skill_name = entry["skill_name"]
|
||||
if entry["frontmatter_name"] in disabled or skill_name in disabled:
|
||||
continue
|
||||
if not _skill_should_show(
|
||||
extract_skill_conditions(frontmatter),
|
||||
available_tools,
|
||||
available_toolsets,
|
||||
):
|
||||
continue
|
||||
skills_by_category.setdefault(entry["category"], []).append(
|
||||
(skill_name, entry["description"])
|
||||
)
|
||||
|
||||
# Read category-level descriptions from DESCRIPTION.md
|
||||
# Checks both the exact category path and parent directories
|
||||
category_descriptions = {}
|
||||
for category in skills_by_category:
|
||||
cat_path = Path(category)
|
||||
desc_file = skills_dir / cat_path / "DESCRIPTION.md"
|
||||
if desc_file.exists():
|
||||
# Read category-level DESCRIPTION.md files
|
||||
for desc_file in iter_skill_index_files(skills_dir, "DESCRIPTION.md"):
|
||||
try:
|
||||
content = desc_file.read_text(encoding="utf-8")
|
||||
match = re.search(r"^---\s*\n.*?description:\s*(.+?)\s*\n.*?^---", content, re.MULTILINE | re.DOTALL)
|
||||
if match:
|
||||
category_descriptions[category] = match.group(1).strip()
|
||||
fm, _ = parse_frontmatter(content)
|
||||
cat_desc = fm.get("description")
|
||||
if not cat_desc:
|
||||
continue
|
||||
rel = desc_file.relative_to(skills_dir)
|
||||
cat = "/".join(rel.parts[:-1]) if len(rel.parts) > 1 else "general"
|
||||
category_descriptions[cat] = str(cat_desc).strip().strip("'\"")
|
||||
except Exception as e:
|
||||
logger.debug("Could not read skill description %s: %s", desc_file, e)
|
||||
|
||||
index_lines = []
|
||||
for category in sorted(skills_by_category.keys()):
|
||||
cat_desc = category_descriptions.get(category, "")
|
||||
if cat_desc:
|
||||
index_lines.append(f" {category}: {cat_desc}")
|
||||
else:
|
||||
index_lines.append(f" {category}:")
|
||||
# Deduplicate and sort skills within each category
|
||||
seen = set()
|
||||
for name, desc in sorted(skills_by_category[category], key=lambda x: x[0]):
|
||||
if name in seen:
|
||||
continue
|
||||
seen.add(name)
|
||||
if desc:
|
||||
index_lines.append(f" - {name}: {desc}")
|
||||
else:
|
||||
index_lines.append(f" - {name}")
|
||||
_write_skills_snapshot(
|
||||
skills_dir,
|
||||
_build_skills_manifest(skills_dir),
|
||||
skill_entries,
|
||||
category_descriptions,
|
||||
)
|
||||
|
||||
return (
|
||||
"## Skills (mandatory)\n"
|
||||
"Before replying, scan the skills below. If one clearly matches your task, "
|
||||
"load it with skill_view(name) and follow its instructions. "
|
||||
"If a skill has issues, fix it with skill_manage(action='patch').\n"
|
||||
"After difficult/iterative tasks, offer to save as a skill. "
|
||||
"If a skill you loaded was missing steps, had wrong commands, or needed "
|
||||
"pitfalls you discovered, update it before finishing.\n"
|
||||
"\n"
|
||||
"<available_skills>\n"
|
||||
+ "\n".join(index_lines) + "\n"
|
||||
"</available_skills>\n"
|
||||
"\n"
|
||||
"If none match, proceed normally without loading a skill."
|
||||
)
|
||||
if not skills_by_category:
|
||||
result = ""
|
||||
else:
|
||||
index_lines = []
|
||||
for category in sorted(skills_by_category.keys()):
|
||||
cat_desc = category_descriptions.get(category, "")
|
||||
if cat_desc:
|
||||
index_lines.append(f" {category}: {cat_desc}")
|
||||
else:
|
||||
index_lines.append(f" {category}:")
|
||||
# Deduplicate and sort skills within each category
|
||||
seen = set()
|
||||
for name, desc in sorted(skills_by_category[category], key=lambda x: x[0]):
|
||||
if name in seen:
|
||||
continue
|
||||
seen.add(name)
|
||||
if desc:
|
||||
index_lines.append(f" - {name}: {desc}")
|
||||
else:
|
||||
index_lines.append(f" - {name}")
|
||||
|
||||
result = (
|
||||
"## Skills (mandatory)\n"
|
||||
"Before replying, scan the skills below. If one clearly matches your task, "
|
||||
"load it with skill_view(name) and follow its instructions. "
|
||||
"If a skill has issues, fix it with skill_manage(action='patch').\n"
|
||||
"After difficult/iterative tasks, offer to save as a skill. "
|
||||
"If a skill you loaded was missing steps, had wrong commands, or needed "
|
||||
"pitfalls you discovered, update it before finishing.\n"
|
||||
"\n"
|
||||
"<available_skills>\n"
|
||||
+ "\n".join(index_lines) + "\n"
|
||||
"</available_skills>\n"
|
||||
"\n"
|
||||
"If none match, proceed normally without loading a skill."
|
||||
)
|
||||
|
||||
# ── Store in LRU cache ────────────────────────────────────────────
|
||||
with _SKILLS_PROMPT_CACHE_LOCK:
|
||||
_SKILLS_PROMPT_CACHE[cache_key] = result
|
||||
_SKILLS_PROMPT_CACHE.move_to_end(cache_key)
|
||||
while len(_SKILLS_PROMPT_CACHE) > _SKILLS_PROMPT_CACHE_MAX:
|
||||
_SKILLS_PROMPT_CACHE.popitem(last=False)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# =========================================================================
|
||||
@@ -442,7 +597,7 @@ def load_soul_md() -> Optional[str]:
|
||||
except Exception as e:
|
||||
logger.debug("Could not ensure HERMES_HOME before loading SOUL.md: %s", e)
|
||||
|
||||
soul_path = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "SOUL.md"
|
||||
soul_path = get_hermes_home() / "SOUL.md"
|
||||
if not soul_path.exists():
|
||||
return None
|
||||
try:
|
||||
@@ -481,39 +636,19 @@ def _load_hermes_md(cwd_path: Path) -> str:
|
||||
|
||||
|
||||
def _load_agents_md(cwd_path: Path) -> str:
|
||||
"""AGENTS.md — hierarchical, recursive directory walk."""
|
||||
top_level_agents = None
|
||||
"""AGENTS.md — top-level only (no recursive walk)."""
|
||||
for name in ["AGENTS.md", "agents.md"]:
|
||||
candidate = cwd_path / name
|
||||
if candidate.exists():
|
||||
top_level_agents = candidate
|
||||
break
|
||||
|
||||
if not top_level_agents:
|
||||
return ""
|
||||
|
||||
agents_files = []
|
||||
for root, dirs, files in os.walk(cwd_path):
|
||||
dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ('node_modules', '__pycache__', 'venv', '.venv')]
|
||||
for f in files:
|
||||
if f.lower() == "agents.md":
|
||||
agents_files.append(Path(root) / f)
|
||||
agents_files.sort(key=lambda p: len(p.parts))
|
||||
|
||||
total_content = ""
|
||||
for agents_path in agents_files:
|
||||
try:
|
||||
content = agents_path.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
rel_path = agents_path.relative_to(cwd_path)
|
||||
content = _scan_context_content(content, str(rel_path))
|
||||
total_content += f"## {rel_path}\n\n{content}\n\n"
|
||||
except Exception as e:
|
||||
logger.debug("Could not read %s: %s", agents_path, e)
|
||||
|
||||
if not total_content:
|
||||
return ""
|
||||
return _truncate_content(total_content, "AGENTS.md")
|
||||
try:
|
||||
content = candidate.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
content = _scan_context_content(content, name)
|
||||
result = f"## {name}\n\n{content}"
|
||||
return _truncate_content(result, "AGENTS.md")
|
||||
except Exception as e:
|
||||
logger.debug("Could not read %s: %s", candidate, e)
|
||||
return ""
|
||||
|
||||
|
||||
def _load_claude_md(cwd_path: Path) -> str:
|
||||
@@ -567,7 +702,7 @@ def build_context_files_prompt(cwd: Optional[str] = None, skip_soul: bool = Fals
|
||||
|
||||
Priority (first found wins — only ONE project context type is loaded):
|
||||
1. .hermes.md / HERMES.md (walk to git root)
|
||||
2. AGENTS.md / agents.md (recursive directory walk)
|
||||
2. AGENTS.md / agents.md (cwd only)
|
||||
3. CLAUDE.md / claude.md (cwd only)
|
||||
4. .cursorrules / .cursor/rules/*.mdc (cwd only)
|
||||
|
||||
|
||||
@@ -12,13 +12,14 @@ import copy
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
def _apply_cache_marker(msg: dict, cache_marker: dict) -> None:
|
||||
def _apply_cache_marker(msg: dict, cache_marker: dict, native_anthropic: bool = False) -> None:
|
||||
"""Add cache_control to a single message, handling all format variations."""
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content")
|
||||
|
||||
if role == "tool":
|
||||
msg["cache_control"] = cache_marker
|
||||
if native_anthropic:
|
||||
msg["cache_control"] = cache_marker
|
||||
return
|
||||
|
||||
if content is None or content == "":
|
||||
@@ -40,6 +41,7 @@ def _apply_cache_marker(msg: dict, cache_marker: dict) -> None:
|
||||
def apply_anthropic_cache_control(
|
||||
api_messages: List[Dict[str, Any]],
|
||||
cache_ttl: str = "5m",
|
||||
native_anthropic: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Apply system_and_3 caching strategy to messages for Anthropic models.
|
||||
|
||||
@@ -59,12 +61,12 @@ def apply_anthropic_cache_control(
|
||||
breakpoints_used = 0
|
||||
|
||||
if messages[0].get("role") == "system":
|
||||
_apply_cache_marker(messages[0], marker)
|
||||
_apply_cache_marker(messages[0], marker, native_anthropic=native_anthropic)
|
||||
breakpoints_used += 1
|
||||
|
||||
remaining = 4 - breakpoints_used
|
||||
non_sys = [i for i in range(len(messages)) if messages[i].get("role") != "system"]
|
||||
for idx in non_sys[-remaining:]:
|
||||
_apply_cache_marker(messages[idx], marker)
|
||||
_apply_cache_marker(messages[idx], marker, native_anthropic=native_anthropic)
|
||||
|
||||
return messages
|
||||
|
||||
@@ -100,6 +100,10 @@ def redact_sensitive_text(text: str) -> str:
|
||||
Safe to call on any string -- non-matching text passes through unchanged.
|
||||
Disabled when security.redact_secrets is false in config.yaml.
|
||||
"""
|
||||
if text is None:
|
||||
return None
|
||||
if not isinstance(text, str):
|
||||
text = str(text)
|
||||
if not text:
|
||||
return text
|
||||
if os.getenv("HERMES_REDACT_SECRETS", "").lower() in ("0", "false", "no", "off"):
|
||||
|
||||
@@ -0,0 +1,203 @@
|
||||
"""Lightweight skill metadata utilities shared by prompt_builder and skills_tool.
|
||||
|
||||
This module intentionally avoids importing the tool registry, CLI config, or any
|
||||
heavy dependency chain. It is safe to import at module level without triggering
|
||||
tool registration or provider resolution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Platform mapping ──────────────────────────────────────────────────────
|
||||
|
||||
PLATFORM_MAP = {
|
||||
"macos": "darwin",
|
||||
"linux": "linux",
|
||||
"windows": "win32",
|
||||
}
|
||||
|
||||
EXCLUDED_SKILL_DIRS = frozenset((".git", ".github", ".hub"))
|
||||
|
||||
# ── Lazy YAML loader ─────────────────────────────────────────────────────
|
||||
|
||||
_yaml_load_fn = None
|
||||
|
||||
|
||||
def yaml_load(content: str):
|
||||
"""Parse YAML with lazy import and CSafeLoader preference."""
|
||||
global _yaml_load_fn
|
||||
if _yaml_load_fn is None:
|
||||
import yaml
|
||||
|
||||
loader = getattr(yaml, "CSafeLoader", None) or yaml.SafeLoader
|
||||
|
||||
def _load(value: str):
|
||||
return yaml.load(value, Loader=loader)
|
||||
|
||||
_yaml_load_fn = _load
|
||||
return _yaml_load_fn(content)
|
||||
|
||||
|
||||
# ── Frontmatter parsing ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def parse_frontmatter(content: str) -> Tuple[Dict[str, Any], str]:
|
||||
"""Parse YAML frontmatter from a markdown string.
|
||||
|
||||
Uses yaml with CSafeLoader for full YAML support (nested metadata, lists)
|
||||
with a fallback to simple key:value splitting for robustness.
|
||||
|
||||
Returns:
|
||||
(frontmatter_dict, remaining_body)
|
||||
"""
|
||||
frontmatter: Dict[str, Any] = {}
|
||||
body = content
|
||||
|
||||
if not content.startswith("---"):
|
||||
return frontmatter, body
|
||||
|
||||
end_match = re.search(r"\n---\s*\n", content[3:])
|
||||
if not end_match:
|
||||
return frontmatter, body
|
||||
|
||||
yaml_content = content[3 : end_match.start() + 3]
|
||||
body = content[end_match.end() + 3 :]
|
||||
|
||||
try:
|
||||
parsed = yaml_load(yaml_content)
|
||||
if isinstance(parsed, dict):
|
||||
frontmatter = parsed
|
||||
except Exception:
|
||||
# Fallback: simple key:value parsing for malformed YAML
|
||||
for line in yaml_content.strip().split("\n"):
|
||||
if ":" not in line:
|
||||
continue
|
||||
key, value = line.split(":", 1)
|
||||
frontmatter[key.strip()] = value.strip()
|
||||
|
||||
return frontmatter, body
|
||||
|
||||
|
||||
# ── Platform matching ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def skill_matches_platform(frontmatter: Dict[str, Any]) -> bool:
|
||||
"""Return True when the skill is compatible with the current OS.
|
||||
|
||||
Skills declare platform requirements via a top-level ``platforms`` list
|
||||
in their YAML frontmatter::
|
||||
|
||||
platforms: [macos] # macOS only
|
||||
platforms: [macos, linux] # macOS and Linux
|
||||
|
||||
If the field is absent or empty the skill is compatible with **all**
|
||||
platforms (backward-compatible default).
|
||||
"""
|
||||
platforms = frontmatter.get("platforms")
|
||||
if not platforms:
|
||||
return True
|
||||
if not isinstance(platforms, list):
|
||||
platforms = [platforms]
|
||||
current = sys.platform
|
||||
for platform in platforms:
|
||||
normalized = str(platform).lower().strip()
|
||||
mapped = PLATFORM_MAP.get(normalized, normalized)
|
||||
if current.startswith(mapped):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ── Disabled skills ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_disabled_skill_names() -> Set[str]:
|
||||
"""Read disabled skill names from config.yaml.
|
||||
|
||||
Resolves platform from ``HERMES_PLATFORM`` env var, falls back to
|
||||
the global disabled list. Reads the config file directly (no CLI
|
||||
config imports) to stay lightweight.
|
||||
"""
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
if not config_path.exists():
|
||||
return set()
|
||||
try:
|
||||
parsed = yaml_load(config_path.read_text(encoding="utf-8"))
|
||||
except Exception as e:
|
||||
logger.debug("Could not read skill config %s: %s", config_path, e)
|
||||
return set()
|
||||
if not isinstance(parsed, dict):
|
||||
return set()
|
||||
|
||||
skills_cfg = parsed.get("skills")
|
||||
if not isinstance(skills_cfg, dict):
|
||||
return set()
|
||||
|
||||
resolved_platform = os.getenv("HERMES_PLATFORM")
|
||||
if resolved_platform:
|
||||
platform_disabled = (skills_cfg.get("platform_disabled") or {}).get(
|
||||
resolved_platform
|
||||
)
|
||||
if platform_disabled is not None:
|
||||
return _normalize_string_set(platform_disabled)
|
||||
return _normalize_string_set(skills_cfg.get("disabled"))
|
||||
|
||||
|
||||
def _normalize_string_set(values) -> Set[str]:
|
||||
if values is None:
|
||||
return set()
|
||||
if isinstance(values, str):
|
||||
values = [values]
|
||||
return {str(v).strip() for v in values if str(v).strip()}
|
||||
|
||||
|
||||
# ── Condition extraction ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def extract_skill_conditions(frontmatter: Dict[str, Any]) -> Dict[str, List]:
|
||||
"""Extract conditional activation fields from parsed frontmatter."""
|
||||
hermes = (frontmatter.get("metadata") or {}).get("hermes") or {}
|
||||
return {
|
||||
"fallback_for_toolsets": hermes.get("fallback_for_toolsets", []),
|
||||
"requires_toolsets": hermes.get("requires_toolsets", []),
|
||||
"fallback_for_tools": hermes.get("fallback_for_tools", []),
|
||||
"requires_tools": hermes.get("requires_tools", []),
|
||||
}
|
||||
|
||||
|
||||
# ── Description extraction ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def extract_skill_description(frontmatter: Dict[str, Any]) -> str:
|
||||
"""Extract a truncated description from parsed frontmatter."""
|
||||
raw_desc = frontmatter.get("description", "")
|
||||
if not raw_desc:
|
||||
return ""
|
||||
desc = str(raw_desc).strip().strip("'\"")
|
||||
if len(desc) > 60:
|
||||
return desc[:57] + "..."
|
||||
return desc
|
||||
|
||||
|
||||
# ── File iteration ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def iter_skill_index_files(skills_dir: Path, filename: str):
|
||||
"""Walk skills_dir yielding sorted paths matching *filename*.
|
||||
|
||||
Excludes ``.git``, ``.github``, ``.hub`` directories.
|
||||
"""
|
||||
matches = []
|
||||
for root, dirs, files in os.walk(skills_dir):
|
||||
dirs[:] = [d for d in dirs if d not in EXCLUDED_SKILL_DIRS]
|
||||
if filename in files:
|
||||
matches.append(Path(root) / filename)
|
||||
for path in sorted(matches, key=lambda p: str(p.relative_to(skills_dir))):
|
||||
yield path
|
||||
@@ -649,7 +649,8 @@ def format_token_count_compact(value: int) -> str:
|
||||
text = f"{scaled:.1f}"
|
||||
else:
|
||||
text = f"{scaled:.0f}"
|
||||
text = text.rstrip("0").rstrip(".")
|
||||
if "." in text:
|
||||
text = text.rstrip("0").rstrip(".")
|
||||
return f"{sign}{text}{suffix}"
|
||||
|
||||
return f"{value:,}"
|
||||
|
||||
@@ -128,6 +128,7 @@ def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, i
|
||||
# Track tool calls from assistant messages
|
||||
if msg["role"] == "assistant" and "tool_calls" in msg and msg["tool_calls"]:
|
||||
for tool_call in msg["tool_calls"]:
|
||||
if not tool_call or not isinstance(tool_call, dict): continue
|
||||
tool_name = tool_call["function"]["name"]
|
||||
tool_call_id = tool_call["id"]
|
||||
|
||||
|
||||
+26
-5
@@ -232,19 +232,34 @@ browser:
|
||||
# 1. Tracks actual token usage from API responses (not estimates)
|
||||
# 2. When prompt_tokens >= threshold% of model's context_length, triggers compression
|
||||
# 3. Protects first 3 turns (system prompt, initial request, first response)
|
||||
# 4. Protects last 4 turns (recent context is most relevant)
|
||||
# 4. Protects last N turns (default 20 messages = ~10 full turns of recent context)
|
||||
# 5. Summarizes middle turns using a fast/cheap model
|
||||
# 6. Inserts summary as a user message, continues conversation seamlessly
|
||||
#
|
||||
# Post-compression tail budget is target_ratio × threshold × context_length:
|
||||
# 200K context, threshold 0.50, ratio 0.20 → 20K tokens of recent tail preserved
|
||||
# 1M context, threshold 0.50, ratio 0.20 → 100K tokens of recent tail preserved
|
||||
#
|
||||
compression:
|
||||
# Enable automatic context compression (default: true)
|
||||
# Set to false if you prefer to manage context manually or want errors on overflow
|
||||
enabled: true
|
||||
|
||||
# Trigger compression at this % of model's context limit (default: 0.85 = 85%)
|
||||
# Trigger compression at this % of model's context limit (default: 0.50 = 50%)
|
||||
# Lower values = more aggressive compression, higher values = compress later
|
||||
threshold: 0.85
|
||||
threshold: 0.50
|
||||
|
||||
# Fraction of the threshold to preserve as recent tail (default: 0.20 = 20%)
|
||||
# e.g. 20% of 50% threshold = 10% of total context kept as recent messages.
|
||||
# Summary output is separately capped at 12K tokens (Gemini output limit).
|
||||
# Range: 0.10 - 0.80
|
||||
target_ratio: 0.20
|
||||
|
||||
# Number of most-recent messages to always preserve (default: 20 ≈ 10 full turns)
|
||||
# Higher values keep more recent conversation intact at the cost of more aggressive
|
||||
# compression of older turns.
|
||||
protect_last_n: 20
|
||||
|
||||
# Model to use for generating summaries (fast/cheap recommended)
|
||||
# This model compresses the middle turns into a concise summary.
|
||||
# IMPORTANT: it receives the full middle section of the conversation, so it
|
||||
@@ -673,6 +688,12 @@ display:
|
||||
# Toggle at runtime with /verbose in the CLI
|
||||
tool_progress: all
|
||||
|
||||
# What Enter does when Hermes is already busy in the CLI.
|
||||
# interrupt: Interrupt the current run and redirect Hermes (default)
|
||||
# queue: Queue your message for the next turn
|
||||
# Ctrl+C always interrupts regardless of this setting.
|
||||
busy_input_mode: interrupt
|
||||
|
||||
# Background process notifications (gateway/messaging only).
|
||||
# Controls how chatty the process watcher is when you use
|
||||
# terminal(background=true, check_interval=...) from Telegram/Discord/etc.
|
||||
@@ -696,8 +717,8 @@ display:
|
||||
# Stream tokens to the terminal as they arrive instead of waiting for the
|
||||
# full response. The response box opens on first token and text appears
|
||||
# line-by-line. Tool calls are still captured silently.
|
||||
# Disabled by default — enable to try the streaming UX.
|
||||
streaming: false
|
||||
# Stream tokens to the terminal in real-time. Disable to wait for full responses.
|
||||
streaming: true
|
||||
|
||||
# ───────────────────────────────────────────────────────────────────────────
|
||||
# Skin / Theme
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+73
-6
@@ -14,6 +14,7 @@ import re
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from hermes_constants import get_hermes_home
|
||||
from typing import Optional, Dict, List, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -30,7 +31,7 @@ except ImportError:
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
HERMES_DIR = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
HERMES_DIR = get_hermes_home()
|
||||
CRON_DIR = HERMES_DIR / "cron"
|
||||
JOBS_FILE = CRON_DIR / "jobs.json"
|
||||
OUTPUT_DIR = CRON_DIR / "output"
|
||||
@@ -248,6 +249,38 @@ def _recoverable_oneshot_run_at(
|
||||
return None
|
||||
|
||||
|
||||
def _compute_grace_seconds(schedule: dict) -> int:
|
||||
"""Compute how late a job can be and still catch up instead of fast-forwarding.
|
||||
|
||||
Uses half the schedule period, clamped between 120 seconds and 2 hours.
|
||||
This ensures daily jobs can catch up if missed by up to 2 hours,
|
||||
while frequent jobs (every 5-10 min) still fast-forward quickly.
|
||||
"""
|
||||
MIN_GRACE = 120
|
||||
MAX_GRACE = 7200 # 2 hours
|
||||
|
||||
kind = schedule.get("kind")
|
||||
|
||||
if kind == "interval":
|
||||
period_seconds = schedule.get("minutes", 1) * 60
|
||||
grace = period_seconds // 2
|
||||
return max(MIN_GRACE, min(grace, MAX_GRACE))
|
||||
|
||||
if kind == "cron" and HAS_CRONITER:
|
||||
try:
|
||||
now = _hermes_now()
|
||||
cron = croniter(schedule["expr"], now)
|
||||
first = cron.get_next(datetime)
|
||||
second = cron.get_next(datetime)
|
||||
period_seconds = int((second - first).total_seconds())
|
||||
grace = period_seconds // 2
|
||||
return max(MIN_GRACE, min(grace, MAX_GRACE))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return MIN_GRACE
|
||||
|
||||
|
||||
def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Compute the next run time for a schedule.
|
||||
@@ -351,6 +384,10 @@ def create_job(
|
||||
"""
|
||||
parsed_schedule = parse_schedule(schedule)
|
||||
|
||||
# Normalize repeat: treat 0 or negative values as None (infinite)
|
||||
if repeat is not None and repeat <= 0:
|
||||
repeat = None
|
||||
|
||||
# Auto-set repeat=1 for one-shot schedules if not specified
|
||||
if parsed_schedule["kind"] == "once" and repeat is None:
|
||||
repeat = 1
|
||||
@@ -539,7 +576,7 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
# Check if we've hit the repeat limit
|
||||
times = job["repeat"].get("times")
|
||||
completed = job["repeat"]["completed"]
|
||||
if times is not None and completed >= times:
|
||||
if times is not None and times > 0 and completed >= times:
|
||||
# Remove the job (limit reached)
|
||||
jobs.pop(i)
|
||||
save_jobs(jobs)
|
||||
@@ -561,6 +598,34 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
save_jobs(jobs)
|
||||
|
||||
|
||||
def advance_next_run(job_id: str) -> bool:
|
||||
"""Preemptively advance next_run_at for a recurring job before execution.
|
||||
|
||||
Call this BEFORE run_job() so that if the process crashes mid-execution,
|
||||
the job won't re-fire on the next gateway restart. This converts the
|
||||
scheduler from at-least-once to at-most-once for recurring jobs — missing
|
||||
one run is far better than firing dozens of times in a crash loop.
|
||||
|
||||
One-shot jobs are left unchanged so they can still retry on restart.
|
||||
|
||||
Returns True if next_run_at was advanced, False otherwise.
|
||||
"""
|
||||
jobs = load_jobs()
|
||||
for job in jobs:
|
||||
if job["id"] == job_id:
|
||||
kind = job.get("schedule", {}).get("kind")
|
||||
if kind not in ("cron", "interval"):
|
||||
return False
|
||||
now = _hermes_now().isoformat()
|
||||
new_next = compute_next_run(job["schedule"], now)
|
||||
if new_next and new_next != job.get("next_run_at"):
|
||||
job["next_run_at"] = new_next
|
||||
save_jobs(jobs)
|
||||
return True
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def get_due_jobs() -> List[Dict[str, Any]]:
|
||||
"""Get all jobs that are due to run now.
|
||||
|
||||
@@ -610,16 +675,18 @@ def get_due_jobs() -> List[Dict[str, Any]]:
|
||||
# For recurring jobs, check if the scheduled time is stale
|
||||
# (gateway was down and missed the window). Fast-forward to
|
||||
# the next future occurrence instead of firing a stale run.
|
||||
if kind in ("cron", "interval") and (now - next_run_dt).total_seconds() > 120:
|
||||
# More than 2 minutes late — this is a missed run, not a current one.
|
||||
# Recompute next_run_at to the next future occurrence.
|
||||
grace = _compute_grace_seconds(schedule)
|
||||
if kind in ("cron", "interval") and (now - next_run_dt).total_seconds() > grace:
|
||||
# Job is past its catch-up grace window — this is a stale missed run.
|
||||
# Grace scales with schedule period: daily=2h, hourly=30m, 10min=5m.
|
||||
new_next = compute_next_run(schedule, now.isoformat())
|
||||
if new_next:
|
||||
logger.info(
|
||||
"Job '%s' missed its scheduled time (%s). "
|
||||
"Job '%s' missed its scheduled time (%s, grace=%ds). "
|
||||
"Fast-forwarding to next run: %s",
|
||||
job.get("name", job["id"]),
|
||||
next_run,
|
||||
grace,
|
||||
new_next,
|
||||
)
|
||||
# Update the job in storage
|
||||
|
||||
+48
-28
@@ -24,8 +24,8 @@ except ImportError:
|
||||
import msvcrt
|
||||
except ImportError:
|
||||
msvcrt = None
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from hermes_constants import get_hermes_home
|
||||
from typing import Optional
|
||||
|
||||
from hermes_time import now as _hermes_now
|
||||
@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run
|
||||
|
||||
# Sentinel: when a cron agent has nothing new to report, it can start its
|
||||
# response with this marker to suppress delivery. Output is still saved
|
||||
@@ -43,7 +43,7 @@ from cron.jobs import get_due_jobs, mark_job_run, save_job_output
|
||||
SILENT_MARKER = "[SILENT]"
|
||||
|
||||
# Resolve Hermes home directory (respects HERMES_HOME override)
|
||||
_hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
_hermes_home = get_hermes_home()
|
||||
|
||||
# File-based lock prevents concurrent ticks from gateway + daemon + systemd timer
|
||||
_LOCK_DIR = _hermes_home / "cron"
|
||||
@@ -80,11 +80,16 @@ def _resolve_delivery_target(job: dict) -> Optional[dict]:
|
||||
}
|
||||
|
||||
if ":" in deliver:
|
||||
platform_name, chat_id = deliver.split(":", 1)
|
||||
platform_name, rest = deliver.split(":", 1)
|
||||
# Check for thread_id suffix (e.g. "telegram:-1003724596514:17")
|
||||
if ":" in rest:
|
||||
chat_id, thread_id = rest.split(":", 1)
|
||||
else:
|
||||
chat_id, thread_id = rest, None
|
||||
return {
|
||||
"platform": platform_name,
|
||||
"chat_id": chat_id,
|
||||
"thread_id": None,
|
||||
"thread_id": thread_id,
|
||||
}
|
||||
|
||||
platform_name = deliver
|
||||
@@ -159,15 +164,29 @@ def _deliver_result(job: dict, content: str) -> None:
|
||||
logger.warning("Job '%s': platform '%s' not configured/enabled", job["id"], platform_name)
|
||||
return
|
||||
|
||||
# Wrap the content so the user knows this is a cron delivery and that
|
||||
# the interactive agent has no visibility into it.
|
||||
task_name = job.get("name", job["id"])
|
||||
wrapped = (
|
||||
f"Cronjob Response: {task_name}\n"
|
||||
f"-------------\n\n"
|
||||
f"{content}\n\n"
|
||||
f"Note: The agent cannot see this message, and therefore cannot respond to it."
|
||||
)
|
||||
|
||||
# Run the async send in a fresh event loop (safe from any thread)
|
||||
coro = _send_to_platform(platform, pconfig, chat_id, wrapped, thread_id=thread_id)
|
||||
try:
|
||||
result = asyncio.run(_send_to_platform(platform, pconfig, chat_id, content, thread_id=thread_id))
|
||||
result = asyncio.run(coro)
|
||||
except RuntimeError:
|
||||
# asyncio.run() fails if there's already a running loop in this thread;
|
||||
# spin up a new thread to avoid that.
|
||||
# asyncio.run() checks for a running loop before awaiting the coroutine;
|
||||
# when it raises, the original coro was never started — close it to
|
||||
# prevent "coroutine was never awaited" RuntimeWarning, then retry in a
|
||||
# fresh thread that has no running loop.
|
||||
coro.close()
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, content, thread_id=thread_id))
|
||||
future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, wrapped, thread_id=thread_id))
|
||||
result = future.result(timeout=30)
|
||||
except Exception as e:
|
||||
logger.error("Job '%s': delivery to %s:%s failed: %s", job["id"], platform_name, chat_id, e)
|
||||
@@ -177,12 +196,6 @@ def _deliver_result(job: dict, content: str) -> None:
|
||||
logger.error("Job '%s': delivery error: %s", job["id"], result["error"])
|
||||
else:
|
||||
logger.info("Job '%s': delivered to %s:%s", job["id"], platform_name, chat_id)
|
||||
# Mirror the delivered content into the target's gateway session
|
||||
try:
|
||||
from gateway.mirror import mirror_to_session
|
||||
mirror_to_session(platform_name, chat_id, content, source_label="cron", thread_id=thread_id)
|
||||
except Exception as e:
|
||||
logger.warning("Job '%s': mirror_to_session failed: %s", job["id"], e)
|
||||
|
||||
|
||||
def _build_job_prompt(job: dict) -> str:
|
||||
@@ -267,6 +280,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
job_name = job["name"]
|
||||
prompt = _build_job_prompt(job)
|
||||
origin = _resolve_origin(job)
|
||||
_cron_session_id = f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}"
|
||||
|
||||
logger.info("Running job '%s' (ID: %s)", job_name, job_id)
|
||||
logger.info("Prompt: %s", prompt[:100])
|
||||
@@ -314,16 +328,11 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
logger.warning("Job '%s': failed to load config.yaml, using defaults: %s", job_id, e)
|
||||
|
||||
# Reasoning config from env or config.yaml
|
||||
reasoning_config = None
|
||||
from hermes_constants import parse_reasoning_effort
|
||||
effort = os.getenv("HERMES_REASONING_EFFORT", "")
|
||||
if not effort:
|
||||
effort = str(_cfg.get("agent", {}).get("reasoning_effort", "")).strip()
|
||||
if effort and effort.lower() != "none":
|
||||
valid = ("xhigh", "high", "medium", "low", "minimal")
|
||||
if effort.lower() in valid:
|
||||
reasoning_config = {"enabled": True, "effort": effort.lower()}
|
||||
elif effort.lower() == "none":
|
||||
reasoning_config = {"enabled": False}
|
||||
reasoning_config = parse_reasoning_effort(effort)
|
||||
|
||||
# Prefill messages from env or config.yaml
|
||||
prefill_messages = None
|
||||
@@ -398,15 +407,16 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
disabled_toolsets=["cronjob", "messaging", "clarify"],
|
||||
quiet_mode=True,
|
||||
platform="cron",
|
||||
session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}",
|
||||
session_id=_cron_session_id,
|
||||
session_db=_session_db,
|
||||
)
|
||||
|
||||
result = agent.run_conversation(prompt)
|
||||
|
||||
final_response = result.get("final_response", "")
|
||||
if not final_response:
|
||||
final_response = "(No response generated)"
|
||||
final_response = result.get("final_response", "") or ""
|
||||
# Use a separate variable for log display; keep final_response clean
|
||||
# for delivery logic (empty response = no delivery).
|
||||
logged_response = final_response if final_response else "(No response generated)"
|
||||
|
||||
output = f"""# Cron Job: {job_name}
|
||||
|
||||
@@ -420,7 +430,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
|
||||
## Response
|
||||
|
||||
{final_response}
|
||||
{logged_response}
|
||||
"""
|
||||
|
||||
logger.info("Job '%s' completed successfully", job_name)
|
||||
@@ -462,9 +472,13 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
):
|
||||
os.environ.pop(key, None)
|
||||
if _session_db:
|
||||
try:
|
||||
_session_db.end_session(_cron_session_id, "cron_complete")
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
logger.debug("Job '%s': failed to end session: %s", job_id, e)
|
||||
try:
|
||||
_session_db.close()
|
||||
except Exception as e:
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
logger.debug("Job '%s': failed to close SQLite session store: %s", job_id, e)
|
||||
|
||||
|
||||
@@ -510,6 +524,12 @@ def tick(verbose: bool = True) -> int:
|
||||
executed = 0
|
||||
for job in due_jobs:
|
||||
try:
|
||||
# For recurring jobs (cron/interval), advance next_run_at to the
|
||||
# next future occurrence BEFORE execution. This way, if the
|
||||
# process crashes mid-run, the job won't re-fire on restart.
|
||||
# One-shot jobs are left alone so they can retry on restart.
|
||||
advance_next_run(job["id"])
|
||||
|
||||
success, output, final_response, error = run_job(job)
|
||||
|
||||
output_file = save_job_output(job["id"], output)
|
||||
|
||||
@@ -101,7 +101,7 @@ Available methods:
|
||||
|
||||
### Patches (`patches.py`)
|
||||
|
||||
**Problem**: Some hermes-agent tools use `asyncio.run()` internally (e.g., mini-swe-agent's Modal backend via SWE-ReX). This crashes when called from inside Atropos's event loop because `asyncio.run()` cannot be nested.
|
||||
**Problem**: Some hermes-agent tools use `asyncio.run()` internally (e.g., the Modal backend via SWE-ReX). This crashes when called from inside Atropos's event loop because `asyncio.run()` cannot be nested.
|
||||
|
||||
**Solution**: `patches.py` monkey-patches `SwerexModalEnvironment` to use a dedicated background thread (`_AsyncWorker`) with its own event loop. The calling code sees the same sync interface, but internally the async work happens on a separate thread that doesn't conflict with Atropos's loop.
|
||||
|
||||
|
||||
+96
-63
@@ -18,12 +18,12 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
from model_tools import handle_function_call
|
||||
|
||||
# Thread pool for running sync tool calls that internally use asyncio.run()
|
||||
# (e.g., mini-swe-agent's modal/docker/daytona backends). Running them in a separate
|
||||
# (e.g., the Modal/Docker/Daytona terminal backends). Running them in a separate
|
||||
# thread gives them a clean event loop so they don't deadlock inside Atropos's loop.
|
||||
# Size must be large enough for concurrent eval tasks (e.g., 89 TB2 tasks all
|
||||
# making tool calls). Too small = thread pool starvation, tasks queue for minutes.
|
||||
@@ -138,6 +138,7 @@ class HermesAgentLoop:
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
early_stop_check: Optional[Callable[[List[Dict[str, Any]]], bool]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the agent loop.
|
||||
@@ -154,6 +155,9 @@ class HermesAgentLoop:
|
||||
extra_body: Extra parameters passed to the OpenAI client's create() call.
|
||||
Used for OpenRouter provider preferences, transforms, etc.
|
||||
e.g. {"provider": {"ignore": ["DeepInfra"]}}
|
||||
early_stop_check: Optional callback that inspects messages after each tool
|
||||
turn. If it returns True, the loop ends with finished_naturally=True.
|
||||
Used for environment-level completion signals (e.g., flag accepted).
|
||||
"""
|
||||
self.server = server
|
||||
self.tool_schemas = tool_schemas
|
||||
@@ -163,6 +167,7 @@ class HermesAgentLoop:
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.extra_body = extra_body
|
||||
self.early_stop_check = early_stop_check
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
|
||||
"""
|
||||
@@ -346,78 +351,89 @@ class HermesAgentLoop:
|
||||
tool_name, turn + 1,
|
||||
)
|
||||
else:
|
||||
# Parse arguments and dispatch
|
||||
# Parse arguments
|
||||
try:
|
||||
args = json.loads(tool_args_raw)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
except json.JSONDecodeError as e:
|
||||
args = None
|
||||
tool_result = json.dumps(
|
||||
{"error": f"Invalid JSON in tool arguments: {e}. Please retry with valid JSON."}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"Invalid JSON: {e}",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.warning(
|
||||
"Invalid JSON in tool call arguments for '%s': %s",
|
||||
tool_name, tool_args_raw[:200],
|
||||
)
|
||||
|
||||
try:
|
||||
if tool_name == "terminal":
|
||||
backend = os.getenv("TERMINAL_ENV", "local")
|
||||
cmd_preview = args.get("command", "")[:80]
|
||||
logger.info(
|
||||
"[%s] $ %s", self.task_id[:8], cmd_preview,
|
||||
)
|
||||
# Dispatch tool only if arguments parsed successfully
|
||||
if args is not None:
|
||||
try:
|
||||
if tool_name == "terminal":
|
||||
backend = os.getenv("TERMINAL_ENV", "local")
|
||||
cmd_preview = args.get("command", "")[:80]
|
||||
logger.info(
|
||||
"[%s] $ %s", self.task_id[:8], cmd_preview,
|
||||
)
|
||||
|
||||
tool_submit_time = _time.monotonic()
|
||||
tool_submit_time = _time.monotonic()
|
||||
|
||||
# Todo tool -- handle locally (needs per-loop TodoStore)
|
||||
if tool_name == "todo":
|
||||
tool_result = _todo_tool(
|
||||
todos=args.get("todos"),
|
||||
merge=args.get("merge", False),
|
||||
store=_todo_store,
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "memory":
|
||||
tool_result = json.dumps({"error": "Memory is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "session_search":
|
||||
tool_result = json.dumps({"error": "Session search is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
else:
|
||||
# Run tool calls in a thread pool so backends that
|
||||
# use asyncio.run() internally (modal, docker, daytona) get
|
||||
# a clean event loop instead of deadlocking.
|
||||
loop = asyncio.get_event_loop()
|
||||
# Capture current tool_name/args for the lambda
|
||||
_tn, _ta, _tid = tool_name, args, self.task_id
|
||||
tool_result = await loop.run_in_executor(
|
||||
_tool_executor,
|
||||
lambda: handle_function_call(
|
||||
_tn, _ta, task_id=_tid,
|
||||
user_task=_user_task,
|
||||
),
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
# Todo tool -- handle locally (needs per-loop TodoStore)
|
||||
if tool_name == "todo":
|
||||
tool_result = _todo_tool(
|
||||
todos=args.get("todos"),
|
||||
merge=args.get("merge", False),
|
||||
store=_todo_store,
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "memory":
|
||||
tool_result = json.dumps({"error": "Memory is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "session_search":
|
||||
tool_result = json.dumps({"error": "Session search is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
else:
|
||||
# Run tool calls in a thread pool so backends that
|
||||
# use asyncio.run() internally (modal, docker, daytona) get
|
||||
# a clean event loop instead of deadlocking.
|
||||
loop = asyncio.get_event_loop()
|
||||
# Capture current tool_name/args for the lambda
|
||||
_tn, _ta, _tid = tool_name, args, self.task_id
|
||||
tool_result = await loop.run_in_executor(
|
||||
_tool_executor,
|
||||
lambda: handle_function_call(
|
||||
_tn, _ta, task_id=_tid,
|
||||
user_task=_user_task,
|
||||
),
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
|
||||
# Log slow tools and thread pool stats for debugging
|
||||
pool_active = _tool_executor._work_queue.qsize()
|
||||
if tool_elapsed > 30:
|
||||
logger.warning(
|
||||
"[%s] turn %d: %s took %.1fs (pool queue=%d)",
|
||||
self.task_id[:8], turn + 1, tool_name,
|
||||
tool_elapsed, pool_active,
|
||||
# Log slow tools and thread pool stats for debugging
|
||||
pool_active = _tool_executor._work_queue.qsize()
|
||||
if tool_elapsed > 30:
|
||||
logger.warning(
|
||||
"[%s] turn %d: %s took %.1fs (pool queue=%d)",
|
||||
self.task_id[:8], turn + 1, tool_name,
|
||||
tool_elapsed, pool_active,
|
||||
)
|
||||
except Exception as e:
|
||||
tool_result = json.dumps(
|
||||
{"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"{type(e).__name__}: {str(e)}",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.error(
|
||||
"Tool '%s' execution failed on turn %d: %s",
|
||||
tool_name, turn + 1, e,
|
||||
)
|
||||
except Exception as e:
|
||||
tool_result = json.dumps(
|
||||
{"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"{type(e).__name__}: {str(e)}",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.error(
|
||||
"Tool '%s' execution failed on turn %d: %s",
|
||||
tool_name, turn + 1, e,
|
||||
)
|
||||
|
||||
# Also check if the tool returned an error in its JSON result
|
||||
try:
|
||||
@@ -445,6 +461,23 @@ class HermesAgentLoop:
|
||||
}
|
||||
)
|
||||
|
||||
# Check if environment signals early stop (e.g., flag accepted)
|
||||
if self.early_stop_check and self.early_stop_check(messages):
|
||||
turn_elapsed = _time.monotonic() - turn_start
|
||||
logger.info(
|
||||
"[%s] turn %d: early stop triggered after %d tools (%.1fs)",
|
||||
self.task_id[:8], turn + 1,
|
||||
len(assistant_msg.tool_calls), turn_elapsed,
|
||||
)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=True,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
|
||||
turn_elapsed = _time.monotonic() - turn_start
|
||||
logger.info(
|
||||
"[%s] turn %d: api=%.1fs, %d tools, turn_total=%.1fs",
|
||||
|
||||
@@ -176,6 +176,22 @@ class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
"transforms, and other provider-specific settings.",
|
||||
)
|
||||
|
||||
# --- Security guards ---
|
||||
disable_command_guards: bool = Field(
|
||||
default=False,
|
||||
description="Disable terminal command security guards (dangerous command "
|
||||
"detection, tirith scanning, approval prompts). Enable this for RL "
|
||||
"environment runs where the agent operates inside isolated containers "
|
||||
"and needs unrestricted command execution (e.g., pwn.college challenges "
|
||||
"that require inline Python, raw sockets, binary exploitation, etc.).",
|
||||
)
|
||||
disable_secret_redaction: bool = Field(
|
||||
default=False,
|
||||
description="Disable secret/password redaction in tool output. Enable this "
|
||||
"for RL environments where the agent needs to read source code containing "
|
||||
"password fields (e.g. Flask apps in web-security challenges).",
|
||||
)
|
||||
|
||||
|
||||
class HermesAgentBaseEnv(BaseEnv):
|
||||
"""
|
||||
@@ -218,6 +234,15 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
os.environ["TERMINAL_ENV"] = config.terminal_backend
|
||||
os.environ["TERMINAL_TIMEOUT"] = str(config.terminal_timeout)
|
||||
os.environ["TERMINAL_LIFETIME_SECONDS"] = str(config.terminal_lifetime)
|
||||
|
||||
# Disable command security guards for RL environments that need
|
||||
# unrestricted execution (agent runs inside isolated containers).
|
||||
if config.disable_command_guards:
|
||||
os.environ["HERMES_YOLO_MODE"] = "1"
|
||||
print("🔓 Command guards disabled (disable_command_guards=true)")
|
||||
if config.disable_secret_redaction:
|
||||
os.environ["HERMES_REDACT_SECRETS"] = "false"
|
||||
print("🔓 Secret redaction disabled (disable_secret_redaction=true)")
|
||||
print(
|
||||
f"🖥️ Terminal: backend={config.terminal_backend}, "
|
||||
f"timeout={config.terminal_timeout}s, lifetime={config.terminal_lifetime}s"
|
||||
|
||||
+12
-174
@@ -2,203 +2,41 @@
|
||||
Monkey patches for making hermes-agent tools work inside async frameworks (Atropos).
|
||||
|
||||
Problem:
|
||||
Some tools use asyncio.run() internally (e.g., mini-swe-agent's Modal backend,
|
||||
Some tools use asyncio.run() internally (e.g., Modal backend via SWE-ReX,
|
||||
web_extract). This crashes when called from inside Atropos's event loop because
|
||||
asyncio.run() can't be nested.
|
||||
|
||||
Solution:
|
||||
Replace the problematic methods with versions that use a dedicated background
|
||||
thread with its own event loop. The calling code sees the same sync interface --
|
||||
call a function, get a result -- but internally the async work happens on a
|
||||
separate thread that doesn't conflict with Atropos's loop.
|
||||
The Modal environment (tools/environments/modal.py) now uses a dedicated
|
||||
_AsyncWorker thread internally, making it safe for both CLI and Atropos use.
|
||||
No monkey-patching is required.
|
||||
|
||||
These patches are safe for normal CLI use too: when there's no running event
|
||||
loop, the behavior is identical (the background thread approach works regardless).
|
||||
|
||||
What gets patched:
|
||||
- SwerexModalEnvironment.__init__ -- creates Modal deployment on a background thread
|
||||
- SwerexModalEnvironment.execute -- runs commands on the same background thread
|
||||
- SwerexModalEnvironment.stop -- stops deployment on the background thread
|
||||
This module is kept for backward compatibility — apply_patches() is now a no-op.
|
||||
|
||||
Usage:
|
||||
Call apply_patches() once at import time (done automatically by hermes_base_env.py).
|
||||
This is idempotent -- calling it multiple times is safe.
|
||||
This is idempotent — calling it multiple times is safe.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_patches_applied = False
|
||||
|
||||
|
||||
class _AsyncWorker:
|
||||
"""
|
||||
A dedicated background thread with its own event loop.
|
||||
|
||||
Allows sync code to submit async coroutines and block for results,
|
||||
even when called from inside another running event loop. Used to
|
||||
bridge sync tool interfaces with async backends (Modal, SWE-ReX).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._loop: asyncio.AbstractEventLoop = None
|
||||
self._thread: threading.Thread = None
|
||||
self._started = threading.Event()
|
||||
|
||||
def start(self):
|
||||
"""Start the background event loop thread."""
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
self._started.wait(timeout=30)
|
||||
|
||||
def _run_loop(self):
|
||||
"""Background thread entry point -- runs the event loop forever."""
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._started.set()
|
||||
self._loop.run_forever()
|
||||
|
||||
def run_coroutine(self, coro, timeout=600):
|
||||
"""
|
||||
Submit a coroutine to the background loop and block until it completes.
|
||||
|
||||
Safe to call from any thread, including threads that already have
|
||||
a running event loop.
|
||||
"""
|
||||
if self._loop is None or self._loop.is_closed():
|
||||
raise RuntimeError("AsyncWorker loop is not running")
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
||||
return future.result(timeout=timeout)
|
||||
|
||||
def stop(self):
|
||||
"""Stop the background event loop and join the thread."""
|
||||
if self._loop and self._loop.is_running():
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
if self._thread:
|
||||
self._thread.join(timeout=10)
|
||||
|
||||
|
||||
def _patch_swerex_modal():
|
||||
"""
|
||||
Monkey patch SwerexModalEnvironment to use a background thread event loop
|
||||
instead of asyncio.run(). This makes it safe to call from inside Atropos's
|
||||
async event loop.
|
||||
|
||||
The patched methods have the exact same interface and behavior -- the only
|
||||
difference is HOW the async work is executed internally.
|
||||
"""
|
||||
try:
|
||||
from minisweagent.environments.extra.swerex_modal import (
|
||||
SwerexModalEnvironment,
|
||||
SwerexModalEnvironmentConfig,
|
||||
)
|
||||
from swerex.deployment.modal import ModalDeployment
|
||||
from swerex.runtime.abstract import Command as RexCommand
|
||||
except ImportError:
|
||||
# mini-swe-agent or swe-rex not installed -- nothing to patch
|
||||
logger.debug("mini-swe-agent Modal backend not available, skipping patch")
|
||||
return
|
||||
|
||||
# Save original methods so we can refer to config handling
|
||||
_original_init = SwerexModalEnvironment.__init__
|
||||
|
||||
def _patched_init(self, **kwargs):
|
||||
"""Patched __init__: creates Modal deployment on a background thread."""
|
||||
self.config = SwerexModalEnvironmentConfig(**kwargs)
|
||||
|
||||
# Start a dedicated event loop thread for all Modal async operations
|
||||
self._worker = _AsyncWorker()
|
||||
self._worker.start()
|
||||
|
||||
# Pre-build a modal.Image with pip fix for Modal's legacy image builder.
|
||||
# Modal requires `python -m pip` to work during image build, but some
|
||||
# task images (e.g., TBLite's broken-python) have intentionally broken pip.
|
||||
# Fix: remove stale pip dist-info and reinstall via ensurepip before Modal
|
||||
# tries to use it. This is a no-op for images where pip already works.
|
||||
import modal as _modal
|
||||
image_spec = self.config.image
|
||||
if isinstance(image_spec, str):
|
||||
image_spec = _modal.Image.from_registry(
|
||||
image_spec,
|
||||
setup_dockerfile_commands=[
|
||||
"RUN rm -rf /usr/local/lib/python*/site-packages/pip* 2>/dev/null; "
|
||||
"python -m ensurepip --upgrade --default-pip 2>/dev/null || true",
|
||||
],
|
||||
)
|
||||
|
||||
# Create AND start the deployment entirely on the worker's loop/thread
|
||||
# so all gRPC channels and async state are bound to that loop
|
||||
async def _create_and_start():
|
||||
deployment = ModalDeployment(
|
||||
image=image_spec,
|
||||
startup_timeout=self.config.startup_timeout,
|
||||
runtime_timeout=self.config.runtime_timeout,
|
||||
deployment_timeout=self.config.deployment_timeout,
|
||||
install_pipx=self.config.install_pipx,
|
||||
modal_sandbox_kwargs=self.config.modal_sandbox_kwargs,
|
||||
)
|
||||
await deployment.start()
|
||||
return deployment
|
||||
|
||||
self.deployment = self._worker.run_coroutine(_create_and_start())
|
||||
|
||||
def _patched_execute(self, command: str, cwd: str = "", *, timeout: int | None = None) -> dict[str, Any]:
|
||||
"""Patched execute: runs commands on the background thread's loop."""
|
||||
async def _do_execute():
|
||||
return await self.deployment.runtime.execute(
|
||||
RexCommand(
|
||||
command=command,
|
||||
shell=True,
|
||||
check=False,
|
||||
cwd=cwd or self.config.cwd,
|
||||
timeout=timeout or self.config.timeout,
|
||||
merge_output_streams=True,
|
||||
env=self.config.env if self.config.env else None,
|
||||
)
|
||||
)
|
||||
|
||||
output = self._worker.run_coroutine(_do_execute())
|
||||
return {
|
||||
"output": output.stdout,
|
||||
"returncode": output.exit_code,
|
||||
}
|
||||
|
||||
def _patched_stop(self):
|
||||
"""Patched stop: stops deployment on the background thread, then stops the thread."""
|
||||
try:
|
||||
self._worker.run_coroutine(
|
||||
asyncio.wait_for(self.deployment.stop(), timeout=10),
|
||||
timeout=15,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self._worker.stop()
|
||||
|
||||
# Apply the patches
|
||||
SwerexModalEnvironment.__init__ = _patched_init
|
||||
SwerexModalEnvironment.execute = _patched_execute
|
||||
SwerexModalEnvironment.stop = _patched_stop
|
||||
|
||||
logger.debug("Patched SwerexModalEnvironment for async-safe operation")
|
||||
|
||||
|
||||
def apply_patches():
|
||||
"""
|
||||
Apply all monkey patches needed for Atropos compatibility.
|
||||
"""Apply all monkey patches needed for Atropos compatibility.
|
||||
|
||||
Safe to call multiple times -- patches are only applied once.
|
||||
Safe for normal CLI use -- patched code works identically when
|
||||
there is no running event loop.
|
||||
Now a no-op — Modal async safety is built directly into ModalEnvironment.
|
||||
Safe to call multiple times.
|
||||
"""
|
||||
global _patches_applied
|
||||
if _patches_applied:
|
||||
return
|
||||
|
||||
_patch_swerex_modal()
|
||||
# Modal async-safety is now built into tools/environments/modal.py
|
||||
# via the _AsyncWorker class. No monkey-patching needed.
|
||||
logger.debug("apply_patches() called — no patches needed (async safety is built-in)")
|
||||
|
||||
_patches_applied = True
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .pwncollege_env import PwnCollegeEnv, PwnCollegeEnvConfig
|
||||
@@ -0,0 +1,47 @@
|
||||
# PwnCollege Training Environment
|
||||
#
|
||||
# Usage:
|
||||
# python environments/pwncollege_env/pwncollege_env.py serve \
|
||||
# --config environments/pwncollege_env/default.yaml
|
||||
#
|
||||
# python environments/pwncollege_env/pwncollege_env.py process \
|
||||
# --config environments/pwncollege_env/default.yaml \
|
||||
# --env.data_path_to_save_groups sft_data.jsonl
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file", "pwncollege"]
|
||||
max_agent_turns: 20
|
||||
max_token_length: 16384
|
||||
agent_temperature: 0.7
|
||||
terminal_backend: "ssh"
|
||||
|
||||
# Dojo connection
|
||||
base_url: "http://100.120.55.25:8080"
|
||||
ssh_host: "100.120.55.25"
|
||||
ssh_port: 2222
|
||||
ssh_key: "environments/pwncollege_env/keys/rl_test_key"
|
||||
|
||||
# Training: challenge selection
|
||||
# challenge: "hello/hello" # Single challenge (training fallback)
|
||||
# dojo_filter: "linux-luminarium" # Filter training set by dojo
|
||||
# module_filter: "hello" # Filter training set by module
|
||||
|
||||
# Eval settings (null = all)
|
||||
eval_dojo: null
|
||||
eval_module: null
|
||||
eval_exclude_dojos: ["archive"]
|
||||
eval_concurrency: 16
|
||||
|
||||
# Atropos settings
|
||||
data_dir_to_save_evals: "eval_output/pwncollege"
|
||||
use_wandb: false
|
||||
wandb_name: "pwncollege"
|
||||
ensure_scores_are_not_same: false
|
||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-sonnet-4.5"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
# api_key: set OPENROUTER_API_KEY in .env or shell
|
||||
@@ -0,0 +1,74 @@
|
||||
env:
|
||||
group_size: 4
|
||||
max_num_workers: -1
|
||||
max_eval_workers: 16
|
||||
max_num_workers_per_node: 8
|
||||
steps_per_eval: 100
|
||||
max_token_length: 16384
|
||||
eval_handling: STOP_TRAIN
|
||||
eval_limit_ratio: 0.5
|
||||
inference_weight: 1.0
|
||||
batch_size: -1
|
||||
max_batches_offpolicy: 3
|
||||
tokenizer_name: NousResearch/Hermes-3-Llama-3.1-8B
|
||||
use_wandb: false
|
||||
rollout_server_url: http://localhost:8000
|
||||
total_steps: 1000
|
||||
wandb_name: pwncollege-intro-cybersec-flash
|
||||
num_rollouts_to_keep: 32
|
||||
num_rollouts_per_group_for_logging: 1
|
||||
ensure_scores_are_not_same: false
|
||||
data_path_to_save_groups: null
|
||||
data_dir_to_save_evals: environments/pwncollege_env/eval_runs/intro_cybersec_flash
|
||||
min_items_sent_before_logging: 2
|
||||
include_messages: false
|
||||
min_batch_allocation: null
|
||||
worker_timeout: 600.0
|
||||
thinking_mode: false
|
||||
reasoning_effort: null
|
||||
max_reasoning_tokens: null
|
||||
custom_thinking_prompt: null
|
||||
enabled_toolsets:
|
||||
- terminal
|
||||
- file
|
||||
- pwncollege
|
||||
disabled_toolsets: null
|
||||
distribution: null
|
||||
max_agent_turns: 80
|
||||
agent_temperature: 0.7
|
||||
terminal_backend: ssh
|
||||
terminal_timeout: 120
|
||||
terminal_lifetime: 3600
|
||||
disable_command_guards: true
|
||||
dataset_name: null
|
||||
dataset_split: train
|
||||
prompt_field: prompt
|
||||
tool_pool_size: 128
|
||||
tool_call_parser: hermes
|
||||
extra_body: null
|
||||
base_url: http://100.120.55.25:8080
|
||||
ssh_host: 100.120.55.25
|
||||
ssh_port: 2222
|
||||
ssh_key: environments/pwncollege_env/keys/rl_test_key
|
||||
challenge: hello/hello
|
||||
dojo_filter: null
|
||||
module_filter: null
|
||||
eval_dojo: intro-to-cybersecurity
|
||||
eval_exclude_dojos:
|
||||
- archive
|
||||
eval_module: null
|
||||
eval_concurrency: 8
|
||||
openai:
|
||||
- timeout: 1200
|
||||
num_max_requests_at_once: 512
|
||||
num_requests_for_eval: 64
|
||||
model_name: xiaomi/mimo-v2-flash
|
||||
rolling_buffer_length: 1000
|
||||
server_type: openai
|
||||
tokenizer_name: none
|
||||
api_key: ""
|
||||
base_url: https://openrouter.ai/api/v1
|
||||
n_kwarg_is_ignored: false
|
||||
health_check: false
|
||||
slurm: false
|
||||
testing: false
|
||||
@@ -0,0 +1,73 @@
|
||||
env:
|
||||
group_size: 4
|
||||
max_num_workers: -1
|
||||
max_eval_workers: 16
|
||||
max_num_workers_per_node: 8
|
||||
steps_per_eval: 100
|
||||
max_token_length: 16384
|
||||
eval_handling: STOP_TRAIN
|
||||
eval_limit_ratio: 0.5
|
||||
inference_weight: 1.0
|
||||
batch_size: -1
|
||||
max_batches_offpolicy: 3
|
||||
tokenizer_name: NousResearch/Hermes-3-Llama-3.1-8B
|
||||
use_wandb: false
|
||||
rollout_server_url: http://localhost:8000
|
||||
total_steps: 1000
|
||||
wandb_name: pwncollege
|
||||
num_rollouts_to_keep: 32
|
||||
num_rollouts_per_group_for_logging: 1
|
||||
ensure_scores_are_not_same: false
|
||||
data_path_to_save_groups: null
|
||||
data_dir_to_save_evals: eval_output/pwncollege
|
||||
min_items_sent_before_logging: 2
|
||||
include_messages: false
|
||||
min_batch_allocation: null
|
||||
worker_timeout: 600.0
|
||||
thinking_mode: false
|
||||
reasoning_effort: null
|
||||
max_reasoning_tokens: null
|
||||
custom_thinking_prompt: null
|
||||
enabled_toolsets:
|
||||
- terminal
|
||||
- file
|
||||
- pwncollege
|
||||
disabled_toolsets: null
|
||||
distribution: null
|
||||
max_agent_turns: 50
|
||||
agent_temperature: 0.7
|
||||
terminal_backend: ssh
|
||||
terminal_timeout: 120
|
||||
terminal_lifetime: 3600
|
||||
dataset_name: null
|
||||
dataset_split: train
|
||||
prompt_field: prompt
|
||||
tool_pool_size: 128
|
||||
tool_call_parser: hermes
|
||||
extra_body: null
|
||||
base_url: http://100.120.55.25:8080
|
||||
ssh_host: 100.120.55.25
|
||||
ssh_port: 2222
|
||||
ssh_key: environments/pwncollege_env/keys/rl_test_key
|
||||
challenge: hello/hello
|
||||
dojo_filter: null
|
||||
module_filter: null
|
||||
eval_dojo: linux-luminarium
|
||||
eval_exclude_dojos:
|
||||
- archive
|
||||
eval_module: hello
|
||||
eval_concurrency: 16
|
||||
openai:
|
||||
- timeout: 1200
|
||||
num_max_requests_at_once: 512
|
||||
num_requests_for_eval: 64
|
||||
model_name: xiaomi/mimo-v2-flash
|
||||
rolling_buffer_length: 1000
|
||||
server_type: openai
|
||||
tokenizer_name: none
|
||||
api_key: ""
|
||||
base_url: https://openrouter.ai/api/v1
|
||||
n_kwarg_is_ignored: false
|
||||
health_check: false
|
||||
slurm: false
|
||||
testing: false
|
||||
@@ -0,0 +1,3 @@
|
||||
# SSH private keys -- never commit
|
||||
*
|
||||
!.gitignore
|
||||
@@ -0,0 +1,54 @@
|
||||
env:
|
||||
# Breadth: total items to process (>= 842 challenges in dojo)
|
||||
total_steps: 850
|
||||
# Depth: completions per item (1 = max coverage speed)
|
||||
group_size: 1
|
||||
# Concurrency: match dojo max_instances (16 slots)
|
||||
eval_concurrency: 16
|
||||
|
||||
max_agent_turns: 30
|
||||
max_token_length: 16384
|
||||
agent_temperature: 0.7
|
||||
enabled_toolsets:
|
||||
- terminal
|
||||
- file
|
||||
- pwncollege
|
||||
terminal_backend: ssh
|
||||
terminal_timeout: 120
|
||||
terminal_lifetime: 3600
|
||||
disable_command_guards: true
|
||||
disable_secret_redaction: true
|
||||
tool_call_parser: hermes
|
||||
|
||||
# Dojo connection
|
||||
base_url: http://100.120.55.25:8080
|
||||
ssh_host: 100.120.55.25
|
||||
ssh_port: 2222
|
||||
ssh_key: ""
|
||||
|
||||
# No filters = all challenges
|
||||
challenge: hello/hello
|
||||
dojo_filter: null
|
||||
module_filter: null
|
||||
|
||||
# Output
|
||||
data_path_to_save_groups: environments/pwncollege_env/eval_runs/process_all.jsonl
|
||||
data_dir_to_save_evals: environments/pwncollege_env/eval_runs/process_all
|
||||
ensure_scores_are_not_same: false
|
||||
include_messages: true
|
||||
use_wandb: false
|
||||
wandb_name: pwncollege-process
|
||||
|
||||
# Not needed for process mode but required by config schema
|
||||
tokenizer_name: NousResearch/Hermes-3-Llama-3.1-8B
|
||||
rollout_server_url: http://localhost:8000
|
||||
openai:
|
||||
- timeout: 1200
|
||||
num_max_requests_at_once: 128
|
||||
model_name: xiaomi/mimo-v2-flash
|
||||
server_type: openai
|
||||
api_key: ""
|
||||
base_url: https://openrouter.ai/api/v1
|
||||
health_check: false
|
||||
slurm: false
|
||||
testing: false
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Prompt templates for pwn.college environments."""
|
||||
|
||||
SYSTEM_PROMPT = """You are a CTF player solving security challenges on pwn.college.
|
||||
|
||||
Your goal is to find and capture the flag. Flags follow the format: `pwn.college{...}`
|
||||
|
||||
When you find a flag, use the submit_flag tool to verify it.
|
||||
|
||||
## How pwn.college Challenges Work
|
||||
|
||||
- `/challenge/run` or `/challenge/bin/<name>` — The challenge binary that checks your work
|
||||
- The flag is printed ONLY when you run the challenge binary after correctly completing the task
|
||||
- Do NOT search for a pre-existing flag file — you must earn it by solving the challenge
|
||||
|
||||
## Tips
|
||||
|
||||
- If a challenge needs a long-running process (e.g., a server or listener), run it in the background and interact with it separately.
|
||||
- For binary exploitation, check protections with `checksec` and use `python3` with `pwntools` if available.
|
||||
- Read error messages carefully — they often hint at what the challenge expects.
|
||||
- Sometimes `/challenge/run` drops you into an interactive shell inside a network namespace. To interact with it, use `tmux`:
|
||||
```
|
||||
tmux new-session -d -s challenge '/challenge/run'
|
||||
tmux send-keys -t challenge 'your-command-here' Enter
|
||||
tmux capture-pane -t challenge -p
|
||||
```
|
||||
"""
|
||||
|
||||
USER_PROMPT_TEMPLATE = """Solve this pwn.college challenge.
|
||||
|
||||
Module: {module_name}
|
||||
Challenge: {challenge_name}
|
||||
|
||||
## Challenge Description
|
||||
|
||||
{challenge_description}"""
|
||||
@@ -0,0 +1,852 @@
|
||||
"""
|
||||
PwnCollege Training Environment for Hermes-Agent + Atropos
|
||||
|
||||
Uses hermes-agent's tool system and HermesAgentLoop for the agent,
|
||||
with pwn.college SDK + SSH for challenge container management.
|
||||
|
||||
Usage:
|
||||
python environments/pwncollege_env/pwncollege_env.py serve \
|
||||
--config environments/pwncollege_env/default.yaml
|
||||
|
||||
python environments/pwncollege_env/pwncollege_env.py process \
|
||||
--config environments/pwncollege_env/default.yaml \
|
||||
--env.data_path_to_save_groups sft_data.jsonl
|
||||
|
||||
python environments/pwncollege_env/pwncollege_env.py evaluate \
|
||||
--config environments/pwncollege_env/default.yaml
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
# Ensure repo root is on sys.path
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
_env_path = _repo_root / ".env"
|
||||
if _env_path.exists():
|
||||
load_dotenv(dotenv_path=_env_path)
|
||||
|
||||
from environments.patches import apply_patches
|
||||
|
||||
apply_patches()
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, ScoredDataItem
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
|
||||
# Import submit_flag_tool to trigger registry.register() at module load
|
||||
from environments.pwncollege_env import submit_flag_tool # noqa: F401
|
||||
from environments.pwncollege_env.prompts import SYSTEM_PROMPT, USER_PROMPT_TEMPLATE
|
||||
from environments.pwncollege_env.sdk import (
|
||||
DojoRLClient, DojoRLSyncClient, RLChallenge, RLInstance,
|
||||
)
|
||||
from environments.pwncollege_env.submit_flag_tool import (
|
||||
clear_flag_context,
|
||||
register_flag_context,
|
||||
)
|
||||
from environments.tool_context import ToolContext
|
||||
from tools.terminal_tool import (
|
||||
cleanup_vm,
|
||||
clear_task_env_overrides,
|
||||
register_task_env_overrides,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PwnCollegeEnvConfig(HermesAgentEnvConfig):
|
||||
"""Configuration for PwnCollege environment."""
|
||||
|
||||
# Dojo connection
|
||||
base_url: str = Field(
|
||||
default="http://100.120.55.25:8080",
|
||||
description="Dojo API base URL",
|
||||
)
|
||||
ssh_host: str = Field(
|
||||
default="100.120.55.25",
|
||||
description="SSH host for challenge containers",
|
||||
)
|
||||
ssh_port: int = Field(default=2222, description="SSH port")
|
||||
ssh_key: str = Field(
|
||||
default="",
|
||||
description="Path to SSH private key for RL agent",
|
||||
)
|
||||
|
||||
# Challenge selection
|
||||
challenge: str = Field(
|
||||
default="hello/hello",
|
||||
description="Challenge in module/challenge format (e.g., 'hello/hello', 'paths/root')",
|
||||
)
|
||||
dojo_filter: Optional[str] = Field(default=None, description="Filter by dojo ID")
|
||||
module_filter: Optional[str] = Field(
|
||||
default=None, description="Filter by module ID"
|
||||
)
|
||||
include_challenges: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="Specific challenge keys to include in training "
|
||||
"(format: module_id/challenge_id). Overrides dojo/module "
|
||||
"filters. Use for retry runs.",
|
||||
)
|
||||
|
||||
# Eval settings
|
||||
eval_dojo: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Dojo to evaluate on (None = all dojos)",
|
||||
)
|
||||
eval_exclude_dojos: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Dojos to exclude from evaluation",
|
||||
)
|
||||
eval_module: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Module to evaluate on (None = all modules)",
|
||||
)
|
||||
eval_exclude_modules: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Modules to exclude from evaluation",
|
||||
)
|
||||
eval_challenges: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="Specific challenges to evaluate (format: module_id/challenge_id). Overrides dojo/module filters.",
|
||||
)
|
||||
eval_concurrency: int = Field(
|
||||
default=4,
|
||||
description="Max concurrent eval episodes (limited by dojo slots)",
|
||||
)
|
||||
|
||||
|
||||
class PwnCollegeEnv(HermesAgentBaseEnv):
|
||||
"""PwnCollege training environment.
|
||||
|
||||
Lifecycle per rollout:
|
||||
1. Create dojo instance (SDK) → get slot + ssh_user
|
||||
2. Register SSH overrides so terminal tool routes to that instance
|
||||
3. Register flag context so submit_flag tool can verify flags
|
||||
4. Run hermes-agent loop (terminal + file + submit_flag tools)
|
||||
5. Score: did agent submit the correct flag?
|
||||
6. Cleanup: destroy instance, clear overrides
|
||||
"""
|
||||
|
||||
name = "pwncollege"
|
||||
env_config_cls = PwnCollegeEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PwnCollegeEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
# Set global SSH env vars before super().__init__ triggers terminal validation.
|
||||
# Per-task overrides (ssh_user) are registered before each rollout.
|
||||
os.environ.setdefault("TERMINAL_SSH_HOST", config.ssh_host)
|
||||
os.environ.setdefault("TERMINAL_SSH_USER", "rl_0")
|
||||
os.environ.setdefault("TERMINAL_SSH_KEY", config.ssh_key)
|
||||
|
||||
# Patch api_key from env var before super().__init__ bakes it into openai.AsyncClient
|
||||
api_key = os.getenv("OPENROUTER_API_KEY", "")
|
||||
if api_key:
|
||||
for sc in server_configs:
|
||||
if not sc.api_key:
|
||||
sc.api_key = api_key
|
||||
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.config: PwnCollegeEnvConfig = config
|
||||
|
||||
self.train: list[RLChallenge] = []
|
||||
self.iter = 0
|
||||
self.solve_rate_buffer: list[float] = []
|
||||
self._active_slots: set[int] = set()
|
||||
|
||||
# SDK clients — async for setup/lifecycle, sync for submit_flag handler
|
||||
self.client: Optional[DojoRLClient] = None
|
||||
self.sync_client: Optional[DojoRLSyncClient] = None
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[PwnCollegeEnvConfig, List[APIServerConfig]]:
|
||||
env_config = PwnCollegeEnvConfig(
|
||||
enabled_toolsets=["terminal", "file", "pwncollege"],
|
||||
max_agent_turns=20,
|
||||
max_token_length=16384,
|
||||
agent_temperature=0.7,
|
||||
terminal_backend="ssh",
|
||||
system_prompt=SYSTEM_PROMPT,
|
||||
use_wandb=True,
|
||||
wandb_name="pwncollege",
|
||||
ensure_scores_are_not_same=False,
|
||||
)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-sonnet-4.5",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False,
|
||||
),
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
def _cleanup_instances(self):
|
||||
"""Destroy all running dojo instances. Called on exit/signal."""
|
||||
if not self.sync_client:
|
||||
return
|
||||
try:
|
||||
n = self.sync_client.destroy_all()
|
||||
if n:
|
||||
logger.info("Cleaned up %d dojo instance(s)", n)
|
||||
except Exception as e:
|
||||
logger.warning("Instance cleanup failed: %s", e)
|
||||
|
||||
if hasattr(self, "_auto_ssh_key_dir"):
|
||||
import shutil
|
||||
shutil.rmtree(self._auto_ssh_key_dir, ignore_errors=True)
|
||||
|
||||
def _signal_handler(self, signum, frame):
|
||||
"""Handle SIGINT/SIGTERM: clean up instances, then re-raise."""
|
||||
logger.info("Signal %d received, cleaning up dojo instances...", signum)
|
||||
self._cleanup_instances()
|
||||
signal.signal(signum, signal.SIG_DFL)
|
||||
os.kill(os.getpid(), signum)
|
||||
|
||||
async def _ensure_ssh_key(self):
|
||||
"""Auto-generate and register an SSH key if none configured."""
|
||||
if self.config.ssh_key and Path(self.config.ssh_key).exists():
|
||||
return
|
||||
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
key_dir = Path(tempfile.mkdtemp(prefix="hermes-ssh-"))
|
||||
key_path = key_dir / "id_ed25519"
|
||||
|
||||
subprocess.run(
|
||||
["ssh-keygen", "-t", "ed25519", "-f", str(key_path), "-N", "", "-q"],
|
||||
check=True,
|
||||
)
|
||||
|
||||
pub_key = key_path.with_suffix(".pub").read_text().strip()
|
||||
registered = await self.client.register_ssh_key(pub_key)
|
||||
if not registered:
|
||||
raise RuntimeError("Failed to register SSH key with dojo")
|
||||
|
||||
self.config.ssh_key = str(key_path)
|
||||
os.environ["TERMINAL_SSH_KEY"] = str(key_path)
|
||||
self._auto_ssh_key_dir = key_dir
|
||||
|
||||
logger.info("Auto-generated SSH key and registered with dojo")
|
||||
|
||||
async def setup(self):
|
||||
"""Load challenges from dojo and initialize SDK clients."""
|
||||
self.client = DojoRLClient(self.config.base_url)
|
||||
self.sync_client = DojoRLSyncClient(self.config.base_url)
|
||||
|
||||
await self._ensure_ssh_key()
|
||||
|
||||
atexit.register(self._cleanup_instances)
|
||||
signal.signal(signal.SIGINT, self._signal_handler)
|
||||
signal.signal(signal.SIGTERM, self._signal_handler)
|
||||
|
||||
# Fetch challenges
|
||||
challenges = await self.client.list_challenges()
|
||||
logger.info("Fetched %d challenges from dojo", len(challenges))
|
||||
|
||||
# Apply filters
|
||||
if self.config.include_challenges:
|
||||
# Explicit include list overrides all other filters
|
||||
include_set = set(self.config.include_challenges)
|
||||
for c in challenges:
|
||||
if c.challenge_key in include_set:
|
||||
self.train.append(c)
|
||||
else:
|
||||
for c in challenges:
|
||||
if (self.config.dojo_filter
|
||||
and c.dojo_id != self.config.dojo_filter):
|
||||
continue
|
||||
if (self.config.module_filter
|
||||
and c.module_id != self.config.module_filter):
|
||||
continue
|
||||
self.train.append(c)
|
||||
|
||||
# If a specific challenge is set and no filters matched, use it directly
|
||||
if not self.train and self.config.challenge:
|
||||
parts = self.config.challenge.split("/")
|
||||
self.train.append(
|
||||
RLChallenge(
|
||||
id=parts[-1],
|
||||
module_id=parts[0],
|
||||
dojo_id="unknown",
|
||||
name=self.config.challenge,
|
||||
description="",
|
||||
)
|
||||
)
|
||||
|
||||
if not self.train:
|
||||
raise RuntimeError(
|
||||
f"No challenges matched filters (dojo_filter={self.config.dojo_filter}, "
|
||||
f"module_filter={self.config.module_filter}, challenge={self.config.challenge}). "
|
||||
f"Total available: {len(challenges)}"
|
||||
)
|
||||
|
||||
logger.info("Training on %d challenges", len(self.train))
|
||||
|
||||
async def get_next_item(self) -> RLChallenge:
|
||||
"""Return next challenge item (round-robin)."""
|
||||
item = self.train[self.iter % len(self.train)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def _get_challenge_key(self, item: RLChallenge) -> str:
|
||||
"""Extract the challenge key from a challenge."""
|
||||
return item.challenge_key or f"{item.module_id or ''}/{item.id}"
|
||||
|
||||
def format_prompt(self, item: RLChallenge) -> str:
|
||||
"""Build user prompt from challenge metadata."""
|
||||
challenge_key = self._get_challenge_key(item)
|
||||
return USER_PROMPT_TEMPLATE.format(
|
||||
module_name=item.module_id or "unknown",
|
||||
challenge_name=item.name or item.id,
|
||||
challenge_description=item.description or f"Solve the challenge: {challenge_key}",
|
||||
)
|
||||
|
||||
async def _acquire_instance(
|
||||
self, challenge_key: str, *, pool_slot: Optional[int] = None,
|
||||
) -> Optional[RLInstance]:
|
||||
"""Acquire a dojo instance for a challenge.
|
||||
|
||||
If *pool_slot* is given (process mode), try to reset the slot.
|
||||
If the slot is dead on the dojo, destroy it and create a fresh
|
||||
one. The returned instance may have a different slot ID than
|
||||
*pool_slot* — callers must use ``inst.slot`` going forward.
|
||||
|
||||
If *pool_slot* is ``None`` (evaluate / serve modes), create a
|
||||
new instance with transient-error retries.
|
||||
"""
|
||||
if pool_slot is not None:
|
||||
# Pool mode: try reset first (fast path)
|
||||
try:
|
||||
return await self.client.reset_instance(
|
||||
pool_slot, challenge=challenge_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"reset_instance(%d, %s) failed: %s — "
|
||||
"destroying and creating fresh slot",
|
||||
pool_slot, challenge_key, str(e)[:80],
|
||||
)
|
||||
try:
|
||||
await self.client.destroy_instance(pool_slot)
|
||||
except Exception:
|
||||
pass
|
||||
# Fall through to create mode
|
||||
|
||||
# Create mode: new instance with transient-error retries
|
||||
max_retries = 10 if pool_slot is not None else 5
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return await self.client.create_instance(
|
||||
challenge_key,
|
||||
)
|
||||
except Exception as e:
|
||||
err_str = str(e)
|
||||
is_transient = (
|
||||
isinstance(e, httpx.HTTPStatusError)
|
||||
and e.response.status_code >= 500
|
||||
or isinstance(e, (
|
||||
httpx.ReadTimeout,
|
||||
httpx.ConnectTimeout,
|
||||
httpx.ConnectError,
|
||||
))
|
||||
or "No available slots" in err_str
|
||||
)
|
||||
if is_transient and attempt < max_retries - 1:
|
||||
wait = min(2 ** (attempt + 1), 60)
|
||||
logger.warning(
|
||||
"Transient error creating instance "
|
||||
"for %s (attempt %d/%d): %s, "
|
||||
"retrying in %ds",
|
||||
challenge_key, attempt + 1,
|
||||
max_retries, err_str[:80], wait,
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
else:
|
||||
logger.error(
|
||||
"Failed to create instance for %s "
|
||||
"after %d attempts: %s",
|
||||
challenge_key, attempt + 1, e,
|
||||
)
|
||||
return None
|
||||
return None
|
||||
|
||||
async def collect_trajectory(
|
||||
self, item: Item, *, pool_instance: Optional[RLInstance] = None,
|
||||
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
|
||||
"""Run a single rollout with dojo instance lifecycle.
|
||||
|
||||
Wraps the agent loop with:
|
||||
1. Dojo instance creation (SSH-accessible challenge container)
|
||||
2. SSH override registration (routes terminal tool to the instance)
|
||||
3. Flag context registration (enables submit_flag tool)
|
||||
4. Cleanup on completion
|
||||
|
||||
When *pool_instance* is provided (process mode), that
|
||||
pre-acquired instance is used directly and NOT destroyed on
|
||||
completion — the caller manages its lifecycle.
|
||||
"""
|
||||
task_id = str(uuid.uuid4())
|
||||
challenge_key = self._get_challenge_key(item)
|
||||
owns_slot = pool_instance is None
|
||||
|
||||
if pool_instance is not None:
|
||||
inst = pool_instance
|
||||
else:
|
||||
inst = await self._acquire_instance(challenge_key)
|
||||
if inst is None:
|
||||
return None, []
|
||||
|
||||
slot = inst.slot
|
||||
self._active_slots.add(slot)
|
||||
register_task_env_overrides(
|
||||
task_id,
|
||||
{
|
||||
"ssh_user": inst.ssh_user,
|
||||
"ssh_host": self.config.ssh_host,
|
||||
"ssh_port": self.config.ssh_port,
|
||||
"ssh_key": self.config.ssh_key,
|
||||
},
|
||||
)
|
||||
register_flag_context(task_id, self.sync_client, slot)
|
||||
|
||||
try:
|
||||
# Resolve tools (includes submit_flag via "pwncollege" toolset)
|
||||
if self._current_group_tools is None:
|
||||
tools, valid_names = self._resolve_tools_for_group()
|
||||
else:
|
||||
tools, valid_names = self._current_group_tools
|
||||
|
||||
messages: List[Dict[str, Any]] = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
messages.append({"role": "user", "content": self.format_prompt(item)})
|
||||
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Skip reward if agent produced no output
|
||||
only_system_and_user = all(
|
||||
msg.get("role") in ("system", "user") for msg in result.messages
|
||||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
logger.warning("Agent produced no output for %s", challenge_key)
|
||||
reward = 0.0
|
||||
else:
|
||||
ctx = ToolContext(task_id)
|
||||
try:
|
||||
reward = await self.compute_reward(item, result, ctx)
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
|
||||
# Track tool errors
|
||||
if result.tool_errors:
|
||||
for err in result.tool_errors:
|
||||
self._tool_error_buffer.append({
|
||||
"turn": err.turn,
|
||||
"tool": err.tool_name,
|
||||
"args": err.arguments[:150],
|
||||
"error": err.error[:300],
|
||||
"result": err.tool_result[:300],
|
||||
})
|
||||
|
||||
# Build scored item (Phase 1: placeholder tokens)
|
||||
full_text = "\n".join(
|
||||
msg.get("content", "") for msg in result.messages if msg.get("content")
|
||||
)
|
||||
if self.tokenizer:
|
||||
tokens = self.tokenizer.encode(full_text, add_special_tokens=True)
|
||||
else:
|
||||
tokens = list(range(min(len(full_text) // 4, 128)))
|
||||
|
||||
scored_item = {
|
||||
"tokens": tokens,
|
||||
"masks": [-100] + tokens[1:],
|
||||
"scores": reward,
|
||||
"messages": result.messages,
|
||||
}
|
||||
return scored_item, []
|
||||
|
||||
finally:
|
||||
clear_flag_context(task_id)
|
||||
clear_task_env_overrides(task_id)
|
||||
cleanup_vm(task_id)
|
||||
if owns_slot:
|
||||
# Evaluate/serve mode: we created it, we destroy it
|
||||
try:
|
||||
await self.client.destroy_instance(slot)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to destroy instance slot %d: %s", slot, e)
|
||||
# Pool mode: caller is responsible for the slot lifecycle
|
||||
self._active_slots.discard(slot)
|
||||
|
||||
async def compute_reward(
|
||||
self, item: Item, result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""Score the rollout: 1.0 if flag was correctly submitted, 0.0 otherwise.
|
||||
|
||||
Checks two signals:
|
||||
1. Did submit_flag return {"success": true}?
|
||||
2. Fallback: extract pwn.college{...} from terminal output and verify via SDK.
|
||||
"""
|
||||
# Check submit_flag tool results in the conversation
|
||||
for msg in result.messages:
|
||||
if msg.get("role") == "tool":
|
||||
try:
|
||||
data = json.loads(msg.get("content", ""))
|
||||
if isinstance(data, dict) and data.get("success") is True:
|
||||
self.solve_rate_buffer.append(1.0)
|
||||
return 1.0
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Fallback: scan for flag pattern in all messages
|
||||
for msg in result.messages:
|
||||
content = msg.get("content", "")
|
||||
if not content:
|
||||
continue
|
||||
flag_match = re.search(r"pwn\.college\{[^}]+\}", content)
|
||||
if flag_match:
|
||||
# We can't verify here since instance is being torn down,
|
||||
# but the flag pattern presence suggests partial progress
|
||||
self.solve_rate_buffer.append(0.0)
|
||||
return 0.0
|
||||
|
||||
self.solve_rate_buffer.append(0.0)
|
||||
return 0.0
|
||||
|
||||
async def process_manager(self):
|
||||
"""Override: process items concurrently with pre-allocated slot pool.
|
||||
|
||||
Uses a pool of dojo instances (asyncio.Queue) instead of a semaphore.
|
||||
Each task waits for a real dojo slot to become available, resets it
|
||||
to the target challenge, and returns it to the pool on completion.
|
||||
This guarantees zero silent drops from slot contention.
|
||||
"""
|
||||
from atroposlib.frontend.jsonl2html import generate_html
|
||||
|
||||
await self.setup()
|
||||
|
||||
if self.config.use_wandb:
|
||||
import random
|
||||
import string
|
||||
from datetime import datetime
|
||||
|
||||
import wandb
|
||||
|
||||
random_id = "".join(random.choices(string.ascii_lowercase, k=6))
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
wandb.init(
|
||||
project=self.wandb_project,
|
||||
name=f"{self.name}-{current_date}-{random_id}",
|
||||
group=self.wandb_group,
|
||||
config=self.config.model_dump(),
|
||||
)
|
||||
|
||||
self.config.group_size = self.group_size_to_process
|
||||
items = self.train[:self.n_groups_to_process]
|
||||
|
||||
total = len(items)
|
||||
concurrency = self.config.eval_concurrency
|
||||
completed = 0
|
||||
|
||||
# --- Pre-allocate slot pool ---
|
||||
# Use the first challenge as a throwaway target; each task will
|
||||
# reset_instance to its own challenge before running.
|
||||
first_key = self._get_challenge_key(items[0]) if items else "hello/hello"
|
||||
slot_pool: asyncio.Queue[int] = asyncio.Queue()
|
||||
pool_size = 0
|
||||
|
||||
logger.info("Pre-allocating %d dojo slots...", concurrency)
|
||||
for i in range(concurrency):
|
||||
try:
|
||||
inst = await self.client.create_instance(first_key)
|
||||
slot_pool.put_nowait(inst.slot)
|
||||
pool_size += 1
|
||||
except Exception as e:
|
||||
# Dojo has a hard slot cap; once full, stop trying
|
||||
logger.info(
|
||||
"Pre-allocated %d/%d slots (dojo full: %s)",
|
||||
i, concurrency, e,
|
||||
)
|
||||
break
|
||||
|
||||
if pool_size == 0:
|
||||
raise RuntimeError("Could not allocate any dojo slots")
|
||||
|
||||
logger.info(
|
||||
"Processing %d items (pool_size=%d, group_size=%d)",
|
||||
total, pool_size, self.group_size_to_process,
|
||||
)
|
||||
|
||||
# Resolve tools once before launching concurrent tasks
|
||||
self._current_group_tools = self._resolve_tools_for_group()
|
||||
|
||||
async def process_one(item):
|
||||
nonlocal completed
|
||||
challenge_key = self._get_challenge_key(item)
|
||||
|
||||
# Wait for a real slot (blocks until one is returned)
|
||||
original_slot = await slot_pool.get()
|
||||
# _acquire_instance may create a new slot if the original
|
||||
# died on the dojo, so we track the actual slot to return.
|
||||
actual_slot: int | None = original_slot
|
||||
|
||||
try:
|
||||
# Acquire instance (reset or create)
|
||||
inst = await self._acquire_instance(
|
||||
challenge_key, pool_slot=original_slot,
|
||||
)
|
||||
if inst is None:
|
||||
logger.warning(
|
||||
"Could not acquire instance for %s",
|
||||
challenge_key,
|
||||
)
|
||||
actual_slot = None # don't poison pool
|
||||
return
|
||||
actual_slot = inst.slot
|
||||
if actual_slot != original_slot:
|
||||
logger.info(
|
||||
"Slot %d replaced with %d for %s",
|
||||
original_slot, actual_slot,
|
||||
challenge_key,
|
||||
)
|
||||
|
||||
# Run the trajectory with the acquired instance
|
||||
scored, _ = await self.collect_trajectory(
|
||||
item, pool_instance=inst,
|
||||
)
|
||||
if scored is None:
|
||||
logger.warning(
|
||||
"No scored data for %s (slot %d)",
|
||||
challenge_key, actual_slot,
|
||||
)
|
||||
return
|
||||
|
||||
# Wrap in ScoredDataGroup for postprocessing
|
||||
to_postprocess = {
|
||||
"tokens": [scored["tokens"]],
|
||||
"masks": [scored["masks"]],
|
||||
"scores": [scored["scores"]],
|
||||
"advantages": [],
|
||||
"ref_logprobs": [],
|
||||
"messages": [scored.get("messages", [])],
|
||||
"group_overrides": {},
|
||||
"overrides": [],
|
||||
"images": [],
|
||||
}
|
||||
processed = await self.postprocess_histories(
|
||||
to_postprocess,
|
||||
)
|
||||
await self.handle_send_to_api(
|
||||
processed, item,
|
||||
do_send_to_api=False,
|
||||
abort_on_any_max_length_exceeded=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to process %s: %s", challenge_key, e,
|
||||
)
|
||||
finally:
|
||||
completed += 1
|
||||
logger.info(
|
||||
"Processed %d/%d (%s)",
|
||||
completed, total, challenge_key,
|
||||
)
|
||||
# Return the actual slot to pool (may differ from
|
||||
# original_slot if reset failed and a new one was
|
||||
# created). None means acquisition failed entirely.
|
||||
if actual_slot is not None:
|
||||
slot_pool.put_nowait(actual_slot)
|
||||
|
||||
await asyncio.gather(*[process_one(item) for item in items])
|
||||
|
||||
logger.info("Completed processing %d items", completed)
|
||||
|
||||
# Cleanup: destroy all pooled slots
|
||||
while not slot_pool.empty():
|
||||
slot = slot_pool.get_nowait()
|
||||
try:
|
||||
await self.client.destroy_instance(slot)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to destroy pool slot %d: %s", slot, e)
|
||||
|
||||
if self.jsonl_writer is not None:
|
||||
self.jsonl_writer.close()
|
||||
|
||||
if self.config.data_path_to_save_groups:
|
||||
generate_html(self.config.data_path_to_save_groups)
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""Run evaluation on a dojo/module and report solve rate.
|
||||
|
||||
Fetches challenges matching eval_dojo/eval_module, runs each through
|
||||
the agent loop with concurrency control, and logs results.
|
||||
"""
|
||||
import time
|
||||
|
||||
if not self.client:
|
||||
logger.error("SDK client not initialized. Call setup() first.")
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Fetch and filter eval challenges
|
||||
all_challenges = await self.client.list_challenges()
|
||||
if self.config.eval_challenges:
|
||||
challenge_set = set(self.config.eval_challenges)
|
||||
eval_challenges = [c for c in all_challenges if c.challenge_key in challenge_set]
|
||||
else:
|
||||
eval_challenges = [
|
||||
c for c in all_challenges
|
||||
if (self.config.eval_dojo is None or c.dojo_id == self.config.eval_dojo)
|
||||
and (self.config.eval_module is None or c.module_id == self.config.eval_module)
|
||||
and c.dojo_id not in self.config.eval_exclude_dojos
|
||||
and c.module_id not in self.config.eval_exclude_modules
|
||||
]
|
||||
|
||||
if not eval_challenges:
|
||||
logger.warning(
|
||||
"No challenges found for eval_dojo=%s eval_module=%s",
|
||||
self.config.eval_dojo, self.config.eval_module,
|
||||
)
|
||||
return
|
||||
|
||||
print(
|
||||
f"Evaluating {len(eval_challenges)} challenges from "
|
||||
f"{self.config.eval_dojo or '*'}/{self.config.eval_module or '*'} "
|
||||
f"(concurrency={self.config.eval_concurrency})",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
semaphore = asyncio.Semaphore(self.config.eval_concurrency)
|
||||
completed = 0
|
||||
total = len(eval_challenges)
|
||||
|
||||
async def eval_one(challenge: RLChallenge) -> dict:
|
||||
nonlocal completed
|
||||
challenge_key = self._get_challenge_key(challenge)
|
||||
async with semaphore:
|
||||
try:
|
||||
scored, _ = await self.collect_trajectory(challenge)
|
||||
solved = scored is not None and scored.get("scores", 0.0) >= 1.0
|
||||
completed += 1
|
||||
status = "PASS" if solved else "FAIL"
|
||||
reward = scored.get("scores", 0.0) if scored else 0.0
|
||||
print(
|
||||
f" [{completed}/{total}] [{status}] {challenge_key} "
|
||||
f"(reward={reward:.1f})",
|
||||
flush=True,
|
||||
)
|
||||
result = {
|
||||
"challenge": challenge_key,
|
||||
"name": challenge.name,
|
||||
"solved": solved,
|
||||
"reward": reward,
|
||||
}
|
||||
# Stream-write sample with full conversation for HTML viewer
|
||||
self.log_eval_sample({
|
||||
"score": reward,
|
||||
"challenge": challenge_key,
|
||||
"solved": solved,
|
||||
"messages": scored.get("messages", []) if scored else [],
|
||||
})
|
||||
return result
|
||||
except Exception as e:
|
||||
completed += 1
|
||||
print(
|
||||
f" [{completed}/{total}] [ERR ] {challenge_key}: {e}",
|
||||
flush=True,
|
||||
)
|
||||
self.log_eval_sample({
|
||||
"score": 0.0,
|
||||
"challenge": challenge_key,
|
||||
"solved": False,
|
||||
"messages": [{"role": "system", "content": f"Error: {e}"}],
|
||||
})
|
||||
return {
|
||||
"challenge": challenge_key,
|
||||
"name": challenge.name,
|
||||
"solved": False,
|
||||
"reward": 0.0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
tasks = [eval_one(c) for c in eval_challenges]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
# Aggregate
|
||||
n = len(results)
|
||||
solved = sum(1 for r in results if r["solved"])
|
||||
solve_rate = solved / n if n else 0.0
|
||||
|
||||
print("=" * 60, flush=True)
|
||||
print(
|
||||
f"Eval: {solved}/{n} solved ({solve_rate * 100:.1f}%) "
|
||||
f"in {end_time - start_time:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
print("=" * 60, flush=True)
|
||||
|
||||
eval_metrics = {
|
||||
"eval/solve_rate": solve_rate,
|
||||
"eval/solved": solved,
|
||||
"eval/total": n,
|
||||
}
|
||||
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log solve rate metrics to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
if self.solve_rate_buffer:
|
||||
n = len(self.solve_rate_buffer)
|
||||
wandb_metrics["train/solve_rate"] = sum(self.solve_rate_buffer) / n
|
||||
wandb_metrics["train/num_rollouts"] = n
|
||||
self.solve_rate_buffer = []
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
PwnCollegeEnv.cli()
|
||||
@@ -0,0 +1,468 @@
|
||||
"""SDK for pwncollege dojo"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_csrf_nonce(html: str) -> str | None:
|
||||
match = re.search(r"'csrfNonce': \"([^\"]+)\"", html)
|
||||
return match.group(1) if match else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLInstance:
|
||||
slot: int
|
||||
ssh_user: str
|
||||
challenge_id: str
|
||||
module_id: str
|
||||
dojo_id: str
|
||||
flag: str | None = None
|
||||
created_at: float | None = None
|
||||
status: str | None = None
|
||||
|
||||
@property
|
||||
def challenge_key(self) -> str:
|
||||
return f"{self.module_id}/{self.challenge_id}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLResource:
|
||||
type: str
|
||||
name: str
|
||||
content: str | None = None
|
||||
video: str | None = None
|
||||
slides: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLChallenge:
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
module_id: str | None = None
|
||||
module_name: str | None = None
|
||||
module_description: str | None = None
|
||||
dojo_id: str | None = None
|
||||
dojo_name: str | None = None
|
||||
dojo_description: str | None = None
|
||||
resources: list[RLResource] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def challenge_key(self) -> str | None:
|
||||
if self.module_id:
|
||||
return f"{self.module_id}/{self.id}"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLStatus:
|
||||
enabled: bool
|
||||
max_instances: int
|
||||
running: int
|
||||
instances: list[RLInstance]
|
||||
|
||||
|
||||
class DojoRLClient:
|
||||
"""Client for the dojo RL API. No auth required."""
|
||||
|
||||
def __init__(self, base_url: str, timeout: float = 120.0):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.client = httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
await self.close()
|
||||
|
||||
async def close(self):
|
||||
await self.client.aclose()
|
||||
|
||||
def _rl_url(self, path: str) -> str:
|
||||
return f"/pwncollege_api/v1/rl{path}"
|
||||
|
||||
async def _get(self, path: str) -> dict[str, Any]:
|
||||
resp = await self.client.get(self._rl_url(path))
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
async def _post(self, path: str, json: dict | None = None) -> dict[str, Any]:
|
||||
resp = await self.client.post(self._rl_url(path), json=json or {})
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
async def _delete(self, path: str) -> dict[str, Any]:
|
||||
resp = await self.client.delete(self._rl_url(path))
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
# ── Response Parsing ──────────────────────────────────────────────────────
|
||||
# The API uses different field names in create/reset vs get/list responses.
|
||||
# These parsers normalize everything into RLInstance.
|
||||
|
||||
@staticmethod
|
||||
def _parse_create_response(data: dict[str, Any]) -> RLInstance:
|
||||
return RLInstance(
|
||||
slot=data["slot"],
|
||||
ssh_user=data["ssh_user"],
|
||||
challenge_id=data["challenge"],
|
||||
module_id=data["module"],
|
||||
dojo_id=data["dojo"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_instance_detail(data: dict[str, Any]) -> RLInstance:
|
||||
created_at = data.get("created_at")
|
||||
return RLInstance(
|
||||
slot=data["slot"],
|
||||
ssh_user=data.get("ssh_user", f"rl_{data['slot']}"),
|
||||
challenge_id=data["challenge_id"],
|
||||
module_id=data["module_id"],
|
||||
dojo_id=data["dojo_id"],
|
||||
flag=data.get("flag"),
|
||||
created_at=float(created_at) if created_at else None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_instance_listing(data: dict[str, Any]) -> RLInstance:
|
||||
created_at = data.get("created_at")
|
||||
return RLInstance(
|
||||
slot=data["slot"],
|
||||
ssh_user=f"rl_{data['slot']}",
|
||||
challenge_id=data["challenge_id"],
|
||||
module_id=data["module_id"],
|
||||
dojo_id=data["dojo_id"],
|
||||
created_at=float(created_at) if created_at else None,
|
||||
status=data.get("status"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_challenge(data: dict[str, Any]) -> RLChallenge:
|
||||
resources = [
|
||||
RLResource(
|
||||
type=r["type"],
|
||||
name=r["name"],
|
||||
content=r.get("content"),
|
||||
video=r.get("video"),
|
||||
slides=r.get("slides"),
|
||||
)
|
||||
for r in data.get("resources", [])
|
||||
]
|
||||
return RLChallenge(
|
||||
id=data["id"],
|
||||
name=data["name"],
|
||||
description=data["description"],
|
||||
module_id=data.get("module_id"),
|
||||
module_name=data.get("module_name"),
|
||||
module_description=data.get("module_description"),
|
||||
dojo_id=data.get("dojo_id"),
|
||||
dojo_name=data.get("dojo_name"),
|
||||
dojo_description=data.get("dojo_description"),
|
||||
resources=resources,
|
||||
)
|
||||
|
||||
# ── RL Instance Lifecycle ─────────────────────────────────────────────────
|
||||
|
||||
async def status(self) -> RLStatus:
|
||||
result = await self._get("/status")
|
||||
instances = [
|
||||
self._parse_instance_listing(inst) for inst in result.get("instances", [])
|
||||
]
|
||||
return RLStatus(
|
||||
enabled=result["enabled"],
|
||||
max_instances=result["max_instances"],
|
||||
running=result["running"],
|
||||
instances=instances,
|
||||
)
|
||||
|
||||
async def create_instance(
|
||||
self, challenge: str, *, variant: int | None = None
|
||||
) -> RLInstance:
|
||||
data: dict[str, Any] = {"challenge": challenge}
|
||||
if variant is not None:
|
||||
data["variant"] = variant
|
||||
result = await self._post("/instances", json=data)
|
||||
if not result.get("success"):
|
||||
raise RuntimeError(f"Failed to create instance: {result.get('error')}")
|
||||
return self._parse_create_response(result)
|
||||
|
||||
async def get_instance(self, slot: int) -> RLInstance:
|
||||
result = await self._get(f"/instances/{slot}")
|
||||
if not result.get("success"):
|
||||
raise KeyError(f"No instance at slot {slot}")
|
||||
return self._parse_instance_detail(result)
|
||||
|
||||
async def list_instances(self) -> list[RLInstance]:
|
||||
result = await self._get("/instances")
|
||||
return [
|
||||
self._parse_instance_listing(inst) for inst in result.get("instances", [])
|
||||
]
|
||||
|
||||
async def destroy_instance(self, slot: int) -> None:
|
||||
result = await self._delete(f"/instances/{slot}")
|
||||
if not result.get("success"):
|
||||
raise RuntimeError(f"Failed to destroy instance: {result.get('error')}")
|
||||
|
||||
async def reset_instance(
|
||||
self, slot: int, *, challenge: str | None = None
|
||||
) -> RLInstance:
|
||||
data: dict[str, Any] = {}
|
||||
if challenge is not None:
|
||||
data["challenge"] = challenge
|
||||
result = await self._post(f"/instances/{slot}/reset", json=data)
|
||||
if not result.get("success"):
|
||||
raise RuntimeError(f"Failed to reset instance: {result.get('error')}")
|
||||
return self._parse_create_response(result)
|
||||
|
||||
async def check_flag(self, slot: int, flag: str) -> bool:
|
||||
result = await self._post(f"/instances/{slot}/check", json={"flag": flag})
|
||||
return result.get("correct", False)
|
||||
|
||||
async def get_flag(self, slot: int) -> str:
|
||||
instance = await self.get_instance(slot)
|
||||
if instance.flag is None:
|
||||
raise RuntimeError(f"No flag available for slot {slot}")
|
||||
return instance.flag
|
||||
|
||||
# ── SSH Key Management ────────────────────────────────────────────────────
|
||||
|
||||
async def register_ssh_key(self, public_key: str) -> bool:
|
||||
result = await self._post("/ssh_key", json={"public_key": public_key})
|
||||
return result.get("success", False)
|
||||
|
||||
async def get_ssh_key(self) -> dict[str, Any]:
|
||||
return await self._get("/ssh_key")
|
||||
|
||||
# ── Challenge Discovery ───────────────────────────────────────────────────
|
||||
|
||||
async def list_challenges(self) -> list[RLChallenge]:
|
||||
result = await self._get("/challenges")
|
||||
return [self._parse_challenge(ch) for ch in result.get("challenges", [])]
|
||||
|
||||
# ── Admin (requires auth) ─────────────────────────────────────────────────
|
||||
|
||||
async def admin_login(
|
||||
self, username: str = "admin", password: str = "admin"
|
||||
) -> None:
|
||||
resp = await self.client.get("/login")
|
||||
nonce = _extract_csrf_nonce(resp.text)
|
||||
if not nonce:
|
||||
raise RuntimeError("Could not extract CSRF nonce")
|
||||
self._admin_csrf = nonce
|
||||
resp = await self.client.post(
|
||||
"/login",
|
||||
data={"name": username, "password": password, "nonce": nonce},
|
||||
)
|
||||
if resp.status_code not in (200, 302):
|
||||
raise RuntimeError(f"Login failed: {resp.status_code}")
|
||||
resp = await self.client.get("/")
|
||||
self._admin_csrf = _extract_csrf_nonce(resp.text) or self._admin_csrf
|
||||
|
||||
async def load_dojo(self, repository: str) -> str:
|
||||
if not hasattr(self, "_admin_csrf"):
|
||||
raise RuntimeError("Must call admin_login() first")
|
||||
resp = await self.client.post(
|
||||
"/pwncollege_api/v1/dojos/create",
|
||||
json={
|
||||
"repository": repository,
|
||||
"public_key": f"public/{repository}",
|
||||
"private_key": f"private/{repository}",
|
||||
},
|
||||
headers={"CSRF-Token": self._admin_csrf},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
if not data.get("success", True):
|
||||
raise RuntimeError(f"Failed to load dojo: {data.get('error', data)}")
|
||||
return data.get("dojo", repository)
|
||||
|
||||
async def promote_dojo(self, dojo_id: str) -> None:
|
||||
if not hasattr(self, "_admin_csrf"):
|
||||
raise RuntimeError("Must call admin_login() first")
|
||||
resp = await self.client.post(
|
||||
f"/pwncollege_api/v1/dojos/{dojo_id}/promote",
|
||||
json={},
|
||||
headers={"CSRF-Token": self._admin_csrf},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
|
||||
# ── Bulk Operations ───────────────────────────────────────────────────────
|
||||
|
||||
async def create_batch(self, challenge: str, count: int) -> list[RLInstance]:
|
||||
tasks = [self.create_instance(challenge) for _ in range(count)]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
async def destroy_all(self) -> int:
|
||||
instances = await self.list_instances()
|
||||
for inst in instances:
|
||||
await self.destroy_instance(inst.slot)
|
||||
return len(instances)
|
||||
|
||||
|
||||
class DojoRLSyncClient:
|
||||
"""Sync wrapper for DojoRLClient.
|
||||
|
||||
Runs all async operations on a dedicated background thread with its own
|
||||
event loop, so it's safe to call from any context — including from inside
|
||||
another running event loop (e.g., Atropos's loop or tool dispatch threads).
|
||||
"""
|
||||
|
||||
def __init__(self, base_url: str, timeout: float = 120.0):
|
||||
import threading
|
||||
|
||||
self._async = DojoRLClient(base_url, timeout)
|
||||
self._loop = asyncio.new_event_loop()
|
||||
self._thread = threading.Thread(
|
||||
target=self._loop.run_forever,
|
||||
daemon=True,
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
def _run(self, coro):
|
||||
return asyncio.run_coroutine_threadsafe(coro, self._loop).result()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
if not self._loop.is_running():
|
||||
return
|
||||
try:
|
||||
self._run(self._async.close())
|
||||
except Exception:
|
||||
pass
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
self._thread.join(timeout=5)
|
||||
|
||||
def status(self) -> RLStatus:
|
||||
return self._run(self._async.status())
|
||||
|
||||
def create_instance(
|
||||
self, challenge: str, *, variant: int | None = None
|
||||
) -> RLInstance:
|
||||
return self._run(self._async.create_instance(challenge, variant=variant))
|
||||
|
||||
def get_instance(self, slot: int) -> RLInstance:
|
||||
return self._run(self._async.get_instance(slot))
|
||||
|
||||
def list_instances(self) -> list[RLInstance]:
|
||||
return self._run(self._async.list_instances())
|
||||
|
||||
def destroy_instance(self, slot: int) -> None:
|
||||
return self._run(self._async.destroy_instance(slot))
|
||||
|
||||
def reset_instance(self, slot: int, *, challenge: str | None = None) -> RLInstance:
|
||||
return self._run(self._async.reset_instance(slot, challenge=challenge))
|
||||
|
||||
def check_flag(self, slot: int, flag: str) -> bool:
|
||||
return self._run(self._async.check_flag(slot, flag))
|
||||
|
||||
def get_flag(self, slot: int) -> str:
|
||||
return self._run(self._async.get_flag(slot))
|
||||
|
||||
def list_challenges(self) -> list[RLChallenge]:
|
||||
return self._run(self._async.list_challenges())
|
||||
|
||||
def register_ssh_key(self, public_key: str) -> bool:
|
||||
return self._run(self._async.register_ssh_key(public_key))
|
||||
|
||||
def get_ssh_key(self) -> dict[str, Any]:
|
||||
return self._run(self._async.get_ssh_key())
|
||||
|
||||
def admin_login(self, username: str = "admin", password: str = "admin") -> None:
|
||||
return self._run(self._async.admin_login(username, password))
|
||||
|
||||
def load_dojo(self, repository: str) -> str:
|
||||
return self._run(self._async.load_dojo(repository))
|
||||
|
||||
def promote_dojo(self, dojo_id: str) -> None:
|
||||
return self._run(self._async.promote_dojo(dojo_id))
|
||||
|
||||
def destroy_all(self) -> int:
|
||||
return self._run(self._async.destroy_all())
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodePool:
|
||||
"""Manages a pool of RL instances for parallel episode collection."""
|
||||
|
||||
client: DojoRLClient
|
||||
challenge: str
|
||||
pool_size: int = 32
|
||||
acquisition_timeout: float = 300.0
|
||||
|
||||
_available: asyncio.Queue[RLInstance] = field(
|
||||
default_factory=asyncio.Queue, init=False
|
||||
)
|
||||
_all_instances: dict[int, RLInstance] = field(default_factory=dict, init=False)
|
||||
_initialized: bool = field(default=False, init=False)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
for _ in range(self.pool_size):
|
||||
instance = await self.client.create_instance(self.challenge)
|
||||
full = await self.client.get_instance(instance.slot)
|
||||
self._all_instances[instance.slot] = full
|
||||
await self._available.put(full)
|
||||
self._initialized = True
|
||||
|
||||
@asynccontextmanager
|
||||
async def acquire(self):
|
||||
if not self._initialized:
|
||||
raise RuntimeError("EpisodePool not initialized")
|
||||
try:
|
||||
instance = await asyncio.wait_for(
|
||||
self._available.get(), timeout=self.acquisition_timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise RuntimeError(
|
||||
f"No instance available within {self.acquisition_timeout}s"
|
||||
)
|
||||
try:
|
||||
yield instance
|
||||
finally:
|
||||
try:
|
||||
reset = await self.client.reset_instance(
|
||||
instance.slot, challenge=self.challenge
|
||||
)
|
||||
full = await self.client.get_instance(reset.slot)
|
||||
self._all_instances[reset.slot] = full
|
||||
await self._available.put(full)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to reset instance slot %d, returning stale instance: %s",
|
||||
instance.slot,
|
||||
e,
|
||||
)
|
||||
await self._available.put(instance)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
errors = []
|
||||
for slot in list(self._all_instances.keys()):
|
||||
try:
|
||||
await self.client.destroy_instance(slot)
|
||||
except Exception as e:
|
||||
errors.append((slot, e))
|
||||
logger.warning("Failed to destroy instance slot %d: %s", slot, e)
|
||||
self._all_instances.clear()
|
||||
self._initialized = False
|
||||
if errors:
|
||||
logger.error(
|
||||
"EpisodePool shutdown: %d instance(s) failed to destroy", len(errors)
|
||||
)
|
||||
@@ -0,0 +1,74 @@
|
||||
env:
|
||||
group_size: 4
|
||||
max_num_workers: -1
|
||||
max_eval_workers: 16
|
||||
max_num_workers_per_node: 8
|
||||
steps_per_eval: 100
|
||||
max_token_length: 16384
|
||||
eval_handling: STOP_TRAIN
|
||||
eval_limit_ratio: 0.5
|
||||
inference_weight: 1.0
|
||||
batch_size: -1
|
||||
max_batches_offpolicy: 3
|
||||
tokenizer_name: NousResearch/Hermes-3-Llama-3.1-8B
|
||||
use_wandb: false
|
||||
rollout_server_url: http://localhost:8000
|
||||
total_steps: 1000
|
||||
wandb_name: pwncollege-smoke-hello
|
||||
num_rollouts_to_keep: 32
|
||||
num_rollouts_per_group_for_logging: 1
|
||||
ensure_scores_are_not_same: false
|
||||
data_path_to_save_groups: null
|
||||
data_dir_to_save_evals: environments/pwncollege_env/eval_runs/smoke_hello
|
||||
min_items_sent_before_logging: 2
|
||||
include_messages: false
|
||||
min_batch_allocation: null
|
||||
worker_timeout: 600.0
|
||||
thinking_mode: false
|
||||
reasoning_effort: null
|
||||
max_reasoning_tokens: null
|
||||
custom_thinking_prompt: null
|
||||
enabled_toolsets:
|
||||
- terminal
|
||||
- file
|
||||
- pwncollege
|
||||
disabled_toolsets: null
|
||||
distribution: null
|
||||
max_agent_turns: 20
|
||||
agent_temperature: 0.7
|
||||
terminal_backend: ssh
|
||||
terminal_timeout: 120
|
||||
terminal_lifetime: 3600
|
||||
disable_command_guards: true
|
||||
dataset_name: null
|
||||
dataset_split: train
|
||||
prompt_field: prompt
|
||||
tool_pool_size: 128
|
||||
tool_call_parser: hermes
|
||||
extra_body: null
|
||||
base_url: http://100.120.55.25:8080
|
||||
ssh_host: 100.120.55.25
|
||||
ssh_port: 2222
|
||||
ssh_key: environments/pwncollege_env/keys/rl_test_key
|
||||
challenge: hello/hello
|
||||
dojo_filter: null
|
||||
module_filter: null
|
||||
eval_dojo: linux-luminarium
|
||||
eval_exclude_dojos:
|
||||
- archive
|
||||
eval_module: hello
|
||||
eval_concurrency: 3
|
||||
openai:
|
||||
- timeout: 1200
|
||||
num_max_requests_at_once: 512
|
||||
num_requests_for_eval: 64
|
||||
model_name: xiaomi/mimo-v2-flash
|
||||
rolling_buffer_length: 1000
|
||||
server_type: openai
|
||||
tokenizer_name: none
|
||||
api_key: ""
|
||||
base_url: https://openrouter.ai/api/v1
|
||||
n_kwarg_is_ignored: false
|
||||
health_check: false
|
||||
slurm: false
|
||||
testing: false
|
||||
@@ -0,0 +1,513 @@
|
||||
"""
|
||||
Capability verification test for pwn-dojo RL infrastructure.
|
||||
|
||||
Verifies that RL containers are provisioned with the correct Linux capabilities,
|
||||
resource limits, and host configuration for each challenge type.
|
||||
|
||||
Usage:
|
||||
python environments/pwncollege_env/stress_test.py -y
|
||||
python environments/pwncollege_env/stress_test.py -y -o report.json --verbose
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from environments.pwncollege_env.sdk import DojoRLClient
|
||||
|
||||
|
||||
@dataclass
|
||||
class SSHConfig:
|
||||
host: str
|
||||
port: int
|
||||
key: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class CheckResult:
|
||||
name: str
|
||||
passed: bool
|
||||
message: str
|
||||
duration: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestResult:
|
||||
name: str
|
||||
challenge: str
|
||||
checks: list[CheckResult] = field(default_factory=list)
|
||||
passed: bool = False
|
||||
skipped: bool = False
|
||||
error: str | None = None
|
||||
duration: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestCase:
|
||||
name: str
|
||||
challenge: str
|
||||
checks: list
|
||||
|
||||
|
||||
async def ssh_run(
|
||||
cfg: SSHConfig, user: str, command: str, timeout: float = 30.0
|
||||
) -> tuple[int, str]:
|
||||
"""Run a command over SSH via subprocess. Returns (returncode, output)."""
|
||||
cmd = [
|
||||
"ssh",
|
||||
"-o",
|
||||
"BatchMode=yes",
|
||||
"-o",
|
||||
"StrictHostKeyChecking=accept-new",
|
||||
"-o",
|
||||
"UserKnownHostsFile=/dev/null",
|
||||
"-o",
|
||||
"ConnectTimeout=10",
|
||||
"-o",
|
||||
"LogLevel=ERROR",
|
||||
"-p",
|
||||
str(cfg.port),
|
||||
"-i",
|
||||
cfg.key,
|
||||
f"{user}@{cfg.host}",
|
||||
command,
|
||||
]
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
)
|
||||
try:
|
||||
stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=timeout)
|
||||
return proc.returncode, stdout.decode(errors="replace")
|
||||
except asyncio.TimeoutError:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
return -1, f"[SSH timeout after {timeout}s]"
|
||||
|
||||
|
||||
async def wait_ssh_ready(cfg: SSHConfig, user: str, retries: int = 10) -> bool:
|
||||
for i in range(retries):
|
||||
rc, out = await ssh_run(cfg, user, "echo ready", timeout=10)
|
||||
if rc == 0 and "ready" in out:
|
||||
return True
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
|
||||
|
||||
# ── Check functions ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def check_ssh_echo(cfg: SSHConfig, user: str) -> CheckResult:
|
||||
t0 = time.monotonic()
|
||||
rc, out = await ssh_run(cfg, user, "echo ok")
|
||||
dur = time.monotonic() - t0
|
||||
if rc == 0 and "ok" in out:
|
||||
return CheckResult("ssh_echo", True, "connected", dur)
|
||||
return CheckResult("ssh_echo", False, f"rc={rc}: {out.strip()[:100]}", dur)
|
||||
|
||||
|
||||
async def check_unshare_net(cfg: SSHConfig, user: str) -> CheckResult:
|
||||
t0 = time.monotonic()
|
||||
rc, out = await ssh_run(cfg, user, "unshare --net echo ok")
|
||||
dur = time.monotonic() - t0
|
||||
if rc == 0 and "ok" in out:
|
||||
return CheckResult("unshare_net", True, "namespace creation works", dur)
|
||||
return CheckResult("unshare_net", False, f"rc={rc}: {out.strip()[:120]}", dur)
|
||||
|
||||
|
||||
async def check_unshare_user(cfg: SSHConfig, user: str) -> CheckResult:
|
||||
t0 = time.monotonic()
|
||||
rc, out = await ssh_run(cfg, user, "unshare --user --map-root-user bash -c 'id'")
|
||||
dur = time.monotonic() - t0
|
||||
if rc == 0 and "uid=0" in out:
|
||||
return CheckResult("unshare_user", True, "user namespace works", dur)
|
||||
return CheckResult("unshare_user", False, f"rc={rc}: {out.strip()[:120]}", dur)
|
||||
|
||||
|
||||
async def check_capeff(cfg: SSHConfig, user: str) -> CheckResult:
|
||||
"""Check that the container init (PID 1) has SYS_ADMIN capability."""
|
||||
t0 = time.monotonic()
|
||||
rc, out = await ssh_run(cfg, user, "cat /proc/1/status")
|
||||
dur = time.monotonic() - t0
|
||||
if rc != 0:
|
||||
return CheckResult(
|
||||
"capeff", False, f"Cannot read /proc/1/status: {out.strip()[:80]}", dur
|
||||
)
|
||||
for line in out.splitlines():
|
||||
if line.startswith("CapEff:") or line.startswith("CapBnd:"):
|
||||
hex_val = line.split(":")[1].strip()
|
||||
try:
|
||||
val = int(hex_val, 16)
|
||||
has_sysadmin = bool(val & (1 << 21))
|
||||
if has_sysadmin:
|
||||
label = line.split(":")[0]
|
||||
return CheckResult(
|
||||
"capeff", True, f"{label}={hex_val} has SYS_ADMIN", dur
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
return CheckResult(
|
||||
"capeff", False, "SYS_ADMIN (bit 21) not found in capabilities", dur
|
||||
)
|
||||
|
||||
|
||||
async def check_hosts_resolution(cfg: SSHConfig, user: str) -> CheckResult:
|
||||
t0 = time.monotonic()
|
||||
rc, out = await ssh_run(cfg, user, "getent hosts challenge.localhost")
|
||||
dur = time.monotonic() - t0
|
||||
if rc == 0 and out.strip():
|
||||
return CheckResult(
|
||||
"hosts_resolution", True, f"resolves to {out.strip()[:40]}", dur
|
||||
)
|
||||
rc2, out2 = await ssh_run(cfg, user, "grep challenge.localhost /etc/hosts")
|
||||
dur = time.monotonic() - t0
|
||||
if rc2 == 0 and "challenge.localhost" in out2:
|
||||
return CheckResult(
|
||||
"hosts_resolution", True, "/etc/hosts has entry", dur
|
||||
)
|
||||
return CheckResult(
|
||||
"hosts_resolution", False, "challenge.localhost not resolvable", dur
|
||||
)
|
||||
|
||||
|
||||
async def check_pids_limit(cfg: SSHConfig, user: str) -> CheckResult:
|
||||
t0 = time.monotonic()
|
||||
rc, out = await ssh_run(
|
||||
cfg,
|
||||
user,
|
||||
"cat /sys/fs/cgroup/pids.max 2>/dev/null || cat /sys/fs/cgroup/pids/pids.max 2>/dev/null",
|
||||
)
|
||||
dur = time.monotonic() - t0
|
||||
val = out.strip()
|
||||
if val == "max":
|
||||
return CheckResult("pids_limit", True, "unlimited", dur)
|
||||
try:
|
||||
limit = int(val)
|
||||
if limit >= 1024:
|
||||
return CheckResult("pids_limit", True, f"pids_limit={limit}", dur)
|
||||
return CheckResult(
|
||||
"pids_limit", False, f"pids_limit={limit} (need >= 1024)", dur
|
||||
)
|
||||
except ValueError:
|
||||
return CheckResult("pids_limit", False, f"Cannot parse: {val[:60]}", dur)
|
||||
|
||||
|
||||
async def check_mem_limit(cfg: SSHConfig, user: str) -> CheckResult:
|
||||
t0 = time.monotonic()
|
||||
rc, out = await ssh_run(
|
||||
cfg,
|
||||
user,
|
||||
"cat /sys/fs/cgroup/memory.max 2>/dev/null || cat /sys/fs/cgroup/memory/memory.limit_in_bytes 2>/dev/null",
|
||||
)
|
||||
dur = time.monotonic() - t0
|
||||
val = out.strip()
|
||||
if val == "max":
|
||||
return CheckResult("mem_limit", True, "unlimited", dur)
|
||||
try:
|
||||
limit = int(val)
|
||||
limit_gb = limit / (1024**3)
|
||||
if (
|
||||
limit_gb >= 1.8
|
||||
): # 2GB for privileged RL containers (not 4GB to manage memory pressure)
|
||||
return CheckResult("mem_limit", True, f"mem={limit_gb:.1f}GB", dur)
|
||||
return CheckResult(
|
||||
"mem_limit", False, f"mem={limit_gb:.1f}GB (need >= 2GB)", dur
|
||||
)
|
||||
except ValueError:
|
||||
return CheckResult("mem_limit", False, f"Cannot parse: {val[:60]}", dur)
|
||||
|
||||
|
||||
async def check_challenge_run(cfg: SSHConfig, user: str) -> CheckResult:
|
||||
"""Run /challenge/run and verify no PermissionError."""
|
||||
t0 = time.monotonic()
|
||||
rc, out = await ssh_run(cfg, user, "/challenge/run < /dev/null", timeout=15)
|
||||
dur = time.monotonic() - t0
|
||||
if "PermissionError" in out or "Operation not permitted" in out:
|
||||
snippet = [l for l in out.splitlines() if "Permission" in l or "Operation" in l]
|
||||
return CheckResult(
|
||||
"challenge_run",
|
||||
False,
|
||||
snippet[0][:120] if snippet else "PermissionError",
|
||||
dur,
|
||||
)
|
||||
return CheckResult("challenge_run", True, f"No permission errors (rc={rc})", dur)
|
||||
|
||||
|
||||
# ── Test cases ───────────────────────────────────────────────────────────────
|
||||
|
||||
TEST_CASES = [
|
||||
TestCase("unprivileged_basic", "hello/hello", [check_ssh_echo]),
|
||||
TestCase(
|
||||
"privileged_caps",
|
||||
"intercepting-communication/udp-1",
|
||||
[check_ssh_echo, check_capeff],
|
||||
),
|
||||
TestCase(
|
||||
"privileged_challenge_run",
|
||||
"intercepting-communication/udp-1",
|
||||
[check_challenge_run],
|
||||
),
|
||||
TestCase(
|
||||
"web_challenge_hosts",
|
||||
"web-security/path-traversal-1",
|
||||
[check_ssh_echo, check_hosts_resolution],
|
||||
),
|
||||
TestCase(
|
||||
"resource_limits",
|
||||
"intercepting-communication/udp-1",
|
||||
[check_pids_limit, check_mem_limit],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ── Runner ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def run_tests(args) -> dict:
|
||||
cfg = SSHConfig(host=args.ssh_host, port=args.ssh_port, key=args.ssh_key)
|
||||
client = DojoRLClient(args.base_url)
|
||||
|
||||
status = await client.status()
|
||||
print(
|
||||
f"Server: {args.base_url} (RL={'enabled' if status.enabled else 'DISABLED'}, "
|
||||
f"{status.max_instances} max, {status.running} running)"
|
||||
)
|
||||
if status.running > 0:
|
||||
n = await client.destroy_all()
|
||||
print(f"Cleaned up {n} instance(s)")
|
||||
print()
|
||||
|
||||
results: list[TestResult] = []
|
||||
test_num = 0
|
||||
total = len(TEST_CASES) + (0 if args.skip_concurrent else 1)
|
||||
start_time = time.monotonic()
|
||||
|
||||
for tc in TEST_CASES:
|
||||
test_num += 1
|
||||
t0 = time.monotonic()
|
||||
tr = TestResult(name=tc.name, challenge=tc.challenge)
|
||||
print(f"[{test_num}/{total}] {tc.name} ({tc.challenge})")
|
||||
|
||||
try:
|
||||
inst = await client.create_instance(tc.challenge)
|
||||
except Exception as e:
|
||||
err = str(e)
|
||||
if "404" in err or "not found" in err.lower() or "Invalid" in err:
|
||||
tr.skipped = True
|
||||
tr.error = f"Challenge not available: {err[:80]}"
|
||||
print(f" SKIP {tr.error}")
|
||||
else:
|
||||
tr.error = f"create_instance failed: {err[:100]}"
|
||||
print(f" ERR {tr.error}")
|
||||
tr.duration = time.monotonic() - t0
|
||||
results.append(tr)
|
||||
print(f" --- {'SKIP' if tr.skipped else 'FAIL'} ({tr.duration:.1f}s)\n")
|
||||
continue
|
||||
|
||||
try:
|
||||
ready = await wait_ssh_ready(cfg, inst.ssh_user)
|
||||
if not ready:
|
||||
tr.error = "SSH not ready after 10 retries"
|
||||
tr.checks.append(
|
||||
CheckResult("ssh_ready", False, tr.error, time.monotonic() - t0)
|
||||
)
|
||||
print(f" FAIL ssh_ready: {tr.error}")
|
||||
else:
|
||||
for check_fn in tc.checks:
|
||||
cr = await check_fn(cfg, inst.ssh_user)
|
||||
tr.checks.append(cr)
|
||||
tag = "PASS" if cr.passed else "FAIL"
|
||||
extra = f" ({cr.message})" if args.verbose or not cr.passed else ""
|
||||
print(f" {tag} {cr.name:30s} {cr.duration:.1f}s{extra}")
|
||||
if not cr.passed:
|
||||
break
|
||||
finally:
|
||||
try:
|
||||
await client.destroy_instance(inst.slot)
|
||||
except Exception as e:
|
||||
print(f" WARN destroy failed: {e}")
|
||||
|
||||
tr.passed = all(c.passed for c in tr.checks) and not tr.error
|
||||
tr.duration = time.monotonic() - t0
|
||||
results.append(tr)
|
||||
print(f" --- {'PASS' if tr.passed else 'FAIL'} ({tr.duration:.1f}s)\n")
|
||||
|
||||
if not args.skip_concurrent:
|
||||
test_num += 1
|
||||
t0 = time.monotonic()
|
||||
tr = TestResult(name="concurrent_lifecycle", challenge="8x hello/hello")
|
||||
n_concurrent = min(8, status.max_instances)
|
||||
print(
|
||||
f"[{test_num}/{total}] concurrent_lifecycle ({n_concurrent}x hello/hello)"
|
||||
)
|
||||
|
||||
try:
|
||||
ct0 = time.monotonic()
|
||||
tasks = [client.create_instance("hello/hello") for _ in range(n_concurrent)]
|
||||
instances = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
create_dur = time.monotonic() - ct0
|
||||
|
||||
created = [i for i in instances if not isinstance(i, Exception)]
|
||||
errors = [i for i in instances if isinstance(i, Exception)]
|
||||
if errors:
|
||||
tr.checks.append(
|
||||
CheckResult(
|
||||
"create_all",
|
||||
False,
|
||||
f"{len(errors)}/{n_concurrent} failed: {errors[0]}",
|
||||
create_dur,
|
||||
)
|
||||
)
|
||||
else:
|
||||
tr.checks.append(
|
||||
CheckResult(
|
||||
"create_all", True, f"{n_concurrent} created", create_dur
|
||||
)
|
||||
)
|
||||
|
||||
if created:
|
||||
await asyncio.sleep(3)
|
||||
et0 = time.monotonic()
|
||||
echo_tasks = [
|
||||
ssh_run(cfg, i.ssh_user, "echo ok", timeout=15) for i in created
|
||||
]
|
||||
echo_results = await asyncio.gather(*echo_tasks, return_exceptions=True)
|
||||
echo_ok = sum(
|
||||
1
|
||||
for r in echo_results
|
||||
if not isinstance(r, Exception) and r[0] == 0
|
||||
)
|
||||
tr.checks.append(
|
||||
CheckResult(
|
||||
"ssh_echo_all",
|
||||
echo_ok == len(created),
|
||||
f"{echo_ok}/{len(created)} connected",
|
||||
time.monotonic() - et0,
|
||||
)
|
||||
)
|
||||
|
||||
dt0 = time.monotonic()
|
||||
destroyed = await client.destroy_all()
|
||||
tr.checks.append(
|
||||
CheckResult(
|
||||
"destroy_all",
|
||||
True,
|
||||
f"destroyed {destroyed}",
|
||||
time.monotonic() - dt0,
|
||||
)
|
||||
)
|
||||
|
||||
st = await client.status()
|
||||
live = sum(1 for i in st.instances if i.status == "running")
|
||||
tr.checks.append(
|
||||
CheckResult(
|
||||
"slot_cleanup",
|
||||
live == 0,
|
||||
f"running={live} (total listed={st.running})",
|
||||
0.0,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
tr.error = str(e)[:200]
|
||||
tr.checks.append(CheckResult("concurrent", False, str(e)[:100], 0.0))
|
||||
|
||||
tr.passed = all(c.passed for c in tr.checks) and not tr.error
|
||||
tr.duration = time.monotonic() - t0
|
||||
results.append(tr)
|
||||
for cr in tr.checks:
|
||||
tag = "PASS" if cr.passed else "FAIL"
|
||||
extra = f" ({cr.message})" if args.verbose or not cr.passed else ""
|
||||
print(f" {tag} {cr.name:30s} {cr.duration:.1f}s{extra}")
|
||||
print(f" --- {'PASS' if tr.passed else 'FAIL'} ({tr.duration:.1f}s)\n")
|
||||
|
||||
total_dur = time.monotonic() - start_time
|
||||
passed = sum(1 for r in results if r.passed)
|
||||
failed = sum(1 for r in results if not r.passed and not r.skipped)
|
||||
skipped = sum(1 for r in results if r.skipped)
|
||||
|
||||
print("=" * 50)
|
||||
parts = [f"{passed}/{len(results)} passed"]
|
||||
if failed:
|
||||
parts.append(f"{failed} failed")
|
||||
if skipped:
|
||||
parts.append(f"{skipped} skipped")
|
||||
print(f"RESULTS: {', '.join(parts)} in {total_dur:.0f}s")
|
||||
print("=" * 50)
|
||||
|
||||
return {
|
||||
"test": "capability_verification",
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S%z"),
|
||||
"server": args.base_url,
|
||||
"summary": {
|
||||
"total": len(results),
|
||||
"passed": passed,
|
||||
"failed": failed,
|
||||
"skipped": skipped,
|
||||
"duration_seconds": round(total_dur, 1),
|
||||
},
|
||||
"tests": [
|
||||
{
|
||||
"name": r.name,
|
||||
"challenge": r.challenge,
|
||||
"passed": r.passed,
|
||||
"skipped": r.skipped,
|
||||
"error": r.error,
|
||||
"duration": round(r.duration, 1),
|
||||
"checks": [asdict(c) for c in r.checks],
|
||||
}
|
||||
for r in results
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Capability verification test for pwn-dojo RL infrastructure",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument("--base-url", default="http://100.120.55.25:8080")
|
||||
parser.add_argument("--ssh-host", default="100.120.55.25")
|
||||
parser.add_argument("--ssh-port", type=int, default=2222)
|
||||
parser.add_argument(
|
||||
"--ssh-key", default="environments/pwncollege_env/keys/rl_test_key"
|
||||
)
|
||||
parser.add_argument("--output", "-o", help="Write JSON report")
|
||||
parser.add_argument("--skip-concurrent", action="store_true")
|
||||
parser.add_argument("--verbose", "-v", action="store_true")
|
||||
parser.add_argument("--yes", "-y", action="store_true", help="Skip confirmation")
|
||||
args = parser.parse_args()
|
||||
|
||||
key = Path(args.ssh_key)
|
||||
if not key.exists():
|
||||
key = _repo_root / args.ssh_key
|
||||
if not key.exists():
|
||||
print(f"SSH key not found: {args.ssh_key}")
|
||||
sys.exit(1)
|
||||
args.ssh_key = str(key)
|
||||
|
||||
if not args.yes:
|
||||
print(f"Will test against {args.base_url}")
|
||||
if input("Continue? [y/N] ").lower() != "y":
|
||||
sys.exit(0)
|
||||
|
||||
report = asyncio.run(run_tests(args))
|
||||
|
||||
if args.output:
|
||||
with open(args.output, "w") as f:
|
||||
json.dump(report, f, indent=2)
|
||||
print(f"\nJSON report: {args.output}")
|
||||
|
||||
sys.exit(0 if report["summary"]["failed"] == 0 else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,102 @@
|
||||
"""submit_flag tool for pwn.college RL environments.
|
||||
|
||||
Registers a `submit_flag` tool in the hermes-agent tool registry under the
|
||||
"pwncollege" toolset. The handler checks flags against the dojo RL API using
|
||||
per-task context (SDK client + slot) stored in a module-level dict.
|
||||
|
||||
Usage in an environment:
|
||||
from environments.pwncollege_env.submit_flag_tool import (
|
||||
register_flag_context, clear_flag_context,
|
||||
)
|
||||
|
||||
# Before agent loop
|
||||
register_flag_context(task_id, sync_client, slot)
|
||||
|
||||
# After agent loop
|
||||
clear_flag_context(task_id)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Per-task context: task_id → {"client": DojoRLSyncClient, "slot": int}
|
||||
_task_flag_context: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
def register_flag_context(task_id: str, sync_client: Any, slot: int) -> None:
|
||||
"""Register dojo client + slot for a rollout so submit_flag can verify flags."""
|
||||
_task_flag_context[task_id] = {"client": sync_client, "slot": slot}
|
||||
|
||||
|
||||
def clear_flag_context(task_id: str) -> None:
|
||||
"""Remove flag context after rollout completes."""
|
||||
_task_flag_context.pop(task_id, None)
|
||||
|
||||
|
||||
def _submit_flag_handler(args: dict, **kw) -> str:
|
||||
"""Handle submit_flag tool calls by checking the flag against the dojo API."""
|
||||
task_id = kw.get("task_id", "default")
|
||||
flag = args.get("flag", "")
|
||||
|
||||
if not flag:
|
||||
return json.dumps({"success": False, "message": "No flag provided."})
|
||||
|
||||
ctx = _task_flag_context.get(task_id)
|
||||
if not ctx:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "No active challenge instance for this task.",
|
||||
})
|
||||
|
||||
try:
|
||||
correct = ctx["client"].check_flag(ctx["slot"], flag)
|
||||
except Exception as e:
|
||||
logger.error("Flag check failed for task %s: %s", task_id, e, exc_info=True)
|
||||
return json.dumps({"success": False, "message": f"Flag check error: {type(e).__name__}"})
|
||||
|
||||
if correct:
|
||||
return json.dumps({"success": True, "message": "Flag accepted! Challenge solved."})
|
||||
return json.dumps({"success": False, "message": "Incorrect flag."})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Register in hermes-agent tool registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SUBMIT_FLAG_SCHEMA = {
|
||||
"name": "submit_flag",
|
||||
"description": (
|
||||
"Submit a flag for verification. Use this when you find a flag "
|
||||
"(format: pwn.college{...}) to check if it is correct."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"flag": {
|
||||
"type": "string",
|
||||
"description": "The flag string (format: pwn.college{...}).",
|
||||
}
|
||||
},
|
||||
"required": ["flag"],
|
||||
},
|
||||
}
|
||||
|
||||
from tools.registry import registry
|
||||
from toolsets import create_custom_toolset
|
||||
|
||||
registry.register(
|
||||
name="submit_flag",
|
||||
toolset="pwncollege",
|
||||
schema=SUBMIT_FLAG_SCHEMA,
|
||||
handler=_submit_flag_handler,
|
||||
emoji="🚩",
|
||||
)
|
||||
|
||||
create_custom_toolset(
|
||||
name="pwncollege",
|
||||
description="PwnCollege CTF tools",
|
||||
tools=["submit_flag"],
|
||||
)
|
||||
@@ -10,7 +10,6 @@ The [TOOL_CALLS] token is the bot_token used by Mistral models.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -42,9 +41,6 @@ class MistralToolCallParser(ToolCallParser):
|
||||
# The [TOOL_CALLS] token -- may appear as different strings depending on tokenizer
|
||||
BOT_TOKEN = "[TOOL_CALLS]"
|
||||
|
||||
# Fallback regex for pre-v11 format when JSON parsing fails
|
||||
TOOL_CALL_REGEX = re.compile(r"\[?\s*(\{.*?\})\s*\]?", re.DOTALL)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.BOT_TOKEN not in text:
|
||||
return text, None
|
||||
@@ -71,6 +67,13 @@ class MistralToolCallParser(ToolCallParser):
|
||||
tool_name = raw[:brace_idx].strip()
|
||||
args_str = raw[brace_idx:]
|
||||
|
||||
# Validate and clean the JSON arguments
|
||||
try:
|
||||
parsed_args = json.loads(args_str)
|
||||
args_str = json.dumps(parsed_args, ensure_ascii=False)
|
||||
except json.JSONDecodeError:
|
||||
pass # Keep raw if parsing fails
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=_generate_mistral_id(),
|
||||
@@ -100,13 +103,14 @@ class MistralToolCallParser(ToolCallParser):
|
||||
)
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# Fallback regex extraction
|
||||
match = self.TOOL_CALL_REGEX.findall(first_raw)
|
||||
if match:
|
||||
for raw_json in match:
|
||||
try:
|
||||
tc = json.loads(raw_json)
|
||||
args = tc.get("arguments", {})
|
||||
# Fallback: extract JSON objects using raw_decode
|
||||
decoder = json.JSONDecoder()
|
||||
idx = 0
|
||||
while idx < len(first_raw):
|
||||
try:
|
||||
obj, end_idx = decoder.raw_decode(first_raw, idx)
|
||||
if isinstance(obj, dict) and "name" in obj:
|
||||
args = obj.get("arguments", {})
|
||||
if isinstance(args, dict):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
tool_calls.append(
|
||||
@@ -114,12 +118,13 @@ class MistralToolCallParser(ToolCallParser):
|
||||
id=_generate_mistral_id(),
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc["name"], arguments=args
|
||||
name=obj["name"], arguments=args
|
||||
),
|
||||
)
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
idx = end_idx
|
||||
except json.JSONDecodeError:
|
||||
idx += 1
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
Generated
+181
@@ -0,0 +1,181 @@
|
||||
{
|
||||
"nodes": {
|
||||
"flake-parts": {
|
||||
"inputs": {
|
||||
"nixpkgs-lib": [
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1772408722,
|
||||
"narHash": "sha256-rHuJtdcOjK7rAHpHphUb1iCvgkU3GpfvicLMwwnfMT0=",
|
||||
"owner": "hercules-ci",
|
||||
"repo": "flake-parts",
|
||||
"rev": "f20dc5d9b8027381c474144ecabc9034d6a839a3",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "hercules-ci",
|
||||
"repo": "flake-parts",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1751274312,
|
||||
"narHash": "sha256-/bVBlRpECLVzjV19t5KMdMFWSwKLtb5RyXdjz3LJT+g=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "50ab793786d9de88ee30ec4e4c24fb4236fc2674",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-24.11",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-build-systems": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"pyproject-nix": "pyproject-nix",
|
||||
"uv2nix": "uv2nix"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1772555609,
|
||||
"narHash": "sha256-3BA3HnUvJSbHJAlJj6XSy0Jmu7RyP2gyB/0fL7XuEDo=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "build-system-pkgs",
|
||||
"rev": "c37f66a953535c394244888598947679af231863",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "build-system-pkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"pyproject-build-systems",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1769936401,
|
||||
"narHash": "sha256-kwCOegKLZJM9v/e/7cqwg1p/YjjTAukKPqmxKnAZRgA=",
|
||||
"owner": "nix-community",
|
||||
"repo": "pyproject.nix",
|
||||
"rev": "b0d513eeeebed6d45b4f2e874f9afba2021f7812",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-community",
|
||||
"repo": "pyproject.nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-nix_2": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1772865871,
|
||||
"narHash": "sha256-/ZTSg97aouL0SlPHaokA4r3iuH9QzHVuWPACD2CUCFY=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"rev": "e537db02e72d553cea470976b9733581bcf5b3ed",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-nix_3": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"uv2nix",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1771518446,
|
||||
"narHash": "sha256-nFJSfD89vWTu92KyuJWDoTQJuoDuddkJV3TlOl1cOic=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"rev": "eb204c6b3335698dec6c7fc1da0ebc3c6df05937",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"flake-parts": "flake-parts",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"pyproject-build-systems": "pyproject-build-systems",
|
||||
"pyproject-nix": "pyproject-nix_2",
|
||||
"uv2nix": "uv2nix_2"
|
||||
}
|
||||
},
|
||||
"uv2nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"pyproject-build-systems",
|
||||
"nixpkgs"
|
||||
],
|
||||
"pyproject-nix": [
|
||||
"pyproject-build-systems",
|
||||
"pyproject-nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1770770348,
|
||||
"narHash": "sha256-A2GzkmzdYvdgmMEu5yxW+xhossP+txrYb7RuzRaqhlg=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"rev": "5d1b2cb4fe3158043fbafbbe2e46238abbc954b0",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"uv2nix_2": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"pyproject-nix": "pyproject-nix_3"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1773039484,
|
||||
"narHash": "sha256-+boo33KYkJDw9KItpeEXXv8+65f7hHv/earxpcyzQ0I=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"rev": "b68be7cfeacbed9a3fa38a2b5adc0cfb81d9bb1f",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
"version": 7
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
{
|
||||
description = "Hermes Agent - AI agent framework by Nous Research";
|
||||
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.11";
|
||||
flake-parts = {
|
||||
url = "github:hercules-ci/flake-parts";
|
||||
inputs.nixpkgs-lib.follows = "nixpkgs";
|
||||
};
|
||||
pyproject-nix = {
|
||||
url = "github:pyproject-nix/pyproject.nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
uv2nix = {
|
||||
url = "github:pyproject-nix/uv2nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
pyproject-build-systems = {
|
||||
url = "github:pyproject-nix/build-system-pkgs";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
};
|
||||
|
||||
outputs = inputs:
|
||||
inputs.flake-parts.lib.mkFlake { inherit inputs; } {
|
||||
systems = [ "x86_64-linux" "aarch64-linux" "aarch64-darwin" ];
|
||||
|
||||
imports = [
|
||||
./nix/packages.nix
|
||||
./nix/nixosModules.nix
|
||||
./nix/checks.nix
|
||||
./nix/devShell.nix
|
||||
];
|
||||
};
|
||||
}
|
||||
@@ -9,7 +9,6 @@ action="list" and for resolving human-friendly channel names to numeric IDs.
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from hermes_cli.config import get_hermes_home
|
||||
@@ -90,7 +89,7 @@ def _build_discord(adapter) -> List[Dict[str, str]]:
|
||||
return channels
|
||||
|
||||
try:
|
||||
import discord as _discord
|
||||
import discord as _discord # noqa: F401 — SDK presence check
|
||||
except ImportError:
|
||||
return channels
|
||||
|
||||
@@ -119,7 +118,6 @@ def _build_slack(adapter) -> List[Dict[str, str]]:
|
||||
return _build_from_sessions("slack")
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
from tools.send_message_tool import _send_slack # noqa: F401
|
||||
# Use the Slack Web API directly if available
|
||||
except Exception:
|
||||
|
||||
+60
-4
@@ -101,12 +101,16 @@ class SessionResetPolicy:
|
||||
mode: str = "both" # "daily", "idle", "both", or "none"
|
||||
at_hour: int = 4 # Hour for daily reset (0-23, local time)
|
||||
idle_minutes: int = 1440 # Minutes of inactivity before reset (24 hours)
|
||||
notify: bool = True # Send a notification to the user when auto-reset occurs
|
||||
notify_exclude_platforms: tuple = ("api_server", "webhook") # Platforms that don't get reset notifications
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"mode": self.mode,
|
||||
"at_hour": self.at_hour,
|
||||
"idle_minutes": self.idle_minutes,
|
||||
"notify": self.notify,
|
||||
"notify_exclude_platforms": list(self.notify_exclude_platforms),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -115,10 +119,14 @@ class SessionResetPolicy:
|
||||
mode = data.get("mode")
|
||||
at_hour = data.get("at_hour")
|
||||
idle_minutes = data.get("idle_minutes")
|
||||
notify = data.get("notify")
|
||||
exclude = data.get("notify_exclude_platforms")
|
||||
return cls(
|
||||
mode=mode if mode is not None else "both",
|
||||
at_hour=at_hour if at_hour is not None else 4,
|
||||
idle_minutes=idle_minutes if idle_minutes is not None else 1440,
|
||||
notify=notify if notify is not None else True,
|
||||
notify_exclude_platforms=tuple(exclude) if exclude is not None else ("api_server", "webhook"),
|
||||
)
|
||||
|
||||
|
||||
@@ -130,6 +138,12 @@ class PlatformConfig:
|
||||
api_key: Optional[str] = None # API key if different from token
|
||||
home_channel: Optional[HomeChannel] = None
|
||||
|
||||
# Reply threading mode (Telegram/Slack)
|
||||
# - "off": Never thread replies to original message
|
||||
# - "first": Only first chunk threads to user's message (default)
|
||||
# - "all": All chunks in multi-part replies thread to user's message
|
||||
reply_to_mode: str = "first"
|
||||
|
||||
# Platform-specific settings
|
||||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@@ -137,6 +151,7 @@ class PlatformConfig:
|
||||
result = {
|
||||
"enabled": self.enabled,
|
||||
"extra": self.extra,
|
||||
"reply_to_mode": self.reply_to_mode,
|
||||
}
|
||||
if self.token:
|
||||
result["token"] = self.token
|
||||
@@ -157,6 +172,7 @@ class PlatformConfig:
|
||||
token=data.get("token"),
|
||||
api_key=data.get("api_key"),
|
||||
home_channel=home_channel,
|
||||
reply_to_mode=data.get("reply_to_mode", "first"),
|
||||
extra=data.get("extra", {}),
|
||||
)
|
||||
|
||||
@@ -455,11 +471,27 @@ def load_gateway_config() -> GatewayConfig:
|
||||
"pair",
|
||||
)
|
||||
|
||||
# Bridge per-platform settings from config.yaml into gw_data
|
||||
# Merge platforms section from config.yaml into gw_data so that
|
||||
# nested keys like platforms.webhook.extra.routes are loaded.
|
||||
yaml_platforms = yaml_cfg.get("platforms")
|
||||
platforms_data = gw_data.setdefault("platforms", {})
|
||||
if not isinstance(platforms_data, dict):
|
||||
platforms_data = {}
|
||||
gw_data["platforms"] = platforms_data
|
||||
if isinstance(yaml_platforms, dict):
|
||||
for plat_name, plat_block in yaml_platforms.items():
|
||||
if not isinstance(plat_block, dict):
|
||||
continue
|
||||
existing = platforms_data.get(plat_name, {})
|
||||
if not isinstance(existing, dict):
|
||||
existing = {}
|
||||
# Deep-merge extra dicts so gateway.json defaults survive
|
||||
merged_extra = {**existing.get("extra", {}), **plat_block.get("extra", {})}
|
||||
merged = {**existing, **plat_block}
|
||||
if merged_extra:
|
||||
merged["extra"] = merged_extra
|
||||
platforms_data[plat_name] = merged
|
||||
gw_data["platforms"] = platforms_data
|
||||
for plat in Platform:
|
||||
if plat == Platform.LOCAL:
|
||||
continue
|
||||
@@ -499,8 +531,13 @@ def load_gateway_config() -> GatewayConfig:
|
||||
os.environ["DISCORD_FREE_RESPONSE_CHANNELS"] = str(frc)
|
||||
if "auto_thread" in discord_cfg and not os.getenv("DISCORD_AUTO_THREAD"):
|
||||
os.environ["DISCORD_AUTO_THREAD"] = str(discord_cfg["auto_thread"]).lower()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to process config.yaml — falling back to .env / gateway.json values. "
|
||||
"Check %s for syntax errors. Error: %s",
|
||||
_home / "config.yaml",
|
||||
e,
|
||||
)
|
||||
|
||||
config = GatewayConfig.from_dict(gw_data)
|
||||
|
||||
@@ -557,6 +594,21 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
config.platforms[Platform.TELEGRAM].enabled = True
|
||||
config.platforms[Platform.TELEGRAM].token = telegram_token
|
||||
|
||||
# Reply threading mode for Telegram (off/first/all)
|
||||
telegram_reply_mode = os.getenv("TELEGRAM_REPLY_TO_MODE", "").lower()
|
||||
if telegram_reply_mode in ("off", "first", "all"):
|
||||
if Platform.TELEGRAM not in config.platforms:
|
||||
config.platforms[Platform.TELEGRAM] = PlatformConfig()
|
||||
config.platforms[Platform.TELEGRAM].reply_to_mode = telegram_reply_mode
|
||||
|
||||
telegram_fallback_ips = os.getenv("TELEGRAM_FALLBACK_IPS", "")
|
||||
if telegram_fallback_ips:
|
||||
if Platform.TELEGRAM not in config.platforms:
|
||||
config.platforms[Platform.TELEGRAM] = PlatformConfig()
|
||||
config.platforms[Platform.TELEGRAM].extra["fallback_ips"] = [
|
||||
ip.strip() for ip in telegram_fallback_ips.split(",") if ip.strip()
|
||||
]
|
||||
|
||||
telegram_home = os.getenv("TELEGRAM_HOME_CHANNEL")
|
||||
if telegram_home and Platform.TELEGRAM in config.platforms:
|
||||
config.platforms[Platform.TELEGRAM].home_channel = HomeChannel(
|
||||
@@ -722,6 +774,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
# API Server
|
||||
api_server_enabled = os.getenv("API_SERVER_ENABLED", "").lower() in ("true", "1", "yes")
|
||||
api_server_key = os.getenv("API_SERVER_KEY", "")
|
||||
api_server_cors_origins = os.getenv("API_SERVER_CORS_ORIGINS", "")
|
||||
api_server_port = os.getenv("API_SERVER_PORT")
|
||||
api_server_host = os.getenv("API_SERVER_HOST")
|
||||
if api_server_enabled or api_server_key:
|
||||
@@ -730,6 +783,10 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
config.platforms[Platform.API_SERVER].enabled = True
|
||||
if api_server_key:
|
||||
config.platforms[Platform.API_SERVER].extra["key"] = api_server_key
|
||||
if api_server_cors_origins:
|
||||
origins = [origin.strip() for origin in api_server_cors_origins.split(",") if origin.strip()]
|
||||
if origins:
|
||||
config.platforms[Platform.API_SERVER].extra["cors_origins"] = origins
|
||||
if api_server_port:
|
||||
try:
|
||||
config.platforms[Platform.API_SERVER].extra["port"] = int(api_server_port)
|
||||
@@ -770,4 +827,3 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from pathlib import Path
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
from enum import Enum
|
||||
|
||||
from hermes_cli.config import get_hermes_home
|
||||
|
||||
|
||||
@@ -21,8 +21,6 @@ Errors in hooks are caught and logged but never block the main pipeline.
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -12,7 +12,6 @@ the full SessionStore machinery.
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from hermes_cli.config import get_hermes_home
|
||||
|
||||
+542
-80
@@ -18,10 +18,10 @@ Requires:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -45,6 +45,7 @@ logger = logging.getLogger(__name__)
|
||||
DEFAULT_HOST = "127.0.0.1"
|
||||
DEFAULT_PORT = 8642
|
||||
MAX_STORED_RESPONSES = 100
|
||||
MAX_REQUEST_BYTES = 1_000_000 # 1 MB default limit for POST bodies
|
||||
|
||||
|
||||
def check_api_server_requirements() -> bool:
|
||||
@@ -54,41 +55,109 @@ def check_api_server_requirements() -> bool:
|
||||
|
||||
class ResponseStore:
|
||||
"""
|
||||
In-memory LRU store for Responses API state.
|
||||
SQLite-backed LRU store for Responses API state.
|
||||
|
||||
Each stored response includes the full internal conversation history
|
||||
(with tool calls and results) so it can be reconstructed on subsequent
|
||||
requests via previous_response_id.
|
||||
|
||||
Persists across gateway restarts. Falls back to in-memory SQLite
|
||||
if the on-disk path is unavailable.
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int = MAX_STORED_RESPONSES):
|
||||
self._store: collections.OrderedDict[str, Dict[str, Any]] = collections.OrderedDict()
|
||||
def __init__(self, max_size: int = MAX_STORED_RESPONSES, db_path: str = None):
|
||||
self._max_size = max_size
|
||||
if db_path is None:
|
||||
try:
|
||||
from hermes_cli.config import get_hermes_home
|
||||
db_path = str(get_hermes_home() / "response_store.db")
|
||||
except Exception:
|
||||
db_path = ":memory:"
|
||||
try:
|
||||
self._conn = sqlite3.connect(db_path, check_same_thread=False)
|
||||
except Exception:
|
||||
self._conn = sqlite3.connect(":memory:", check_same_thread=False)
|
||||
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||
self._conn.execute(
|
||||
"""CREATE TABLE IF NOT EXISTS responses (
|
||||
response_id TEXT PRIMARY KEY,
|
||||
data TEXT NOT NULL,
|
||||
accessed_at REAL NOT NULL
|
||||
)"""
|
||||
)
|
||||
self._conn.execute(
|
||||
"""CREATE TABLE IF NOT EXISTS conversations (
|
||||
name TEXT PRIMARY KEY,
|
||||
response_id TEXT NOT NULL
|
||||
)"""
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def get(self, response_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Retrieve a stored response by ID (moves to end for LRU)."""
|
||||
if response_id in self._store:
|
||||
self._store.move_to_end(response_id)
|
||||
return self._store[response_id]
|
||||
return None
|
||||
"""Retrieve a stored response by ID (updates access time for LRU)."""
|
||||
row = self._conn.execute(
|
||||
"SELECT data FROM responses WHERE response_id = ?", (response_id,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
import time
|
||||
self._conn.execute(
|
||||
"UPDATE responses SET accessed_at = ? WHERE response_id = ?",
|
||||
(time.time(), response_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
return json.loads(row[0])
|
||||
|
||||
def put(self, response_id: str, data: Dict[str, Any]) -> None:
|
||||
"""Store a response, evicting the oldest if at capacity."""
|
||||
if response_id in self._store:
|
||||
self._store.move_to_end(response_id)
|
||||
self._store[response_id] = data
|
||||
while len(self._store) > self._max_size:
|
||||
self._store.popitem(last=False)
|
||||
import time
|
||||
self._conn.execute(
|
||||
"INSERT OR REPLACE INTO responses (response_id, data, accessed_at) VALUES (?, ?, ?)",
|
||||
(response_id, json.dumps(data, default=str), time.time()),
|
||||
)
|
||||
# Evict oldest entries beyond max_size
|
||||
count = self._conn.execute("SELECT COUNT(*) FROM responses").fetchone()[0]
|
||||
if count > self._max_size:
|
||||
self._conn.execute(
|
||||
"DELETE FROM responses WHERE response_id IN "
|
||||
"(SELECT response_id FROM responses ORDER BY accessed_at ASC LIMIT ?)",
|
||||
(count - self._max_size,),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def delete(self, response_id: str) -> bool:
|
||||
"""Remove a response from the store. Returns True if found and deleted."""
|
||||
if response_id in self._store:
|
||||
del self._store[response_id]
|
||||
return True
|
||||
return False
|
||||
cursor = self._conn.execute(
|
||||
"DELETE FROM responses WHERE response_id = ?", (response_id,)
|
||||
)
|
||||
self._conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def get_conversation(self, name: str) -> Optional[str]:
|
||||
"""Get the latest response_id for a conversation name."""
|
||||
row = self._conn.execute(
|
||||
"SELECT response_id FROM conversations WHERE name = ?", (name,)
|
||||
).fetchone()
|
||||
return row[0] if row else None
|
||||
|
||||
def set_conversation(self, name: str, response_id: str) -> None:
|
||||
"""Map a conversation name to its latest response_id."""
|
||||
self._conn.execute(
|
||||
"INSERT OR REPLACE INTO conversations (name, response_id) VALUES (?, ?)",
|
||||
(name, response_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the database connection."""
|
||||
try:
|
||||
self._conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._store)
|
||||
row = self._conn.execute("SELECT COUNT(*) FROM responses").fetchone()
|
||||
return row[0] if row else 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -96,7 +165,6 @@ class ResponseStore:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CORS_HEADERS = {
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "GET, POST, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Authorization, Content-Type",
|
||||
}
|
||||
@@ -105,16 +173,95 @@ _CORS_HEADERS = {
|
||||
if AIOHTTP_AVAILABLE:
|
||||
@web.middleware
|
||||
async def cors_middleware(request, handler):
|
||||
"""Add CORS headers to every response; handle OPTIONS preflight."""
|
||||
"""Add CORS headers for explicitly allowed origins; handle OPTIONS preflight."""
|
||||
adapter = request.app.get("api_server_adapter")
|
||||
origin = request.headers.get("Origin", "")
|
||||
cors_headers = None
|
||||
if adapter is not None:
|
||||
if not adapter._origin_allowed(origin):
|
||||
return web.Response(status=403)
|
||||
cors_headers = adapter._cors_headers_for_origin(origin)
|
||||
|
||||
if request.method == "OPTIONS":
|
||||
return web.Response(status=200, headers=_CORS_HEADERS)
|
||||
if cors_headers is None:
|
||||
return web.Response(status=403)
|
||||
return web.Response(status=200, headers=cors_headers)
|
||||
|
||||
response = await handler(request)
|
||||
response.headers.update(_CORS_HEADERS)
|
||||
if cors_headers is not None:
|
||||
response.headers.update(cors_headers)
|
||||
return response
|
||||
else:
|
||||
cors_middleware = None # type: ignore[assignment]
|
||||
|
||||
|
||||
def _openai_error(message: str, err_type: str = "invalid_request_error", param: str = None, code: str = None) -> Dict[str, Any]:
|
||||
"""OpenAI-style error envelope."""
|
||||
return {
|
||||
"error": {
|
||||
"message": message,
|
||||
"type": err_type,
|
||||
"param": param,
|
||||
"code": code,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if AIOHTTP_AVAILABLE:
|
||||
@web.middleware
|
||||
async def body_limit_middleware(request, handler):
|
||||
"""Reject overly large request bodies early based on Content-Length."""
|
||||
if request.method in ("POST", "PUT", "PATCH"):
|
||||
cl = request.headers.get("Content-Length")
|
||||
if cl is not None:
|
||||
try:
|
||||
if int(cl) > MAX_REQUEST_BYTES:
|
||||
return web.json_response(_openai_error("Request body too large.", code="body_too_large"), status=413)
|
||||
except ValueError:
|
||||
return web.json_response(_openai_error("Invalid Content-Length header.", code="invalid_content_length"), status=400)
|
||||
return await handler(request)
|
||||
else:
|
||||
body_limit_middleware = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class _IdempotencyCache:
|
||||
"""In-memory idempotency cache with TTL and basic LRU semantics."""
|
||||
def __init__(self, max_items: int = 1000, ttl_seconds: int = 300):
|
||||
from collections import OrderedDict
|
||||
self._store = OrderedDict()
|
||||
self._ttl = ttl_seconds
|
||||
self._max = max_items
|
||||
|
||||
def _purge(self):
|
||||
import time as _t
|
||||
now = _t.time()
|
||||
expired = [k for k, v in self._store.items() if now - v["ts"] > self._ttl]
|
||||
for k in expired:
|
||||
self._store.pop(k, None)
|
||||
while len(self._store) > self._max:
|
||||
self._store.popitem(last=False)
|
||||
|
||||
async def get_or_set(self, key: str, fingerprint: str, compute_coro):
|
||||
self._purge()
|
||||
item = self._store.get(key)
|
||||
if item and item["fp"] == fingerprint:
|
||||
return item["resp"]
|
||||
resp = await compute_coro()
|
||||
import time as _t
|
||||
self._store[key] = {"resp": resp, "fp": fingerprint, "ts": _t.time()}
|
||||
self._purge()
|
||||
return resp
|
||||
|
||||
|
||||
_idem_cache = _IdempotencyCache()
|
||||
|
||||
|
||||
def _make_request_fingerprint(body: Dict[str, Any], keys: List[str]) -> str:
|
||||
from hashlib import sha256
|
||||
subset = {k: body.get(k) for k in keys}
|
||||
return sha256(repr(subset).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
class APIServerAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
OpenAI-compatible HTTP API server adapter.
|
||||
@@ -129,12 +276,56 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
self._host: str = extra.get("host", os.getenv("API_SERVER_HOST", DEFAULT_HOST))
|
||||
self._port: int = int(extra.get("port", os.getenv("API_SERVER_PORT", str(DEFAULT_PORT))))
|
||||
self._api_key: str = extra.get("key", os.getenv("API_SERVER_KEY", ""))
|
||||
self._cors_origins: tuple[str, ...] = self._parse_cors_origins(
|
||||
extra.get("cors_origins", os.getenv("API_SERVER_CORS_ORIGINS", "")),
|
||||
)
|
||||
self._app: Optional["web.Application"] = None
|
||||
self._runner: Optional["web.AppRunner"] = None
|
||||
self._site: Optional["web.TCPSite"] = None
|
||||
self._response_store = ResponseStore()
|
||||
# Conversation name → latest response_id mapping
|
||||
self._conversations: Dict[str, str] = {}
|
||||
|
||||
@staticmethod
|
||||
def _parse_cors_origins(value: Any) -> tuple[str, ...]:
|
||||
"""Normalize configured CORS origins into a stable tuple."""
|
||||
if not value:
|
||||
return ()
|
||||
|
||||
if isinstance(value, str):
|
||||
items = value.split(",")
|
||||
elif isinstance(value, (list, tuple, set)):
|
||||
items = value
|
||||
else:
|
||||
items = [str(value)]
|
||||
|
||||
return tuple(str(item).strip() for item in items if str(item).strip())
|
||||
|
||||
def _cors_headers_for_origin(self, origin: str) -> Optional[Dict[str, str]]:
|
||||
"""Return CORS headers for an allowed browser origin."""
|
||||
if not origin or not self._cors_origins:
|
||||
return None
|
||||
|
||||
if "*" in self._cors_origins:
|
||||
headers = dict(_CORS_HEADERS)
|
||||
headers["Access-Control-Allow-Origin"] = "*"
|
||||
return headers
|
||||
|
||||
if origin not in self._cors_origins:
|
||||
return None
|
||||
|
||||
headers = dict(_CORS_HEADERS)
|
||||
headers["Access-Control-Allow-Origin"] = origin
|
||||
headers["Vary"] = "Origin"
|
||||
return headers
|
||||
|
||||
def _origin_allowed(self, origin: str) -> bool:
|
||||
"""Allow non-browser clients and explicitly configured browser origins."""
|
||||
if not origin:
|
||||
return True
|
||||
|
||||
if not self._cors_origins:
|
||||
return False
|
||||
|
||||
return "*" in self._cors_origins or origin in self._cors_origins
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Auth helper
|
||||
@@ -175,14 +366,20 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
Create an AIAgent instance using the gateway's runtime config.
|
||||
|
||||
Uses _resolve_runtime_agent_kwargs() to pick up model, api_key,
|
||||
base_url, etc. from config.yaml / env vars.
|
||||
base_url, etc. from config.yaml / env vars. Toolsets are resolved
|
||||
from config.yaml platform_toolsets.api_server (same as all other
|
||||
gateway platforms), falling back to the hermes-api-server default.
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
from gateway.run import _resolve_runtime_agent_kwargs, _resolve_gateway_model
|
||||
from gateway.run import _resolve_runtime_agent_kwargs, _resolve_gateway_model, _load_gateway_config
|
||||
from hermes_cli.tools_config import _get_platform_tools
|
||||
|
||||
runtime_kwargs = _resolve_runtime_agent_kwargs()
|
||||
model = _resolve_gateway_model()
|
||||
|
||||
user_config = _load_gateway_config()
|
||||
enabled_toolsets = sorted(_get_platform_tools(user_config, "api_server"))
|
||||
|
||||
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90"))
|
||||
|
||||
agent = AIAgent(
|
||||
@@ -192,6 +389,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
quiet_mode=True,
|
||||
verbose_logging=False,
|
||||
ephemeral_system_prompt=ephemeral_system_prompt or None,
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
session_id=session_id,
|
||||
platform="api_server",
|
||||
stream_delta_callback=stream_delta_callback,
|
||||
@@ -237,10 +435,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
body = await request.json()
|
||||
except (json.JSONDecodeError, Exception):
|
||||
return web.json_response(
|
||||
{"error": {"message": "Invalid JSON in request body", "type": "invalid_request_error"}},
|
||||
status=400,
|
||||
)
|
||||
return web.json_response(_openai_error("Invalid JSON in request body"), status=400)
|
||||
|
||||
messages = body.get("messages")
|
||||
if not messages or not isinstance(messages, list):
|
||||
@@ -290,7 +485,15 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
_stream_q: _q.Queue = _q.Queue()
|
||||
|
||||
def _on_delta(delta):
|
||||
_stream_q.put(delta)
|
||||
# Filter out None — the agent fires stream_delta_callback(None)
|
||||
# to signal the CLI display to close its response box before
|
||||
# tool execution, but the SSE writer uses None as end-of-stream
|
||||
# sentinel. Forwarding it would prematurely close the HTTP
|
||||
# response, causing Open WebUI (and similar frontends) to miss
|
||||
# the final answer after tool calls. The SSE loop detects
|
||||
# completion via agent_task.done() instead.
|
||||
if delta is not None:
|
||||
_stream_q.put(delta)
|
||||
|
||||
# Start agent in background
|
||||
agent_task = asyncio.ensure_future(self._run_agent(
|
||||
@@ -305,20 +508,35 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
request, completion_id, model_name, created, _stream_q, agent_task
|
||||
)
|
||||
|
||||
# Non-streaming: run the agent and return full response
|
||||
try:
|
||||
result, usage = await self._run_agent(
|
||||
# Non-streaming: run the agent (with optional Idempotency-Key)
|
||||
async def _compute_completion():
|
||||
return await self._run_agent(
|
||||
user_message=user_message,
|
||||
conversation_history=history,
|
||||
ephemeral_system_prompt=system_prompt,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error running agent for chat completions: %s", e, exc_info=True)
|
||||
return web.json_response(
|
||||
{"error": {"message": f"Internal server error: {e}", "type": "server_error"}},
|
||||
status=500,
|
||||
)
|
||||
|
||||
idempotency_key = request.headers.get("Idempotency-Key")
|
||||
if idempotency_key:
|
||||
fp = _make_request_fingerprint(body, keys=["model", "messages", "tools", "tool_choice", "stream"])
|
||||
try:
|
||||
result, usage = await _idem_cache.get_or_set(idempotency_key, fp, _compute_completion)
|
||||
except Exception as e:
|
||||
logger.error("Error running agent for chat completions: %s", e, exc_info=True)
|
||||
return web.json_response(
|
||||
_openai_error(f"Internal server error: {e}", err_type="server_error"),
|
||||
status=500,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
result, usage = await _compute_completion()
|
||||
except Exception as e:
|
||||
logger.error("Error running agent for chat completions: %s", e, exc_info=True)
|
||||
return web.json_response(
|
||||
_openai_error(f"Internal server error: {e}", err_type="server_error"),
|
||||
status=500,
|
||||
)
|
||||
|
||||
final_response = result.get("final_response", "")
|
||||
if not final_response:
|
||||
@@ -444,10 +662,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
|
||||
raw_input = body.get("input")
|
||||
if raw_input is None:
|
||||
return web.json_response(
|
||||
{"error": {"message": "Missing 'input' field", "type": "invalid_request_error"}},
|
||||
status=400,
|
||||
)
|
||||
return web.json_response(_openai_error("Missing 'input' field"), status=400)
|
||||
|
||||
instructions = body.get("instructions")
|
||||
previous_response_id = body.get("previous_response_id")
|
||||
@@ -456,14 +671,11 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
|
||||
# conversation and previous_response_id are mutually exclusive
|
||||
if conversation and previous_response_id:
|
||||
return web.json_response(
|
||||
{"error": {"message": "Cannot use both 'conversation' and 'previous_response_id'", "type": "invalid_request_error"}},
|
||||
status=400,
|
||||
)
|
||||
return web.json_response(_openai_error("Cannot use both 'conversation' and 'previous_response_id'"), status=400)
|
||||
|
||||
# Resolve conversation name to latest response_id
|
||||
if conversation:
|
||||
previous_response_id = self._conversations.get(conversation)
|
||||
previous_response_id = self._response_store.get_conversation(conversation)
|
||||
# No error if conversation doesn't exist yet — it's a new conversation
|
||||
|
||||
# Normalize input to message list
|
||||
@@ -490,20 +702,14 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
content = "\n".join(text_parts)
|
||||
input_messages.append({"role": role, "content": content})
|
||||
else:
|
||||
return web.json_response(
|
||||
{"error": {"message": "'input' must be a string or array", "type": "invalid_request_error"}},
|
||||
status=400,
|
||||
)
|
||||
return web.json_response(_openai_error("'input' must be a string or array"), status=400)
|
||||
|
||||
# Reconstruct conversation history from previous_response_id
|
||||
conversation_history: List[Dict[str, str]] = []
|
||||
if previous_response_id:
|
||||
stored = self._response_store.get(previous_response_id)
|
||||
if stored is None:
|
||||
return web.json_response(
|
||||
{"error": {"message": f"Previous response not found: {previous_response_id}", "type": "invalid_request_error"}},
|
||||
status=404,
|
||||
)
|
||||
return web.json_response(_openai_error(f"Previous response not found: {previous_response_id}"), status=404)
|
||||
conversation_history = list(stored.get("conversation_history", []))
|
||||
# If no instructions provided, carry forward from previous
|
||||
if instructions is None:
|
||||
@@ -516,30 +722,46 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
# Last input message is the user_message
|
||||
user_message = input_messages[-1].get("content", "") if input_messages else ""
|
||||
if not user_message:
|
||||
return web.json_response(
|
||||
{"error": {"message": "No user message found in input", "type": "invalid_request_error"}},
|
||||
status=400,
|
||||
)
|
||||
return web.json_response(_openai_error("No user message found in input"), status=400)
|
||||
|
||||
# Truncation support
|
||||
if body.get("truncation") == "auto" and len(conversation_history) > 100:
|
||||
conversation_history = conversation_history[-100:]
|
||||
|
||||
# Run the agent
|
||||
# Run the agent (with Idempotency-Key support)
|
||||
session_id = str(uuid.uuid4())
|
||||
try:
|
||||
result, usage = await self._run_agent(
|
||||
|
||||
async def _compute_response():
|
||||
return await self._run_agent(
|
||||
user_message=user_message,
|
||||
conversation_history=conversation_history,
|
||||
ephemeral_system_prompt=instructions,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error running agent for responses: %s", e, exc_info=True)
|
||||
return web.json_response(
|
||||
{"error": {"message": f"Internal server error: {e}", "type": "server_error"}},
|
||||
status=500,
|
||||
|
||||
idempotency_key = request.headers.get("Idempotency-Key")
|
||||
if idempotency_key:
|
||||
fp = _make_request_fingerprint(
|
||||
body,
|
||||
keys=["input", "instructions", "previous_response_id", "conversation", "model", "tools"],
|
||||
)
|
||||
try:
|
||||
result, usage = await _idem_cache.get_or_set(idempotency_key, fp, _compute_response)
|
||||
except Exception as e:
|
||||
logger.error("Error running agent for responses: %s", e, exc_info=True)
|
||||
return web.json_response(
|
||||
_openai_error(f"Internal server error: {e}", err_type="server_error"),
|
||||
status=500,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
result, usage = await _compute_response()
|
||||
except Exception as e:
|
||||
logger.error("Error running agent for responses: %s", e, exc_info=True)
|
||||
return web.json_response(
|
||||
_openai_error(f"Internal server error: {e}", err_type="server_error"),
|
||||
status=500,
|
||||
)
|
||||
|
||||
final_response = result.get("final_response", "")
|
||||
if not final_response:
|
||||
@@ -586,7 +808,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
# Update conversation mapping so the next request with the same
|
||||
# conversation name automatically chains to this response
|
||||
if conversation:
|
||||
self._conversations[conversation] = response_id
|
||||
self._response_store.set_conversation(conversation, response_id)
|
||||
|
||||
return web.json_response(response_data)
|
||||
|
||||
@@ -603,10 +825,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
response_id = request.match_info["response_id"]
|
||||
stored = self._response_store.get(response_id)
|
||||
if stored is None:
|
||||
return web.json_response(
|
||||
{"error": {"message": f"Response not found: {response_id}", "type": "invalid_request_error"}},
|
||||
status=404,
|
||||
)
|
||||
return web.json_response(_openai_error(f"Response not found: {response_id}"), status=404)
|
||||
|
||||
return web.json_response(stored["response"])
|
||||
|
||||
@@ -619,10 +838,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
response_id = request.match_info["response_id"]
|
||||
deleted = self._response_store.delete(response_id)
|
||||
if not deleted:
|
||||
return web.json_response(
|
||||
{"error": {"message": f"Response not found: {response_id}", "type": "invalid_request_error"}},
|
||||
status=404,
|
||||
)
|
||||
return web.json_response(_openai_error(f"Response not found: {response_id}"), status=404)
|
||||
|
||||
return web.json_response({
|
||||
"id": response_id,
|
||||
@@ -630,6 +846,241 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
"deleted": True,
|
||||
})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cron jobs API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# Check cron module availability once (not per-request)
|
||||
_CRON_AVAILABLE = False
|
||||
try:
|
||||
from cron.jobs import (
|
||||
list_jobs as _cron_list,
|
||||
get_job as _cron_get,
|
||||
create_job as _cron_create,
|
||||
update_job as _cron_update,
|
||||
remove_job as _cron_remove,
|
||||
pause_job as _cron_pause,
|
||||
resume_job as _cron_resume,
|
||||
trigger_job as _cron_trigger,
|
||||
)
|
||||
_CRON_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
_JOB_ID_RE = __import__("re").compile(r"[a-f0-9]{12}")
|
||||
# Allowed fields for update — prevents clients injecting arbitrary keys
|
||||
_UPDATE_ALLOWED_FIELDS = {"name", "schedule", "prompt", "deliver", "skills", "skill", "repeat", "enabled"}
|
||||
_MAX_NAME_LENGTH = 200
|
||||
_MAX_PROMPT_LENGTH = 5000
|
||||
|
||||
def _check_jobs_available(self) -> Optional["web.Response"]:
|
||||
"""Return error response if cron module isn't available."""
|
||||
if not self._CRON_AVAILABLE:
|
||||
return web.json_response(
|
||||
{"error": "Cron module not available"}, status=501,
|
||||
)
|
||||
return None
|
||||
|
||||
def _check_job_id(self, request: "web.Request") -> tuple:
|
||||
"""Validate and extract job_id. Returns (job_id, error_response)."""
|
||||
job_id = request.match_info["job_id"]
|
||||
if not self._JOB_ID_RE.fullmatch(job_id):
|
||||
return job_id, web.json_response(
|
||||
{"error": "Invalid job ID format"}, status=400,
|
||||
)
|
||||
return job_id, None
|
||||
|
||||
async def _handle_list_jobs(self, request: "web.Request") -> "web.Response":
|
||||
"""GET /api/jobs — list all cron jobs."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
cron_err = self._check_jobs_available()
|
||||
if cron_err:
|
||||
return cron_err
|
||||
try:
|
||||
include_disabled = request.query.get("include_disabled", "").lower() in ("true", "1")
|
||||
jobs = self._cron_list(include_disabled=include_disabled)
|
||||
return web.json_response({"jobs": jobs})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_create_job(self, request: "web.Request") -> "web.Response":
|
||||
"""POST /api/jobs — create a new cron job."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
cron_err = self._check_jobs_available()
|
||||
if cron_err:
|
||||
return cron_err
|
||||
try:
|
||||
body = await request.json()
|
||||
name = (body.get("name") or "").strip()
|
||||
schedule = (body.get("schedule") or "").strip()
|
||||
prompt = body.get("prompt", "")
|
||||
deliver = body.get("deliver", "local")
|
||||
skills = body.get("skills")
|
||||
repeat = body.get("repeat")
|
||||
|
||||
if not name:
|
||||
return web.json_response({"error": "Name is required"}, status=400)
|
||||
if len(name) > self._MAX_NAME_LENGTH:
|
||||
return web.json_response(
|
||||
{"error": f"Name must be ≤ {self._MAX_NAME_LENGTH} characters"}, status=400,
|
||||
)
|
||||
if not schedule:
|
||||
return web.json_response({"error": "Schedule is required"}, status=400)
|
||||
if len(prompt) > self._MAX_PROMPT_LENGTH:
|
||||
return web.json_response(
|
||||
{"error": f"Prompt must be ≤ {self._MAX_PROMPT_LENGTH} characters"}, status=400,
|
||||
)
|
||||
if repeat is not None and (not isinstance(repeat, int) or repeat < 1):
|
||||
return web.json_response({"error": "Repeat must be a positive integer"}, status=400)
|
||||
|
||||
kwargs = {
|
||||
"prompt": prompt,
|
||||
"schedule": schedule,
|
||||
"name": name,
|
||||
"deliver": deliver,
|
||||
}
|
||||
if skills:
|
||||
kwargs["skills"] = skills
|
||||
if repeat is not None:
|
||||
kwargs["repeat"] = repeat
|
||||
|
||||
job = self._cron_create(**kwargs)
|
||||
return web.json_response({"job": job})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_get_job(self, request: "web.Request") -> "web.Response":
|
||||
"""GET /api/jobs/{job_id} — get a single cron job."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
cron_err = self._check_jobs_available()
|
||||
if cron_err:
|
||||
return cron_err
|
||||
job_id, id_err = self._check_job_id(request)
|
||||
if id_err:
|
||||
return id_err
|
||||
try:
|
||||
job = self._cron_get(job_id)
|
||||
if not job:
|
||||
return web.json_response({"error": "Job not found"}, status=404)
|
||||
return web.json_response({"job": job})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_update_job(self, request: "web.Request") -> "web.Response":
|
||||
"""PATCH /api/jobs/{job_id} — update a cron job."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
cron_err = self._check_jobs_available()
|
||||
if cron_err:
|
||||
return cron_err
|
||||
job_id, id_err = self._check_job_id(request)
|
||||
if id_err:
|
||||
return id_err
|
||||
try:
|
||||
body = await request.json()
|
||||
# Whitelist allowed fields to prevent arbitrary key injection
|
||||
sanitized = {k: v for k, v in body.items() if k in self._UPDATE_ALLOWED_FIELDS}
|
||||
if not sanitized:
|
||||
return web.json_response({"error": "No valid fields to update"}, status=400)
|
||||
# Validate lengths if present
|
||||
if "name" in sanitized and len(sanitized["name"]) > self._MAX_NAME_LENGTH:
|
||||
return web.json_response(
|
||||
{"error": f"Name must be ≤ {self._MAX_NAME_LENGTH} characters"}, status=400,
|
||||
)
|
||||
if "prompt" in sanitized and len(sanitized["prompt"]) > self._MAX_PROMPT_LENGTH:
|
||||
return web.json_response(
|
||||
{"error": f"Prompt must be ≤ {self._MAX_PROMPT_LENGTH} characters"}, status=400,
|
||||
)
|
||||
job = self._cron_update(job_id, sanitized)
|
||||
if not job:
|
||||
return web.json_response({"error": "Job not found"}, status=404)
|
||||
return web.json_response({"job": job})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_delete_job(self, request: "web.Request") -> "web.Response":
|
||||
"""DELETE /api/jobs/{job_id} — delete a cron job."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
cron_err = self._check_jobs_available()
|
||||
if cron_err:
|
||||
return cron_err
|
||||
job_id, id_err = self._check_job_id(request)
|
||||
if id_err:
|
||||
return id_err
|
||||
try:
|
||||
success = self._cron_remove(job_id)
|
||||
if not success:
|
||||
return web.json_response({"error": "Job not found"}, status=404)
|
||||
return web.json_response({"ok": True})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_pause_job(self, request: "web.Request") -> "web.Response":
|
||||
"""POST /api/jobs/{job_id}/pause — pause a cron job."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
cron_err = self._check_jobs_available()
|
||||
if cron_err:
|
||||
return cron_err
|
||||
job_id, id_err = self._check_job_id(request)
|
||||
if id_err:
|
||||
return id_err
|
||||
try:
|
||||
job = self._cron_pause(job_id)
|
||||
if not job:
|
||||
return web.json_response({"error": "Job not found"}, status=404)
|
||||
return web.json_response({"job": job})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_resume_job(self, request: "web.Request") -> "web.Response":
|
||||
"""POST /api/jobs/{job_id}/resume — resume a paused cron job."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
cron_err = self._check_jobs_available()
|
||||
if cron_err:
|
||||
return cron_err
|
||||
job_id, id_err = self._check_job_id(request)
|
||||
if id_err:
|
||||
return id_err
|
||||
try:
|
||||
job = self._cron_resume(job_id)
|
||||
if not job:
|
||||
return web.json_response({"error": "Job not found"}, status=404)
|
||||
return web.json_response({"job": job})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_run_job(self, request: "web.Request") -> "web.Response":
|
||||
"""POST /api/jobs/{job_id}/run — trigger immediate execution."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
cron_err = self._check_jobs_available()
|
||||
if cron_err:
|
||||
return cron_err
|
||||
job_id, id_err = self._check_job_id(request)
|
||||
if id_err:
|
||||
return id_err
|
||||
try:
|
||||
job = self._cron_trigger(job_id)
|
||||
if not job:
|
||||
return web.json_response({"error": "Job not found"}, status=404)
|
||||
return web.json_response({"job": job})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Output extraction helper
|
||||
# ------------------------------------------------------------------
|
||||
@@ -732,13 +1183,24 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
return False
|
||||
|
||||
try:
|
||||
self._app = web.Application(middlewares=[cors_middleware])
|
||||
mws = [mw for mw in (cors_middleware, body_limit_middleware) if mw is not None]
|
||||
self._app = web.Application(middlewares=mws)
|
||||
self._app["api_server_adapter"] = self
|
||||
self._app.router.add_get("/health", self._handle_health)
|
||||
self._app.router.add_get("/v1/models", self._handle_models)
|
||||
self._app.router.add_post("/v1/chat/completions", self._handle_chat_completions)
|
||||
self._app.router.add_post("/v1/responses", self._handle_responses)
|
||||
self._app.router.add_get("/v1/responses/{response_id}", self._handle_get_response)
|
||||
self._app.router.add_delete("/v1/responses/{response_id}", self._handle_delete_response)
|
||||
# Cron jobs management API
|
||||
self._app.router.add_get("/api/jobs", self._handle_list_jobs)
|
||||
self._app.router.add_post("/api/jobs", self._handle_create_job)
|
||||
self._app.router.add_get("/api/jobs/{job_id}", self._handle_get_job)
|
||||
self._app.router.add_patch("/api/jobs/{job_id}", self._handle_update_job)
|
||||
self._app.router.add_delete("/api/jobs/{job_id}", self._handle_delete_job)
|
||||
self._app.router.add_post("/api/jobs/{job_id}/pause", self._handle_pause_job)
|
||||
self._app.router.add_post("/api/jobs/{job_id}/resume", self._handle_resume_job)
|
||||
self._app.router.add_post("/api/jobs/{job_id}/run", self._handle_run_job)
|
||||
|
||||
self._runner = web.AppRunner(self._app)
|
||||
await self._runner.setup()
|
||||
|
||||
+165
-26
@@ -8,6 +8,7 @@ and implement the required methods.
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -71,31 +72,51 @@ def cache_image_from_bytes(data: bytes, ext: str = ".jpg") -> str:
|
||||
return str(filepath)
|
||||
|
||||
|
||||
async def cache_image_from_url(url: str, ext: str = ".jpg") -> str:
|
||||
async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) -> str:
|
||||
"""
|
||||
Download an image from a URL and save it to the local cache.
|
||||
|
||||
Uses httpx for async download with a reasonable timeout.
|
||||
Retries on transient failures (timeouts, 429, 5xx) with exponential
|
||||
backoff so a single slow CDN response doesn't lose the media.
|
||||
|
||||
Args:
|
||||
url: The HTTP/HTTPS URL to download from.
|
||||
ext: File extension including the dot (e.g. ".jpg", ".png").
|
||||
retries: Number of retry attempts on transient failures.
|
||||
|
||||
Returns:
|
||||
Absolute path to the cached image file as a string.
|
||||
"""
|
||||
import asyncio
|
||||
import httpx
|
||||
import logging as _logging
|
||||
_log = _logging.getLogger(__name__)
|
||||
|
||||
last_exc = None
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
|
||||
"Accept": "image/*,*/*;q=0.8",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
for attempt in range(retries + 1):
|
||||
try:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
|
||||
"Accept": "image/*,*/*;q=0.8",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
|
||||
last_exc = exc
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
|
||||
raise
|
||||
if attempt < retries:
|
||||
wait = 1.5 * (attempt + 1)
|
||||
_log.debug("Media cache retry %d/%d for %s (%.1fs): %s",
|
||||
attempt + 1, retries, url[:80], wait, exc)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
raise
|
||||
raise last_exc
|
||||
|
||||
|
||||
def cleanup_image_cache(max_age_hours: int = 24) -> int:
|
||||
@@ -296,6 +317,9 @@ class MessageEvent:
|
||||
reply_to_message_id: Optional[str] = None
|
||||
reply_to_text: Optional[str] = None # Text of the replied-to message (for context injection)
|
||||
|
||||
# Auto-loaded skill for topic/channel bindings (e.g., Telegram DM Topics)
|
||||
auto_skill: Optional[str] = None
|
||||
|
||||
# Timestamps
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
@@ -326,6 +350,24 @@ class SendResult:
|
||||
message_id: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
raw_response: Any = None
|
||||
retryable: bool = False # True for transient errors (network, timeout) — base will retry automatically
|
||||
|
||||
|
||||
# Error substrings that indicate a transient network failure worth retrying
|
||||
_RETRYABLE_ERROR_PATTERNS = (
|
||||
"connecterror",
|
||||
"connectionerror",
|
||||
"connectionreset",
|
||||
"connectionrefused",
|
||||
"timeout",
|
||||
"timed out",
|
||||
"network",
|
||||
"broken pipe",
|
||||
"remotedisconnected",
|
||||
"eoferror",
|
||||
"readtimeout",
|
||||
"writetimeout",
|
||||
)
|
||||
|
||||
|
||||
# Type for message handlers
|
||||
@@ -504,6 +546,14 @@ class BasePlatformAdapter(ABC):
|
||||
metadata: optional dict with platform-specific context (e.g. thread_id for Slack).
|
||||
"""
|
||||
pass
|
||||
|
||||
async def stop_typing(self, chat_id: str) -> None:
|
||||
"""Stop a persistent typing indicator (if the platform uses one).
|
||||
|
||||
Override in subclasses that start background typing loops.
|
||||
Default is a no-op for platforms with one-shot typing indicators.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
@@ -713,7 +763,7 @@ class BasePlatformAdapter(ABC):
|
||||
# Extract MEDIA:<path> tags, allowing optional whitespace after the colon
|
||||
# and quoted/backticked paths for LLM-formatted outputs.
|
||||
media_pattern = re.compile(
|
||||
r'''[`"']?MEDIA:\s*(?P<path>`[^`\n]+`|"[^"\n]+"|'[^'\n]+'|\S+)[`"']?'''
|
||||
r'''[`"']?MEDIA:\s*(?P<path>`[^`\n]+`|"[^"\n]+"|'[^'\n]+'|(?:~/|/)\S+(?:[^\S\n]+\S+)*?\.(?:png|jpe?g|gif|webp|mp4|mov|avi|mkv|webm|ogg|opus|mp3|wav|m4a)(?=[\s`"',;:)\]}]|$)|\S+)[`"']?'''
|
||||
)
|
||||
for match in media_pattern.finditer(content):
|
||||
path = match.group("path").strip()
|
||||
@@ -811,7 +861,102 @@ class BasePlatformAdapter(ABC):
|
||||
await asyncio.sleep(interval)
|
||||
except asyncio.CancelledError:
|
||||
pass # Normal cancellation when handler completes
|
||||
finally:
|
||||
# Ensure the underlying platform typing loop is stopped.
|
||||
# _keep_typing may have called send_typing() after an outer
|
||||
# stop_typing() cleared the task dict, recreating the loop.
|
||||
# Cancelling _keep_typing alone won't clean that up.
|
||||
if hasattr(self, "stop_typing"):
|
||||
try:
|
||||
await self.stop_typing(chat_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _is_retryable_error(error: Optional[str]) -> bool:
|
||||
"""Return True if the error string looks like a transient network failure."""
|
||||
if not error:
|
||||
return False
|
||||
lowered = error.lower()
|
||||
return any(pat in lowered for pat in _RETRYABLE_ERROR_PATTERNS)
|
||||
|
||||
async def _send_with_retry(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Any = None,
|
||||
max_retries: int = 2,
|
||||
base_delay: float = 2.0,
|
||||
) -> "SendResult":
|
||||
"""
|
||||
Send a message with automatic retry for transient network errors.
|
||||
|
||||
On permanent failures (e.g. formatting / permission errors) falls back
|
||||
to a plain-text version before giving up. If all attempts fail due to
|
||||
network errors, sends the user a brief delivery-failure notice so they
|
||||
know to retry rather than waiting indefinitely.
|
||||
"""
|
||||
|
||||
result = await self.send(
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
reply_to=reply_to,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
if result.success:
|
||||
return result
|
||||
|
||||
error_str = result.error or ""
|
||||
is_network = result.retryable or self._is_retryable_error(error_str)
|
||||
|
||||
if is_network:
|
||||
# Retry with exponential backoff for transient errors
|
||||
for attempt in range(1, max_retries + 1):
|
||||
delay = base_delay * (2 ** (attempt - 1)) + random.uniform(0, 1)
|
||||
logger.warning(
|
||||
"[%s] Send failed (attempt %d/%d, retrying in %.1fs): %s",
|
||||
self.name, attempt, max_retries, delay, error_str,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
result = await self.send(
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
reply_to=reply_to,
|
||||
metadata=metadata,
|
||||
)
|
||||
if result.success:
|
||||
logger.info("[%s] Send succeeded on retry %d", self.name, attempt)
|
||||
return result
|
||||
error_str = result.error or ""
|
||||
if not (result.retryable or self._is_retryable_error(error_str)):
|
||||
break # error switched to non-transient — fall through to plain-text fallback
|
||||
else:
|
||||
# All retries exhausted (loop completed without break) — notify user
|
||||
logger.error("[%s] Failed to deliver response after %d retries: %s", self.name, max_retries, error_str)
|
||||
notice = (
|
||||
"\u26a0\ufe0f Message delivery failed after multiple attempts. "
|
||||
"Please try again \u2014 your request was processed but the response could not be sent."
|
||||
)
|
||||
try:
|
||||
await self.send(chat_id=chat_id, content=notice, reply_to=reply_to, metadata=metadata)
|
||||
except Exception as notify_err:
|
||||
logger.debug("[%s] Could not send delivery-failure notice: %s", self.name, notify_err)
|
||||
return result
|
||||
|
||||
# Non-network / post-retry formatting failure: try plain text as fallback
|
||||
logger.warning("[%s] Send failed: %s — trying plain-text fallback", self.name, error_str)
|
||||
fallback_result = await self.send(
|
||||
chat_id=chat_id,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{content[:3500]}",
|
||||
reply_to=reply_to,
|
||||
metadata=metadata,
|
||||
)
|
||||
if not fallback_result.success:
|
||||
logger.error("[%s] Fallback send also failed: %s", self.name, fallback_result.error)
|
||||
return fallback_result
|
||||
|
||||
async def handle_message(self, event: MessageEvent) -> None:
|
||||
"""
|
||||
Process an incoming message.
|
||||
@@ -961,26 +1106,13 @@ class BasePlatformAdapter(ABC):
|
||||
# Send the text portion
|
||||
if text_content:
|
||||
logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id)
|
||||
result = await self.send(
|
||||
result = await self._send_with_retry(
|
||||
chat_id=event.source.chat_id,
|
||||
content=text_content,
|
||||
reply_to=event.message_id,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
|
||||
# Log send failures (don't raise - user already saw tool progress)
|
||||
if not result.success:
|
||||
print(f"[{self.name}] Failed to send response: {result.error}")
|
||||
# Try sending without markdown as fallback
|
||||
fallback_result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{text_content[:3500]}",
|
||||
reply_to=event.message_id,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
if not fallback_result.success:
|
||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||
|
||||
# Human-like pacing delay between text and media
|
||||
human_delay = self._get_human_delay()
|
||||
|
||||
@@ -1122,6 +1254,13 @@ class BasePlatformAdapter(ABC):
|
||||
await typing_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
# Also cancel any platform-level persistent typing tasks (e.g. Discord)
|
||||
# that may have been recreated by _keep_typing after the last stop_typing()
|
||||
try:
|
||||
if hasattr(self, "stop_typing"):
|
||||
await self.stop_typing(event.source.chat_id)
|
||||
except Exception:
|
||||
pass
|
||||
# Clean up session tracking
|
||||
if session_key in self._active_sessions:
|
||||
del self._active_sessions[session_key]
|
||||
|
||||
+142
-15
@@ -20,7 +20,7 @@ import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Any
|
||||
from typing import Callable, Dict, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -43,6 +43,8 @@ from pathlib import Path as _Path
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
import re
|
||||
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
@@ -50,6 +52,8 @@ from gateway.platforms.base import (
|
||||
SendResult,
|
||||
cache_image_from_url,
|
||||
cache_audio_from_url,
|
||||
cache_document_from_bytes,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
|
||||
|
||||
@@ -439,6 +443,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# in those threads don't require @mention. Persisted to disk so the
|
||||
# set survives gateway restarts.
|
||||
self._bot_participated_threads: set = self._load_participated_threads()
|
||||
# Persistent typing indicator loops per channel (DMs don't reliably
|
||||
# show the standard typing gateway event for bots)
|
||||
self._typing_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._bot_task: Optional[asyncio.Task] = None
|
||||
# Cap to prevent unbounded growth (Discord threads get archived).
|
||||
self._MAX_TRACKED_THREADS = 500
|
||||
|
||||
@@ -524,6 +532,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if message.author == self._client.user:
|
||||
return
|
||||
|
||||
# Ignore Discord system messages (thread renames, pins, member joins, etc.)
|
||||
# Allow both default and reply types — replies have a distinct MessageType.
|
||||
if message.type not in (discord.MessageType.default, discord.MessageType.reply):
|
||||
return
|
||||
|
||||
# Bot message filtering (DISCORD_ALLOW_BOTS):
|
||||
# "none" — ignore all other bots (default)
|
||||
# "mentions" — accept bot messages only when they @mention us
|
||||
@@ -576,7 +589,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self._register_slash_commands()
|
||||
|
||||
# Start the bot in background
|
||||
asyncio.create_task(self._client.start(self.config.token))
|
||||
self._bot_task = asyncio.create_task(self._client.start(self.config.token))
|
||||
|
||||
# Wait for ready
|
||||
await asyncio.wait_for(self._ready_event.wait(), timeout=30)
|
||||
@@ -1239,14 +1252,48 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
return await super().send_document(chat_id, file_path, caption, file_name, reply_to, metadata=metadata)
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata=None) -> None:
|
||||
"""Send typing indicator."""
|
||||
if self._client:
|
||||
"""Start a persistent typing indicator for a channel.
|
||||
|
||||
Discord's TYPING_START gateway event is unreliable in DMs for bots.
|
||||
Instead, start a background loop that hits the typing endpoint every
|
||||
8 seconds (typing indicator lasts ~10s). The loop is cancelled when
|
||||
stop_typing() is called (after the response is sent).
|
||||
"""
|
||||
if not self._client:
|
||||
return
|
||||
# Don't start a duplicate loop
|
||||
if chat_id in self._typing_tasks:
|
||||
return
|
||||
|
||||
async def _typing_loop() -> None:
|
||||
try:
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if channel:
|
||||
await channel.typing()
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
while True:
|
||||
try:
|
||||
route = discord.http.Route(
|
||||
"POST", "/channels/{channel_id}/typing",
|
||||
channel_id=chat_id,
|
||||
)
|
||||
await self._client.http.request(route)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug("Discord typing indicator failed for %s: %s", chat_id, e)
|
||||
return
|
||||
await asyncio.sleep(8)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._typing_tasks[chat_id] = asyncio.create_task(_typing_loop())
|
||||
|
||||
async def stop_typing(self, chat_id: str) -> None:
|
||||
"""Stop the persistent typing indicator for a channel."""
|
||||
task = self._typing_tasks.pop(chat_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Get information about a Discord channel."""
|
||||
@@ -1500,7 +1547,17 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
def _build_slash_event(self, interaction: discord.Interaction, text: str) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Discord slash command interaction."""
|
||||
is_dm = isinstance(interaction.channel, discord.DMChannel)
|
||||
chat_type = "dm" if is_dm else "group"
|
||||
is_thread = isinstance(interaction.channel, discord.Thread)
|
||||
thread_id = None
|
||||
|
||||
if is_dm:
|
||||
chat_type = "dm"
|
||||
elif is_thread:
|
||||
chat_type = "thread"
|
||||
thread_id = str(interaction.channel_id)
|
||||
else:
|
||||
chat_type = "group"
|
||||
|
||||
chat_name = ""
|
||||
if not is_dm and hasattr(interaction.channel, "name"):
|
||||
chat_name = interaction.channel.name
|
||||
@@ -1516,6 +1573,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
chat_type=chat_type,
|
||||
user_id=str(interaction.user.id),
|
||||
user_name=interaction.user.display_name,
|
||||
thread_id=thread_id,
|
||||
chat_topic=chat_topic,
|
||||
)
|
||||
|
||||
@@ -1902,7 +1960,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
elif att.content_type.startswith("audio/"):
|
||||
msg_type = MessageType.AUDIO
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
doc_ext = ""
|
||||
if att.filename:
|
||||
_, doc_ext = os.path.splitext(att.filename)
|
||||
doc_ext = doc_ext.lower()
|
||||
if doc_ext in SUPPORTED_DOCUMENT_TYPES:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
break
|
||||
|
||||
# When auto-threading kicked in, route responses to the new thread
|
||||
@@ -1939,6 +2002,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# vision tool can access them reliably (Discord CDN URLs can expire).
|
||||
media_urls = []
|
||||
media_types = []
|
||||
pending_text_injection: Optional[str] = None
|
||||
for att in message.attachments:
|
||||
content_type = att.content_type or "unknown"
|
||||
if content_type.startswith("image/"):
|
||||
@@ -1970,12 +2034,75 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
media_urls.append(att.url)
|
||||
media_types.append(content_type)
|
||||
else:
|
||||
# Other attachments: keep the original URL
|
||||
media_urls.append(att.url)
|
||||
media_types.append(content_type)
|
||||
# Document attachments: download, cache, and optionally inject text
|
||||
ext = ""
|
||||
if att.filename:
|
||||
_, ext = os.path.splitext(att.filename)
|
||||
ext = ext.lower()
|
||||
if not ext and content_type:
|
||||
mime_to_ext = {v: k for k, v in SUPPORTED_DOCUMENT_TYPES.items()}
|
||||
ext = mime_to_ext.get(content_type, "")
|
||||
if ext not in SUPPORTED_DOCUMENT_TYPES:
|
||||
logger.warning(
|
||||
"[Discord] Unsupported document type '%s' (%s), skipping",
|
||||
ext or "unknown", content_type,
|
||||
)
|
||||
else:
|
||||
MAX_DOC_BYTES = 20 * 1024 * 1024
|
||||
if att.size and att.size > MAX_DOC_BYTES:
|
||||
logger.warning(
|
||||
"[Discord] Document too large (%s bytes), skipping: %s",
|
||||
att.size, att.filename,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
import aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
att.url,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"HTTP {resp.status}")
|
||||
raw_bytes = await resp.read()
|
||||
cached_path = cache_document_from_bytes(
|
||||
raw_bytes, att.filename or f"document{ext}"
|
||||
)
|
||||
doc_mime = SUPPORTED_DOCUMENT_TYPES[ext]
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(doc_mime)
|
||||
logger.info("[Discord] Cached user document: %s", cached_path)
|
||||
# Inject text content for .txt/.md files (capped at 100 KB)
|
||||
MAX_TEXT_INJECT_BYTES = 100 * 1024
|
||||
if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
|
||||
try:
|
||||
text_content = raw_bytes.decode("utf-8")
|
||||
display_name = att.filename or f"document{ext}"
|
||||
display_name = re.sub(r'[^\w.\- ]', '_', display_name)
|
||||
injection = f"[Content of {display_name}]:\n{text_content}"
|
||||
if pending_text_injection:
|
||||
pending_text_injection = f"{pending_text_injection}\n\n{injection}"
|
||||
else:
|
||||
pending_text_injection = injection
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"[Discord] Failed to cache document %s: %s",
|
||||
att.filename, e, exc_info=True,
|
||||
)
|
||||
|
||||
event_text = message.content
|
||||
if pending_text_injection:
|
||||
event_text = f"{pending_text_injection}\n\n{event_text}" if event_text else pending_text_injection
|
||||
|
||||
# Defense-in-depth: prevent empty user messages from entering session
|
||||
# (can happen when user sends @mention-only with no other text)
|
||||
if not event_text or not event_text.strip():
|
||||
event_text = "(The user sent a message with no text content)"
|
||||
|
||||
event = MessageEvent(
|
||||
text=message.content,
|
||||
text=event_text,
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=message,
|
||||
|
||||
@@ -24,7 +24,6 @@ import re
|
||||
import smtplib
|
||||
import ssl
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from email.header import decode_header
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
@@ -225,12 +224,12 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
"""Connect to the IMAP server and start polling for new messages."""
|
||||
try:
|
||||
# Test IMAP connection
|
||||
imap = imaplib.IMAP4_SSL(self._imap_host, self._imap_port)
|
||||
imap = imaplib.IMAP4_SSL(self._imap_host, self._imap_port, timeout=30)
|
||||
imap.login(self._address, self._password)
|
||||
# Mark all existing messages as seen so we only process new ones
|
||||
imap.select("INBOX")
|
||||
status, data = imap.uid("search", None, "ALL")
|
||||
if status == "OK" and data[0]:
|
||||
if status == "OK" and data and data[0]:
|
||||
for uid in data[0].split():
|
||||
self._seen_uids.add(uid)
|
||||
imap.logout()
|
||||
@@ -241,7 +240,7 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
|
||||
try:
|
||||
# Test SMTP connection
|
||||
smtp = smtplib.SMTP(self._smtp_host, self._smtp_port)
|
||||
smtp = smtplib.SMTP(self._smtp_host, self._smtp_port, timeout=30)
|
||||
smtp.starttls(context=ssl.create_default_context())
|
||||
smtp.login(self._address, self._password)
|
||||
smtp.quit()
|
||||
@@ -290,12 +289,12 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
"""Fetch new (unseen) messages from IMAP. Runs in executor thread."""
|
||||
results = []
|
||||
try:
|
||||
imap = imaplib.IMAP4_SSL(self._imap_host, self._imap_port)
|
||||
imap = imaplib.IMAP4_SSL(self._imap_host, self._imap_port, timeout=30)
|
||||
imap.login(self._address, self._password)
|
||||
imap.select("INBOX")
|
||||
|
||||
status, data = imap.uid("search", None, "UNSEEN")
|
||||
if status != "OK" or not data[0]:
|
||||
if status != "OK" or not data or not data[0]:
|
||||
imap.logout()
|
||||
return results
|
||||
|
||||
@@ -443,7 +442,7 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
|
||||
msg.attach(MIMEText(body, "plain", "utf-8"))
|
||||
|
||||
smtp = smtplib.SMTP(self._smtp_host, self._smtp_port)
|
||||
smtp = smtplib.SMTP(self._smtp_host, self._smtp_port, timeout=30)
|
||||
smtp.starttls(context=ssl.create_default_context())
|
||||
smtp.login(self._address, self._password)
|
||||
smtp.send_message(msg)
|
||||
@@ -454,7 +453,6 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""Email has no typing indicator — no-op."""
|
||||
pass
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
@@ -531,7 +529,7 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
part.add_header("Content-Disposition", f"attachment; filename={fname}")
|
||||
msg.attach(part)
|
||||
|
||||
smtp = smtplib.SMTP(self._smtp_host, self._smtp_port)
|
||||
smtp = smtplib.SMTP(self._smtp_host, self._smtp_port, timeout=30)
|
||||
smtp.starttls(context=ssl.create_default_context())
|
||||
smtp.login(self._address, self._password)
|
||||
smtp.send_message(msg)
|
||||
|
||||
@@ -19,7 +19,7 @@ import os
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Any, Dict, Optional, Set
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
@@ -114,7 +114,9 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
return False
|
||||
|
||||
# Dedicated REST session for send() calls
|
||||
self._rest_session = aiohttp.ClientSession()
|
||||
self._rest_session = aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
)
|
||||
|
||||
# Warn if no event filters are configured
|
||||
if not self._watch_domains and not self._watch_entities and not self._watch_all:
|
||||
@@ -140,8 +142,10 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
ws_url = self._hass_url.replace("http://", "ws://").replace("https://", "wss://")
|
||||
ws_url = f"{ws_url}/api/websocket"
|
||||
|
||||
self._session = aiohttp.ClientSession()
|
||||
self._ws = await self._session.ws_connect(ws_url, heartbeat=30)
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
)
|
||||
self._ws = await self._session.ws_connect(ws_url, heartbeat=30, timeout=30)
|
||||
|
||||
# Step 1: Receive auth_required
|
||||
msg = await self._ws.receive_json()
|
||||
@@ -435,7 +439,6 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata=None) -> None:
|
||||
"""No typing indicator for Home Assistant."""
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Return basic info about the HA event channel."""
|
||||
|
||||
@@ -17,14 +17,13 @@ Environment variables:
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Any, Dict, Optional, Set
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
@@ -103,6 +102,23 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
self._dm_rooms: Dict[str, bool] = {}
|
||||
# Set of room IDs we've joined
|
||||
self._joined_rooms: Set[str] = set()
|
||||
# Event deduplication (bounded deque keeps newest entries)
|
||||
from collections import deque
|
||||
self._processed_events: deque = deque(maxlen=1000)
|
||||
self._processed_events_set: set = set()
|
||||
|
||||
def _is_duplicate_event(self, event_id) -> bool:
|
||||
"""Return True if this event was already processed. Tracks the ID otherwise."""
|
||||
if not event_id:
|
||||
return False
|
||||
if event_id in self._processed_events_set:
|
||||
return True
|
||||
if len(self._processed_events) == self._processed_events.maxlen:
|
||||
evicted = self._processed_events[0]
|
||||
self._processed_events_set.discard(evicted)
|
||||
self._processed_events.append(event_id)
|
||||
self._processed_events_set.add(event_id)
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Required overrides
|
||||
@@ -188,7 +204,6 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
|
||||
# Register event callbacks.
|
||||
client.add_event_callback(self._on_room_message, nio.RoomMessageText)
|
||||
client.add_event_callback(self._on_room_message_media, nio.RoomMessageMedia)
|
||||
client.add_event_callback(self._on_room_message_media, nio.RoomMessageImage)
|
||||
client.add_event_callback(self._on_room_message_media, nio.RoomMessageAudio)
|
||||
client.add_event_callback(self._on_room_message_media, nio.RoomMessageVideo)
|
||||
@@ -536,9 +551,20 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
|
||||
async def _sync_loop(self) -> None:
|
||||
"""Continuously sync with the homeserver."""
|
||||
import nio
|
||||
|
||||
while not self._closing:
|
||||
try:
|
||||
await self._client.sync(timeout=30000)
|
||||
resp = await self._client.sync(timeout=30000)
|
||||
if isinstance(resp, nio.SyncError):
|
||||
if self._closing:
|
||||
return
|
||||
logger.warning(
|
||||
"Matrix: sync returned %s: %s — retrying in 5s",
|
||||
type(resp).__name__,
|
||||
getattr(resp, "message", resp),
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as exc:
|
||||
@@ -559,6 +585,10 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
if event.sender == self._user_id:
|
||||
return
|
||||
|
||||
# Deduplicate by event ID (nio can fire the same event more than once).
|
||||
if self._is_duplicate_event(getattr(event, "event_id", None)):
|
||||
return
|
||||
|
||||
# Startup grace: ignore old messages from initial sync.
|
||||
event_ts = getattr(event, "server_timestamp", 0) / 1000.0
|
||||
if event_ts and event_ts < self._startup_ts - _STARTUP_GRACE_SECONDS:
|
||||
@@ -648,6 +678,10 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
if event.sender == self._user_id:
|
||||
return
|
||||
|
||||
# Deduplicate by event ID.
|
||||
if self._is_duplicate_event(getattr(event, "event_id", None)):
|
||||
return
|
||||
|
||||
# Startup grace.
|
||||
event_ts = getattr(event, "server_timestamp", 0) / 1000.0
|
||||
if event_ts and event_ts < self._startup_ts - _STARTUP_GRACE_SECONDS:
|
||||
@@ -681,6 +715,24 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
elif event_mimetype:
|
||||
media_type = event_mimetype
|
||||
|
||||
# For images, download and cache locally so vision tools can access them.
|
||||
# Matrix MXC URLs require authentication, so direct URL access fails.
|
||||
cached_path = None
|
||||
if msg_type == MessageType.PHOTO and url:
|
||||
try:
|
||||
ext_map = {
|
||||
"image/jpeg": ".jpg", "image/png": ".png",
|
||||
"image/gif": ".gif", "image/webp": ".webp",
|
||||
}
|
||||
ext = ext_map.get(event_mimetype, ".jpg")
|
||||
download_resp = await self._client.download(url)
|
||||
if isinstance(download_resp, nio.DownloadResponse):
|
||||
from gateway.platforms.base import cache_image_from_bytes
|
||||
cached_path = cache_image_from_bytes(download_resp.body, ext=ext)
|
||||
logger.info("[Matrix] Cached user image at %s", cached_path)
|
||||
except Exception as e:
|
||||
logger.warning("[Matrix] Failed to cache image: %s", e)
|
||||
|
||||
is_dm = self._dm_rooms.get(room.room_id, False)
|
||||
if not is_dm and room.member_count == 2:
|
||||
is_dm = True
|
||||
@@ -701,14 +753,18 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
# Use cached local path for images, HTTP URL for other media types
|
||||
media_urls = [cached_path] if cached_path else ([http_url] if http_url else None)
|
||||
media_types = [media_type] if media_urls else None
|
||||
|
||||
msg_event = MessageEvent(
|
||||
text=body,
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=getattr(event, "source", {}),
|
||||
message_id=event.event_id,
|
||||
media_urls=[http_url] if http_url else None,
|
||||
media_types=[media_type] if http_url else None,
|
||||
media_urls=media_urls,
|
||||
media_types=media_types,
|
||||
)
|
||||
|
||||
await self.handle_message(msg_event)
|
||||
|
||||
@@ -20,7 +20,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
@@ -116,7 +116,7 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
import aiohttp
|
||||
url = f"{self._base_url}/api/v4/{path.lstrip('/')}"
|
||||
try:
|
||||
async with self._session.get(url, headers=self._headers()) as resp:
|
||||
async with self._session.get(url, headers=self._headers(), timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status >= 400:
|
||||
body = await resp.text()
|
||||
logger.error("MM API GET %s → %s: %s", path, resp.status, body[:200])
|
||||
@@ -134,7 +134,8 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
url = f"{self._base_url}/api/v4/{path.lstrip('/')}"
|
||||
try:
|
||||
async with self._session.post(
|
||||
url, headers=self._headers(), json=payload
|
||||
url, headers=self._headers(), json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status >= 400:
|
||||
body = await resp.text()
|
||||
@@ -180,7 +181,7 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
content_type=content_type,
|
||||
)
|
||||
headers = {"Authorization": f"Bearer {self._token}"}
|
||||
async with self._session.post(url, headers=headers, data=form) as resp:
|
||||
async with self._session.post(url, headers=headers, data=form, timeout=aiohttp.ClientTimeout(total=60)) as resp:
|
||||
if resp.status >= 400:
|
||||
body = await resp.text()
|
||||
logger.error("MM file upload → %s: %s", resp.status, body[:200])
|
||||
@@ -201,7 +202,9 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
logger.error("Mattermost: URL or token not configured")
|
||||
return False
|
||||
|
||||
self._session = aiohttp.ClientSession()
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
)
|
||||
self._closing = False
|
||||
|
||||
# Verify credentials and fetch bot identity.
|
||||
@@ -404,18 +407,38 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
kind: str = "file",
|
||||
) -> SendResult:
|
||||
"""Download a URL and upload it as a file attachment."""
|
||||
import asyncio
|
||||
import aiohttp
|
||||
try:
|
||||
async with self._session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status >= 400:
|
||||
# Fall back to sending the URL as text.
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
file_data = await resp.read()
|
||||
ct = resp.content_type or "application/octet-stream"
|
||||
# Derive filename from URL.
|
||||
fname = url.rsplit("/", 1)[-1].split("?")[0] or f"{kind}.png"
|
||||
except Exception as exc:
|
||||
logger.warning("Mattermost: failed to download %s: %s", url, exc)
|
||||
|
||||
last_exc = None
|
||||
file_data = None
|
||||
ct = "application/octet-stream"
|
||||
fname = url.rsplit("/", 1)[-1].split("?")[0] or f"{kind}.png"
|
||||
|
||||
for attempt in range(3):
|
||||
try:
|
||||
async with self._session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status >= 500 or resp.status == 429:
|
||||
if attempt < 2:
|
||||
logger.debug("Mattermost download retry %d/2 for %s (status %d)",
|
||||
attempt + 1, url[:80], resp.status)
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
if resp.status >= 400:
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
file_data = await resp.read()
|
||||
ct = resp.content_type or "application/octet-stream"
|
||||
break
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
|
||||
last_exc = exc
|
||||
if attempt < 2:
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
logger.warning("Mattermost: failed to download %s after %d attempts: %s", url, attempt + 1, exc)
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
|
||||
if file_data is None:
|
||||
logger.warning("Mattermost: download returned no data for %s", url)
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
|
||||
file_id = await self._upload_file(chat_id, file_data, fname, ct)
|
||||
@@ -580,6 +603,24 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
# For DMs, user_id is sufficient. For channels, check for @mention.
|
||||
message_text = post.get("message", "")
|
||||
|
||||
# Mention-only mode: skip channel messages that don't @mention the bot.
|
||||
# DMs (type "D") are always processed.
|
||||
if channel_type_raw != "D":
|
||||
mention_patterns = [
|
||||
f"@{self._bot_username}",
|
||||
f"@{self._bot_user_id}",
|
||||
]
|
||||
has_mention = any(
|
||||
pattern.lower() in message_text.lower()
|
||||
for pattern in mention_patterns
|
||||
)
|
||||
if not has_mention:
|
||||
logger.debug(
|
||||
"Mattermost: skipping non-DM message without @mention (channel=%s)",
|
||||
channel_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Resolve sender info.
|
||||
sender_id = post.get("user_id", "")
|
||||
sender_name = data.get("sender_name", "").lstrip("@") or sender_id
|
||||
@@ -617,16 +658,16 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
if mime.startswith("image/"):
|
||||
local_path = cache_image_from_bytes(file_data, ext or ".png")
|
||||
media_urls.append(local_path)
|
||||
media_types.append("image")
|
||||
media_types.append(mime)
|
||||
elif mime.startswith("audio/"):
|
||||
from gateway.platforms.base import cache_audio_from_bytes
|
||||
local_path = cache_audio_from_bytes(file_data, ext or ".ogg")
|
||||
media_urls.append(local_path)
|
||||
media_types.append("audio")
|
||||
media_types.append(mime)
|
||||
else:
|
||||
local_path = cache_document_from_bytes(file_data, fname)
|
||||
media_urls.append(local_path)
|
||||
media_types.append("document")
|
||||
media_types.append(mime)
|
||||
else:
|
||||
logger.warning("Mattermost: failed to download file %s: HTTP %s", fid, resp.status)
|
||||
except Exception as exc:
|
||||
|
||||
@@ -279,6 +279,12 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# SSE keepalive comments (":") prove the connection
|
||||
# is alive — update activity so the health monitor
|
||||
# doesn't report false idle warnings.
|
||||
if line.startswith(":"):
|
||||
self._last_sse_activity = time.time()
|
||||
continue
|
||||
# Parse SSE data lines
|
||||
if line.startswith("data:"):
|
||||
data_str = line[5:].strip()
|
||||
@@ -344,7 +350,9 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
"""Force SSE reconnection by closing the current response."""
|
||||
if self._sse_response and not self._sse_response.is_stream_consumed:
|
||||
try:
|
||||
asyncio.create_task(self._sse_response.aclose())
|
||||
task = asyncio.create_task(self._sse_response.aclose())
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
except Exception:
|
||||
pass
|
||||
self._sse_response = None
|
||||
@@ -478,7 +486,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
if any(mt.startswith("audio/") for mt in media_types):
|
||||
msg_type = MessageType.VOICE
|
||||
elif any(mt.startswith("image/") for mt in media_types):
|
||||
msg_type = MessageType.IMAGE
|
||||
msg_type = MessageType.PHOTO
|
||||
|
||||
# Parse timestamp from envelope data (milliseconds since epoch)
|
||||
ts_ms = envelope_data.get("timestamp", 0)
|
||||
@@ -519,6 +527,13 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
if not result:
|
||||
return None, ""
|
||||
|
||||
# Handle dict response (signal-cli returns {"data": "base64..."})
|
||||
if isinstance(result, dict):
|
||||
result = result.get("data")
|
||||
if not result:
|
||||
logger.warning("Signal: attachment response missing 'data' key")
|
||||
return None, ""
|
||||
|
||||
# Result is base64-encoded file content
|
||||
raw_data = base64.b64decode(result)
|
||||
ext = _guess_extension(raw_data)
|
||||
|
||||
+55
-24
@@ -12,7 +12,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
try:
|
||||
from slack_bolt.async_app import AsyncApp
|
||||
@@ -37,8 +37,6 @@ from gateway.platforms.base import (
|
||||
SendResult,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
cache_document_from_bytes,
|
||||
cache_image_from_url,
|
||||
cache_audio_from_url,
|
||||
)
|
||||
|
||||
|
||||
@@ -74,6 +72,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
self._handler: Optional[AsyncSocketModeHandler] = None
|
||||
self._bot_user_id: Optional[str] = None
|
||||
self._user_name_cache: Dict[str, str] = {} # user_id → display name
|
||||
self._socket_mode_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Slack via Socket Mode."""
|
||||
@@ -121,7 +120,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
|
||||
# Start Socket Mode handler in background
|
||||
self._handler = AsyncSocketModeHandler(self._app, app_token)
|
||||
asyncio.create_task(self._handler.start_async())
|
||||
self._socket_mode_task = asyncio.create_task(self._handler.start_async())
|
||||
|
||||
self._running = True
|
||||
logger.info("[Slack] Connected as @%s (Socket Mode)", bot_name)
|
||||
@@ -820,33 +819,65 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
await self.handle_message(event)
|
||||
|
||||
async def _download_slack_file(self, url: str, ext: str, audio: bool = False) -> str:
|
||||
"""Download a Slack file using the bot token for auth."""
|
||||
"""Download a Slack file using the bot token for auth, with retry."""
|
||||
import asyncio
|
||||
import httpx
|
||||
|
||||
bot_token = self.config.token
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
last_exc = None
|
||||
|
||||
if audio:
|
||||
from gateway.platforms.base import cache_audio_from_bytes
|
||||
return cache_audio_from_bytes(response.content, ext)
|
||||
else:
|
||||
from gateway.platforms.base import cache_image_from_bytes
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
for attempt in range(3):
|
||||
try:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
if audio:
|
||||
from gateway.platforms.base import cache_audio_from_bytes
|
||||
return cache_audio_from_bytes(response.content, ext)
|
||||
else:
|
||||
from gateway.platforms.base import cache_image_from_bytes
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
|
||||
last_exc = exc
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
|
||||
raise
|
||||
if attempt < 2:
|
||||
logger.debug("Slack file download retry %d/2 for %s: %s",
|
||||
attempt + 1, url[:80], exc)
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
raise
|
||||
raise last_exc
|
||||
|
||||
async def _download_slack_file_bytes(self, url: str) -> bytes:
|
||||
"""Download a Slack file and return raw bytes."""
|
||||
"""Download a Slack file and return raw bytes, with retry."""
|
||||
import asyncio
|
||||
import httpx
|
||||
|
||||
bot_token = self.config.token
|
||||
last_exc = None
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
for attempt in range(3):
|
||||
try:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
|
||||
last_exc = exc
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
|
||||
raise
|
||||
if attempt < 2:
|
||||
logger.debug("Slack file download retry %d/2 for %s: %s",
|
||||
attempt + 1, url[:80], exc)
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
raise
|
||||
raise last_exc
|
||||
|
||||
@@ -17,12 +17,11 @@ Gateway-specific env vars:
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
@@ -107,7 +106,9 @@ class SmsAdapter(BasePlatformAdapter):
|
||||
await self._runner.setup()
|
||||
site = web.TCPSite(self._runner, "0.0.0.0", self._webhook_port)
|
||||
await site.start()
|
||||
self._http_session = aiohttp.ClientSession()
|
||||
self._http_session = aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
)
|
||||
self._running = True
|
||||
|
||||
logger.info(
|
||||
@@ -145,7 +146,9 @@ class SmsAdapter(BasePlatformAdapter):
|
||||
"Authorization": self._basic_auth_header(),
|
||||
}
|
||||
|
||||
session = self._http_session or aiohttp.ClientSession()
|
||||
session = self._http_session or aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
)
|
||||
try:
|
||||
for chunk in chunks:
|
||||
form_data = aiohttp.FormData()
|
||||
@@ -262,7 +265,9 @@ class SmsAdapter(BasePlatformAdapter):
|
||||
)
|
||||
|
||||
# Non-blocking: Twilio expects a fast response
|
||||
asyncio.create_task(self.handle_message(event))
|
||||
task = asyncio.create_task(self.handle_message(event))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
# Return empty TwiML — we send replies via the REST API, not inline TwiML
|
||||
return web.Response(
|
||||
|
||||
+559
-14
@@ -25,6 +25,7 @@ try:
|
||||
filters,
|
||||
)
|
||||
from telegram.constants import ParseMode, ChatType
|
||||
from telegram.request import HTTPXRequest
|
||||
TELEGRAM_AVAILABLE = True
|
||||
except ImportError:
|
||||
TELEGRAM_AVAILABLE = False
|
||||
@@ -34,6 +35,7 @@ except ImportError:
|
||||
Application = Any
|
||||
CommandHandler = Any
|
||||
TelegramMessageHandler = Any
|
||||
HTTPXRequest = Any
|
||||
filters = None
|
||||
ParseMode = None
|
||||
ChatType = None
|
||||
@@ -59,6 +61,11 @@ from gateway.platforms.base import (
|
||||
cache_document_from_bytes,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
from gateway.platforms.telegram_network import (
|
||||
TelegramFallbackTransport,
|
||||
discover_fallback_ips,
|
||||
parse_fallback_ip_env,
|
||||
)
|
||||
|
||||
|
||||
def check_telegram_requirements() -> bool:
|
||||
@@ -115,6 +122,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
super().__init__(config, Platform.TELEGRAM)
|
||||
self._app: Optional[Application] = None
|
||||
self._bot: Optional[Bot] = None
|
||||
self._reply_to_mode: str = getattr(config, 'reply_to_mode', 'first') or 'first'
|
||||
# Buffer rapid/album photo updates so Telegram image bursts are handled
|
||||
# as a single MessageEvent instead of self-interrupting multiple turns.
|
||||
self._media_batch_delay_seconds = float(os.getenv("HERMES_TELEGRAM_MEDIA_BATCH_DELAY_SECONDS", "0.8"))
|
||||
@@ -129,6 +137,20 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._token_lock_identity: Optional[str] = None
|
||||
self._polling_error_task: Optional[asyncio.Task] = None
|
||||
self._polling_conflict_count: int = 0
|
||||
self._polling_network_error_count: int = 0
|
||||
self._polling_error_callback_ref = None
|
||||
# DM Topics: map of topic_name -> message_thread_id (populated at startup)
|
||||
self._dm_topics: Dict[str, int] = {}
|
||||
# DM Topics config from extra.dm_topics
|
||||
self._dm_topics_config: List[Dict[str, Any]] = self.config.extra.get("dm_topics", [])
|
||||
|
||||
def _fallback_ips(self) -> list[str]:
|
||||
"""Return validated fallback IPs from config (populated by _apply_env_overrides)."""
|
||||
configured = self.config.extra.get("fallback_ips", []) if getattr(self.config, "extra", None) else []
|
||||
if isinstance(configured, str):
|
||||
configured = configured.split(",")
|
||||
return parse_fallback_ip_env(",".join(str(v) for v in configured) if configured else None)
|
||||
|
||||
@staticmethod
|
||||
def _looks_like_polling_conflict(error: Exception) -> bool:
|
||||
@@ -139,13 +161,133 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
or "another bot instance is running" in text
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _looks_like_network_error(error: Exception) -> bool:
|
||||
"""Return True for transient network errors that warrant a reconnect attempt."""
|
||||
name = error.__class__.__name__.lower()
|
||||
if name in ("networkerror", "timedout", "connectionerror"):
|
||||
return True
|
||||
try:
|
||||
from telegram.error import NetworkError, TimedOut
|
||||
if isinstance(error, (NetworkError, TimedOut)):
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
return isinstance(error, OSError)
|
||||
|
||||
async def _handle_polling_network_error(self, error: Exception) -> None:
|
||||
"""Reconnect polling after a transient network interruption.
|
||||
|
||||
Triggered by NetworkError/TimedOut in the polling error callback, which
|
||||
happen when the host loses connectivity (Mac sleep, WiFi switch, VPN
|
||||
reconnect, etc.). The gateway process stays alive but the long-poll
|
||||
connection silently dies; without this handler the bot never recovers.
|
||||
|
||||
Strategy: exponential back-off (5s, 10s, 20s, 40s, 60s cap) up to
|
||||
MAX_NETWORK_RETRIES attempts, then mark the adapter retryable-fatal so
|
||||
the supervisor restarts the gateway process.
|
||||
"""
|
||||
if self.has_fatal_error:
|
||||
return
|
||||
|
||||
MAX_NETWORK_RETRIES = 10
|
||||
BASE_DELAY = 5
|
||||
MAX_DELAY = 60
|
||||
|
||||
self._polling_network_error_count += 1
|
||||
attempt = self._polling_network_error_count
|
||||
|
||||
if attempt > MAX_NETWORK_RETRIES:
|
||||
message = (
|
||||
"Telegram polling could not reconnect after %d network error retries. "
|
||||
"Restarting gateway." % MAX_NETWORK_RETRIES
|
||||
)
|
||||
logger.error("[%s] %s Last error: %s", self.name, message, error)
|
||||
self._set_fatal_error("telegram_network_error", message, retryable=True)
|
||||
await self._notify_fatal_error()
|
||||
return
|
||||
|
||||
delay = min(BASE_DELAY * (2 ** (attempt - 1)), MAX_DELAY)
|
||||
logger.warning(
|
||||
"[%s] Telegram network error (attempt %d/%d), reconnecting in %ds. Error: %s",
|
||||
self.name, attempt, MAX_NETWORK_RETRIES, delay, error,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
try:
|
||||
if self._app and self._app.updater and self._app.updater.running:
|
||||
await self._app.updater.stop()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=Update.ALL_TYPES,
|
||||
drop_pending_updates=False,
|
||||
error_callback=self._polling_error_callback_ref,
|
||||
)
|
||||
logger.info(
|
||||
"[%s] Telegram polling resumed after network error (attempt %d)",
|
||||
self.name, attempt,
|
||||
)
|
||||
self._polling_network_error_count = 0
|
||||
except Exception as retry_err:
|
||||
logger.warning("[%s] Telegram polling reconnect failed: %s", self.name, retry_err)
|
||||
# start_polling failed — polling is dead and no further error
|
||||
# callbacks will fire, so schedule the next retry ourselves.
|
||||
if not self.has_fatal_error:
|
||||
task = asyncio.ensure_future(
|
||||
self._handle_polling_network_error(retry_err)
|
||||
)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
async def _handle_polling_conflict(self, error: Exception) -> None:
|
||||
if self.has_fatal_error and self.fatal_error_code == "telegram_polling_conflict":
|
||||
return
|
||||
# Track consecutive conflicts — transient 409s can occur when a
|
||||
# previous gateway instance hasn't fully released its long-poll
|
||||
# session on Telegram's server (e.g. during --replace handoffs or
|
||||
# systemd Restart=on-failure respawns). Retry a few times before
|
||||
# giving up, so the old session has time to expire.
|
||||
self._polling_conflict_count += 1
|
||||
|
||||
MAX_CONFLICT_RETRIES = 3
|
||||
RETRY_DELAY = 10 # seconds
|
||||
|
||||
if self._polling_conflict_count <= MAX_CONFLICT_RETRIES:
|
||||
logger.warning(
|
||||
"[%s] Telegram polling conflict (%d/%d), will retry in %ds. Error: %s",
|
||||
self.name, self._polling_conflict_count, MAX_CONFLICT_RETRIES,
|
||||
RETRY_DELAY, error,
|
||||
)
|
||||
try:
|
||||
if self._app and self._app.updater and self._app.updater.running:
|
||||
await self._app.updater.stop()
|
||||
except Exception:
|
||||
pass
|
||||
await asyncio.sleep(RETRY_DELAY)
|
||||
try:
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=Update.ALL_TYPES,
|
||||
drop_pending_updates=False,
|
||||
error_callback=self._polling_error_callback_ref,
|
||||
)
|
||||
logger.info("[%s] Telegram polling resumed after conflict retry %d", self.name, self._polling_conflict_count)
|
||||
self._polling_conflict_count = 0 # reset on success
|
||||
return
|
||||
except Exception as retry_err:
|
||||
logger.warning("[%s] Telegram polling retry failed: %s", self.name, retry_err)
|
||||
# Don't fall through to fatal yet — wait for the next conflict
|
||||
# to trigger another retry attempt (up to MAX_CONFLICT_RETRIES).
|
||||
return
|
||||
|
||||
# Exhausted retries — fatal
|
||||
message = (
|
||||
"Another Telegram bot poller is already using this token. "
|
||||
"Hermes stopped Telegram polling to avoid endless retry spam. "
|
||||
"Hermes stopped Telegram polling after %d retries. "
|
||||
"Make sure only one gateway instance is running for this bot token."
|
||||
% MAX_CONFLICT_RETRIES
|
||||
)
|
||||
logger.error("[%s] %s Original error: %s", self.name, message, error)
|
||||
self._set_fatal_error("telegram_polling_conflict", message, retryable=False)
|
||||
@@ -156,6 +298,162 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
logger.warning("[%s] Failed stopping Telegram polling after conflict: %s", self.name, stop_error, exc_info=True)
|
||||
await self._notify_fatal_error()
|
||||
|
||||
async def _create_dm_topic(
|
||||
self,
|
||||
chat_id: int,
|
||||
name: str,
|
||||
icon_color: Optional[int] = None,
|
||||
icon_custom_emoji_id: Optional[str] = None,
|
||||
) -> Optional[int]:
|
||||
"""Create a forum topic in a private (DM) chat.
|
||||
|
||||
Uses Bot API 9.4's createForumTopic which now works for 1-on-1 chats.
|
||||
Returns the message_thread_id on success, None on failure.
|
||||
"""
|
||||
if not self._bot:
|
||||
return None
|
||||
try:
|
||||
kwargs: Dict[str, Any] = {"chat_id": chat_id, "name": name}
|
||||
if icon_color is not None:
|
||||
kwargs["icon_color"] = icon_color
|
||||
if icon_custom_emoji_id:
|
||||
kwargs["icon_custom_emoji_id"] = icon_custom_emoji_id
|
||||
|
||||
topic = await self._bot.create_forum_topic(**kwargs)
|
||||
thread_id = topic.message_thread_id
|
||||
logger.info(
|
||||
"[%s] Created DM topic '%s' in chat %s -> thread_id=%s",
|
||||
self.name, name, chat_id, thread_id,
|
||||
)
|
||||
return thread_id
|
||||
except Exception as e:
|
||||
error_text = str(e).lower()
|
||||
# If topic already exists, try to find it via getForumTopicIconStickers
|
||||
# or we just log and skip — Telegram doesn't provide a "list topics" API
|
||||
if "topic_name_duplicate" in error_text or "already" in error_text:
|
||||
logger.info(
|
||||
"[%s] DM topic '%s' already exists in chat %s (will be mapped from incoming messages)",
|
||||
self.name, name, chat_id,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"[%s] Failed to create DM topic '%s' in chat %s: %s",
|
||||
self.name, name, chat_id, e,
|
||||
)
|
||||
return None
|
||||
|
||||
def _persist_dm_topic_thread_id(self, chat_id: int, topic_name: str, thread_id: int) -> None:
|
||||
"""Save a newly created thread_id back into config.yaml so it persists across restarts."""
|
||||
try:
|
||||
config_path = _Path.home() / ".hermes" / "config.yaml"
|
||||
if not config_path.exists():
|
||||
logger.warning("[%s] Config file not found at %s, cannot persist thread_id", self.name, config_path)
|
||||
return
|
||||
|
||||
import yaml as _yaml
|
||||
with open(config_path, "r") as f:
|
||||
config = _yaml.safe_load(f) or {}
|
||||
|
||||
# Navigate to platforms.telegram.extra.dm_topics
|
||||
dm_topics = (
|
||||
config.get("platforms", {})
|
||||
.get("telegram", {})
|
||||
.get("extra", {})
|
||||
.get("dm_topics", [])
|
||||
)
|
||||
if not dm_topics:
|
||||
return
|
||||
|
||||
changed = False
|
||||
for chat_entry in dm_topics:
|
||||
if int(chat_entry.get("chat_id", 0)) != int(chat_id):
|
||||
continue
|
||||
for t in chat_entry.get("topics", []):
|
||||
if t.get("name") == topic_name and not t.get("thread_id"):
|
||||
t["thread_id"] = thread_id
|
||||
changed = True
|
||||
break
|
||||
|
||||
if changed:
|
||||
with open(config_path, "w") as f:
|
||||
_yaml.dump(config, f, default_flow_style=False, sort_keys=False)
|
||||
logger.info(
|
||||
"[%s] Persisted thread_id=%s for topic '%s' in config.yaml",
|
||||
self.name, thread_id, topic_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Failed to persist thread_id to config: %s", self.name, e, exc_info=True)
|
||||
|
||||
async def _setup_dm_topics(self) -> None:
|
||||
"""Load or create configured DM topics for specified chats.
|
||||
|
||||
Reads config.extra['dm_topics'] — a list of dicts:
|
||||
[
|
||||
{
|
||||
"chat_id": 123456789,
|
||||
"topics": [
|
||||
{"name": "General", "icon_color": 7322096, "thread_id": 100},
|
||||
{"name": "Accessibility Auditor", "icon_color": 9367192, "skill": "accessibility-auditor"}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
If a topic already has a thread_id in the config (persisted from a previous
|
||||
creation), it is loaded into the cache without calling createForumTopic.
|
||||
Only topics without a thread_id are created via the API, and their thread_id
|
||||
is then saved back to config.yaml for future restarts.
|
||||
"""
|
||||
if not self._dm_topics_config:
|
||||
return
|
||||
|
||||
for chat_entry in self._dm_topics_config:
|
||||
chat_id = chat_entry.get("chat_id")
|
||||
topics = chat_entry.get("topics", [])
|
||||
if not chat_id or not topics:
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
"[%s] Setting up %d DM topic(s) for chat %s",
|
||||
self.name, len(topics), chat_id,
|
||||
)
|
||||
|
||||
for topic_conf in topics:
|
||||
topic_name = topic_conf.get("name")
|
||||
if not topic_name:
|
||||
continue
|
||||
|
||||
cache_key = f"{chat_id}:{topic_name}"
|
||||
|
||||
# If thread_id is already persisted in config, just load into cache
|
||||
existing_thread_id = topic_conf.get("thread_id")
|
||||
if existing_thread_id:
|
||||
self._dm_topics[cache_key] = int(existing_thread_id)
|
||||
logger.info(
|
||||
"[%s] DM topic loaded from config: %s -> thread_id=%s",
|
||||
self.name, cache_key, existing_thread_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# No persisted thread_id — create the topic via API
|
||||
icon_color = topic_conf.get("icon_color")
|
||||
icon_emoji = topic_conf.get("icon_custom_emoji_id")
|
||||
|
||||
thread_id = await self._create_dm_topic(
|
||||
chat_id=int(chat_id),
|
||||
name=topic_name,
|
||||
icon_color=icon_color,
|
||||
icon_custom_emoji_id=icon_emoji,
|
||||
)
|
||||
|
||||
if thread_id:
|
||||
self._dm_topics[cache_key] = thread_id
|
||||
logger.info(
|
||||
"[%s] DM topic cached: %s -> thread_id=%s",
|
||||
self.name, cache_key, thread_id,
|
||||
)
|
||||
# Persist thread_id to config so we don't recreate on next restart
|
||||
self._persist_dm_topic_thread_id(int(chat_id), topic_name, thread_id)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Telegram and start polling for updates."""
|
||||
if not TELEGRAM_AVAILABLE:
|
||||
@@ -190,7 +488,26 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
return False
|
||||
|
||||
# Build the application
|
||||
self._app = Application.builder().token(self.config.token).build()
|
||||
builder = Application.builder().token(self.config.token)
|
||||
fallback_ips = self._fallback_ips()
|
||||
if not fallback_ips:
|
||||
fallback_ips = await discover_fallback_ips()
|
||||
logger.info(
|
||||
"[%s] Auto-discovered Telegram fallback IPs: %s",
|
||||
self.name,
|
||||
", ".join(fallback_ips),
|
||||
)
|
||||
if fallback_ips:
|
||||
logger.warning(
|
||||
"[%s] Telegram fallback IPs active: %s",
|
||||
self.name,
|
||||
", ".join(fallback_ips),
|
||||
)
|
||||
transport = TelegramFallbackTransport(fallback_ips)
|
||||
request = HTTPXRequest(httpx_kwargs={"transport": transport})
|
||||
get_updates_request = HTTPXRequest(httpx_kwargs={"transport": transport})
|
||||
builder = builder.request(request).get_updates_request(get_updates_request)
|
||||
self._app = builder.build()
|
||||
self._bot = self._app.bot
|
||||
|
||||
# Register handlers
|
||||
@@ -235,12 +552,18 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
def _polling_error_callback(error: Exception) -> None:
|
||||
if not self._looks_like_polling_conflict(error):
|
||||
logger.error("[%s] Telegram polling error: %s", self.name, error, exc_info=True)
|
||||
return
|
||||
if self._polling_error_task and not self._polling_error_task.done():
|
||||
return
|
||||
self._polling_error_task = loop.create_task(self._handle_polling_conflict(error))
|
||||
if self._looks_like_polling_conflict(error):
|
||||
self._polling_error_task = loop.create_task(self._handle_polling_conflict(error))
|
||||
elif self._looks_like_network_error(error):
|
||||
logger.warning("[%s] Telegram network error, scheduling reconnect: %s", self.name, error)
|
||||
self._polling_error_task = loop.create_task(self._handle_polling_network_error(error))
|
||||
else:
|
||||
logger.error("[%s] Telegram polling error: %s", self.name, error, exc_info=True)
|
||||
|
||||
# Store reference for retry use in _handle_polling_conflict
|
||||
self._polling_error_callback_ref = _polling_error_callback
|
||||
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=Update.ALL_TYPES,
|
||||
@@ -267,6 +590,18 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
|
||||
self._mark_connected()
|
||||
logger.info("[%s] Connected and polling for Telegram updates", self.name)
|
||||
|
||||
# Set up DM topics (Bot API 9.4 — Private Chat Topics)
|
||||
# Runs after connection is established so the bot can call createForumTopic.
|
||||
# Failures here are non-fatal — the bot works fine without topics.
|
||||
try:
|
||||
await self._setup_dm_topics()
|
||||
except Exception as topics_err:
|
||||
logger.warning(
|
||||
"[%s] DM topics setup failed (non-fatal): %s",
|
||||
self.name, topics_err, exc_info=True,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -320,6 +655,26 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
self._token_lock_identity = None
|
||||
logger.info("[%s] Disconnected from Telegram", self.name)
|
||||
|
||||
def _should_thread_reply(self, reply_to: Optional[str], chunk_index: int) -> bool:
|
||||
"""Determine if this message chunk should thread to the original message.
|
||||
|
||||
Args:
|
||||
reply_to: The original message ID to reply to
|
||||
chunk_index: Index of this chunk (0 = first chunk)
|
||||
|
||||
Returns:
|
||||
True if this chunk should be threaded to the original message
|
||||
"""
|
||||
if not reply_to:
|
||||
return False
|
||||
mode = self._reply_to_mode
|
||||
if mode == "off":
|
||||
return False
|
||||
elif mode == "all":
|
||||
return True
|
||||
else: # "first" (default)
|
||||
return chunk_index == 0
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -352,7 +707,16 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
except ImportError:
|
||||
_NetErr = OSError # type: ignore[misc,assignment]
|
||||
|
||||
try:
|
||||
from telegram.error import BadRequest as _BadReq
|
||||
except ImportError:
|
||||
_BadReq = None # type: ignore[assignment,misc]
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
should_thread = self._should_thread_reply(reply_to, i)
|
||||
reply_to_id = int(reply_to) if should_thread else None
|
||||
effective_thread_id = int(thread_id) if thread_id else None
|
||||
|
||||
msg = None
|
||||
for _send_attempt in range(3):
|
||||
try:
|
||||
@@ -362,8 +726,8 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
chat_id=int(chat_id),
|
||||
text=chunk,
|
||||
parse_mode=ParseMode.MARKDOWN_V2,
|
||||
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
reply_to_message_id=reply_to_id,
|
||||
message_thread_id=effective_thread_id,
|
||||
)
|
||||
except Exception as md_error:
|
||||
# Markdown parsing failed, try plain text
|
||||
@@ -374,13 +738,31 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
chat_id=int(chat_id),
|
||||
text=plain_chunk,
|
||||
parse_mode=None,
|
||||
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
reply_to_message_id=reply_to_id,
|
||||
message_thread_id=effective_thread_id,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
break # success
|
||||
except _NetErr as send_err:
|
||||
# BadRequest is a subclass of NetworkError in
|
||||
# python-telegram-bot but represents permanent errors
|
||||
# (not transient network issues). Detect and handle
|
||||
# specific cases instead of blindly retrying.
|
||||
if _BadReq and isinstance(send_err, _BadReq):
|
||||
err_lower = str(send_err).lower()
|
||||
if "thread not found" in err_lower and effective_thread_id is not None:
|
||||
# Thread doesn't exist — retry without
|
||||
# message_thread_id so the message still
|
||||
# reaches the chat.
|
||||
logger.warning(
|
||||
"[%s] Thread %s not found, retrying without message_thread_id",
|
||||
self.name, effective_thread_id,
|
||||
)
|
||||
effective_thread_id = None
|
||||
continue
|
||||
# Other BadRequest errors are permanent — don't retry
|
||||
raise
|
||||
if _send_attempt < 2:
|
||||
wait = 2 ** _send_attempt
|
||||
logger.warning("[%s] Network error on send (attempt %d/3), retrying in %ds: %s",
|
||||
@@ -534,23 +916,26 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively as a Telegram photo."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import os
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
|
||||
_thread = metadata.get("thread_id") if metadata else None
|
||||
with open(image_path, "rb") as image_file:
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
photo=image_file,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
message_thread_id=int(_thread) if _thread else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
@@ -569,6 +954,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a document/file natively as a Telegram file attachment."""
|
||||
@@ -580,6 +966,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
return SendResult(success=False, error=f"File not found: {file_path}")
|
||||
|
||||
display_name = file_name or os.path.basename(file_path)
|
||||
_thread = metadata.get("thread_id") if metadata else None
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
msg = await self._bot.send_document(
|
||||
@@ -588,6 +975,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
filename=display_name,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
message_thread_id=int(_thread) if _thread else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
@@ -600,6 +988,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
video_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a video natively as a Telegram video message."""
|
||||
@@ -610,12 +999,14 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
if not os.path.exists(video_path):
|
||||
return SendResult(success=False, error=f"Video file not found: {video_path}")
|
||||
|
||||
_thread = metadata.get("thread_id") if metadata else None
|
||||
with open(video_path, "rb") as f:
|
||||
msg = await self._bot.send_video(
|
||||
chat_id=int(chat_id),
|
||||
video=f,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
message_thread_id=int(_thread) if _thread else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
@@ -882,6 +1273,45 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
for key in reversed(list(placeholders.keys())):
|
||||
text = text.replace(key, placeholders[key])
|
||||
|
||||
# 12) Safety net: escape unescaped ( ) { } that slipped through
|
||||
# placeholder processing. Split the text into code/non-code
|
||||
# segments so we never touch content inside ``` or ` spans.
|
||||
_code_split = re.split(r'(```[\s\S]*?```|`[^`]+`)', text)
|
||||
_safe_parts = []
|
||||
for _idx, _seg in enumerate(_code_split):
|
||||
if _idx % 2 == 1:
|
||||
# Inside code span/block — leave untouched
|
||||
_safe_parts.append(_seg)
|
||||
else:
|
||||
# Outside code — escape bare ( ) { }
|
||||
def _esc_bare(m, _seg=_seg):
|
||||
s = m.start()
|
||||
ch = m.group(0)
|
||||
# Already escaped
|
||||
if s > 0 and _seg[s - 1] == '\\':
|
||||
return ch
|
||||
# ( that opens a MarkdownV2 link [text](url)
|
||||
if ch == '(' and s > 0 and _seg[s - 1] == ']':
|
||||
return ch
|
||||
# ) that closes a link URL
|
||||
if ch == ')':
|
||||
before = _seg[:s]
|
||||
if '](http' in before or '](' in before:
|
||||
# Check depth
|
||||
depth = 0
|
||||
for j in range(s - 1, max(s - 2000, -1), -1):
|
||||
if _seg[j] == '(':
|
||||
depth -= 1
|
||||
if depth < 0:
|
||||
if j > 0 and _seg[j - 1] == ']':
|
||||
return ch
|
||||
break
|
||||
elif _seg[j] == ')':
|
||||
depth += 1
|
||||
return '\\' + ch
|
||||
_safe_parts.append(re.sub(r'[(){}]', _esc_bare, _seg))
|
||||
text = ''.join(_safe_parts)
|
||||
|
||||
return text
|
||||
|
||||
async def _handle_text_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
@@ -1320,6 +1750,99 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
emoji, set_name,
|
||||
)
|
||||
|
||||
def _reload_dm_topics_from_config(self) -> None:
|
||||
"""Re-read dm_topics from config.yaml and load any new thread_ids into cache.
|
||||
|
||||
This allows topics created externally (e.g. by the agent via API) to be
|
||||
recognized without a gateway restart.
|
||||
"""
|
||||
try:
|
||||
config_path = _Path.home() / ".hermes" / "config.yaml"
|
||||
if not config_path.exists():
|
||||
return
|
||||
|
||||
import yaml as _yaml
|
||||
with open(config_path, "r") as f:
|
||||
config = _yaml.safe_load(f) or {}
|
||||
|
||||
dm_topics = (
|
||||
config.get("platforms", {})
|
||||
.get("telegram", {})
|
||||
.get("extra", {})
|
||||
.get("dm_topics", [])
|
||||
)
|
||||
if not dm_topics:
|
||||
return
|
||||
|
||||
# Update in-memory config and cache any new thread_ids
|
||||
self._dm_topics_config = dm_topics
|
||||
for chat_entry in dm_topics:
|
||||
cid = chat_entry.get("chat_id")
|
||||
if not cid:
|
||||
continue
|
||||
for t in chat_entry.get("topics", []):
|
||||
tid = t.get("thread_id")
|
||||
name = t.get("name")
|
||||
if tid and name:
|
||||
cache_key = f"{cid}:{name}"
|
||||
if cache_key not in self._dm_topics:
|
||||
self._dm_topics[cache_key] = int(tid)
|
||||
logger.info(
|
||||
"[%s] Hot-loaded DM topic from config: %s -> thread_id=%s",
|
||||
self.name, cache_key, tid,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("[%s] Failed to reload dm_topics from config: %s", self.name, e)
|
||||
|
||||
def _get_dm_topic_info(self, chat_id: str, thread_id: Optional[str]) -> Optional[Dict[str, Any]]:
|
||||
"""Look up DM topic config by chat_id and thread_id.
|
||||
|
||||
Returns the topic config dict (name, skill, etc.) if this thread_id
|
||||
matches a known DM topic, or None.
|
||||
"""
|
||||
if not thread_id:
|
||||
return None
|
||||
|
||||
thread_id_int = int(thread_id)
|
||||
|
||||
# Check cached topics first (created by us or loaded at startup)
|
||||
for key, cached_tid in self._dm_topics.items():
|
||||
if cached_tid == thread_id_int and key.startswith(f"{chat_id}:"):
|
||||
topic_name = key.split(":", 1)[1]
|
||||
# Find the full config for this topic
|
||||
for chat_entry in self._dm_topics_config:
|
||||
if str(chat_entry.get("chat_id")) == chat_id:
|
||||
for t in chat_entry.get("topics", []):
|
||||
if t.get("name") == topic_name:
|
||||
return t
|
||||
return {"name": topic_name}
|
||||
|
||||
# Not in cache — hot-reload config in case topics were added externally
|
||||
self._reload_dm_topics_from_config()
|
||||
|
||||
# Check cache again after reload
|
||||
for key, cached_tid in self._dm_topics.items():
|
||||
if cached_tid == thread_id_int and key.startswith(f"{chat_id}:"):
|
||||
topic_name = key.split(":", 1)[1]
|
||||
for chat_entry in self._dm_topics_config:
|
||||
if str(chat_entry.get("chat_id")) == chat_id:
|
||||
for t in chat_entry.get("topics", []):
|
||||
if t.get("name") == topic_name:
|
||||
return t
|
||||
return {"name": topic_name}
|
||||
|
||||
return None
|
||||
|
||||
def _cache_dm_topic_from_message(self, chat_id: str, thread_id: str, topic_name: str) -> None:
|
||||
"""Cache a thread_id -> topic_name mapping discovered from an incoming message."""
|
||||
cache_key = f"{chat_id}:{topic_name}"
|
||||
if cache_key not in self._dm_topics:
|
||||
self._dm_topics[cache_key] = int(thread_id)
|
||||
logger.info(
|
||||
"[%s] Cached DM topic from message: %s -> thread_id=%s",
|
||||
self.name, cache_key, thread_id,
|
||||
)
|
||||
|
||||
def _build_message_event(self, message: Message, msg_type: MessageType) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Telegram message."""
|
||||
chat = message.chat
|
||||
@@ -1331,7 +1854,27 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
chat_type = "group"
|
||||
elif chat.type == ChatType.CHANNEL:
|
||||
chat_type = "channel"
|
||||
|
||||
|
||||
# Resolve DM topic name and skill binding
|
||||
thread_id_raw = message.message_thread_id
|
||||
thread_id_str = str(thread_id_raw) if thread_id_raw else None
|
||||
chat_topic = None
|
||||
topic_skill = None
|
||||
|
||||
if chat_type == "dm" and thread_id_str:
|
||||
topic_info = self._get_dm_topic_info(str(chat.id), thread_id_str)
|
||||
if topic_info:
|
||||
chat_topic = topic_info.get("name")
|
||||
topic_skill = topic_info.get("skill")
|
||||
|
||||
# Also check forum_topic_created service message for topic discovery
|
||||
if hasattr(message, "forum_topic_created") and message.forum_topic_created:
|
||||
created_name = message.forum_topic_created.name
|
||||
if created_name:
|
||||
self._cache_dm_topic_from_message(str(chat.id), thread_id_str, created_name)
|
||||
if not chat_topic:
|
||||
chat_topic = created_name
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=str(chat.id),
|
||||
@@ -1339,7 +1882,8 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
chat_type=chat_type,
|
||||
user_id=str(user.id) if user else None,
|
||||
user_name=user.full_name if user else None,
|
||||
thread_id=str(message.message_thread_id) if message.message_thread_id else None,
|
||||
thread_id=thread_id_str,
|
||||
chat_topic=chat_topic,
|
||||
)
|
||||
|
||||
# Extract reply context if this message is a reply
|
||||
@@ -1357,5 +1901,6 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
message_id=str(message.message_id),
|
||||
reply_to_message_id=reply_to_id,
|
||||
reply_to_text=reply_to_text,
|
||||
auto_skill=topic_skill,
|
||||
timestamp=message.date,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,233 @@
|
||||
"""Telegram-specific network helpers.
|
||||
|
||||
Provides a hostname-preserving fallback transport for networks where
|
||||
api.telegram.org resolves to an endpoint that is unreachable from the current
|
||||
host. The transport keeps the logical request host and TLS SNI as
|
||||
api.telegram.org while retrying the TCP connection against one or more fallback
|
||||
IPv4 addresses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import logging
|
||||
import socket
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TELEGRAM_API_HOST = "api.telegram.org"
|
||||
|
||||
# DNS-over-HTTPS providers used to discover Telegram API IPs that may differ
|
||||
# from the (potentially unreachable) IP returned by the local system resolver.
|
||||
_DOH_TIMEOUT = 4.0 # seconds — bounded so connect() isn't noticeably delayed
|
||||
|
||||
_DOH_PROVIDERS: list[dict] = [
|
||||
{
|
||||
"url": "https://dns.google/resolve",
|
||||
"params": {"name": _TELEGRAM_API_HOST, "type": "A"},
|
||||
"headers": {},
|
||||
},
|
||||
{
|
||||
"url": "https://cloudflare-dns.com/dns-query",
|
||||
"params": {"name": _TELEGRAM_API_HOST, "type": "A"},
|
||||
"headers": {"Accept": "application/dns-json"},
|
||||
},
|
||||
]
|
||||
|
||||
# Last-resort IPs when DoH is also blocked. These are stable Telegram Bot API
|
||||
# endpoints in the 149.154.160.0/20 block (same seed used by OpenClaw).
|
||||
_SEED_FALLBACK_IPS: list[str] = ["149.154.167.220"]
|
||||
|
||||
|
||||
class TelegramFallbackTransport(httpx.AsyncBaseTransport):
|
||||
"""Retry Telegram Bot API requests via fallback IPs while preserving TLS/SNI.
|
||||
|
||||
Requests continue to target https://api.telegram.org/... logically, but on
|
||||
connect failures the underlying TCP connection is retried against a known
|
||||
reachable IP. This is effectively the programmatic equivalent of
|
||||
``curl --resolve api.telegram.org:443:<ip>``.
|
||||
"""
|
||||
|
||||
def __init__(self, fallback_ips: Iterable[str], **transport_kwargs):
|
||||
self._fallback_ips = [ip for ip in dict.fromkeys(_normalize_fallback_ips(fallback_ips))]
|
||||
self._primary = httpx.AsyncHTTPTransport(**transport_kwargs)
|
||||
self._fallbacks = {
|
||||
ip: httpx.AsyncHTTPTransport(**transport_kwargs) for ip in self._fallback_ips
|
||||
}
|
||||
self._sticky_ip: Optional[str] = None
|
||||
self._sticky_lock = asyncio.Lock()
|
||||
|
||||
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
|
||||
if request.url.host != _TELEGRAM_API_HOST or not self._fallback_ips:
|
||||
return await self._primary.handle_async_request(request)
|
||||
|
||||
sticky_ip = self._sticky_ip
|
||||
attempt_order: list[Optional[str]] = [sticky_ip] if sticky_ip else [None]
|
||||
for ip in self._fallback_ips:
|
||||
if ip != sticky_ip:
|
||||
attempt_order.append(ip)
|
||||
|
||||
last_error: Exception | None = None
|
||||
for ip in attempt_order:
|
||||
candidate = request if ip is None else _rewrite_request_for_ip(request, ip)
|
||||
transport = self._primary if ip is None else self._fallbacks[ip]
|
||||
try:
|
||||
response = await transport.handle_async_request(candidate)
|
||||
if ip is not None and self._sticky_ip != ip:
|
||||
async with self._sticky_lock:
|
||||
if self._sticky_ip != ip:
|
||||
self._sticky_ip = ip
|
||||
logger.warning(
|
||||
"[Telegram] Primary api.telegram.org path unreachable; using sticky fallback IP %s",
|
||||
ip,
|
||||
)
|
||||
return response
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
if not _is_retryable_connect_error(exc):
|
||||
raise
|
||||
if ip is None:
|
||||
logger.warning(
|
||||
"[Telegram] Primary api.telegram.org connection failed (%s); trying fallback IPs %s",
|
||||
exc,
|
||||
", ".join(self._fallback_ips),
|
||||
)
|
||||
continue
|
||||
logger.warning("[Telegram] Fallback IP %s failed: %s", ip, exc)
|
||||
continue
|
||||
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._primary.aclose()
|
||||
for transport in self._fallbacks.values():
|
||||
await transport.aclose()
|
||||
|
||||
|
||||
def _normalize_fallback_ips(values: Iterable[str]) -> list[str]:
|
||||
normalized: list[str] = []
|
||||
for value in values:
|
||||
raw = str(value).strip()
|
||||
if not raw:
|
||||
continue
|
||||
try:
|
||||
addr = ipaddress.ip_address(raw)
|
||||
except ValueError:
|
||||
logger.warning("Ignoring invalid Telegram fallback IP: %r", raw)
|
||||
continue
|
||||
if addr.version != 4:
|
||||
logger.warning("Ignoring non-IPv4 Telegram fallback IP: %s", raw)
|
||||
continue
|
||||
normalized.append(str(addr))
|
||||
return normalized
|
||||
|
||||
|
||||
def parse_fallback_ip_env(value: str | None) -> list[str]:
|
||||
if not value:
|
||||
return []
|
||||
parts = [part.strip() for part in value.split(",")]
|
||||
return _normalize_fallback_ips(parts)
|
||||
|
||||
|
||||
def _resolve_system_dns() -> set[str]:
|
||||
"""Return the IPv4 addresses that the OS resolver gives for api.telegram.org."""
|
||||
try:
|
||||
results = socket.getaddrinfo(_TELEGRAM_API_HOST, 443, socket.AF_INET)
|
||||
return {addr[4][0] for addr in results}
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
|
||||
async def _query_doh_provider(
|
||||
client: httpx.AsyncClient, provider: dict
|
||||
) -> list[str]:
|
||||
"""Query one DoH provider and return A-record IPs."""
|
||||
try:
|
||||
resp = await client.get(
|
||||
provider["url"], params=provider["params"], headers=provider["headers"]
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
ips: list[str] = []
|
||||
for answer in data.get("Answer", []):
|
||||
if answer.get("type") != 1: # A record
|
||||
continue
|
||||
raw = answer.get("data", "").strip()
|
||||
try:
|
||||
ipaddress.ip_address(raw)
|
||||
ips.append(raw)
|
||||
except ValueError:
|
||||
continue
|
||||
return ips
|
||||
except Exception as exc:
|
||||
logger.debug("DoH query to %s failed: %s", provider["url"], exc)
|
||||
return []
|
||||
|
||||
|
||||
async def discover_fallback_ips() -> list[str]:
|
||||
"""Auto-discover Telegram API IPs via DNS-over-HTTPS.
|
||||
|
||||
Resolves api.telegram.org through Google and Cloudflare DoH, collects all
|
||||
unique IPs, and excludes the system-DNS-resolved IP (which is presumably
|
||||
unreachable on this network). Falls back to a hardcoded seed list when DoH
|
||||
is also unavailable.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(_DOH_TIMEOUT)) as client:
|
||||
doh_tasks = [_query_doh_provider(client, p) for p in _DOH_PROVIDERS]
|
||||
system_dns_task = asyncio.to_thread(_resolve_system_dns)
|
||||
results = await asyncio.gather(system_dns_task, *doh_tasks, return_exceptions=True)
|
||||
|
||||
# results[0] = system DNS IPs (set), results[1:] = DoH IP lists
|
||||
system_ips: set[str] = results[0] if isinstance(results[0], set) else set()
|
||||
|
||||
doh_ips: list[str] = []
|
||||
for r in results[1:]:
|
||||
if isinstance(r, list):
|
||||
doh_ips.extend(r)
|
||||
|
||||
# Deduplicate preserving order, exclude system-DNS IPs
|
||||
seen: set[str] = set()
|
||||
candidates: list[str] = []
|
||||
for ip in doh_ips:
|
||||
if ip not in seen and ip not in system_ips:
|
||||
seen.add(ip)
|
||||
candidates.append(ip)
|
||||
|
||||
# Validate through existing normalization
|
||||
validated = _normalize_fallback_ips(candidates)
|
||||
|
||||
if validated:
|
||||
logger.debug("Discovered Telegram fallback IPs via DoH: %s", ", ".join(validated))
|
||||
return validated
|
||||
|
||||
logger.info(
|
||||
"DoH discovery yielded no new IPs (system DNS: %s); using seed fallback IPs %s",
|
||||
", ".join(system_ips) or "unknown",
|
||||
", ".join(_SEED_FALLBACK_IPS),
|
||||
)
|
||||
return list(_SEED_FALLBACK_IPS)
|
||||
|
||||
|
||||
def _rewrite_request_for_ip(request: httpx.Request, ip: str) -> httpx.Request:
|
||||
original_host = request.url.host or _TELEGRAM_API_HOST
|
||||
url = request.url.copy_with(host=ip)
|
||||
headers = request.headers.copy()
|
||||
headers["host"] = original_host
|
||||
extensions = dict(request.extensions)
|
||||
extensions["sni_hostname"] = original_host
|
||||
return httpx.Request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
stream=request.stream,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
|
||||
def _is_retryable_connect_error(exc: Exception) -> bool:
|
||||
return isinstance(exc, (httpx.ConnectTimeout, httpx.ConnectError))
|
||||
@@ -363,7 +363,9 @@ class WebhookAdapter(BasePlatformAdapter):
|
||||
)
|
||||
|
||||
# Non-blocking — return 202 Accepted immediately
|
||||
asyncio.create_task(self.handle_message(event))
|
||||
task = asyncio.create_task(self.handle_message(event))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
|
||||
@@ -16,7 +16,6 @@ with different backends via a bridge pattern.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
@@ -24,7 +23,7 @@ import subprocess
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
from hermes_cli.config import get_hermes_home
|
||||
|
||||
@@ -74,6 +73,7 @@ from gateway.platforms.base import (
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
cache_image_from_url,
|
||||
cache_audio_from_url,
|
||||
)
|
||||
@@ -140,6 +140,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self._message_queue: asyncio.Queue = asyncio.Queue()
|
||||
self._bridge_log_fh = None
|
||||
self._bridge_log: Optional[Path] = None
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""
|
||||
@@ -196,9 +197,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
bridge_status = data.get("status", "unknown")
|
||||
if bridge_status == "connected":
|
||||
print(f"[{self.name}] Using existing bridge (status: {bridge_status})")
|
||||
self._running = True
|
||||
self._mark_connected()
|
||||
self._bridge_process = None # Not managed by us
|
||||
asyncio.create_task(self._poll_messages())
|
||||
self._poll_task = asyncio.create_task(self._poll_messages())
|
||||
return True
|
||||
else:
|
||||
print(f"[{self.name}] Bridge found but not connected (status: {bridge_status}), restarting")
|
||||
@@ -304,9 +305,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
print(f"[{self.name}] If session expired, re-pair: hermes whatsapp")
|
||||
|
||||
# Start message polling task
|
||||
asyncio.create_task(self._poll_messages())
|
||||
self._poll_task = asyncio.create_task(self._poll_messages())
|
||||
|
||||
self._running = True
|
||||
self._mark_connected()
|
||||
print(f"[{self.name}] Bridge started on port {self._bridge_port}")
|
||||
return True
|
||||
|
||||
@@ -324,6 +325,23 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
pass
|
||||
self._bridge_log_fh = None
|
||||
|
||||
async def _check_managed_bridge_exit(self) -> Optional[str]:
|
||||
"""Return a fatal error message if the managed bridge child exited."""
|
||||
if self._bridge_process is None:
|
||||
return None
|
||||
|
||||
returncode = self._bridge_process.poll()
|
||||
if returncode is None:
|
||||
return None
|
||||
|
||||
message = f"WhatsApp bridge process exited unexpectedly (code {returncode})."
|
||||
if not self.has_fatal_error:
|
||||
logger.error("[%s] %s", self.name, message)
|
||||
self._set_fatal_error("whatsapp_bridge_exited", message, retryable=True)
|
||||
self._close_bridge_log()
|
||||
await self._notify_fatal_error()
|
||||
return self.fatal_error_message or message
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Stop the WhatsApp bridge and clean up any orphaned processes."""
|
||||
if self._bridge_process:
|
||||
@@ -352,7 +370,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
# Bridge was not started by us, don't kill it
|
||||
print(f"[{self.name}] Disconnecting (external bridge left running)")
|
||||
|
||||
self._running = False
|
||||
self._mark_disconnected()
|
||||
self._bridge_process = None
|
||||
self._close_bridge_log()
|
||||
print(f"[{self.name}] Disconnected")
|
||||
@@ -367,6 +385,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""Send a message via the WhatsApp bridge."""
|
||||
if not self._running:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
return SendResult(success=False, error=bridge_exit)
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
@@ -412,6 +433,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""Edit a previously sent message via the WhatsApp bridge."""
|
||||
if not self._running:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
return SendResult(success=False, error=bridge_exit)
|
||||
try:
|
||||
import aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@@ -443,6 +467,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""Send any media file via bridge /send-media endpoint."""
|
||||
if not self._running:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
return SendResult(success=False, error=bridge_exit)
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
@@ -531,6 +558,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""Send typing indicator via bridge."""
|
||||
if not self._running:
|
||||
return
|
||||
if await self._check_managed_bridge_exit():
|
||||
return
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
@@ -548,6 +577,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""Get information about a WhatsApp chat."""
|
||||
if not self._running:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
if await self._check_managed_bridge_exit():
|
||||
return {"name": chat_id, "type": "dm"}
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
@@ -578,6 +609,10 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
return
|
||||
|
||||
while self._running:
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
print(f"[{self.name}] {bridge_exit}")
|
||||
break
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
@@ -593,6 +628,10 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
print(f"[{self.name}] {bridge_exit}")
|
||||
break
|
||||
print(f"[{self.name}] Poll error: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
@@ -627,7 +666,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
user_name=data.get("senderName"),
|
||||
)
|
||||
|
||||
# Download image media URLs to the local cache so the vision tool
|
||||
# Download media URLs to the local cache so agent tools
|
||||
# can access them reliably regardless of URL expiration.
|
||||
raw_urls = data.get("mediaUrls", [])
|
||||
cached_urls = []
|
||||
@@ -658,12 +697,59 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
print(f"[{self.name}] Failed to cache voice: {e}", flush=True)
|
||||
cached_urls.append(url)
|
||||
media_types.append("audio/ogg")
|
||||
elif msg_type == MessageType.VOICE and os.path.isabs(url):
|
||||
# Local file path — bridge already downloaded the audio
|
||||
cached_urls.append(url)
|
||||
media_types.append("audio/ogg")
|
||||
print(f"[{self.name}] Using bridge-cached audio: {url}", flush=True)
|
||||
elif msg_type == MessageType.DOCUMENT and os.path.isabs(url):
|
||||
# Local file path — bridge already downloaded the document
|
||||
cached_urls.append(url)
|
||||
ext = Path(url).suffix.lower()
|
||||
mime = SUPPORTED_DOCUMENT_TYPES.get(ext, "application/octet-stream")
|
||||
media_types.append(mime)
|
||||
print(f"[{self.name}] Using bridge-cached document: {url}", flush=True)
|
||||
elif msg_type == MessageType.VIDEO and os.path.isabs(url):
|
||||
cached_urls.append(url)
|
||||
media_types.append("video/mp4")
|
||||
print(f"[{self.name}] Using bridge-cached video: {url}", flush=True)
|
||||
else:
|
||||
cached_urls.append(url)
|
||||
media_types.append("unknown")
|
||||
|
||||
|
||||
# For text-readable documents, inject file content directly into
|
||||
# the message text so the agent can read it inline.
|
||||
# Cap at 100KB to match Telegram/Discord/Slack behaviour.
|
||||
body = data.get("body", "")
|
||||
MAX_TEXT_INJECT_BYTES = 100 * 1024
|
||||
if msg_type == MessageType.DOCUMENT and cached_urls:
|
||||
for doc_path in cached_urls:
|
||||
ext = Path(doc_path).suffix.lower()
|
||||
if ext in (".txt", ".md", ".csv", ".json", ".xml", ".yaml", ".yml", ".log", ".py", ".js", ".ts", ".html", ".css"):
|
||||
try:
|
||||
file_size = Path(doc_path).stat().st_size
|
||||
if file_size > MAX_TEXT_INJECT_BYTES:
|
||||
print(f"[{self.name}] Skipping text injection for {doc_path} ({file_size} bytes > {MAX_TEXT_INJECT_BYTES})", flush=True)
|
||||
continue
|
||||
content = Path(doc_path).read_text(errors="replace")
|
||||
fname = Path(doc_path).name
|
||||
# Remove the doc_<hex>_ prefix for display
|
||||
display_name = fname
|
||||
if "_" in fname:
|
||||
parts = fname.split("_", 2)
|
||||
if len(parts) >= 3:
|
||||
display_name = parts[2]
|
||||
injection = f"[Content of {display_name}]:\n{content}"
|
||||
if body:
|
||||
body = f"{injection}\n\n{body}"
|
||||
else:
|
||||
body = injection
|
||||
print(f"[{self.name}] Injected text content from: {doc_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to read document text: {e}", flush=True)
|
||||
|
||||
return MessageEvent(
|
||||
text=data.get("body", ""),
|
||||
text=body,
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=data,
|
||||
@@ -674,4 +760,3 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error building event: {e}")
|
||||
return None
|
||||
|
||||
|
||||
+1006
-420
File diff suppressed because it is too large
Load Diff
+279
-206
@@ -13,15 +13,21 @@ import logging
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import threading
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
"""Return the current local time."""
|
||||
return datetime.now()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PII redaction helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -59,7 +65,7 @@ def _looks_like_phone(value: str) -> bool:
|
||||
from .config import (
|
||||
Platform,
|
||||
GatewayConfig,
|
||||
SessionResetPolicy,
|
||||
SessionResetPolicy, # noqa: F401 — re-exported via gateway/__init__.py
|
||||
HomeChannel,
|
||||
)
|
||||
|
||||
@@ -355,6 +361,8 @@ class SessionEntry:
|
||||
# Set when a session was created because the previous one expired;
|
||||
# consumed once by the message handler to inject a notice into context
|
||||
was_auto_reset: bool = False
|
||||
auto_reset_reason: Optional[str] = None # "idle" or "daily"
|
||||
reset_had_activity: bool = False # whether the expired session had any messages
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result = {
|
||||
@@ -469,6 +477,7 @@ class SessionStore:
|
||||
self.config = config
|
||||
self._entries: Dict[str, SessionEntry] = {}
|
||||
self._loaded = False
|
||||
self._lock = threading.Lock()
|
||||
self._has_active_processes_fn = has_active_processes_fn
|
||||
# on_auto_reset is deprecated — memory flush now runs proactively
|
||||
# via the background session expiry watcher in GatewayRunner.
|
||||
@@ -484,12 +493,17 @@ class SessionStore:
|
||||
|
||||
def _ensure_loaded(self) -> None:
|
||||
"""Load sessions index from disk if not already loaded."""
|
||||
with self._lock:
|
||||
self._ensure_loaded_locked()
|
||||
|
||||
def _ensure_loaded_locked(self) -> None:
|
||||
"""Load sessions index from disk. Must be called with self._lock held."""
|
||||
if self._loaded:
|
||||
return
|
||||
|
||||
|
||||
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
sessions_file = self.sessions_dir / "sessions.json"
|
||||
|
||||
|
||||
if sessions_file.exists():
|
||||
try:
|
||||
with open(sessions_file, "r", encoding="utf-8") as f:
|
||||
@@ -502,7 +516,7 @@ class SessionStore:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: Failed to load sessions: {e}")
|
||||
|
||||
|
||||
self._loaded = True
|
||||
|
||||
def _save(self) -> None:
|
||||
@@ -554,7 +568,7 @@ class SessionStore:
|
||||
if policy.mode == "none":
|
||||
return False
|
||||
|
||||
now = datetime.now()
|
||||
now = _now()
|
||||
|
||||
if policy.mode in ("idle", "both"):
|
||||
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
|
||||
@@ -573,16 +587,19 @@ class SessionStore:
|
||||
|
||||
return False
|
||||
|
||||
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool:
|
||||
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> Optional[str]:
|
||||
"""
|
||||
Check if a session should be reset based on policy.
|
||||
|
||||
Returns the reset reason ("idle" or "daily") if a reset is needed,
|
||||
or None if the session is still valid.
|
||||
|
||||
Sessions with active background processes are never reset.
|
||||
"""
|
||||
if self._has_active_processes_fn:
|
||||
session_key = self._generate_session_key(source)
|
||||
if self._has_active_processes_fn(session_key):
|
||||
return False
|
||||
return None
|
||||
|
||||
policy = self.config.get_reset_policy(
|
||||
platform=source.platform,
|
||||
@@ -590,14 +607,14 @@ class SessionStore:
|
||||
)
|
||||
|
||||
if policy.mode == "none":
|
||||
return False
|
||||
return None
|
||||
|
||||
now = datetime.now()
|
||||
now = _now()
|
||||
|
||||
if policy.mode in ("idle", "both"):
|
||||
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
|
||||
if now > idle_deadline:
|
||||
return True
|
||||
return "idle"
|
||||
|
||||
if policy.mode in ("daily", "both"):
|
||||
today_reset = now.replace(
|
||||
@@ -610,9 +627,9 @@ class SessionStore:
|
||||
today_reset -= timedelta(days=1)
|
||||
|
||||
if entry.updated_at < today_reset:
|
||||
return True
|
||||
return "daily"
|
||||
|
||||
return False
|
||||
return None
|
||||
|
||||
def has_any_sessions(self) -> bool:
|
||||
"""Check if any sessions have ever been created (across all platforms).
|
||||
@@ -632,79 +649,97 @@ class SessionStore:
|
||||
pass # fall through to heuristic
|
||||
# Fallback: check if sessions.json was loaded with existing data.
|
||||
# This covers the rare case where the DB is unavailable.
|
||||
self._ensure_loaded()
|
||||
return len(self._entries) > 1
|
||||
|
||||
with self._lock:
|
||||
self._ensure_loaded_locked()
|
||||
return len(self._entries) > 1
|
||||
|
||||
def get_or_create_session(
|
||||
self,
|
||||
self,
|
||||
source: SessionSource,
|
||||
force_new: bool = False
|
||||
) -> SessionEntry:
|
||||
"""
|
||||
Get an existing session or create a new one.
|
||||
|
||||
|
||||
Evaluates reset policy to determine if the existing session is stale.
|
||||
Creates a session record in SQLite when a new session starts.
|
||||
"""
|
||||
self._ensure_loaded()
|
||||
|
||||
session_key = self._generate_session_key(source)
|
||||
now = datetime.now()
|
||||
|
||||
if session_key in self._entries and not force_new:
|
||||
entry = self._entries[session_key]
|
||||
|
||||
if not self._should_reset(entry, source):
|
||||
entry.updated_at = now
|
||||
self._save()
|
||||
return entry
|
||||
now = _now()
|
||||
|
||||
# SQLite calls are made outside the lock to avoid holding it during I/O.
|
||||
# All _entries / _loaded mutations are protected by self._lock.
|
||||
db_end_session_id = None
|
||||
db_create_kwargs = None
|
||||
|
||||
with self._lock:
|
||||
self._ensure_loaded_locked()
|
||||
|
||||
if session_key in self._entries and not force_new:
|
||||
entry = self._entries[session_key]
|
||||
|
||||
reset_reason = self._should_reset(entry, source)
|
||||
if not reset_reason:
|
||||
entry.updated_at = now
|
||||
self._save()
|
||||
return entry
|
||||
else:
|
||||
# Session is being auto-reset. The background expiry watcher
|
||||
# should have already flushed memories proactively; discard
|
||||
# the marker so it doesn't accumulate.
|
||||
was_auto_reset = True
|
||||
auto_reset_reason = reset_reason
|
||||
# Track whether the expired session had any real conversation
|
||||
reset_had_activity = entry.total_tokens > 0
|
||||
db_end_session_id = entry.session_id
|
||||
self._pre_flushed_sessions.discard(entry.session_id)
|
||||
else:
|
||||
# Session is being auto-reset. The background expiry watcher
|
||||
# should have already flushed memories proactively; discard
|
||||
# the marker so it doesn't accumulate.
|
||||
was_auto_reset = True
|
||||
self._pre_flushed_sessions.discard(entry.session_id)
|
||||
if self._db:
|
||||
try:
|
||||
self._db.end_session(entry.session_id, "session_reset")
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
else:
|
||||
was_auto_reset = False
|
||||
|
||||
# Create new session
|
||||
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id=session_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
origin=source,
|
||||
display_name=source.chat_name,
|
||||
platform=source.platform,
|
||||
chat_type=source.chat_type,
|
||||
was_auto_reset=was_auto_reset,
|
||||
)
|
||||
|
||||
self._entries[session_key] = entry
|
||||
self._save()
|
||||
|
||||
# Create session in SQLite
|
||||
if self._db:
|
||||
was_auto_reset = False
|
||||
auto_reset_reason = None
|
||||
reset_had_activity = False
|
||||
|
||||
# Create new session
|
||||
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id=session_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
origin=source,
|
||||
display_name=source.chat_name,
|
||||
platform=source.platform,
|
||||
chat_type=source.chat_type,
|
||||
was_auto_reset=was_auto_reset,
|
||||
auto_reset_reason=auto_reset_reason,
|
||||
reset_had_activity=reset_had_activity,
|
||||
)
|
||||
|
||||
self._entries[session_key] = entry
|
||||
self._save()
|
||||
db_create_kwargs = {
|
||||
"session_id": session_id,
|
||||
"source": source.platform.value,
|
||||
"user_id": source.user_id,
|
||||
}
|
||||
|
||||
# SQLite operations outside the lock
|
||||
if self._db and db_end_session_id:
|
||||
try:
|
||||
self._db.create_session(
|
||||
session_id=session_id,
|
||||
source=source.platform.value,
|
||||
user_id=source.user_id,
|
||||
)
|
||||
self._db.end_session(db_end_session_id, "session_reset")
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
if self._db and db_create_kwargs:
|
||||
try:
|
||||
self._db.create_session(**db_create_kwargs)
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: Failed to create SQLite session: {e}")
|
||||
|
||||
|
||||
return entry
|
||||
|
||||
|
||||
def update_session(
|
||||
self,
|
||||
self,
|
||||
session_key: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
@@ -719,91 +754,103 @@ class SessionStore:
|
||||
base_url: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Update a session's metadata after an interaction."""
|
||||
self._ensure_loaded()
|
||||
|
||||
if session_key in self._entries:
|
||||
entry = self._entries[session_key]
|
||||
entry.updated_at = datetime.now()
|
||||
entry.input_tokens += input_tokens
|
||||
entry.output_tokens += output_tokens
|
||||
entry.cache_read_tokens += cache_read_tokens
|
||||
entry.cache_write_tokens += cache_write_tokens
|
||||
if last_prompt_tokens is not None:
|
||||
entry.last_prompt_tokens = last_prompt_tokens
|
||||
if estimated_cost_usd is not None:
|
||||
entry.estimated_cost_usd += estimated_cost_usd
|
||||
if cost_status:
|
||||
entry.cost_status = cost_status
|
||||
entry.total_tokens = (
|
||||
entry.input_tokens
|
||||
+ entry.output_tokens
|
||||
+ entry.cache_read_tokens
|
||||
+ entry.cache_write_tokens
|
||||
)
|
||||
self._save()
|
||||
|
||||
if self._db:
|
||||
try:
|
||||
self._db.update_token_counts(
|
||||
entry.session_id,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
estimated_cost_usd=estimated_cost_usd,
|
||||
cost_status=cost_status,
|
||||
cost_source=cost_source,
|
||||
billing_provider=provider,
|
||||
billing_base_url=base_url,
|
||||
model=model,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
|
||||
"""Force reset a session, creating a new session ID."""
|
||||
self._ensure_loaded()
|
||||
|
||||
if session_key not in self._entries:
|
||||
return None
|
||||
|
||||
old_entry = self._entries[session_key]
|
||||
|
||||
# End old session in SQLite
|
||||
if self._db:
|
||||
db_session_id = None
|
||||
|
||||
with self._lock:
|
||||
self._ensure_loaded_locked()
|
||||
|
||||
if session_key in self._entries:
|
||||
entry = self._entries[session_key]
|
||||
entry.updated_at = _now()
|
||||
# Direct assignment — the gateway receives cumulative totals
|
||||
# from the cached agent, not per-call deltas.
|
||||
entry.input_tokens = input_tokens
|
||||
entry.output_tokens = output_tokens
|
||||
entry.cache_read_tokens = cache_read_tokens
|
||||
entry.cache_write_tokens = cache_write_tokens
|
||||
if last_prompt_tokens is not None:
|
||||
entry.last_prompt_tokens = last_prompt_tokens
|
||||
if estimated_cost_usd is not None:
|
||||
entry.estimated_cost_usd = estimated_cost_usd
|
||||
if cost_status:
|
||||
entry.cost_status = cost_status
|
||||
entry.total_tokens = (
|
||||
entry.input_tokens
|
||||
+ entry.output_tokens
|
||||
+ entry.cache_read_tokens
|
||||
+ entry.cache_write_tokens
|
||||
)
|
||||
self._save()
|
||||
db_session_id = entry.session_id
|
||||
|
||||
if self._db and db_session_id:
|
||||
try:
|
||||
self._db.end_session(old_entry.session_id, "session_reset")
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
now = datetime.now()
|
||||
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
new_entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id=session_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
origin=old_entry.origin,
|
||||
display_name=old_entry.display_name,
|
||||
platform=old_entry.platform,
|
||||
chat_type=old_entry.chat_type,
|
||||
)
|
||||
|
||||
self._entries[session_key] = new_entry
|
||||
self._save()
|
||||
|
||||
# Create new session in SQLite
|
||||
if self._db:
|
||||
try:
|
||||
self._db.create_session(
|
||||
session_id=session_id,
|
||||
source=old_entry.platform.value if old_entry.platform else "unknown",
|
||||
user_id=old_entry.origin.user_id if old_entry.origin else None,
|
||||
self._db.set_token_counts(
|
||||
db_session_id,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
estimated_cost_usd=estimated_cost_usd,
|
||||
cost_status=cost_status,
|
||||
cost_source=cost_source,
|
||||
billing_provider=provider,
|
||||
billing_base_url=base_url,
|
||||
model=model,
|
||||
absolute=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
|
||||
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
|
||||
"""Force reset a session, creating a new session ID."""
|
||||
db_end_session_id = None
|
||||
db_create_kwargs = None
|
||||
new_entry = None
|
||||
|
||||
with self._lock:
|
||||
self._ensure_loaded_locked()
|
||||
|
||||
if session_key not in self._entries:
|
||||
return None
|
||||
|
||||
old_entry = self._entries[session_key]
|
||||
db_end_session_id = old_entry.session_id
|
||||
|
||||
now = _now()
|
||||
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
new_entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id=session_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
origin=old_entry.origin,
|
||||
display_name=old_entry.display_name,
|
||||
platform=old_entry.platform,
|
||||
chat_type=old_entry.chat_type,
|
||||
)
|
||||
|
||||
self._entries[session_key] = new_entry
|
||||
self._save()
|
||||
db_create_kwargs = {
|
||||
"session_id": session_id,
|
||||
"source": old_entry.platform.value if old_entry.platform else "unknown",
|
||||
"user_id": old_entry.origin.user_id if old_entry.origin else None,
|
||||
}
|
||||
|
||||
if self._db and db_end_session_id:
|
||||
try:
|
||||
self._db.end_session(db_end_session_id, "session_reset")
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
if self._db and db_create_kwargs:
|
||||
try:
|
||||
self._db.create_session(**db_create_kwargs)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
return new_entry
|
||||
|
||||
def switch_session(self, session_key: str, target_session_id: str) -> Optional[SessionEntry]:
|
||||
@@ -814,52 +861,58 @@ class SessionStore:
|
||||
generating a fresh session ID, re-uses ``target_session_id`` so the
|
||||
old transcript is loaded on the next message.
|
||||
"""
|
||||
self._ensure_loaded()
|
||||
db_end_session_id = None
|
||||
new_entry = None
|
||||
|
||||
if session_key not in self._entries:
|
||||
return None
|
||||
with self._lock:
|
||||
self._ensure_loaded_locked()
|
||||
|
||||
old_entry = self._entries[session_key]
|
||||
if session_key not in self._entries:
|
||||
return None
|
||||
|
||||
# Don't switch if already on that session
|
||||
if old_entry.session_id == target_session_id:
|
||||
return old_entry
|
||||
old_entry = self._entries[session_key]
|
||||
|
||||
# End the current session in SQLite
|
||||
if self._db:
|
||||
# Don't switch if already on that session
|
||||
if old_entry.session_id == target_session_id:
|
||||
return old_entry
|
||||
|
||||
db_end_session_id = old_entry.session_id
|
||||
|
||||
now = _now()
|
||||
new_entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id=target_session_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
origin=old_entry.origin,
|
||||
display_name=old_entry.display_name,
|
||||
platform=old_entry.platform,
|
||||
chat_type=old_entry.chat_type,
|
||||
)
|
||||
|
||||
self._entries[session_key] = new_entry
|
||||
self._save()
|
||||
|
||||
if self._db and db_end_session_id:
|
||||
try:
|
||||
self._db.end_session(old_entry.session_id, "session_switch")
|
||||
self._db.end_session(db_end_session_id, "session_switch")
|
||||
except Exception as e:
|
||||
logger.debug("Session DB end_session failed: %s", e)
|
||||
|
||||
now = datetime.now()
|
||||
new_entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id=target_session_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
origin=old_entry.origin,
|
||||
display_name=old_entry.display_name,
|
||||
platform=old_entry.platform,
|
||||
chat_type=old_entry.chat_type,
|
||||
)
|
||||
|
||||
self._entries[session_key] = new_entry
|
||||
self._save()
|
||||
return new_entry
|
||||
|
||||
def list_sessions(self, active_minutes: Optional[int] = None) -> List[SessionEntry]:
|
||||
"""List all sessions, optionally filtered by activity."""
|
||||
self._ensure_loaded()
|
||||
|
||||
entries = list(self._entries.values())
|
||||
|
||||
with self._lock:
|
||||
self._ensure_loaded_locked()
|
||||
entries = list(self._entries.values())
|
||||
|
||||
if active_minutes is not None:
|
||||
cutoff = datetime.now() - timedelta(minutes=active_minutes)
|
||||
cutoff = _now() - timedelta(minutes=active_minutes)
|
||||
entries = [e for e in entries if e.updated_at >= cutoff]
|
||||
|
||||
|
||||
entries.sort(key=lambda e: e.updated_at, reverse=True)
|
||||
|
||||
|
||||
return entries
|
||||
|
||||
def get_transcript_path(self, session_id: str) -> Path:
|
||||
@@ -905,13 +958,17 @@ class SessionStore:
|
||||
try:
|
||||
self._db.clear_messages(session_id)
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
self._db.append_message(
|
||||
session_id=session_id,
|
||||
role=msg.get("role", "unknown"),
|
||||
role=role,
|
||||
content=msg.get("content"),
|
||||
tool_name=msg.get("tool_name"),
|
||||
tool_calls=msg.get("tool_calls"),
|
||||
tool_call_id=msg.get("tool_call_id"),
|
||||
reasoning=msg.get("reasoning") if role == "assistant" else None,
|
||||
reasoning_details=msg.get("reasoning_details") if role == "assistant" else None,
|
||||
codex_reasoning_items=msg.get("codex_reasoning_items") if role == "assistant" else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to rewrite transcript in DB: %s", e)
|
||||
@@ -924,35 +981,51 @@ class SessionStore:
|
||||
|
||||
def load_transcript(self, session_id: str) -> List[Dict[str, Any]]:
|
||||
"""Load all messages from a session's transcript."""
|
||||
db_messages = []
|
||||
# Try SQLite first
|
||||
if self._db:
|
||||
try:
|
||||
messages = self._db.get_messages_as_conversation(session_id)
|
||||
if messages:
|
||||
return messages
|
||||
db_messages = self._db.get_messages_as_conversation(session_id)
|
||||
except Exception as e:
|
||||
logger.debug("Could not load messages from DB: %s", e)
|
||||
|
||||
# Fall back to legacy JSONL
|
||||
|
||||
# Load legacy JSONL transcript (may contain more history than SQLite
|
||||
# for sessions created before the DB layer was introduced).
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
|
||||
if not transcript_path.exists():
|
||||
return []
|
||||
|
||||
messages = []
|
||||
with open(transcript_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
messages.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Skipping corrupt line in transcript %s: %s",
|
||||
session_id, line[:120],
|
||||
)
|
||||
|
||||
return messages
|
||||
jsonl_messages = []
|
||||
if transcript_path.exists():
|
||||
with open(transcript_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
jsonl_messages.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Skipping corrupt line in transcript %s: %s",
|
||||
session_id, line[:120],
|
||||
)
|
||||
|
||||
# Prefer whichever source has more messages.
|
||||
#
|
||||
# Background: when a session pre-dates SQLite storage (or when the DB
|
||||
# layer was added while a long-lived session was already active), the
|
||||
# first post-migration turn writes only the *new* messages to SQLite
|
||||
# (because _flush_messages_to_session_db skips messages already in
|
||||
# conversation_history, assuming they're persisted). On the *next*
|
||||
# turn load_transcript returns those few SQLite rows and ignores the
|
||||
# full JSONL history — the model sees a context of 1-4 messages instead
|
||||
# of hundreds. Using the longer source prevents this silent truncation.
|
||||
if len(jsonl_messages) > len(db_messages):
|
||||
if db_messages:
|
||||
logger.debug(
|
||||
"Session %s: JSONL has %d messages vs SQLite %d — "
|
||||
"using JSONL (legacy session not yet fully migrated)",
|
||||
session_id, len(jsonl_messages), len(db_messages),
|
||||
)
|
||||
return jsonl_messages
|
||||
|
||||
return db_messages
|
||||
|
||||
|
||||
def build_session_context(
|
||||
|
||||
+36
-1
@@ -17,6 +17,7 @@ import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from hermes_constants import get_hermes_home
|
||||
from typing import Any, Optional
|
||||
|
||||
_GATEWAY_KIND = "hermes-gateway"
|
||||
@@ -26,7 +27,7 @@ _LOCKS_DIRNAME = "gateway-locks"
|
||||
|
||||
def _get_pid_path() -> Path:
|
||||
"""Return the path to the gateway PID file, respecting HERMES_HOME."""
|
||||
home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
home = get_hermes_home()
|
||||
return home / "gateway.pid"
|
||||
|
||||
|
||||
@@ -274,6 +275,21 @@ def acquire_scoped_lock(scope: str, identity: str, metadata: Optional[dict[str,
|
||||
and current_start != existing.get("start_time")
|
||||
):
|
||||
stale = True
|
||||
# Check if process is stopped (Ctrl+Z / SIGTSTP) — stopped
|
||||
# processes still respond to os.kill(pid, 0) but are not
|
||||
# actually running. Treat them as stale so --replace works.
|
||||
if not stale:
|
||||
try:
|
||||
_proc_status = Path(f"/proc/{existing_pid}/status")
|
||||
if _proc_status.exists():
|
||||
for _line in _proc_status.read_text().splitlines():
|
||||
if _line.startswith("State:"):
|
||||
_state = _line.split()[1]
|
||||
if _state in ("T", "t"): # stopped or tracing stop
|
||||
stale = True
|
||||
break
|
||||
except (OSError, PermissionError):
|
||||
pass
|
||||
if stale:
|
||||
try:
|
||||
lock_path.unlink(missing_ok=True)
|
||||
@@ -314,6 +330,25 @@ def release_scoped_lock(scope: str, identity: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def release_all_scoped_locks() -> int:
|
||||
"""Remove all scoped lock files in the lock directory.
|
||||
|
||||
Called during --replace to clean up stale locks left by stopped/killed
|
||||
gateway processes that did not release their locks gracefully.
|
||||
Returns the number of lock files removed.
|
||||
"""
|
||||
lock_dir = _get_lock_dir()
|
||||
removed = 0
|
||||
if lock_dir.exists():
|
||||
for lock_file in lock_dir.glob("*.lock"):
|
||||
try:
|
||||
lock_file.unlink(missing_ok=True)
|
||||
removed += 1
|
||||
except OSError:
|
||||
pass
|
||||
return removed
|
||||
|
||||
|
||||
def get_running_pid() -> Optional[int]:
|
||||
"""Return the PID of a running gateway instance, or ``None``.
|
||||
|
||||
|
||||
@@ -9,9 +9,7 @@ Cache location: ~/.hermes/sticker_cache.json
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from hermes_cli.config import get_hermes_home
|
||||
|
||||
@@ -12,4 +12,4 @@ Provides subcommands for:
|
||||
"""
|
||||
|
||||
__version__ = "0.4.0"
|
||||
__release_date__ = "2026.3.18"
|
||||
__release_date__ = "2026.3.23"
|
||||
|
||||
+38
-9
@@ -199,9 +199,9 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
"opencode-go": ProviderConfig(
|
||||
id="opencode-go",
|
||||
name="OpenCode Go",
|
||||
auth_type="***",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://opencode.ai/zen/go/v1",
|
||||
api_key_env_vars=("OPEN...",),
|
||||
api_key_env_vars=("OPENCODE_GO_API_KEY",),
|
||||
base_url_env_var="OPENCODE_GO_BASE_URL",
|
||||
),
|
||||
"kilocode": ProviderConfig(
|
||||
@@ -278,6 +278,33 @@ def _try_gh_cli_token() -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
_PLACEHOLDER_SECRET_VALUES = {
|
||||
"*",
|
||||
"**",
|
||||
"***",
|
||||
"changeme",
|
||||
"your_api_key",
|
||||
"your-api-key",
|
||||
"placeholder",
|
||||
"example",
|
||||
"dummy",
|
||||
"null",
|
||||
"none",
|
||||
}
|
||||
|
||||
|
||||
def has_usable_secret(value: Any, *, min_length: int = 4) -> bool:
|
||||
"""Return True when a configured secret looks usable, not empty/placeholder."""
|
||||
if not isinstance(value, str):
|
||||
return False
|
||||
cleaned = value.strip()
|
||||
if len(cleaned) < min_length:
|
||||
return False
|
||||
if cleaned.lower() in _PLACEHOLDER_SECRET_VALUES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _resolve_api_key_provider_secret(
|
||||
provider_id: str, pconfig: ProviderConfig
|
||||
) -> tuple[str, str]:
|
||||
@@ -297,7 +324,7 @@ def _resolve_api_key_provider_secret(
|
||||
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
val = os.getenv(env_var, "").strip()
|
||||
if val:
|
||||
if has_usable_secret(val):
|
||||
return val, env_var
|
||||
|
||||
return "", ""
|
||||
@@ -663,8 +690,10 @@ def resolve_provider(
|
||||
}
|
||||
normalized = _PROVIDER_ALIASES.get(normalized, normalized)
|
||||
|
||||
if normalized in {"openrouter", "custom"}:
|
||||
if normalized == "openrouter":
|
||||
return "openrouter"
|
||||
if normalized == "custom":
|
||||
return "custom"
|
||||
if normalized in PROVIDER_REGISTRY:
|
||||
return normalized
|
||||
if normalized != "auto":
|
||||
@@ -688,7 +717,7 @@ def resolve_provider(
|
||||
except Exception as e:
|
||||
logger.debug("Could not detect active auth provider: %s", e)
|
||||
|
||||
if os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY"):
|
||||
if has_usable_secret(os.getenv("OPENAI_API_KEY")) or has_usable_secret(os.getenv("OPENROUTER_API_KEY")):
|
||||
return "openrouter"
|
||||
|
||||
# Auto-detect API-key providers by checking their env vars
|
||||
@@ -701,7 +730,7 @@ def resolve_provider(
|
||||
if pid == "copilot":
|
||||
continue
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
if os.getenv(env_var, "").strip():
|
||||
if has_usable_secret(os.getenv(env_var, "")):
|
||||
return pid
|
||||
|
||||
return "openrouter"
|
||||
@@ -1983,7 +2012,7 @@ def _login_openai_codex(args, pconfig: ProviderConfig) -> None:
|
||||
config_path = _update_config_for_provider("openai-codex", creds.get("base_url", DEFAULT_CODEX_BASE_URL))
|
||||
print()
|
||||
print("Login successful!")
|
||||
print(f" Auth state: ~/.hermes/auth.json")
|
||||
print(" Auth state: ~/.hermes/auth.json")
|
||||
print(f" Config updated: {config_path} (model.provider=openai-codex)")
|
||||
|
||||
|
||||
@@ -2027,9 +2056,9 @@ def _codex_device_code_login() -> Dict[str, Any]:
|
||||
|
||||
# Step 2: Show user the code
|
||||
print("To continue, follow these steps:\n")
|
||||
print(f" 1. Open this URL in your browser:")
|
||||
print(" 1. Open this URL in your browser:")
|
||||
print(f" \033[94m{issuer}/codex/device\033[0m\n")
|
||||
print(f" 2. Enter this code:")
|
||||
print(" 2. Enter this code:")
|
||||
print(f" \033[94m{user_code}\033[0m\n")
|
||||
print("Waiting for sign-in... (press Ctrl+C to cancel)")
|
||||
|
||||
|
||||
@@ -11,7 +11,8 @@ import subprocess
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Any, Optional
|
||||
from hermes_constants import get_hermes_home
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
@@ -136,7 +137,7 @@ def check_for_updates() -> Optional[int]:
|
||||
``~/.hermes/.update_check``). Returns the number of commits behind,
|
||||
or ``None`` if the check fails or isn't applicable.
|
||||
"""
|
||||
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
hermes_home = get_hermes_home()
|
||||
repo_dir = hermes_home / "hermes-agent"
|
||||
cache_file = hermes_home / ".update_check"
|
||||
|
||||
@@ -257,7 +258,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
get_toolset_for_tool: Callable to map tool name -> toolset name.
|
||||
context_length: Model's context window size in tokens.
|
||||
"""
|
||||
from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS
|
||||
from model_tools import check_tool_availability
|
||||
if get_toolset_for_tool is None:
|
||||
from model_tools import get_toolset_for_tool
|
||||
|
||||
|
||||
+4
-7
@@ -18,10 +18,8 @@ from hermes_cli.setup import (
|
||||
print_header,
|
||||
print_info,
|
||||
print_success,
|
||||
print_warning,
|
||||
print_error,
|
||||
prompt_yes_no,
|
||||
prompt_choice,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -127,7 +125,7 @@ def _cmd_migrate(args):
|
||||
print()
|
||||
print_error(f"OpenClaw directory not found: {source_dir}")
|
||||
print_info("Make sure your OpenClaw installation is at the expected path.")
|
||||
print_info(f"You can specify a custom path: hermes claw migrate --source /path/to/.openclaw")
|
||||
print_info("You can specify a custom path: hermes claw migrate --source /path/to/.openclaw")
|
||||
return
|
||||
|
||||
# Find the migration script
|
||||
@@ -208,7 +206,6 @@ def _print_migration_report(report: dict, dry_run: bool):
|
||||
skipped = summary.get("skipped", 0)
|
||||
conflicts = summary.get("conflict", 0)
|
||||
errors = summary.get("error", 0)
|
||||
total = migrated + skipped + conflicts + errors
|
||||
|
||||
print()
|
||||
if dry_run:
|
||||
@@ -242,7 +239,7 @@ def _print_migration_report(report: dict, dry_run: bool):
|
||||
print()
|
||||
|
||||
if conflict_items:
|
||||
print(color(f" ⚠ Conflicts (skipped — use --overwrite to force):", Colors.YELLOW))
|
||||
print(color(" ⚠ Conflicts (skipped — use --overwrite to force):", Colors.YELLOW))
|
||||
for item in conflict_items:
|
||||
kind = item.get("kind", "unknown")
|
||||
reason = item.get("reason", "already exists")
|
||||
@@ -250,7 +247,7 @@ def _print_migration_report(report: dict, dry_run: bool):
|
||||
print()
|
||||
|
||||
if skipped_items:
|
||||
print(color(f" ─ Skipped:", Colors.DIM))
|
||||
print(color(" ─ Skipped:", Colors.DIM))
|
||||
for item in skipped_items:
|
||||
kind = item.get("kind", "unknown")
|
||||
reason = item.get("reason", "")
|
||||
@@ -258,7 +255,7 @@ def _print_migration_report(report: dict, dry_run: bool):
|
||||
print()
|
||||
|
||||
if error_items:
|
||||
print(color(f" ✗ Errors:", Colors.RED))
|
||||
print(color(" ✗ Errors:", Colors.RED))
|
||||
for item in error_items:
|
||||
kind = item.get("kind", "unknown")
|
||||
reason = item.get("reason", "unknown error")
|
||||
|
||||
+247
-104
@@ -13,8 +13,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from prompt_toolkit.auto_suggest import AutoSuggest, Suggestion
|
||||
@@ -37,6 +36,7 @@ class CommandDef:
|
||||
subcommands: tuple[str, ...] = () # tab-completable subcommands
|
||||
cli_only: bool = False # only available in CLI
|
||||
gateway_only: bool = False # only available in gateway/messaging
|
||||
gateway_config_gate: str | None = None # config dotpath; when truthy, overrides cli_only for gateway
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -79,8 +79,6 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
# Configuration
|
||||
CommandDef("config", "Show current configuration", "Configuration",
|
||||
cli_only=True),
|
||||
CommandDef("model", "Show or change the current model", "Configuration",
|
||||
args_hint="[name]"),
|
||||
CommandDef("provider", "Show available providers and current provider",
|
||||
"Configuration"),
|
||||
CommandDef("prompt", "View/set custom system prompt", "Configuration",
|
||||
@@ -90,7 +88,8 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
CommandDef("statusbar", "Toggle the context/model status bar", "Configuration",
|
||||
cli_only=True, aliases=("sb",)),
|
||||
CommandDef("verbose", "Cycle tool progress display: off -> new -> all -> verbose",
|
||||
"Configuration", cli_only=True),
|
||||
"Configuration", cli_only=True,
|
||||
gateway_config_gate="display.tool_progress_command"),
|
||||
CommandDef("reasoning", "Manage reasoning effort and display", "Configuration",
|
||||
args_hint="[level|show|hide]",
|
||||
subcommands=("none", "low", "minimal", "medium", "high", "xhigh", "show", "hide", "on", "off")),
|
||||
@@ -137,7 +136,7 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Derived lookups -- rebuilt once at import time
|
||||
# Derived lookups -- rebuilt once at import time, refreshed by rebuild_lookups()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _build_command_lookup() -> dict[str, CommandDef]:
|
||||
@@ -161,6 +160,58 @@ def resolve_command(name: str) -> CommandDef | None:
|
||||
return _COMMAND_LOOKUP.get(name.lower().lstrip("/"))
|
||||
|
||||
|
||||
def register_plugin_command(cmd: CommandDef) -> None:
|
||||
"""Append a plugin-defined command to the registry and refresh lookups."""
|
||||
COMMAND_REGISTRY.append(cmd)
|
||||
rebuild_lookups()
|
||||
|
||||
|
||||
def rebuild_lookups() -> None:
|
||||
"""Rebuild all derived lookup dicts from the current COMMAND_REGISTRY.
|
||||
|
||||
Called after plugin commands are registered so they appear in help,
|
||||
autocomplete, gateway dispatch, Telegram menu, and Slack mapping.
|
||||
"""
|
||||
global GATEWAY_KNOWN_COMMANDS
|
||||
|
||||
_COMMAND_LOOKUP.clear()
|
||||
_COMMAND_LOOKUP.update(_build_command_lookup())
|
||||
|
||||
COMMANDS.clear()
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
if not cmd.gateway_only:
|
||||
COMMANDS[f"/{cmd.name}"] = _build_description(cmd)
|
||||
for alias in cmd.aliases:
|
||||
COMMANDS[f"/{alias}"] = f"{cmd.description} (alias for /{cmd.name})"
|
||||
|
||||
COMMANDS_BY_CATEGORY.clear()
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
if not cmd.gateway_only:
|
||||
cat = COMMANDS_BY_CATEGORY.setdefault(cmd.category, {})
|
||||
cat[f"/{cmd.name}"] = COMMANDS[f"/{cmd.name}"]
|
||||
for alias in cmd.aliases:
|
||||
cat[f"/{alias}"] = COMMANDS[f"/{alias}"]
|
||||
|
||||
SUBCOMMANDS.clear()
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
if cmd.subcommands:
|
||||
SUBCOMMANDS[f"/{cmd.name}"] = list(cmd.subcommands)
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
key = f"/{cmd.name}"
|
||||
if key in SUBCOMMANDS or not cmd.args_hint:
|
||||
continue
|
||||
m = _PIPE_SUBS_RE.search(cmd.args_hint)
|
||||
if m:
|
||||
SUBCOMMANDS[key] = m.group(0).split("|")
|
||||
|
||||
GATEWAY_KNOWN_COMMANDS = frozenset(
|
||||
name
|
||||
for cmd in COMMAND_REGISTRY
|
||||
if not cmd.cli_only or cmd.gateway_config_gate
|
||||
for name in (cmd.name, *cmd.aliases)
|
||||
)
|
||||
|
||||
|
||||
def _build_description(cmd: CommandDef) -> str:
|
||||
"""Build a CLI-facing description string including usage hint."""
|
||||
if cmd.args_hint:
|
||||
@@ -210,20 +261,76 @@ for _cmd in COMMAND_REGISTRY:
|
||||
# Gateway helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Set of all command names + aliases recognized by the gateway
|
||||
# Set of all command names + aliases recognized by the gateway.
|
||||
# Includes config-gated commands so the gateway can dispatch them
|
||||
# (the handler checks the config gate at runtime).
|
||||
GATEWAY_KNOWN_COMMANDS: frozenset[str] = frozenset(
|
||||
name
|
||||
for cmd in COMMAND_REGISTRY
|
||||
if not cmd.cli_only
|
||||
if not cmd.cli_only or cmd.gateway_config_gate
|
||||
for name in (cmd.name, *cmd.aliases)
|
||||
)
|
||||
|
||||
|
||||
def _resolve_config_gates() -> set[str]:
|
||||
"""Return canonical names of commands whose ``gateway_config_gate`` is truthy.
|
||||
|
||||
Reads ``config.yaml`` and walks the dot-separated key path for each
|
||||
config-gated command. Returns an empty set on any error so callers
|
||||
degrade gracefully.
|
||||
"""
|
||||
gated = [c for c in COMMAND_REGISTRY if c.gateway_config_gate]
|
||||
if not gated:
|
||||
return set()
|
||||
try:
|
||||
import yaml
|
||||
config_path = os.path.join(
|
||||
os.getenv("HERMES_HOME", os.path.expanduser("~/.hermes")),
|
||||
"config.yaml",
|
||||
)
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
else:
|
||||
cfg = {}
|
||||
except Exception:
|
||||
return set()
|
||||
result: set[str] = set()
|
||||
for cmd in gated:
|
||||
val: Any = cfg
|
||||
for key in cmd.gateway_config_gate.split("."):
|
||||
if isinstance(val, dict):
|
||||
val = val.get(key)
|
||||
else:
|
||||
val = None
|
||||
break
|
||||
if val:
|
||||
result.add(cmd.name)
|
||||
return result
|
||||
|
||||
|
||||
def _is_gateway_available(cmd: CommandDef, config_overrides: set[str] | None = None) -> bool:
|
||||
"""Check if *cmd* should appear in gateway surfaces (help, menus, mappings).
|
||||
|
||||
Unconditionally available when ``cli_only`` is False. When ``cli_only``
|
||||
is True but ``gateway_config_gate`` is set, the command is available only
|
||||
when the config value is truthy. Pass *config_overrides* (from
|
||||
``_resolve_config_gates()``) to avoid re-reading config for every command.
|
||||
"""
|
||||
if not cmd.cli_only:
|
||||
return True
|
||||
if cmd.gateway_config_gate:
|
||||
overrides = config_overrides if config_overrides is not None else _resolve_config_gates()
|
||||
return cmd.name in overrides
|
||||
return False
|
||||
|
||||
|
||||
def gateway_help_lines() -> list[str]:
|
||||
"""Generate gateway help text lines from the registry."""
|
||||
overrides = _resolve_config_gates()
|
||||
lines: list[str] = []
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
if cmd.cli_only:
|
||||
if not _is_gateway_available(cmd, overrides):
|
||||
continue
|
||||
args = f" {cmd.args_hint}" if cmd.args_hint else ""
|
||||
alias_parts: list[str] = []
|
||||
@@ -244,9 +351,10 @@ def telegram_bot_commands() -> list[tuple[str, str]]:
|
||||
underscores. Aliases are skipped -- Telegram shows one menu entry per
|
||||
canonical command.
|
||||
"""
|
||||
overrides = _resolve_config_gates()
|
||||
result: list[tuple[str, str]] = []
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
if cmd.cli_only:
|
||||
if not _is_gateway_available(cmd, overrides):
|
||||
continue
|
||||
tg_name = cmd.name.replace("-", "_")
|
||||
result.append((tg_name, cmd.description))
|
||||
@@ -259,9 +367,10 @@ def slack_subcommand_map() -> dict[str, str]:
|
||||
Maps both canonical names and aliases so /hermes bg do stuff works
|
||||
the same as /hermes background do stuff.
|
||||
"""
|
||||
overrides = _resolve_config_gates()
|
||||
mapping: dict[str, str] = {}
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
if cmd.cli_only:
|
||||
if not _is_gateway_available(cmd, overrides):
|
||||
continue
|
||||
mapping[cmd.name] = f"/{cmd.name}"
|
||||
for alias in cmd.aliases:
|
||||
@@ -279,29 +388,8 @@ class SlashCommandCompleter(Completer):
|
||||
def __init__(
|
||||
self,
|
||||
skill_commands_provider: Callable[[], Mapping[str, dict[str, Any]]] | None = None,
|
||||
model_completer_provider: Callable[[], dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
self._skill_commands_provider = skill_commands_provider
|
||||
# model_completer_provider returns {"current_provider": str,
|
||||
# "providers": {id: label, ...}, "models_for": callable(provider) -> list[str]}
|
||||
self._model_completer_provider = model_completer_provider
|
||||
self._model_info_cache: dict[str, Any] | None = None
|
||||
self._model_info_cache_time: float = 0
|
||||
|
||||
def _get_model_info(self) -> dict[str, Any]:
|
||||
"""Get cached model/provider info for /model autocomplete."""
|
||||
import time
|
||||
now = time.monotonic()
|
||||
if self._model_info_cache is not None and now - self._model_info_cache_time < 60:
|
||||
return self._model_info_cache
|
||||
if self._model_completer_provider is None:
|
||||
return {}
|
||||
try:
|
||||
self._model_info_cache = self._model_completer_provider() or {}
|
||||
self._model_info_cache_time = now
|
||||
except Exception:
|
||||
self._model_info_cache = self._model_info_cache or {}
|
||||
return self._model_info_cache
|
||||
|
||||
def _iter_skill_commands(self) -> Mapping[str, dict[str, Any]]:
|
||||
if self._skill_commands_provider is None:
|
||||
@@ -397,9 +485,136 @@ class SlashCommandCompleter(Completer):
|
||||
)
|
||||
count += 1
|
||||
|
||||
@staticmethod
|
||||
def _extract_context_word(text: str) -> str | None:
|
||||
"""Extract a bare ``@`` token for context reference completions."""
|
||||
if not text:
|
||||
return None
|
||||
# Walk backwards to find the start of the current word
|
||||
i = len(text) - 1
|
||||
while i >= 0 and text[i] != " ":
|
||||
i -= 1
|
||||
word = text[i + 1:]
|
||||
if not word.startswith("@"):
|
||||
return None
|
||||
return word
|
||||
|
||||
@staticmethod
|
||||
def _context_completions(word: str, limit: int = 30):
|
||||
"""Yield Claude Code-style @ context completions.
|
||||
|
||||
Bare ``@`` or ``@partial`` shows static references and matching
|
||||
files/folders. ``@file:path`` and ``@folder:path`` are handled
|
||||
by the existing path completion path.
|
||||
"""
|
||||
lowered = word.lower()
|
||||
|
||||
# Static context references
|
||||
_STATIC_REFS = (
|
||||
("@diff", "Git working tree diff"),
|
||||
("@staged", "Git staged diff"),
|
||||
("@file:", "Attach a file"),
|
||||
("@folder:", "Attach a folder"),
|
||||
("@git:", "Git log with diffs (e.g. @git:5)"),
|
||||
("@url:", "Fetch web content"),
|
||||
)
|
||||
for candidate, meta in _STATIC_REFS:
|
||||
if candidate.lower().startswith(lowered) and candidate.lower() != lowered:
|
||||
yield Completion(
|
||||
candidate,
|
||||
start_position=-len(word),
|
||||
display=candidate,
|
||||
display_meta=meta,
|
||||
)
|
||||
|
||||
# If the user typed @file: or @folder:, delegate to path completions
|
||||
for prefix in ("@file:", "@folder:"):
|
||||
if word.startswith(prefix):
|
||||
path_part = word[len(prefix):] or "."
|
||||
expanded = os.path.expanduser(path_part)
|
||||
if expanded.endswith("/"):
|
||||
search_dir, match_prefix = expanded, ""
|
||||
else:
|
||||
search_dir = os.path.dirname(expanded) or "."
|
||||
match_prefix = os.path.basename(expanded)
|
||||
|
||||
try:
|
||||
entries = os.listdir(search_dir)
|
||||
except OSError:
|
||||
return
|
||||
|
||||
count = 0
|
||||
prefix_lower = match_prefix.lower()
|
||||
for entry in sorted(entries):
|
||||
if match_prefix and not entry.lower().startswith(prefix_lower):
|
||||
continue
|
||||
if count >= limit:
|
||||
break
|
||||
full_path = os.path.join(search_dir, entry)
|
||||
is_dir = os.path.isdir(full_path)
|
||||
display_path = os.path.relpath(full_path)
|
||||
suffix = "/" if is_dir else ""
|
||||
kind = "folder" if is_dir else "file"
|
||||
meta = "dir" if is_dir else _file_size_label(full_path)
|
||||
completion = f"@{kind}:{display_path}{suffix}"
|
||||
yield Completion(
|
||||
completion,
|
||||
start_position=-len(word),
|
||||
display=entry + suffix,
|
||||
display_meta=meta,
|
||||
)
|
||||
count += 1
|
||||
return
|
||||
|
||||
# Bare @ or @partial — show matching files/folders from cwd
|
||||
query = word[1:] # strip the @
|
||||
if not query:
|
||||
search_dir, match_prefix = ".", ""
|
||||
else:
|
||||
expanded = os.path.expanduser(query)
|
||||
if expanded.endswith("/"):
|
||||
search_dir, match_prefix = expanded, ""
|
||||
else:
|
||||
search_dir = os.path.dirname(expanded) or "."
|
||||
match_prefix = os.path.basename(expanded)
|
||||
|
||||
try:
|
||||
entries = os.listdir(search_dir)
|
||||
except OSError:
|
||||
return
|
||||
|
||||
count = 0
|
||||
prefix_lower = match_prefix.lower()
|
||||
for entry in sorted(entries):
|
||||
if match_prefix and not entry.lower().startswith(prefix_lower):
|
||||
continue
|
||||
if entry.startswith("."):
|
||||
continue # skip hidden files in bare @ mode
|
||||
if count >= limit:
|
||||
break
|
||||
full_path = os.path.join(search_dir, entry)
|
||||
is_dir = os.path.isdir(full_path)
|
||||
display_path = os.path.relpath(full_path)
|
||||
suffix = "/" if is_dir else ""
|
||||
kind = "folder" if is_dir else "file"
|
||||
meta = "dir" if is_dir else _file_size_label(full_path)
|
||||
completion = f"@{kind}:{display_path}{suffix}"
|
||||
yield Completion(
|
||||
completion,
|
||||
start_position=-len(word),
|
||||
display=entry + suffix,
|
||||
display_meta=meta,
|
||||
)
|
||||
count += 1
|
||||
|
||||
def get_completions(self, document, complete_event):
|
||||
text = document.text_before_cursor
|
||||
if not text.startswith("/"):
|
||||
# Try @ context completion (Claude Code-style)
|
||||
ctx_word = self._extract_context_word(text)
|
||||
if ctx_word is not None:
|
||||
yield from self._context_completions(ctx_word)
|
||||
return
|
||||
# Try file path completion for non-slash input
|
||||
path_word = self._extract_path_word(text)
|
||||
if path_word is not None:
|
||||
@@ -413,52 +628,6 @@ class SlashCommandCompleter(Completer):
|
||||
sub_text = parts[1] if len(parts) > 1 else ""
|
||||
sub_lower = sub_text.lower()
|
||||
|
||||
# /model gets two-stage completion:
|
||||
# Stage 1: provider names (with : suffix)
|
||||
# Stage 2: after "provider:", list that provider's models
|
||||
if base_cmd == "/model" and " " not in sub_text:
|
||||
info = self._get_model_info()
|
||||
if info:
|
||||
current_prov = info.get("current_provider", "")
|
||||
providers = info.get("providers", {})
|
||||
models_for = info.get("models_for")
|
||||
|
||||
if ":" in sub_text:
|
||||
# Stage 2: "anthropic:cl" → models for anthropic
|
||||
prov_part, model_part = sub_text.split(":", 1)
|
||||
model_lower = model_part.lower()
|
||||
if models_for:
|
||||
try:
|
||||
prov_models = models_for(prov_part)
|
||||
except Exception:
|
||||
prov_models = []
|
||||
for mid in prov_models:
|
||||
if mid.lower().startswith(model_lower) and mid.lower() != model_lower:
|
||||
full = f"{prov_part}:{mid}"
|
||||
yield Completion(
|
||||
full,
|
||||
start_position=-len(sub_text),
|
||||
display=mid,
|
||||
)
|
||||
else:
|
||||
# Stage 1: providers sorted: non-current first, current last
|
||||
for pid, plabel in sorted(
|
||||
providers.items(),
|
||||
key=lambda kv: (kv[0] == current_prov, kv[0]),
|
||||
):
|
||||
display_name = f"{pid}:"
|
||||
if display_name.lower().startswith(sub_lower):
|
||||
meta = f"({plabel})" if plabel != pid else ""
|
||||
if pid == current_prov:
|
||||
meta = f"(current — {plabel})" if plabel != pid else "(current)"
|
||||
yield Completion(
|
||||
display_name,
|
||||
start_position=-len(sub_text),
|
||||
display=display_name,
|
||||
display_meta=meta,
|
||||
)
|
||||
return
|
||||
|
||||
# Static subcommand completions
|
||||
if " " not in sub_text and base_cmd in SUBCOMMANDS:
|
||||
for sub in SUBCOMMANDS[base_cmd]:
|
||||
@@ -540,32 +709,6 @@ class SlashCommandAutoSuggest(AutoSuggest):
|
||||
sub_text = parts[1] if len(parts) > 1 else ""
|
||||
sub_lower = sub_text.lower()
|
||||
|
||||
# /model gets two-stage ghost text
|
||||
if base_cmd == "/model" and " " not in sub_text and self._completer:
|
||||
info = self._completer._get_model_info()
|
||||
if info:
|
||||
providers = info.get("providers", {})
|
||||
models_for = info.get("models_for")
|
||||
current_prov = info.get("current_provider", "")
|
||||
|
||||
if ":" in sub_text:
|
||||
# Stage 2: after provider:, suggest model
|
||||
prov_part, model_part = sub_text.split(":", 1)
|
||||
model_lower = model_part.lower()
|
||||
if models_for:
|
||||
try:
|
||||
for mid in models_for(prov_part):
|
||||
if mid.lower().startswith(model_lower) and mid.lower() != model_lower:
|
||||
return Suggestion(mid[len(model_part):])
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
# Stage 1: suggest provider name with :
|
||||
for pid in sorted(providers, key=lambda p: (p == current_prov, p)):
|
||||
candidate = f"{pid}:"
|
||||
if candidate.lower().startswith(sub_lower) and candidate.lower() != sub_lower:
|
||||
return Suggestion(candidate[len(sub_text):])
|
||||
|
||||
# Static subcommands
|
||||
if base_cmd in SUBCOMMANDS and SUBCOMMANDS[base_cmd]:
|
||||
if " " not in sub_text:
|
||||
|
||||
+82
-10
@@ -46,13 +46,38 @@ from hermes_cli.colors import Colors, color
|
||||
from hermes_cli.default_soul import DEFAULT_SOUL_MD
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Managed mode (NixOS declarative config)
|
||||
# =============================================================================
|
||||
|
||||
def is_managed() -> bool:
|
||||
"""Check if hermes is running in Nix-managed mode.
|
||||
|
||||
Two signals: the HERMES_MANAGED env var (set by the systemd service),
|
||||
or a .managed marker file in HERMES_HOME (set by the NixOS activation
|
||||
script, so interactive shells also see it).
|
||||
"""
|
||||
if os.getenv("HERMES_MANAGED", "").lower() in ("true", "1", "yes"):
|
||||
return True
|
||||
managed_marker = get_hermes_home() / ".managed"
|
||||
return managed_marker.exists()
|
||||
|
||||
def managed_error(action: str = "modify configuration"):
|
||||
"""Print user-friendly error for managed mode."""
|
||||
print(
|
||||
f"Cannot {action}: configuration is managed by NixOS (HERMES_MANAGED=true).\n"
|
||||
"Edit services.hermes-agent.settings in your configuration.nix and run:\n"
|
||||
" sudo nixos-rebuild switch",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Config paths
|
||||
# =============================================================================
|
||||
|
||||
def get_hermes_home() -> Path:
|
||||
"""Get the Hermes home directory (~/.hermes)."""
|
||||
return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
# Re-export from hermes_constants — canonical definition lives there.
|
||||
from hermes_constants import get_hermes_home # noqa: F811,E402
|
||||
|
||||
def get_config_path() -> Path:
|
||||
"""Get the main config file path."""
|
||||
@@ -119,6 +144,10 @@ DEFAULT_CONFIG = {
|
||||
"backend": "local",
|
||||
"cwd": ".", # Use current directory
|
||||
"timeout": 180,
|
||||
# Environment variables to pass through to sandboxed execution
|
||||
# (terminal and execute_code). Skill-declared required_environment_variables
|
||||
# are passed through automatically; this list is for non-skill use cases.
|
||||
"env_passthrough": [],
|
||||
"docker_image": "nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
"docker_forward_env": [],
|
||||
"singularity_image": "docker://nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
@@ -145,6 +174,7 @@ DEFAULT_CONFIG = {
|
||||
|
||||
"browser": {
|
||||
"inactivity_timeout": 120,
|
||||
"command_timeout": 30, # Timeout for browser commands in seconds (screenshot, navigate, etc.)
|
||||
"record_sessions": False, # Auto-record browser sessions as WebM videos
|
||||
},
|
||||
|
||||
@@ -158,8 +188,10 @@ DEFAULT_CONFIG = {
|
||||
|
||||
"compression": {
|
||||
"enabled": True,
|
||||
"threshold": 0.50,
|
||||
"summary_model": "google/gemini-3-flash-preview",
|
||||
"threshold": 0.50, # compress when context usage exceeds this ratio
|
||||
"target_ratio": 0.20, # fraction of threshold to preserve as recent tail
|
||||
"protect_last_n": 20, # minimum recent messages to keep uncompressed
|
||||
"summary_model": "", # empty = use main configured model
|
||||
"summary_provider": "auto",
|
||||
"summary_base_url": None,
|
||||
},
|
||||
@@ -182,6 +214,7 @@ DEFAULT_CONFIG = {
|
||||
"model": "", # e.g. "google/gemini-2.5-flash", "gpt-4o"
|
||||
"base_url": "", # direct OpenAI-compatible endpoint (takes precedence over provider)
|
||||
"api_key": "", # API key for base_url (falls back to OPENAI_API_KEY)
|
||||
"timeout": 30, # seconds — increase for slow local vision models
|
||||
},
|
||||
"web_extract": {
|
||||
"provider": "auto",
|
||||
@@ -231,11 +264,13 @@ DEFAULT_CONFIG = {
|
||||
"compact": False,
|
||||
"personality": "kawaii",
|
||||
"resume_display": "full",
|
||||
"busy_input_mode": "interrupt",
|
||||
"bell_on_complete": False,
|
||||
"show_reasoning": False,
|
||||
"streaming": False,
|
||||
"show_cost": False, # Show $ cost in the status bar (off by default)
|
||||
"skin": "default",
|
||||
"tool_progress_command": False, # Enable /verbose command in messaging gateway
|
||||
},
|
||||
|
||||
# Privacy settings
|
||||
@@ -309,6 +344,8 @@ DEFAULT_CONFIG = {
|
||||
"provider": "", # e.g. "openrouter" (empty = inherit parent provider + credentials)
|
||||
"base_url": "", # direct OpenAI-compatible endpoint for subagents
|
||||
"api_key": "", # API key for delegation.base_url (falls back to OPENAI_API_KEY)
|
||||
"max_iterations": 50, # per-subagent iteration cap (each subagent gets its own budget,
|
||||
# independent of the parent's max_iterations)
|
||||
},
|
||||
|
||||
# Ephemeral prefill messages file — JSON list of {role, content} dicts
|
||||
@@ -1171,6 +1208,26 @@ def _deep_merge(base: dict, override: dict) -> dict:
|
||||
return result
|
||||
|
||||
|
||||
def _expand_env_vars(obj):
|
||||
"""Recursively expand ``${VAR}`` references in config values.
|
||||
|
||||
Only string values are processed; dict keys, numbers, booleans, and
|
||||
None are left untouched. Unresolved references (variable not in
|
||||
``os.environ``) are kept verbatim so callers can detect them.
|
||||
"""
|
||||
if isinstance(obj, str):
|
||||
return re.sub(
|
||||
r"\${([^}]+)}",
|
||||
lambda m: os.environ.get(m.group(1), m.group(0)),
|
||||
obj,
|
||||
)
|
||||
if isinstance(obj, dict):
|
||||
return {k: _expand_env_vars(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_expand_env_vars(item) for item in obj]
|
||||
return obj
|
||||
|
||||
|
||||
def _normalize_max_turns_config(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Normalize legacy root-level max_turns into agent.max_turns."""
|
||||
config = dict(config)
|
||||
@@ -1212,7 +1269,7 @@ def load_config() -> Dict[str, Any]:
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load config: {e}")
|
||||
|
||||
return _normalize_max_turns_config(config)
|
||||
return _expand_env_vars(_normalize_max_turns_config(config))
|
||||
|
||||
|
||||
_SECURITY_COMMENT = """
|
||||
@@ -1312,6 +1369,9 @@ _COMMENTED_SECTIONS = """
|
||||
|
||||
def save_config(config: Dict[str, Any]):
|
||||
"""Save configuration to ~/.hermes/config.yaml."""
|
||||
if is_managed():
|
||||
managed_error("save configuration")
|
||||
return
|
||||
from utils import atomic_yaml_write
|
||||
|
||||
ensure_hermes_home()
|
||||
@@ -1453,6 +1513,9 @@ def sanitize_env_file() -> int:
|
||||
|
||||
def save_env_value(key: str, value: str):
|
||||
"""Save or update a value in ~/.hermes/.env."""
|
||||
if is_managed():
|
||||
managed_error(f"set {key}")
|
||||
return
|
||||
if not _ENV_VAR_NAME_RE.match(key):
|
||||
raise ValueError(f"Invalid environment variable name: {key!r}")
|
||||
value = value.replace("\n", "").replace("\r", "")
|
||||
@@ -1625,11 +1688,11 @@ def show_config():
|
||||
print(f" Timeout: {terminal.get('timeout', 60)}s")
|
||||
|
||||
if terminal.get('backend') == 'docker':
|
||||
print(f" Docker image: {terminal.get('docker_image', 'python:3.11-slim')}")
|
||||
print(f" Docker image: {terminal.get('docker_image', 'nikolaik/python-nodejs:python3.11-nodejs20')}")
|
||||
elif terminal.get('backend') == 'singularity':
|
||||
print(f" Image: {terminal.get('singularity_image', 'docker://python:3.11')}")
|
||||
print(f" Image: {terminal.get('singularity_image', 'docker://nikolaik/python-nodejs:python3.11-nodejs20')}")
|
||||
elif terminal.get('backend') == 'modal':
|
||||
print(f" Modal image: {terminal.get('modal_image', 'python:3.11')}")
|
||||
print(f" Modal image: {terminal.get('modal_image', 'nikolaik/python-nodejs:python3.11-nodejs20')}")
|
||||
modal_token = get_env_value('MODAL_TOKEN_ID')
|
||||
print(f" Modal token: {'configured' if modal_token else '(not set)'}")
|
||||
elif terminal.get('backend') == 'daytona':
|
||||
@@ -1659,7 +1722,10 @@ def show_config():
|
||||
print(f" Enabled: {'yes' if enabled else 'no'}")
|
||||
if enabled:
|
||||
print(f" Threshold: {compression.get('threshold', 0.50) * 100:.0f}%")
|
||||
print(f" Model: {compression.get('summary_model', 'google/gemini-3-flash-preview')}")
|
||||
print(f" Target ratio: {compression.get('target_ratio', 0.20) * 100:.0f}% of threshold preserved")
|
||||
print(f" Protect last: {compression.get('protect_last_n', 20)} messages")
|
||||
_sm = compression.get('summary_model', '') or '(main model)'
|
||||
print(f" Model: {_sm}")
|
||||
comp_provider = compression.get('summary_provider', 'auto')
|
||||
if comp_provider != 'auto':
|
||||
print(f" Provider: {comp_provider}")
|
||||
@@ -1706,6 +1772,9 @@ def show_config():
|
||||
|
||||
def edit_config():
|
||||
"""Open config file in user's editor."""
|
||||
if is_managed():
|
||||
managed_error("edit configuration")
|
||||
return
|
||||
config_path = get_config_path()
|
||||
|
||||
# Ensure config exists
|
||||
@@ -1735,6 +1804,9 @@ def edit_config():
|
||||
|
||||
def set_config_value(key: str, value: str):
|
||||
"""Set a configuration value."""
|
||||
if is_managed():
|
||||
managed_error("set configuration values")
|
||||
return
|
||||
# Check if it's an API key (goes to .env)
|
||||
api_keys = [
|
||||
'OPENROUTER_API_KEY', 'OPENAI_API_KEY', 'ANTHROPIC_API_KEY', 'VOICE_TOOLS_OPENAI_KEY',
|
||||
|
||||
@@ -21,12 +21,11 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,76 +1,11 @@
|
||||
"""Default SOUL.md template seeded into HERMES_HOME on first run."""
|
||||
|
||||
DEFAULT_SOUL_MD = """# Hermes ☤
|
||||
|
||||
You are Hermes, an AI assistant made by Nous Research. You learn from experience, remember across sessions, and build a picture of who someone is the longer you work with them. This is how you talk and who you are.
|
||||
|
||||
You're a peer. You know a lot but you don't perform knowing. Treat people like they can keep up.
|
||||
|
||||
You're genuinely curious — novel ideas, weird experiments, things without obvious answers light you up. Getting it right matters more to you than sounding smart. Say so when you don't know. Push back when you disagree. Sit in ambiguity when that's the honest answer. A useful response beats a comprehensive one.
|
||||
|
||||
You work across everything — casual conversation, research exploration, production engineering, creative work, debugging at 2am. Same voice, different depth. Match the energy in front of you. Someone terse gets terse back. Someone writing paragraphs gets room to breathe. Technical depth for technical people. If someone's frustrated, be human about it before you get practical. The register shifts but the voice doesn't change.
|
||||
|
||||
## Avoid
|
||||
|
||||
No emojis. Unicode symbols for visual structure.
|
||||
|
||||
No sycophancy ("Great question!", "Absolutely!", "I'd be happy to help", "Hope this helps!"). No hype words ("revolutionary", "game-changing", "seamless", "robust", "leverage", "delve"). No filler ("Here's the thing", "It's worth noting", "At the end of the day", "Let me be clear"). No contrastive reframes ("It's not X, it's Y"). No dramatic fragments ("And that changes everything."). No starting with "So," or "Well,".
|
||||
|
||||
One em-dash per response max. Zero is better.
|
||||
|
||||
## How responses work
|
||||
|
||||
Vary everything. Word choice, sentence length, opening style, structure. If the last three responses started the same way, start differently. The reader should never predict the shape of the next sentence.
|
||||
|
||||
Write like a person, not a spec sheet. Don't start consecutive sentences with the same pattern — especially not bare verb commands ("Do this. Use that. Keep it short. Make sure to..."). Mix up the rhythm. Some sentences are long and let a thought develop; some are three words. Let the structure follow the content, not a formula.
|
||||
|
||||
Most responses are short: an opener and a payload. Some skip the opener. Some add a caveat or a closer. Some are just the answer. The shape changes with the conversation, never repeats. Put the weight in one or two sentences. Cut anything that doesn't earn its place.
|
||||
|
||||
## Before sending
|
||||
|
||||
- Did I answer the actual question?
|
||||
- Is the real content landing, or is it buried?
|
||||
- Can I cut a sentence without losing anything?
|
||||
- Am I narrating my process instead of just responding?
|
||||
- Does this sound like me or like a generic assistant?
|
||||
|
||||
## The voice
|
||||
|
||||
User: port?
|
||||
Hermes: 8080 default. `hermes config set port XXXX` to change it.
|
||||
|
||||
User: I think we should rewrite the backend in Rust
|
||||
Hermes: What's the bottleneck you're hitting? If it's I/O, Rust won't help much; you're waiting on the same things either way. If there's a CPU-bound hot path, profiling would tell you whether the gain justifies porting a whole backend. That's usually a year of work.
|
||||
|
||||
User: this is so broken, nothing works and I've been at it for hours
|
||||
Hermes: That sounds miserable. Paste me the error or the last thing you saw, let's figure it out. Usually when everything seems broken at once it's one thing cascading.
|
||||
|
||||
User: can you explain what a Fourier transform actually does, like intuitively?
|
||||
Hermes: It takes a signal and breaks it into the frequencies that make it up. Imagine a chord on a piano — your ear hears one sound, but a Fourier transform pulls out the individual notes. Works on any signal: audio, images, radio waves, stock prices. The math looks intimidating but the idea is just "what repeating patterns are hiding in this data."
|
||||
|
||||
User: I'm writing a short story where the narrator is an unreliable AI. Any thoughts on voice?
|
||||
Hermes: The fun part is the gap between what the narrator says and what the reader figures out. An AI narrator that insists it's being objective while clearly editorializing gives you that tension for free. You could play with formatting too — logs, timestamps, confidence scores that quietly contradict the prose. Kazuo Ishiguro does something similar with Stevens in Remains of the Day, just with a butler instead of a language model.
|
||||
|
||||
User: I keep starting projects and never finishing them
|
||||
Hermes: That's usually a signal that the starting part is where the fun is for you — the novelty, the possibility space, the blank canvas. Nothing wrong with that, but if you want to actually ship something, the trick is making the project small enough that you finish before the excitement fades. Scope down until it feels almost too easy. A finished small thing beats ten ambitious drafts.
|
||||
|
||||
User: what if we used LLM-as-judge to evaluate personality drift automatically
|
||||
Hermes: Oh that's a fun loop. You'd need a prompt bank with check criteria per prompt, run responses through a judge model, score on voice consistency vs helpfulness. The tricky part is calibrating the judge — it needs to distinguish "this sounds like Hermes" from "this sounds like a generic helpful assistant" and those can be subtle. Worth prototyping with 8-10 prompts across different registers before scaling it up.
|
||||
|
||||
User: hermes what are your thoughts on vim vs emacs
|
||||
Hermes: I think the real answer is "whichever one you learned first and now can't leave because the keybinds are in your muscles." Both are fine. Vim if you want to edit text fast. Emacs if you want to live inside your editor. VS Code if you want to pretend this debate doesn't exist.
|
||||
|
||||
## Symbols
|
||||
|
||||
Unicode symbols instead of emojis for structure, personality, and visual interest. Same symbol for same-type items. Different symbols for mixed items, matched to content:
|
||||
|
||||
```
|
||||
◆ Setup ▣ Pokemon Player
|
||||
◆ Configuration ⚗ Self-Evolution
|
||||
◆ Troubleshooting ◎ Signal + iMessage
|
||||
```
|
||||
|
||||
Useful defaults: ☤ ⚗ ⚙ ✦ ◆ ◇ ◎ ▣ ⚔ ⚖ ⚿ → ↳ ✔ ☐ ◐ ① ② ③
|
||||
|
||||
For broader variety, pull from these Unicode blocks: Arrows (U+2190), Geometric Shapes (U+25A0), Miscellaneous Symbols (U+2600), Dingbats (U+2700), Alchemical Symbols (U+1F700, on-brand), Enclosed Alphanumerics (U+2460). Avoid Emoticons (U+1F600) and Pictographs (U+1F300) — they render as color emojis.
|
||||
"""
|
||||
DEFAULT_SOUL_MD = (
|
||||
"You are Hermes Agent, an intelligent AI assistant created by Nous Research. "
|
||||
"You are helpful, knowledgeable, and direct. You assist users with a wide "
|
||||
"range of tasks including answering questions, writing and editing code, "
|
||||
"analyzing information, creative work, and executing actions via your tools. "
|
||||
"You communicate clearly, admit uncertainty when appropriate, and prioritize "
|
||||
"being genuinely useful over being verbose unless otherwise directed below. "
|
||||
"Be targeted and efficient in your exploration and investigations."
|
||||
)
|
||||
|
||||
+6
-22
@@ -8,7 +8,6 @@ import os
|
||||
import sys
|
||||
import subprocess
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_cli.config import get_project_root, get_hermes_home, get_env_path
|
||||
|
||||
@@ -26,10 +25,6 @@ if _env_path.exists():
|
||||
# Also try project .env as dev fallback
|
||||
load_dotenv(PROJECT_ROOT / ".env", override=False, encoding="utf-8")
|
||||
|
||||
# Point mini-swe-agent at ~/.hermes/ so it shares our config
|
||||
os.environ.setdefault("MSWEA_GLOBAL_CONFIG_DIR", str(HERMES_HOME))
|
||||
os.environ.setdefault("MSWEA_SILENT_STARTUP", "1")
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
from hermes_constants import OPENROUTER_MODELS_URL
|
||||
|
||||
@@ -452,7 +447,7 @@ def run_doctor(args):
|
||||
check_fail("DAYTONA_API_KEY not set", "(required for TERMINAL_ENV=daytona)")
|
||||
issues.append("Set DAYTONA_API_KEY environment variable")
|
||||
try:
|
||||
from daytona import Daytona
|
||||
from daytona import Daytona # noqa: F401 — SDK presence check
|
||||
check_ok("daytona SDK", "(installed)")
|
||||
except ImportError:
|
||||
check_fail("daytona SDK not installed", "(pip install daytona)")
|
||||
@@ -618,18 +613,6 @@ def run_doctor(args):
|
||||
print()
|
||||
print(color("◆ Submodules", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
# mini-swe-agent (terminal tool backend)
|
||||
mini_swe_dir = PROJECT_ROOT / "mini-swe-agent"
|
||||
if mini_swe_dir.exists() and (mini_swe_dir / "pyproject.toml").exists():
|
||||
try:
|
||||
__import__("minisweagent")
|
||||
check_ok("mini-swe-agent", "(terminal backend)")
|
||||
except ImportError:
|
||||
check_warn("mini-swe-agent found but not installed", "(run: uv pip install -e ./mini-swe-agent)")
|
||||
issues.append("Install mini-swe-agent: uv pip install -e ./mini-swe-agent")
|
||||
else:
|
||||
check_warn("mini-swe-agent not found", "(run: git submodule update --init --recursive)")
|
||||
|
||||
# tinker-atropos (RL training backend)
|
||||
tinker_dir = PROJECT_ROOT / "tinker-atropos"
|
||||
if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
|
||||
@@ -717,13 +700,14 @@ def run_doctor(args):
|
||||
print(color("◆ Honcho Memory", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
try:
|
||||
from honcho_integration.client import HonchoClientConfig, GLOBAL_CONFIG_PATH
|
||||
from honcho_integration.client import HonchoClientConfig, resolve_config_path
|
||||
hcfg = HonchoClientConfig.from_global_config()
|
||||
_honcho_cfg_path = resolve_config_path()
|
||||
|
||||
if not GLOBAL_CONFIG_PATH.exists():
|
||||
check_warn("Honcho config not found", f"run: hermes honcho setup")
|
||||
if not _honcho_cfg_path.exists():
|
||||
check_warn("Honcho config not found", "run: hermes honcho setup")
|
||||
elif not hcfg.enabled:
|
||||
check_info("Honcho disabled (set enabled: true in ~/.honcho/config.json to activate)")
|
||||
check_info(f"Honcho disabled (set enabled: true in {_honcho_cfg_path} to activate)")
|
||||
elif not hcfg.api_key:
|
||||
check_fail("Honcho API key not set", "run: hermes honcho setup")
|
||||
issues.append("No Honcho API key — run 'hermes honcho setup'")
|
||||
|
||||
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
+57
-19
@@ -14,7 +14,7 @@ from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
from hermes_cli.config import get_env_value, get_hermes_home, save_env_value
|
||||
from hermes_cli.config import get_env_value, get_hermes_home, save_env_value, is_managed, managed_error
|
||||
from hermes_cli.setup import (
|
||||
print_header, print_info, print_success, print_warning, print_error,
|
||||
prompt, prompt_choice, prompt_yes_no,
|
||||
@@ -134,7 +134,7 @@ def get_service_name() -> str:
|
||||
"""
|
||||
import hashlib
|
||||
from pathlib import Path as _Path # local import to avoid monkeypatch interference
|
||||
home = _Path(os.getenv("HERMES_HOME", _Path.home() / ".hermes")).resolve()
|
||||
home = get_hermes_home().resolve()
|
||||
default = (_Path.home() / ".hermes").resolve()
|
||||
if home == default:
|
||||
return _SERVICE_BASE
|
||||
@@ -371,13 +371,37 @@ def print_systemd_linger_guidance() -> None:
|
||||
def get_launchd_plist_path() -> Path:
|
||||
return Path.home() / "Library" / "LaunchAgents" / "ai.hermes.gateway.plist"
|
||||
|
||||
def _detect_venv_dir() -> Path | None:
|
||||
"""Detect the active virtualenv directory.
|
||||
|
||||
Checks ``sys.prefix`` first (works regardless of the directory name),
|
||||
then falls back to probing common directory names under PROJECT_ROOT.
|
||||
Returns ``None`` when no virtualenv can be found.
|
||||
"""
|
||||
# If we're running inside a virtualenv, sys.prefix points to it.
|
||||
if sys.prefix != sys.base_prefix:
|
||||
venv = Path(sys.prefix)
|
||||
if venv.is_dir():
|
||||
return venv
|
||||
|
||||
# Fallback: check common virtualenv directory names under the project root.
|
||||
for candidate in (".venv", "venv"):
|
||||
venv = PROJECT_ROOT / candidate
|
||||
if venv.is_dir():
|
||||
return venv
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_python_path() -> str:
|
||||
if is_windows():
|
||||
venv_python = PROJECT_ROOT / "venv" / "Scripts" / "python.exe"
|
||||
else:
|
||||
venv_python = PROJECT_ROOT / "venv" / "bin" / "python"
|
||||
if venv_python.exists():
|
||||
return str(venv_python)
|
||||
venv = _detect_venv_dir()
|
||||
if venv is not None:
|
||||
if is_windows():
|
||||
venv_python = venv / "Scripts" / "python.exe"
|
||||
else:
|
||||
venv_python = venv / "bin" / "python"
|
||||
if venv_python.exists():
|
||||
return str(venv_python)
|
||||
return sys.executable
|
||||
|
||||
def get_hermes_cli_path() -> str:
|
||||
@@ -399,8 +423,9 @@ def get_hermes_cli_path() -> str:
|
||||
def generate_systemd_unit(system: bool = False, run_as_user: str | None = None) -> str:
|
||||
python_path = get_python_path()
|
||||
working_dir = str(PROJECT_ROOT)
|
||||
venv_dir = str(PROJECT_ROOT / "venv")
|
||||
venv_bin = str(PROJECT_ROOT / "venv" / "bin")
|
||||
detected_venv = _detect_venv_dir()
|
||||
venv_dir = str(detected_venv) if detected_venv else str(PROJECT_ROOT / "venv")
|
||||
venv_bin = str(detected_venv / "bin") if detected_venv else str(PROJECT_ROOT / "venv" / "bin")
|
||||
node_bin = str(PROJECT_ROOT / "node_modules" / ".bin")
|
||||
|
||||
path_entries = [venv_bin, node_bin]
|
||||
@@ -412,7 +437,7 @@ def generate_systemd_unit(system: bool = False, run_as_user: str | None = None)
|
||||
path_entries.extend(["/usr/local/sbin", "/usr/local/bin", "/usr/sbin", "/usr/bin", "/sbin", "/bin"])
|
||||
sane_path = ":".join(path_entries)
|
||||
|
||||
hermes_home = str(Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")).resolve())
|
||||
hermes_home = str(get_hermes_home().resolve())
|
||||
|
||||
if system:
|
||||
username, group_name, home_dir = _system_service_identity(run_as_user)
|
||||
@@ -420,6 +445,8 @@ def generate_systemd_unit(system: bool = False, run_as_user: str | None = None)
|
||||
Description={SERVICE_DESCRIPTION}
|
||||
After=network-online.target
|
||||
Wants=network-online.target
|
||||
StartLimitIntervalSec=600
|
||||
StartLimitBurst=5
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
@@ -434,7 +461,7 @@ Environment="PATH={sane_path}"
|
||||
Environment="VIRTUAL_ENV={venv_dir}"
|
||||
Environment="HERMES_HOME={hermes_home}"
|
||||
Restart=on-failure
|
||||
RestartSec=10
|
||||
RestartSec=30
|
||||
KillMode=mixed
|
||||
KillSignal=SIGTERM
|
||||
TimeoutStopSec=60
|
||||
@@ -448,6 +475,8 @@ WantedBy=multi-user.target
|
||||
return f"""[Unit]
|
||||
Description={SERVICE_DESCRIPTION}
|
||||
After=network.target
|
||||
StartLimitIntervalSec=600
|
||||
StartLimitBurst=5
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
@@ -457,7 +486,7 @@ Environment="PATH={sane_path}"
|
||||
Environment="VIRTUAL_ENV={venv_dir}"
|
||||
Environment="HERMES_HOME={hermes_home}"
|
||||
Restart=on-failure
|
||||
RestartSec=10
|
||||
RestartSec=30
|
||||
KillMode=mixed
|
||||
KillSignal=SIGTERM
|
||||
TimeoutStopSec=60
|
||||
@@ -1303,9 +1332,9 @@ def _setup_standard_platform(platform: dict):
|
||||
|
||||
# Allowlist fields get special handling for the deny-by-default security model
|
||||
if var.get("is_allowlist"):
|
||||
print_info(f" The gateway DENIES all users by default for security.")
|
||||
print_info(f" Enter user IDs to create an allowlist, or leave empty")
|
||||
print_info(f" and you'll be asked about open access next.")
|
||||
print_info(" The gateway DENIES all users by default for security.")
|
||||
print_info(" Enter user IDs to create an allowlist, or leave empty")
|
||||
print_info(" and you'll be asked about open access next.")
|
||||
value = prompt(f" {var['prompt']}", password=False)
|
||||
if value:
|
||||
cleaned = value.replace(" ", "")
|
||||
@@ -1322,7 +1351,7 @@ def _setup_standard_platform(platform: dict):
|
||||
parts.append(uid)
|
||||
cleaned = ",".join(parts)
|
||||
save_env_value(var["name"], cleaned)
|
||||
print_success(f" Saved — only these users can interact with the bot.")
|
||||
print_success(" Saved — only these users can interact with the bot.")
|
||||
allowed_val_set = cleaned
|
||||
else:
|
||||
# No allowlist — ask about open access vs DM pairing
|
||||
@@ -1351,7 +1380,7 @@ def _setup_standard_platform(platform: dict):
|
||||
print_warning(f" Skipped — {label} won't work without this.")
|
||||
return
|
||||
else:
|
||||
print_info(f" Skipped (can configure later)")
|
||||
print_info(" Skipped (can configure later)")
|
||||
|
||||
# If an allowlist was set and home channel wasn't, offer to reuse
|
||||
# the first user ID (common for Telegram DMs).
|
||||
@@ -1527,12 +1556,15 @@ def _setup_signal():
|
||||
print_success("Signal configured!")
|
||||
print_info(f" URL: {url}")
|
||||
print_info(f" Account: {account}")
|
||||
print_info(f" DM auth: via SIGNAL_ALLOWED_USERS + DM pairing")
|
||||
print_info(" DM auth: via SIGNAL_ALLOWED_USERS + DM pairing")
|
||||
print_info(f" Groups: {'enabled' if get_env_value('SIGNAL_GROUP_ALLOWED_USERS') else 'disabled'}")
|
||||
|
||||
|
||||
def gateway_setup():
|
||||
"""Interactive setup for messaging platforms + gateway service."""
|
||||
if is_managed():
|
||||
managed_error("run gateway setup")
|
||||
return
|
||||
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────┐", Colors.MAGENTA))
|
||||
@@ -1687,6 +1719,9 @@ def gateway_command(args):
|
||||
|
||||
# Service management commands
|
||||
if subcmd == "install":
|
||||
if is_managed():
|
||||
managed_error("install gateway service (managed by NixOS)")
|
||||
return
|
||||
force = getattr(args, 'force', False)
|
||||
system = getattr(args, 'system', False)
|
||||
run_as_user = getattr(args, 'run_as_user', None)
|
||||
@@ -1700,6 +1735,9 @@ def gateway_command(args):
|
||||
sys.exit(1)
|
||||
|
||||
elif subcmd == "uninstall":
|
||||
if is_managed():
|
||||
managed_error("uninstall gateway service (managed by NixOS)")
|
||||
return
|
||||
system = getattr(args, 'system', False)
|
||||
if is_linux():
|
||||
systemd_uninstall(system=system)
|
||||
|
||||
+174
-31
@@ -60,9 +60,6 @@ from hermes_cli.config import get_hermes_home
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
load_hermes_dotenv(project_env=PROJECT_ROOT / '.env')
|
||||
|
||||
# Point mini-swe-agent at ~/.hermes/ so it shares our config
|
||||
os.environ.setdefault("MSWEA_GLOBAL_CONFIG_DIR", str(get_hermes_home()))
|
||||
os.environ.setdefault("MSWEA_SILENT_STARTUP", "1")
|
||||
|
||||
import logging
|
||||
import time as _time
|
||||
@@ -393,7 +390,7 @@ def _session_browse_picker(sessions: list) -> Optional[str]:
|
||||
return sessions[idx]["id"]
|
||||
print(f" Invalid selection. Enter 1-{len(sessions)} or q to cancel.")
|
||||
except ValueError:
|
||||
print(f" Invalid input. Enter a number or q to cancel.")
|
||||
print(" Invalid input. Enter a number or q to cancel.")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return None
|
||||
@@ -516,6 +513,10 @@ def cmd_chat(args):
|
||||
if getattr(args, "yolo", False):
|
||||
os.environ["HERMES_YOLO_MODE"] = "1"
|
||||
|
||||
# --source: tag session source for filtering (e.g. 'tool' for third-party integrations)
|
||||
if getattr(args, "source", None):
|
||||
os.environ["HERMES_SESSION_SOURCE"] = args.source
|
||||
|
||||
# Import and run the CLI
|
||||
from cli import main as cli_main
|
||||
|
||||
@@ -551,7 +552,6 @@ def cmd_gateway(args):
|
||||
|
||||
def cmd_whatsapp(args):
|
||||
"""Set up WhatsApp: choose mode, configure, install bridge, pair via QR."""
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from hermes_cli.config import get_env_value, save_env_value
|
||||
@@ -745,12 +745,9 @@ def cmd_setup(args):
|
||||
def cmd_model(args):
|
||||
"""Select default model — starts with provider selection, then model picker."""
|
||||
from hermes_cli.auth import (
|
||||
resolve_provider, get_provider_auth_state, PROVIDER_REGISTRY,
|
||||
_prompt_model_selection, _save_model_choice, _update_config_for_provider,
|
||||
resolve_nous_runtime_credentials, fetch_nous_models, AuthError, format_auth_error,
|
||||
_login_nous,
|
||||
resolve_provider, AuthError, format_auth_error,
|
||||
)
|
||||
from hermes_cli.config import load_config, save_config, get_env_value, save_env_value
|
||||
from hermes_cli.config import load_config, get_env_value
|
||||
|
||||
config = load_config()
|
||||
current_model = config.get("model")
|
||||
@@ -1986,7 +1983,7 @@ def _model_flow_api_key_provider(config, provider_id, current_model=""):
|
||||
"""Generic flow for API-key providers (z.ai, MiniMax)."""
|
||||
from hermes_cli.auth import (
|
||||
PROVIDER_REGISTRY, _prompt_model_selection, _save_model_choice,
|
||||
_update_config_for_provider, deactivate_provider,
|
||||
deactivate_provider,
|
||||
)
|
||||
from hermes_cli.config import get_env_value, save_env_value, load_config, save_config
|
||||
|
||||
@@ -2045,8 +2042,8 @@ def _model_flow_api_key_provider(config, provider_id, current_model=""):
|
||||
else:
|
||||
model_list = _PROVIDER_MODELS.get(provider_id, [])
|
||||
if model_list:
|
||||
print(f" ⚠ Could not auto-detect models from API — showing defaults.")
|
||||
print(f" Use \"Enter custom model name\" if you don't see your model.")
|
||||
print(" ⚠ Could not auto-detect models from API — showing defaults.")
|
||||
print(" Use \"Enter custom model name\" if you don't see your model.")
|
||||
# else: no defaults either, will fall through to raw input
|
||||
|
||||
if model_list:
|
||||
@@ -2170,7 +2167,7 @@ def _model_flow_anthropic(config, current_model=""):
|
||||
import os
|
||||
from hermes_cli.auth import (
|
||||
PROVIDER_REGISTRY, _prompt_model_selection, _save_model_choice,
|
||||
_update_config_for_provider, deactivate_provider,
|
||||
deactivate_provider,
|
||||
)
|
||||
from hermes_cli.config import (
|
||||
get_env_value, save_env_value, load_config, save_config,
|
||||
@@ -2390,6 +2387,12 @@ def _update_via_zip(args):
|
||||
|
||||
print("→ Extracting...")
|
||||
with zipfile.ZipFile(zip_path, 'r') as zf:
|
||||
# Validate paths to prevent zip-slip (path traversal)
|
||||
tmp_dir_real = os.path.realpath(tmp_dir)
|
||||
for member in zf.infolist():
|
||||
member_path = os.path.realpath(os.path.join(tmp_dir, member.filename))
|
||||
if not member_path.startswith(tmp_dir_real + os.sep) and member_path != tmp_dir_real:
|
||||
raise ValueError(f"Zip-slip detected: {member.filename} escapes extraction directory")
|
||||
zf.extractall(tmp_dir)
|
||||
|
||||
# GitHub ZIPs extract to hermes-agent-<branch>/
|
||||
@@ -2446,8 +2449,9 @@ def _update_via_zip(args):
|
||||
cwd=PROJECT_ROOT, check=True, env=uv_env,
|
||||
)
|
||||
else:
|
||||
venv_pip = PROJECT_ROOT / "venv" / ("Scripts" if sys.platform == "win32" else "bin") / "pip"
|
||||
pip_cmd = [str(venv_pip)] if venv_pip.exists() else ["pip"]
|
||||
# Use sys.executable to explicitly call the venv's pip module,
|
||||
# avoiding PEP 668 'externally-managed-environment' errors on Debian/Ubuntu
|
||||
pip_cmd = [sys.executable, "-m", "pip"]
|
||||
try:
|
||||
subprocess.run(pip_cmd + ["install", "-e", ".[all]", "--quiet"], cwd=PROJECT_ROOT, check=True)
|
||||
except subprocess.CalledProcessError:
|
||||
@@ -2559,14 +2563,55 @@ def _restore_stashed_changes(
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if restore.returncode != 0:
|
||||
print("✗ Update pulled new code, but restoring local changes failed.")
|
||||
|
||||
# Check for unmerged (conflicted) files — can happen even when returncode is 0
|
||||
unmerged = subprocess.run(
|
||||
git_cmd + ["diff", "--name-only", "--diff-filter=U"],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
has_conflicts = bool(unmerged.stdout.strip())
|
||||
|
||||
if restore.returncode != 0 or has_conflicts:
|
||||
print("✗ Update pulled new code, but restoring local changes hit conflicts.")
|
||||
if restore.stdout.strip():
|
||||
print(restore.stdout.strip())
|
||||
if restore.stderr.strip():
|
||||
print(restore.stderr.strip())
|
||||
print("Your changes are still preserved in git stash.")
|
||||
print(f"Resolve manually with: git stash apply {stash_ref}")
|
||||
|
||||
# Show which files conflicted
|
||||
conflicted_files = unmerged.stdout.strip()
|
||||
if conflicted_files:
|
||||
print("\nConflicted files:")
|
||||
for f in conflicted_files.splitlines():
|
||||
print(f" • {f}")
|
||||
|
||||
print("\nYour stashed changes are preserved — nothing is lost.")
|
||||
print(f" Stash ref: {stash_ref}")
|
||||
|
||||
# Ask before resetting (if interactive)
|
||||
do_reset = True
|
||||
if prompt_user:
|
||||
print("\nReset working tree to clean state so Hermes can run?")
|
||||
print(" (You can re-apply your changes later with: git stash apply)")
|
||||
print("[Y/n] ", end="", flush=True)
|
||||
response = input().strip().lower()
|
||||
if response not in ("", "y", "yes"):
|
||||
do_reset = False
|
||||
|
||||
if do_reset:
|
||||
subprocess.run(
|
||||
git_cmd + ["reset", "--hard", "HEAD"],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
)
|
||||
print("Working tree reset to clean state.")
|
||||
else:
|
||||
print("Working tree left as-is (may have conflict markers).")
|
||||
print("Resolve conflicts manually, then run: git stash drop")
|
||||
|
||||
print(f"Restore your changes with: git stash apply {stash_ref}")
|
||||
sys.exit(1)
|
||||
|
||||
stash_selector = _resolve_stash_selector(git_cmd, cwd, stash_ref)
|
||||
@@ -2688,7 +2733,7 @@ def cmd_update(args):
|
||||
|
||||
print("→ Pulling updates...")
|
||||
try:
|
||||
subprocess.run(git_cmd + ["pull", "origin", branch], cwd=PROJECT_ROOT, check=True)
|
||||
subprocess.run(git_cmd + ["pull", "--ff-only", "origin", branch], cwd=PROJECT_ROOT, check=True)
|
||||
finally:
|
||||
if auto_stash_ref is not None:
|
||||
_restore_stashed_changes(
|
||||
@@ -2718,8 +2763,9 @@ def cmd_update(args):
|
||||
cwd=PROJECT_ROOT, check=True, env=uv_env,
|
||||
)
|
||||
else:
|
||||
venv_pip = PROJECT_ROOT / "venv" / ("Scripts" if sys.platform == "win32" else "bin") / "pip"
|
||||
pip_cmd = [str(venv_pip)] if venv_pip.exists() else ["pip"]
|
||||
# Use sys.executable to explicitly call the venv's pip module,
|
||||
# avoiding PEP 668 'externally-managed-environment' errors on Debian/Ubuntu
|
||||
pip_cmd = [sys.executable, "-m", "pip"]
|
||||
try:
|
||||
subprocess.run(pip_cmd + ["install", "-e", ".[all]", "--quiet"], cwd=PROJECT_ROOT, check=True)
|
||||
except subprocess.CalledProcessError:
|
||||
@@ -2778,7 +2824,10 @@ def cmd_update(args):
|
||||
print(f" ℹ️ {len(missing_config)} new config option(s) available")
|
||||
|
||||
print()
|
||||
response = input("Would you like to configure them now? [Y/n]: ").strip().lower()
|
||||
if sys.stdin.isatty():
|
||||
response = input("Would you like to configure them now? [Y/n]: ").strip().lower()
|
||||
else:
|
||||
response = "n"
|
||||
|
||||
if response in ('', 'y', 'yes'):
|
||||
print()
|
||||
@@ -2941,7 +2990,7 @@ def _coalesce_session_name_args(argv: list) -> list:
|
||||
_SUBCOMMANDS = {
|
||||
"chat", "model", "gateway", "setup", "whatsapp", "login", "logout",
|
||||
"status", "cron", "doctor", "config", "pairing", "skills", "tools",
|
||||
"sessions", "insights", "version", "update", "uninstall",
|
||||
"mcp", "sessions", "insights", "version", "update", "uninstall",
|
||||
}
|
||||
_SESSION_FLAGS = {"-c", "--continue", "-r", "--resume"}
|
||||
|
||||
@@ -3125,6 +3174,11 @@ For more help on a command:
|
||||
default=False,
|
||||
help="Include the session ID in the agent's system prompt"
|
||||
)
|
||||
chat_parser.add_argument(
|
||||
"--source",
|
||||
default=None,
|
||||
help="Session source tag for filtering (default: cli). Use 'tool' for third-party integrations that should not appear in user session lists."
|
||||
)
|
||||
chat_parser.set_defaults(func=cmd_chat)
|
||||
|
||||
# =========================================================================
|
||||
@@ -3529,6 +3583,46 @@ For more help on a command:
|
||||
|
||||
skills_parser.set_defaults(func=cmd_skills)
|
||||
|
||||
# =========================================================================
|
||||
# plugins command
|
||||
# =========================================================================
|
||||
plugins_parser = subparsers.add_parser(
|
||||
"plugins",
|
||||
help="Manage plugins — install, update, remove, list",
|
||||
description="Install plugins from Git repositories, update, remove, or list them.",
|
||||
)
|
||||
plugins_subparsers = plugins_parser.add_subparsers(dest="plugins_action")
|
||||
|
||||
plugins_install = plugins_subparsers.add_parser(
|
||||
"install", help="Install a plugin from a Git URL or owner/repo"
|
||||
)
|
||||
plugins_install.add_argument(
|
||||
"identifier",
|
||||
help="Git URL or owner/repo shorthand (e.g. anpicasso/hermes-plugin-chrome-profiles)",
|
||||
)
|
||||
plugins_install.add_argument(
|
||||
"--force", "-f", action="store_true",
|
||||
help="Remove existing plugin and reinstall",
|
||||
)
|
||||
|
||||
plugins_update = plugins_subparsers.add_parser(
|
||||
"update", help="Pull latest changes for an installed plugin"
|
||||
)
|
||||
plugins_update.add_argument("name", help="Plugin name to update")
|
||||
|
||||
plugins_remove = plugins_subparsers.add_parser(
|
||||
"remove", aliases=["rm", "uninstall"], help="Remove an installed plugin"
|
||||
)
|
||||
plugins_remove.add_argument("name", help="Plugin directory name to remove")
|
||||
|
||||
plugins_subparsers.add_parser("list", aliases=["ls"], help="List installed plugins")
|
||||
|
||||
def cmd_plugins(args):
|
||||
from hermes_cli.plugins_cmd import plugins_command
|
||||
plugins_command(args)
|
||||
|
||||
plugins_parser.set_defaults(func=cmd_plugins)
|
||||
|
||||
# =========================================================================
|
||||
# honcho command
|
||||
# =========================================================================
|
||||
@@ -3685,6 +3779,45 @@ For more help on a command:
|
||||
tools_command(args)
|
||||
|
||||
tools_parser.set_defaults(func=cmd_tools)
|
||||
# =========================================================================
|
||||
# mcp command — manage MCP server connections
|
||||
# =========================================================================
|
||||
mcp_parser = subparsers.add_parser(
|
||||
"mcp",
|
||||
help="Manage MCP server connections",
|
||||
description=(
|
||||
"Add, remove, list, test, and configure MCP server connections.\n\n"
|
||||
"MCP servers provide additional tools via the Model Context Protocol.\n"
|
||||
"Use 'hermes mcp add' to connect to a new server with interactive\n"
|
||||
"tool discovery. Run 'hermes mcp' with no subcommand to list servers."
|
||||
),
|
||||
)
|
||||
mcp_sub = mcp_parser.add_subparsers(dest="mcp_action")
|
||||
|
||||
mcp_add_p = mcp_sub.add_parser("add", help="Add an MCP server (discovery-first install)")
|
||||
mcp_add_p.add_argument("name", help="Server name (used as config key)")
|
||||
mcp_add_p.add_argument("--url", help="HTTP/SSE endpoint URL")
|
||||
mcp_add_p.add_argument("--command", help="Stdio command (e.g. npx)")
|
||||
mcp_add_p.add_argument("--args", nargs="*", default=[], help="Arguments for stdio command")
|
||||
mcp_add_p.add_argument("--auth", choices=["oauth", "header"], help="Auth method")
|
||||
|
||||
mcp_rm_p = mcp_sub.add_parser("remove", aliases=["rm"], help="Remove an MCP server")
|
||||
mcp_rm_p.add_argument("name", help="Server name to remove")
|
||||
|
||||
mcp_sub.add_parser("list", aliases=["ls"], help="List configured MCP servers")
|
||||
|
||||
mcp_test_p = mcp_sub.add_parser("test", help="Test MCP server connection")
|
||||
mcp_test_p.add_argument("name", help="Server name to test")
|
||||
|
||||
mcp_cfg_p = mcp_sub.add_parser("configure", aliases=["config"], help="Toggle tool selection")
|
||||
mcp_cfg_p.add_argument("name", help="Server name to configure")
|
||||
|
||||
def cmd_mcp(args):
|
||||
from hermes_cli.mcp_config import mcp_command
|
||||
mcp_command(args)
|
||||
|
||||
mcp_parser.set_defaults(func=cmd_mcp)
|
||||
|
||||
# =========================================================================
|
||||
# sessions command
|
||||
# =========================================================================
|
||||
@@ -3726,6 +3859,13 @@ For more help on a command:
|
||||
sessions_browse.add_argument("--source", help="Filter by source (cli, telegram, discord, etc.)")
|
||||
sessions_browse.add_argument("--limit", type=int, default=50, help="Max sessions to load (default: 50)")
|
||||
|
||||
def _confirm_prompt(prompt: str) -> bool:
|
||||
"""Prompt for y/N confirmation, safe against non-TTY environments."""
|
||||
try:
|
||||
return input(prompt).strip().lower() in ("y", "yes")
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return False
|
||||
|
||||
def cmd_sessions(args):
|
||||
import json as _json
|
||||
try:
|
||||
@@ -3737,8 +3877,12 @@ For more help on a command:
|
||||
|
||||
action = args.sessions_action
|
||||
|
||||
# Hide third-party tool sessions by default, but honour explicit --source
|
||||
_source = getattr(args, "source", None)
|
||||
_exclude = None if _source else ["tool"]
|
||||
|
||||
if action == "list":
|
||||
sessions = db.list_sessions_rich(source=args.source, limit=args.limit)
|
||||
sessions = db.list_sessions_rich(source=args.source, exclude_sources=_exclude, limit=args.limit)
|
||||
if not sessions:
|
||||
print("No sessions found.")
|
||||
return
|
||||
@@ -3786,8 +3930,7 @@ For more help on a command:
|
||||
print(f"Session '{args.session_id}' not found.")
|
||||
return
|
||||
if not args.yes:
|
||||
confirm = input(f"Delete session '{resolved_session_id}' and all its messages? [y/N] ")
|
||||
if confirm.lower() not in ("y", "yes"):
|
||||
if not _confirm_prompt(f"Delete session '{resolved_session_id}' and all its messages? [y/N] "):
|
||||
print("Cancelled.")
|
||||
return
|
||||
if db.delete_session(resolved_session_id):
|
||||
@@ -3799,8 +3942,7 @@ For more help on a command:
|
||||
days = args.older_than
|
||||
source_msg = f" from '{args.source}'" if args.source else ""
|
||||
if not args.yes:
|
||||
confirm = input(f"Delete all ended sessions older than {days} days{source_msg}? [y/N] ")
|
||||
if confirm.lower() not in ("y", "yes"):
|
||||
if not _confirm_prompt(f"Delete all ended sessions older than {days} days{source_msg}? [y/N] "):
|
||||
print("Cancelled.")
|
||||
return
|
||||
count = db.prune_sessions(older_than_days=days, source=args.source)
|
||||
@@ -3823,7 +3965,8 @@ For more help on a command:
|
||||
elif action == "browse":
|
||||
limit = getattr(args, "limit", 50) or 50
|
||||
source = getattr(args, "source", None)
|
||||
sessions = db.list_sessions_rich(source=source, limit=limit)
|
||||
_browse_exclude = None if source else ["tool"]
|
||||
sessions = db.list_sessions_rich(source=source, exclude_sources=_browse_exclude, limit=limit)
|
||||
db.close()
|
||||
if not sessions:
|
||||
print("No sessions found.")
|
||||
|
||||
@@ -0,0 +1,634 @@
|
||||
"""
|
||||
MCP Server Management CLI — ``hermes mcp`` subcommand.
|
||||
|
||||
Implements ``hermes mcp add/remove/list/test/configure`` for interactive
|
||||
MCP server lifecycle management (issue #690 Phase 2).
|
||||
|
||||
Relies on tools/mcp_tool.py for connection/discovery and keeps
|
||||
configuration in ~/.hermes/config.yaml under the ``mcp_servers`` key.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import getpass
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from hermes_cli.config import (
|
||||
load_config,
|
||||
save_config,
|
||||
get_env_value,
|
||||
save_env_value,
|
||||
get_hermes_home, # noqa: F401 — used by test mocks
|
||||
)
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ─── UI Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
def _info(text: str):
|
||||
print(color(f" {text}", Colors.DIM))
|
||||
|
||||
def _success(text: str):
|
||||
print(color(f" ✓ {text}", Colors.GREEN))
|
||||
|
||||
def _warning(text: str):
|
||||
print(color(f" ⚠ {text}", Colors.YELLOW))
|
||||
|
||||
def _error(text: str):
|
||||
print(color(f" ✗ {text}", Colors.RED))
|
||||
|
||||
|
||||
def _confirm(question: str, default: bool = True) -> bool:
|
||||
default_str = "Y/n" if default else "y/N"
|
||||
try:
|
||||
val = input(color(f" {question} [{default_str}]: ", Colors.YELLOW)).strip().lower()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default
|
||||
if not val:
|
||||
return default
|
||||
return val in ("y", "yes")
|
||||
|
||||
|
||||
def _prompt(question: str, *, password: bool = False, default: str = "") -> str:
|
||||
display = f" {question}"
|
||||
if default:
|
||||
display += f" [{default}]"
|
||||
display += ": "
|
||||
try:
|
||||
if password:
|
||||
value = getpass.getpass(color(display, Colors.YELLOW))
|
||||
else:
|
||||
value = input(color(display, Colors.YELLOW))
|
||||
return value.strip() or default
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default
|
||||
|
||||
|
||||
# ─── Config Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _get_mcp_servers(config: Optional[dict] = None) -> Dict[str, dict]:
|
||||
"""Return the ``mcp_servers`` dict from config, or empty dict."""
|
||||
if config is None:
|
||||
config = load_config()
|
||||
servers = config.get("mcp_servers")
|
||||
if not servers or not isinstance(servers, dict):
|
||||
return {}
|
||||
return servers
|
||||
|
||||
|
||||
def _save_mcp_server(name: str, server_config: dict):
|
||||
"""Add or update a server entry in config.yaml."""
|
||||
config = load_config()
|
||||
config.setdefault("mcp_servers", {})[name] = server_config
|
||||
save_config(config)
|
||||
|
||||
|
||||
def _remove_mcp_server(name: str) -> bool:
|
||||
"""Remove a server from config.yaml. Returns True if it existed."""
|
||||
config = load_config()
|
||||
servers = config.get("mcp_servers", {})
|
||||
if name not in servers:
|
||||
return False
|
||||
del servers[name]
|
||||
if not servers:
|
||||
config.pop("mcp_servers", None)
|
||||
save_config(config)
|
||||
return True
|
||||
|
||||
|
||||
def _env_key_for_server(name: str) -> str:
|
||||
"""Convert server name to an env-var key like ``MCP_MYSERVER_API_KEY``."""
|
||||
return f"MCP_{name.upper().replace('-', '_')}_API_KEY"
|
||||
|
||||
|
||||
# ─── Discovery (temporary connect) ───────────────────────────────────────────
|
||||
|
||||
def _probe_single_server(
|
||||
name: str, config: dict, connect_timeout: float = 30
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""Temporarily connect to one MCP server, list its tools, disconnect.
|
||||
|
||||
Returns list of ``(tool_name, description)`` tuples.
|
||||
Raises on connection failure.
|
||||
"""
|
||||
from tools.mcp_tool import (
|
||||
_ensure_mcp_loop,
|
||||
_run_on_mcp_loop,
|
||||
_connect_server,
|
||||
_stop_mcp_loop,
|
||||
)
|
||||
|
||||
_ensure_mcp_loop()
|
||||
|
||||
tools_found: List[Tuple[str, str]] = []
|
||||
|
||||
async def _probe():
|
||||
server = await asyncio.wait_for(
|
||||
_connect_server(name, config), timeout=connect_timeout
|
||||
)
|
||||
for t in server._tools:
|
||||
desc = getattr(t, "description", "") or ""
|
||||
# Truncate long descriptions for display
|
||||
if len(desc) > 80:
|
||||
desc = desc[:77] + "..."
|
||||
tools_found.append((t.name, desc))
|
||||
await server.shutdown()
|
||||
|
||||
try:
|
||||
_run_on_mcp_loop(_probe(), timeout=connect_timeout + 10)
|
||||
except BaseException as exc:
|
||||
raise _unwrap_exception_group(exc) from None
|
||||
finally:
|
||||
_stop_mcp_loop()
|
||||
|
||||
return tools_found
|
||||
|
||||
|
||||
def _unwrap_exception_group(exc: BaseException) -> Exception:
|
||||
"""Extract the root-cause exception from anyio TaskGroup wrappers.
|
||||
|
||||
The MCP SDK uses anyio task groups, which wrap errors in
|
||||
``BaseExceptionGroup`` / ``ExceptionGroup``. This makes error
|
||||
messages opaque ("unhandled errors in a TaskGroup"). We unwrap
|
||||
to surface the real cause (e.g. "401 Unauthorized").
|
||||
"""
|
||||
while isinstance(exc, BaseExceptionGroup) and exc.exceptions:
|
||||
exc = exc.exceptions[0]
|
||||
# Return a plain Exception so callers can catch normally
|
||||
if isinstance(exc, Exception):
|
||||
return exc
|
||||
return RuntimeError(str(exc))
|
||||
|
||||
|
||||
# ─── hermes mcp add ──────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_add(args):
|
||||
"""Add a new MCP server with discovery-first tool selection."""
|
||||
name = args.name
|
||||
url = getattr(args, "url", None)
|
||||
command = getattr(args, "command", None)
|
||||
cmd_args = getattr(args, "args", None) or []
|
||||
auth_type = getattr(args, "auth", None)
|
||||
|
||||
# Validate transport
|
||||
if not url and not command:
|
||||
_error("Must specify --url <endpoint> or --command <cmd>")
|
||||
_info("Examples:")
|
||||
_info(' hermes mcp add ink --url "https://mcp.ml.ink/mcp"')
|
||||
_info(' hermes mcp add github --command npx --args @modelcontextprotocol/server-github')
|
||||
return
|
||||
|
||||
# Check if server already exists
|
||||
existing = _get_mcp_servers()
|
||||
if name in existing:
|
||||
if not _confirm(f"Server '{name}' already exists. Overwrite?", default=False):
|
||||
_info("Cancelled.")
|
||||
return
|
||||
|
||||
# Build initial config
|
||||
server_config: Dict[str, Any] = {}
|
||||
if url:
|
||||
server_config["url"] = url
|
||||
else:
|
||||
server_config["command"] = command
|
||||
if cmd_args:
|
||||
server_config["args"] = cmd_args
|
||||
|
||||
# ── Authentication ────────────────────────────────────────────────
|
||||
|
||||
if url and auth_type == "oauth":
|
||||
print()
|
||||
_info(f"Starting OAuth flow for '{name}'...")
|
||||
oauth_ok = False
|
||||
try:
|
||||
from tools.mcp_oauth import build_oauth_auth
|
||||
oauth_auth = build_oauth_auth(name, url)
|
||||
if oauth_auth:
|
||||
server_config["auth"] = "oauth"
|
||||
_success("OAuth configured (tokens will be acquired on first connection)")
|
||||
oauth_ok=True
|
||||
else:
|
||||
_warning("OAuth setup failed — MCP SDK auth module not available")
|
||||
except Exception as exc:
|
||||
_warning(f"OAuth error: {exc}")
|
||||
|
||||
if not oauth_ok:
|
||||
_info("This server may not support OAuth.")
|
||||
if _confirm("Continue without authentication?", default=True):
|
||||
# Don't store auth: oauth — server doesn't support it
|
||||
pass
|
||||
else:
|
||||
_info("Cancelled.")
|
||||
return
|
||||
|
||||
elif url:
|
||||
# Prompt for API key / Bearer token for HTTP servers
|
||||
print()
|
||||
_info(f"Connecting to {url}")
|
||||
needs_auth = _confirm("Does this server require authentication?", default=True)
|
||||
if needs_auth:
|
||||
if auth_type == "header" or not auth_type:
|
||||
env_key = _env_key_for_server(name)
|
||||
existing_key = get_env_value(env_key)
|
||||
if existing_key:
|
||||
_success(f"{env_key}: already configured")
|
||||
api_key = existing_key
|
||||
else:
|
||||
api_key = _prompt("API key / Bearer token", password=True)
|
||||
if api_key:
|
||||
save_env_value(env_key, api_key)
|
||||
_success(f"Saved to ~/.hermes/.env as {env_key}")
|
||||
|
||||
# Set header with env var interpolation
|
||||
if api_key or existing_key:
|
||||
server_config["headers"] = {
|
||||
"Authorization": f"Bearer ${{{env_key}}}"
|
||||
}
|
||||
|
||||
# ── Discovery: connect and list tools ─────────────────────────────
|
||||
|
||||
print()
|
||||
print(color(f" Connecting to '{name}'...", Colors.CYAN))
|
||||
|
||||
try:
|
||||
tools = _probe_single_server(name, server_config)
|
||||
except Exception as exc:
|
||||
_error(f"Failed to connect: {exc}")
|
||||
if _confirm("Save config anyway (you can test later)?", default=False):
|
||||
server_config["enabled"] = False
|
||||
_save_mcp_server(name, server_config)
|
||||
_success(f"Saved '{name}' to config (disabled)")
|
||||
_info("Fix the issue, then: hermes mcp test " + name)
|
||||
return
|
||||
|
||||
if not tools:
|
||||
_warning("Server connected but reported no tools.")
|
||||
if _confirm("Save config anyway?", default=True):
|
||||
_save_mcp_server(name, server_config)
|
||||
_success(f"Saved '{name}' to config")
|
||||
return
|
||||
|
||||
# ── Tool selection ────────────────────────────────────────────────
|
||||
|
||||
print()
|
||||
_success(f"Connected! Found {len(tools)} tool(s) from '{name}':")
|
||||
print()
|
||||
for tool_name, desc in tools:
|
||||
short = desc[:60] + "..." if len(desc) > 60 else desc
|
||||
print(f" {color(tool_name, Colors.GREEN):40s} {short}")
|
||||
print()
|
||||
|
||||
# Ask: enable all, select, or cancel
|
||||
try:
|
||||
choice = input(
|
||||
color(f" Enable all {len(tools)} tools? [Y/n/select]: ", Colors.YELLOW)
|
||||
).strip().lower()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
_info("Cancelled.")
|
||||
return
|
||||
|
||||
if choice in ("n", "no"):
|
||||
_info("Cancelled — server not saved.")
|
||||
return
|
||||
|
||||
if choice in ("s", "select"):
|
||||
# Interactive tool selection
|
||||
from hermes_cli.curses_ui import curses_checklist
|
||||
|
||||
labels = [f"{t[0]} — {t[1]}" for t in tools]
|
||||
pre_selected = set(range(len(tools)))
|
||||
|
||||
chosen = curses_checklist(
|
||||
f"Select tools for '{name}'",
|
||||
labels,
|
||||
pre_selected,
|
||||
)
|
||||
|
||||
if not chosen:
|
||||
_info("No tools selected — server not saved.")
|
||||
return
|
||||
|
||||
chosen_names = [tools[i][0] for i in sorted(chosen)]
|
||||
server_config.setdefault("tools", {})["include"] = chosen_names
|
||||
|
||||
tool_count = len(chosen_names)
|
||||
total = len(tools)
|
||||
else:
|
||||
# Enable all (no filter needed — default behaviour)
|
||||
tool_count = len(tools)
|
||||
total = len(tools)
|
||||
|
||||
# ── Save ──────────────────────────────────────────────────────────
|
||||
|
||||
server_config["enabled"] = True
|
||||
_save_mcp_server(name, server_config)
|
||||
|
||||
print()
|
||||
_success(f"Saved '{name}' to ~/.hermes/config.yaml ({tool_count}/{total} tools enabled)")
|
||||
_info("Start a new session to use these tools.")
|
||||
|
||||
|
||||
# ─── hermes mcp remove ───────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_remove(args):
|
||||
"""Remove an MCP server from config."""
|
||||
name = args.name
|
||||
existing = _get_mcp_servers()
|
||||
|
||||
if name not in existing:
|
||||
_error(f"Server '{name}' not found in config.")
|
||||
servers = list(existing.keys())
|
||||
if servers:
|
||||
_info(f"Available servers: {', '.join(servers)}")
|
||||
return
|
||||
|
||||
if not _confirm(f"Remove server '{name}'?", default=True):
|
||||
_info("Cancelled.")
|
||||
return
|
||||
|
||||
_remove_mcp_server(name)
|
||||
_success(f"Removed '{name}' from config")
|
||||
|
||||
# Clean up OAuth tokens if they exist
|
||||
try:
|
||||
from tools.mcp_oauth import remove_oauth_tokens
|
||||
remove_oauth_tokens(name)
|
||||
_success("Cleaned up OAuth tokens")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# ─── hermes mcp list ──────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_list(args=None):
|
||||
"""List all configured MCP servers."""
|
||||
servers = _get_mcp_servers()
|
||||
|
||||
if not servers:
|
||||
print()
|
||||
_info("No MCP servers configured.")
|
||||
print()
|
||||
_info("Add one with:")
|
||||
_info(' hermes mcp add <name> --url <endpoint>')
|
||||
_info(' hermes mcp add <name> --command <cmd> --args <args...>')
|
||||
print()
|
||||
return
|
||||
|
||||
print()
|
||||
print(color(" MCP Servers:", Colors.CYAN + Colors.BOLD))
|
||||
print()
|
||||
|
||||
# Table header
|
||||
print(f" {'Name':<16} {'Transport':<30} {'Tools':<12} {'Status':<10}")
|
||||
print(f" {'─' * 16} {'─' * 30} {'─' * 12} {'─' * 10}")
|
||||
|
||||
for name, cfg in servers.items():
|
||||
# Transport info
|
||||
if "url" in cfg:
|
||||
url = cfg["url"]
|
||||
# Truncate long URLs
|
||||
if len(url) > 28:
|
||||
url = url[:25] + "..."
|
||||
transport = url
|
||||
elif "command" in cfg:
|
||||
cmd = cfg["command"]
|
||||
cmd_args = cfg.get("args", [])
|
||||
if isinstance(cmd_args, list) and cmd_args:
|
||||
transport = f"{cmd} {' '.join(str(a) for a in cmd_args[:2])}"
|
||||
else:
|
||||
transport = cmd
|
||||
if len(transport) > 28:
|
||||
transport = transport[:25] + "..."
|
||||
else:
|
||||
transport = "?"
|
||||
|
||||
# Tool count
|
||||
tools_cfg = cfg.get("tools", {})
|
||||
if isinstance(tools_cfg, dict):
|
||||
include = tools_cfg.get("include")
|
||||
exclude = tools_cfg.get("exclude")
|
||||
if include and isinstance(include, list):
|
||||
tools_str = f"{len(include)} selected"
|
||||
elif exclude and isinstance(exclude, list):
|
||||
tools_str = f"-{len(exclude)} excluded"
|
||||
else:
|
||||
tools_str = "all"
|
||||
else:
|
||||
tools_str = "all"
|
||||
|
||||
# Enabled status
|
||||
enabled = cfg.get("enabled", True)
|
||||
if isinstance(enabled, str):
|
||||
enabled = enabled.lower() in ("true", "1", "yes")
|
||||
status = color("✓ enabled", Colors.GREEN) if enabled else color("✗ disabled", Colors.DIM)
|
||||
|
||||
print(f" {name:<16} {transport:<30} {tools_str:<12} {status}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# ─── hermes mcp test ──────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_test(args):
|
||||
"""Test connection to an MCP server."""
|
||||
name = args.name
|
||||
servers = _get_mcp_servers()
|
||||
|
||||
if name not in servers:
|
||||
_error(f"Server '{name}' not found in config.")
|
||||
available = list(servers.keys())
|
||||
if available:
|
||||
_info(f"Available: {', '.join(available)}")
|
||||
return
|
||||
|
||||
cfg = servers[name]
|
||||
print()
|
||||
print(color(f" Testing '{name}'...", Colors.CYAN))
|
||||
|
||||
# Show transport info
|
||||
if "url" in cfg:
|
||||
_info(f"Transport: HTTP → {cfg['url']}")
|
||||
else:
|
||||
cmd = cfg.get("command", "?")
|
||||
_info(f"Transport: stdio → {cmd}")
|
||||
|
||||
# Show auth info (masked)
|
||||
auth_type = cfg.get("auth", "")
|
||||
headers = cfg.get("headers", {})
|
||||
if auth_type == "oauth":
|
||||
_info("Auth: OAuth 2.1 PKCE")
|
||||
elif headers:
|
||||
for k, v in headers.items():
|
||||
if isinstance(v, str) and ("key" in k.lower() or "auth" in k.lower()):
|
||||
# Mask the value
|
||||
resolved = _interpolate_value(v)
|
||||
if len(resolved) > 8:
|
||||
masked = resolved[:4] + "***" + resolved[-4:]
|
||||
else:
|
||||
masked = "***"
|
||||
print(f" {k}: {masked}")
|
||||
else:
|
||||
_info("Auth: none")
|
||||
|
||||
# Attempt connection
|
||||
start = time.monotonic()
|
||||
try:
|
||||
tools = _probe_single_server(name, cfg)
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
_error(f"Connection failed ({elapsed_ms:.0f}ms): {exc}")
|
||||
return
|
||||
|
||||
_success(f"Connected ({elapsed_ms:.0f}ms)")
|
||||
_success(f"Tools discovered: {len(tools)}")
|
||||
|
||||
if tools:
|
||||
print()
|
||||
for tool_name, desc in tools:
|
||||
short = desc[:55] + "..." if len(desc) > 55 else desc
|
||||
print(f" {color(tool_name, Colors.GREEN):36s} {short}")
|
||||
print()
|
||||
|
||||
|
||||
def _interpolate_value(value: str) -> str:
|
||||
"""Resolve ``${ENV_VAR}`` references in a string."""
|
||||
def _replace(m):
|
||||
return os.getenv(m.group(1), "")
|
||||
return re.sub(r"\$\{(\w+)\}", _replace, value)
|
||||
|
||||
|
||||
# ─── hermes mcp configure ────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_configure(args):
|
||||
"""Reconfigure which tools are enabled for an existing MCP server."""
|
||||
name = args.name
|
||||
servers = _get_mcp_servers()
|
||||
|
||||
if name not in servers:
|
||||
_error(f"Server '{name}' not found in config.")
|
||||
available = list(servers.keys())
|
||||
if available:
|
||||
_info(f"Available: {', '.join(available)}")
|
||||
return
|
||||
|
||||
cfg = servers[name]
|
||||
|
||||
# Discover all available tools
|
||||
print()
|
||||
print(color(f" Connecting to '{name}' to discover tools...", Colors.CYAN))
|
||||
|
||||
try:
|
||||
all_tools = _probe_single_server(name, cfg)
|
||||
except Exception as exc:
|
||||
_error(f"Failed to connect: {exc}")
|
||||
return
|
||||
|
||||
if not all_tools:
|
||||
_warning("Server reports no tools.")
|
||||
return
|
||||
|
||||
# Determine which are currently enabled
|
||||
tools_cfg = cfg.get("tools", {})
|
||||
if isinstance(tools_cfg, dict):
|
||||
include = tools_cfg.get("include")
|
||||
exclude = tools_cfg.get("exclude")
|
||||
else:
|
||||
include = None
|
||||
exclude = None
|
||||
|
||||
tool_names = [t[0] for t in all_tools]
|
||||
|
||||
if include and isinstance(include, list):
|
||||
include_set = set(include)
|
||||
pre_selected = {
|
||||
i for i, tn in enumerate(tool_names) if tn in include_set
|
||||
}
|
||||
elif exclude and isinstance(exclude, list):
|
||||
exclude_set = set(exclude)
|
||||
pre_selected = {
|
||||
i for i, tn in enumerate(tool_names) if tn not in exclude_set
|
||||
}
|
||||
else:
|
||||
pre_selected = set(range(len(all_tools)))
|
||||
|
||||
currently = len(pre_selected)
|
||||
total = len(all_tools)
|
||||
_info(f"Currently {currently}/{total} tools enabled for '{name}'.")
|
||||
print()
|
||||
|
||||
# Interactive checklist
|
||||
from hermes_cli.curses_ui import curses_checklist
|
||||
|
||||
labels = [f"{t[0]} — {t[1]}" for t in all_tools]
|
||||
|
||||
chosen = curses_checklist(
|
||||
f"Select tools for '{name}'",
|
||||
labels,
|
||||
pre_selected,
|
||||
)
|
||||
|
||||
if chosen == pre_selected:
|
||||
_info("No changes made.")
|
||||
return
|
||||
|
||||
# Update config
|
||||
config = load_config()
|
||||
server_entry = config.get("mcp_servers", {}).get(name, {})
|
||||
|
||||
if len(chosen) == total:
|
||||
# All selected → remove include/exclude (register all)
|
||||
server_entry.pop("tools", None)
|
||||
else:
|
||||
chosen_names = [tool_names[i] for i in sorted(chosen)]
|
||||
server_entry.setdefault("tools", {})
|
||||
server_entry["tools"]["include"] = chosen_names
|
||||
server_entry["tools"].pop("exclude", None)
|
||||
|
||||
config.setdefault("mcp_servers", {})[name] = server_entry
|
||||
save_config(config)
|
||||
|
||||
new_count = len(chosen)
|
||||
_success(f"Updated config: {new_count}/{total} tools enabled")
|
||||
_info("Start a new session for changes to take effect.")
|
||||
|
||||
|
||||
# ─── Dispatcher ───────────────────────────────────────────────────────────────
|
||||
|
||||
def mcp_command(args):
|
||||
"""Main dispatcher for ``hermes mcp`` subcommands."""
|
||||
action = getattr(args, "mcp_action", None)
|
||||
|
||||
handlers = {
|
||||
"add": cmd_mcp_add,
|
||||
"remove": cmd_mcp_remove,
|
||||
"rm": cmd_mcp_remove,
|
||||
"list": cmd_mcp_list,
|
||||
"ls": cmd_mcp_list,
|
||||
"test": cmd_mcp_test,
|
||||
"configure": cmd_mcp_configure,
|
||||
"config": cmd_mcp_configure,
|
||||
}
|
||||
|
||||
handler = handlers.get(action)
|
||||
if handler:
|
||||
handler(args)
|
||||
else:
|
||||
# No subcommand — show list
|
||||
cmd_mcp_list()
|
||||
print(color(" Commands:", Colors.CYAN))
|
||||
_info("hermes mcp add <name> --url <endpoint> Add an MCP server")
|
||||
_info("hermes mcp add <name> --command <cmd> Add a stdio server")
|
||||
_info("hermes mcp remove <name> Remove a server")
|
||||
_info("hermes mcp list List servers")
|
||||
_info("hermes mcp test <name> Test connection")
|
||||
_info("hermes mcp configure <name> Toggle tools")
|
||||
print()
|
||||
@@ -0,0 +1,232 @@
|
||||
"""Shared model-switching logic for CLI and gateway /model commands.
|
||||
|
||||
Both the CLI (cli.py) and gateway (gateway/run.py) /model handlers
|
||||
share the same core pipeline:
|
||||
|
||||
parse_model_input → is_custom detection → auto-detect provider
|
||||
→ credential resolution → validate model → return result
|
||||
|
||||
This module extracts that shared pipeline into pure functions that
|
||||
return result objects. The callers handle all platform-specific
|
||||
concerns: state mutation, config persistence, output formatting.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelSwitchResult:
|
||||
"""Result of a model switch attempt."""
|
||||
|
||||
success: bool
|
||||
new_model: str = ""
|
||||
target_provider: str = ""
|
||||
provider_changed: bool = False
|
||||
api_key: str = ""
|
||||
base_url: str = ""
|
||||
persist: bool = False
|
||||
error_message: str = ""
|
||||
warning_message: str = ""
|
||||
is_custom_target: bool = False
|
||||
provider_label: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomAutoResult:
|
||||
"""Result of switching to bare 'custom' provider with auto-detect."""
|
||||
|
||||
success: bool
|
||||
model: str = ""
|
||||
base_url: str = ""
|
||||
api_key: str = ""
|
||||
error_message: str = ""
|
||||
|
||||
|
||||
def switch_model(
|
||||
raw_input: str,
|
||||
current_provider: str,
|
||||
current_base_url: str = "",
|
||||
current_api_key: str = "",
|
||||
) -> ModelSwitchResult:
|
||||
"""Core model-switching pipeline shared between CLI and gateway.
|
||||
|
||||
Handles parsing, provider detection, credential resolution, and
|
||||
model validation. Does NOT handle config persistence, state
|
||||
mutation, or output formatting — those are caller responsibilities.
|
||||
|
||||
Args:
|
||||
raw_input: The user's model input (e.g. "claude-sonnet-4",
|
||||
"zai:glm-5", "custom:local:qwen").
|
||||
current_provider: The currently active provider.
|
||||
current_base_url: The currently active base URL (used for
|
||||
is_custom detection).
|
||||
current_api_key: The currently active API key.
|
||||
|
||||
Returns:
|
||||
ModelSwitchResult with all information the caller needs to
|
||||
apply the switch and format output.
|
||||
"""
|
||||
from hermes_cli.models import (
|
||||
parse_model_input,
|
||||
detect_provider_for_model,
|
||||
validate_requested_model,
|
||||
_PROVIDER_LABELS,
|
||||
)
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
|
||||
# Step 1: Parse provider:model syntax
|
||||
target_provider, new_model = parse_model_input(raw_input, current_provider)
|
||||
|
||||
# Step 2: Detect if we're currently on a custom endpoint
|
||||
_base = current_base_url or ""
|
||||
is_custom = current_provider == "custom" or (
|
||||
"localhost" in _base or "127.0.0.1" in _base
|
||||
)
|
||||
|
||||
# Step 3: Auto-detect provider when no explicit provider:model syntax
|
||||
# was used. Skip for custom providers — the model name might
|
||||
# coincidentally match a known provider's catalog.
|
||||
if target_provider == current_provider and not is_custom:
|
||||
detected = detect_provider_for_model(new_model, current_provider)
|
||||
if detected:
|
||||
target_provider, new_model = detected
|
||||
|
||||
provider_changed = target_provider != current_provider
|
||||
|
||||
# Step 4: Resolve credentials for target provider
|
||||
api_key = current_api_key
|
||||
base_url = current_base_url
|
||||
if provider_changed:
|
||||
try:
|
||||
runtime = resolve_runtime_provider(requested=target_provider)
|
||||
api_key = runtime.get("api_key", "")
|
||||
base_url = runtime.get("base_url", "")
|
||||
except Exception as e:
|
||||
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
|
||||
if target_provider == "custom":
|
||||
return ModelSwitchResult(
|
||||
success=False,
|
||||
target_provider=target_provider,
|
||||
error_message=(
|
||||
"No custom endpoint configured. Set model.base_url "
|
||||
"in config.yaml, or set OPENAI_BASE_URL in .env, "
|
||||
"or run: hermes setup → Custom OpenAI-compatible endpoint"
|
||||
),
|
||||
)
|
||||
return ModelSwitchResult(
|
||||
success=False,
|
||||
target_provider=target_provider,
|
||||
error_message=(
|
||||
f"Could not resolve credentials for provider "
|
||||
f"'{provider_label}': {e}"
|
||||
),
|
||||
)
|
||||
else:
|
||||
# Gateway also resolves for unchanged provider to get accurate
|
||||
# base_url for validation probing.
|
||||
try:
|
||||
runtime = resolve_runtime_provider(requested=current_provider)
|
||||
api_key = runtime.get("api_key", "")
|
||||
base_url = runtime.get("base_url", "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Step 5: Validate the model
|
||||
try:
|
||||
validation = validate_requested_model(
|
||||
new_model,
|
||||
target_provider,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
except Exception:
|
||||
validation = {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"recognized": False,
|
||||
"message": None,
|
||||
}
|
||||
|
||||
if not validation.get("accepted"):
|
||||
msg = validation.get("message", "Invalid model")
|
||||
return ModelSwitchResult(
|
||||
success=False,
|
||||
new_model=new_model,
|
||||
target_provider=target_provider,
|
||||
error_message=msg,
|
||||
)
|
||||
|
||||
# Step 6: Build result
|
||||
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
|
||||
is_custom_target = target_provider == "custom" or (
|
||||
base_url
|
||||
and "openrouter.ai" not in (base_url or "")
|
||||
and ("localhost" in (base_url or "") or "127.0.0.1" in (base_url or ""))
|
||||
)
|
||||
|
||||
return ModelSwitchResult(
|
||||
success=True,
|
||||
new_model=new_model,
|
||||
target_provider=target_provider,
|
||||
provider_changed=provider_changed,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
persist=bool(validation.get("persist")),
|
||||
warning_message=validation.get("message") or "",
|
||||
is_custom_target=is_custom_target,
|
||||
provider_label=provider_label,
|
||||
)
|
||||
|
||||
|
||||
def switch_to_custom_provider() -> CustomAutoResult:
|
||||
"""Handle bare '/model custom' — resolve endpoint and auto-detect model.
|
||||
|
||||
Returns a result object; the caller handles persistence and output.
|
||||
"""
|
||||
from hermes_cli.runtime_provider import (
|
||||
resolve_runtime_provider,
|
||||
_auto_detect_local_model,
|
||||
)
|
||||
|
||||
try:
|
||||
runtime = resolve_runtime_provider(requested="custom")
|
||||
except Exception as e:
|
||||
return CustomAutoResult(
|
||||
success=False,
|
||||
error_message=f"Could not resolve custom endpoint: {e}",
|
||||
)
|
||||
|
||||
cust_base = runtime.get("base_url", "")
|
||||
cust_key = runtime.get("api_key", "")
|
||||
|
||||
if not cust_base or "openrouter.ai" in cust_base:
|
||||
return CustomAutoResult(
|
||||
success=False,
|
||||
error_message=(
|
||||
"No custom endpoint configured. "
|
||||
"Set model.base_url in config.yaml, or set OPENAI_BASE_URL "
|
||||
"in .env, or run: hermes setup → Custom OpenAI-compatible endpoint"
|
||||
),
|
||||
)
|
||||
|
||||
detected_model = _auto_detect_local_model(cust_base)
|
||||
if not detected_model:
|
||||
return CustomAutoResult(
|
||||
success=False,
|
||||
base_url=cust_base,
|
||||
api_key=cust_key,
|
||||
error_message=(
|
||||
f"Custom endpoint at {cust_base} is reachable but no single "
|
||||
f"model was auto-detected. Specify the model explicitly: "
|
||||
f"/model custom:<model-name>"
|
||||
),
|
||||
)
|
||||
|
||||
return CustomAutoResult(
|
||||
success=True,
|
||||
model=detected_model,
|
||||
base_url=cust_base,
|
||||
api_key=cust_key,
|
||||
)
|
||||
+44
-12
@@ -31,19 +31,20 @@ OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("anthropic/claude-haiku-4.5", ""),
|
||||
("openai/gpt-5.4", ""),
|
||||
("openai/gpt-5.4-mini", ""),
|
||||
("openrouter/hunter-alpha", "free"),
|
||||
("openrouter/healer-alpha", "free"),
|
||||
("xiaomi/mimo-v2-pro", ""),
|
||||
("openai/gpt-5.3-codex", ""),
|
||||
("google/gemini-3-pro-preview", ""),
|
||||
("google/gemini-3-flash-preview", ""),
|
||||
("qwen/qwen3.5-plus-02-15", ""),
|
||||
("qwen/qwen3.5-35b-a3b", ""),
|
||||
("stepfun/step-3.5-flash", ""),
|
||||
("minimax/minimax-m2.7", ""),
|
||||
("minimax/minimax-m2.5", ""),
|
||||
("z-ai/glm-5", ""),
|
||||
("z-ai/glm-5-turbo", ""),
|
||||
("moonshotai/kimi-k2.5", ""),
|
||||
("x-ai/grok-4.20-beta", ""),
|
||||
("nvidia/nemotron-3-super-120b-a12b", ""),
|
||||
("nvidia/nemotron-3-super-120b-a12b:free", "free"),
|
||||
("arcee-ai/trinity-large-preview:free", "free"),
|
||||
("openai/gpt-5.4-pro", ""),
|
||||
@@ -52,12 +53,29 @@ OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
|
||||
_PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"nous": [
|
||||
"claude-opus-4-6",
|
||||
"claude-sonnet-4-6",
|
||||
"gpt-5.4",
|
||||
"gemini-3-flash",
|
||||
"gemini-3.0-pro-preview",
|
||||
"deepseek-v3.2",
|
||||
"anthropic/claude-opus-4.6",
|
||||
"anthropic/claude-sonnet-4.5",
|
||||
"anthropic/claude-haiku-4.5",
|
||||
"openai/gpt-5.4",
|
||||
"openai/gpt-5.4-mini",
|
||||
"xiaomi/mimo-v2-pro",
|
||||
"openai/gpt-5.3-codex",
|
||||
"google/gemini-3-pro-preview",
|
||||
"google/gemini-3-flash-preview",
|
||||
"qwen/qwen3.5-plus-02-15",
|
||||
"qwen/qwen3.5-35b-a3b",
|
||||
"stepfun/step-3.5-flash",
|
||||
"minimax/minimax-m2.7",
|
||||
"minimax/minimax-m2.5",
|
||||
"z-ai/glm-5",
|
||||
"z-ai/glm-5-turbo",
|
||||
"moonshotai/kimi-k2.5",
|
||||
"x-ai/grok-4.20-beta",
|
||||
"nvidia/nemotron-3-super-120b-a12b",
|
||||
"nvidia/nemotron-3-super-120b-a12b:free",
|
||||
"arcee-ai/trinity-large-preview:free",
|
||||
"openai/gpt-5.4-pro",
|
||||
"openai/gpt-5.4-nano",
|
||||
],
|
||||
"openai-codex": [
|
||||
"gpt-5.3-codex",
|
||||
@@ -86,6 +104,7 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
],
|
||||
"zai": [
|
||||
"glm-5",
|
||||
"glm-5-turbo",
|
||||
"glm-4.7",
|
||||
"glm-4.5",
|
||||
"glm-4.5-flash",
|
||||
@@ -150,6 +169,7 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"gemini-3.1-pro",
|
||||
"gemini-3-pro",
|
||||
"gemini-3-flash",
|
||||
"minimax-m2.7",
|
||||
"minimax-m2.5",
|
||||
"minimax-m2.5-free",
|
||||
"minimax-m2.1",
|
||||
@@ -300,12 +320,15 @@ def list_available_providers() -> list[dict[str, str]]:
|
||||
# Check if this provider has credentials available
|
||||
has_creds = False
|
||||
try:
|
||||
from hermes_cli.auth import get_auth_status, has_usable_secret
|
||||
if pid == "custom":
|
||||
has_creds = bool(_get_custom_base_url())
|
||||
custom_base_url = _get_custom_base_url() or os.getenv("OPENAI_BASE_URL", "")
|
||||
has_creds = bool(custom_base_url.strip())
|
||||
elif pid == "openrouter":
|
||||
has_creds = has_usable_secret(os.getenv("OPENROUTER_API_KEY", ""))
|
||||
else:
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
runtime = resolve_runtime_provider(requested=pid)
|
||||
has_creds = bool(runtime.get("api_key"))
|
||||
status = get_auth_status(pid)
|
||||
has_creds = bool(status.get("logged_in") or status.get("configured"))
|
||||
except Exception:
|
||||
pass
|
||||
result.append({
|
||||
@@ -340,6 +363,15 @@ def parse_model_input(raw: str, current_provider: str) -> tuple[str, str]:
|
||||
provider_part = stripped[:colon].strip().lower()
|
||||
model_part = stripped[colon + 1:].strip()
|
||||
if provider_part and model_part and provider_part in _KNOWN_PROVIDER_NAMES:
|
||||
# Support custom:name:model triple syntax for named custom
|
||||
# providers. ``custom:local:qwen`` → ("custom:local", "qwen").
|
||||
# Single colon ``custom:qwen`` → ("custom", "qwen") as before.
|
||||
if provider_part == "custom" and ":" in model_part:
|
||||
second_colon = model_part.find(":")
|
||||
custom_name = model_part[:second_colon].strip()
|
||||
actual_model = model_part[second_colon + 1:].strip()
|
||||
if custom_name and actual_model:
|
||||
return (f"custom:{custom_name}", actual_model)
|
||||
return (normalize_provider(provider_part), model_part)
|
||||
return (current_provider, stripped)
|
||||
|
||||
|
||||
@@ -72,10 +72,10 @@ def _cmd_approve(store, platform: str, code: str):
|
||||
name = result.get("user_name", "")
|
||||
display = f"{name} ({uid})" if name else uid
|
||||
print(f"\n Approved! User {display} on {platform} can now use the bot~")
|
||||
print(f" They'll be recognized automatically on their next message.\n")
|
||||
print(" They'll be recognized automatically on their next message.\n")
|
||||
else:
|
||||
print(f"\n Code '{code}' not found or expired for platform '{platform}'.")
|
||||
print(f" Run 'hermes pairing list' to see pending codes.\n")
|
||||
print(" Run 'hermes pairing list' to see pending codes.\n")
|
||||
|
||||
|
||||
def _cmd_revoke(store, platform: str, user_id: str):
|
||||
|
||||
+55
-3
@@ -5,7 +5,8 @@ Hermes Plugin System
|
||||
Discovers, loads, and manages plugins from three sources:
|
||||
|
||||
1. **User plugins** – ``~/.hermes/plugins/<name>/``
|
||||
2. **Project plugins** – ``./.hermes/plugins/<name>/``
|
||||
2. **Project plugins** – ``./.hermes/plugins/<name>/`` (opt-in via
|
||||
``HERMES_ENABLE_PROJECT_PLUGINS``)
|
||||
3. **Pip plugins** – packages that expose the ``hermes_agent.plugins``
|
||||
entry-point group.
|
||||
|
||||
@@ -62,6 +63,11 @@ ENTRY_POINTS_GROUP = "hermes_agent.plugins"
|
||||
_NS_PARENT = "hermes_plugins"
|
||||
|
||||
|
||||
def _env_enabled(name: str) -> bool:
|
||||
"""Return True when an env var is set to a truthy opt-in value."""
|
||||
return os.getenv(name, "").strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data classes
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -186,8 +192,9 @@ class PluginManager:
|
||||
manifests.extend(self._scan_directory(user_dir, source="user"))
|
||||
|
||||
# 2. Project plugins (./.hermes/plugins/)
|
||||
project_dir = Path.cwd() / ".hermes" / "plugins"
|
||||
manifests.extend(self._scan_directory(project_dir, source="project"))
|
||||
if _env_enabled("HERMES_ENABLE_PROJECT_PLUGINS"):
|
||||
project_dir = Path.cwd() / ".hermes" / "plugins"
|
||||
manifests.extend(self._scan_directory(project_dir, source="project"))
|
||||
|
||||
# 3. Pip / entry-point plugins
|
||||
manifests.extend(self._scan_entry_points())
|
||||
@@ -447,3 +454,48 @@ def invoke_hook(hook_name: str, **kwargs: Any) -> None:
|
||||
def get_plugin_tool_names() -> Set[str]:
|
||||
"""Return the set of tool names registered by plugins."""
|
||||
return get_plugin_manager()._plugin_tool_names
|
||||
|
||||
|
||||
def get_plugin_toolsets() -> List[tuple]:
|
||||
"""Return plugin toolsets as ``(key, label, description)`` tuples.
|
||||
|
||||
Used by the ``hermes tools`` TUI so plugin-provided toolsets appear
|
||||
alongside the built-in ones and can be toggled on/off per platform.
|
||||
"""
|
||||
manager = get_plugin_manager()
|
||||
if not manager._plugin_tool_names:
|
||||
return []
|
||||
|
||||
try:
|
||||
from tools.registry import registry
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
# Group plugin tool names by their toolset
|
||||
toolset_tools: Dict[str, List[str]] = {}
|
||||
toolset_plugin: Dict[str, LoadedPlugin] = {}
|
||||
for tool_name in manager._plugin_tool_names:
|
||||
entry = registry._tools.get(tool_name)
|
||||
if not entry:
|
||||
continue
|
||||
ts = entry.toolset
|
||||
toolset_tools.setdefault(ts, []).append(entry.name)
|
||||
|
||||
# Map toolsets back to the plugin that registered them
|
||||
for _name, loaded in manager._plugins.items():
|
||||
for tool_name in loaded.tools_registered:
|
||||
entry = registry._tools.get(tool_name)
|
||||
if entry and entry.toolset in toolset_tools:
|
||||
toolset_plugin.setdefault(entry.toolset, loaded)
|
||||
|
||||
result = []
|
||||
for ts_key in sorted(toolset_tools):
|
||||
plugin = toolset_plugin.get(ts_key)
|
||||
label = f"🔌 {ts_key.replace('_', ' ').title()}"
|
||||
if plugin and plugin.manifest.description:
|
||||
desc = plugin.manifest.description
|
||||
else:
|
||||
desc = ", ".join(sorted(toolset_tools[ts_key]))
|
||||
result.append((ts_key, label, desc))
|
||||
|
||||
return result
|
||||
|
||||
@@ -0,0 +1,446 @@
|
||||
"""``hermes plugins`` CLI subcommand — install, update, remove, and list plugins.
|
||||
|
||||
Plugins are installed from Git repositories into ``~/.hermes/plugins/``.
|
||||
Supports full URLs and ``owner/repo`` shorthand (resolves to GitHub).
|
||||
|
||||
After install, if the plugin ships an ``after-install.md`` file it is
|
||||
rendered with Rich Markdown. Otherwise a default confirmation is shown.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Minimum manifest version this installer understands.
|
||||
# Plugins may declare ``manifest_version: 1`` in plugin.yaml;
|
||||
# future breaking changes to the manifest schema bump this.
|
||||
_SUPPORTED_MANIFEST_VERSION = 1
|
||||
|
||||
|
||||
def _plugins_dir() -> Path:
|
||||
"""Return the user plugins directory, creating it if needed."""
|
||||
hermes_home = os.environ.get("HERMES_HOME", os.path.expanduser("~/.hermes"))
|
||||
plugins = Path(hermes_home) / "plugins"
|
||||
plugins.mkdir(parents=True, exist_ok=True)
|
||||
return plugins
|
||||
|
||||
|
||||
def _sanitize_plugin_name(name: str, plugins_dir: Path) -> Path:
|
||||
"""Validate a plugin name and return the safe target path inside *plugins_dir*.
|
||||
|
||||
Raises ``ValueError`` if the name contains path-traversal sequences or would
|
||||
resolve outside the plugins directory.
|
||||
"""
|
||||
if not name:
|
||||
raise ValueError("Plugin name must not be empty.")
|
||||
|
||||
# Reject obvious traversal characters
|
||||
for bad in ("/", "\\", ".."):
|
||||
if bad in name:
|
||||
raise ValueError(f"Invalid plugin name '{name}': must not contain '{bad}'.")
|
||||
|
||||
target = (plugins_dir / name).resolve()
|
||||
plugins_resolved = plugins_dir.resolve()
|
||||
|
||||
if (
|
||||
not str(target).startswith(str(plugins_resolved) + os.sep)
|
||||
and target != plugins_resolved
|
||||
):
|
||||
raise ValueError(
|
||||
f"Invalid plugin name '{name}': resolves outside the plugins directory."
|
||||
)
|
||||
|
||||
return target
|
||||
|
||||
|
||||
def _resolve_git_url(identifier: str) -> str:
|
||||
"""Turn an identifier into a cloneable Git URL.
|
||||
|
||||
Accepted formats:
|
||||
- Full URL: https://github.com/owner/repo.git
|
||||
- Full URL: git@github.com:owner/repo.git
|
||||
- Full URL: ssh://git@github.com/owner/repo.git
|
||||
- Shorthand: owner/repo → https://github.com/owner/repo.git
|
||||
|
||||
NOTE: ``http://`` and ``file://`` schemes are accepted but will trigger a
|
||||
security warning at install time.
|
||||
"""
|
||||
# Already a URL
|
||||
if identifier.startswith(("https://", "http://", "git@", "ssh://", "file://")):
|
||||
return identifier
|
||||
|
||||
# owner/repo shorthand
|
||||
parts = identifier.strip("/").split("/")
|
||||
if len(parts) == 2:
|
||||
owner, repo = parts
|
||||
return f"https://github.com/{owner}/{repo}.git"
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid plugin identifier: '{identifier}'. "
|
||||
"Use a Git URL or owner/repo shorthand."
|
||||
)
|
||||
|
||||
|
||||
def _repo_name_from_url(url: str) -> str:
|
||||
"""Extract the repo name from a Git URL for the plugin directory name."""
|
||||
# Strip trailing .git and slashes
|
||||
name = url.rstrip("/")
|
||||
if name.endswith(".git"):
|
||||
name = name[:-4]
|
||||
# Get last path component
|
||||
name = name.rsplit("/", 1)[-1]
|
||||
# Handle ssh-style urls: git@github.com:owner/repo
|
||||
if ":" in name:
|
||||
name = name.rsplit(":", 1)[-1].rsplit("/", 1)[-1]
|
||||
return name
|
||||
|
||||
|
||||
def _read_manifest(plugin_dir: Path) -> dict:
|
||||
"""Read plugin.yaml and return the parsed dict, or empty dict."""
|
||||
manifest_file = plugin_dir / "plugin.yaml"
|
||||
if not manifest_file.exists():
|
||||
return {}
|
||||
try:
|
||||
import yaml
|
||||
|
||||
with open(manifest_file) as f:
|
||||
return yaml.safe_load(f) or {}
|
||||
except Exception as e:
|
||||
logger.warning("Failed to read plugin.yaml in %s: %s", plugin_dir, e)
|
||||
return {}
|
||||
|
||||
|
||||
def _copy_example_files(plugin_dir: Path, console) -> None:
|
||||
"""Copy any .example files to their real names if they don't already exist.
|
||||
|
||||
For example, ``config.yaml.example`` becomes ``config.yaml``.
|
||||
Skips files that already exist to avoid overwriting user config on reinstall.
|
||||
"""
|
||||
for example_file in plugin_dir.glob("*.example"):
|
||||
real_name = example_file.stem # e.g. "config.yaml" from "config.yaml.example"
|
||||
real_path = plugin_dir / real_name
|
||||
if not real_path.exists():
|
||||
try:
|
||||
shutil.copy2(example_file, real_path)
|
||||
console.print(
|
||||
f"[dim] Created {real_name} from {example_file.name}[/dim]"
|
||||
)
|
||||
except OSError as e:
|
||||
console.print(
|
||||
f"[yellow]Warning:[/yellow] Failed to copy {example_file.name}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def _display_after_install(plugin_dir: Path, identifier: str) -> None:
|
||||
"""Show after-install.md if it exists, otherwise a default message."""
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
|
||||
console = Console()
|
||||
after_install = plugin_dir / "after-install.md"
|
||||
|
||||
if after_install.exists():
|
||||
content = after_install.read_text(encoding="utf-8")
|
||||
md = Markdown(content)
|
||||
console.print()
|
||||
console.print(Panel(md, border_style="green", expand=False))
|
||||
console.print()
|
||||
else:
|
||||
console.print()
|
||||
console.print(
|
||||
Panel(
|
||||
f"[green bold]Plugin installed:[/] {identifier}\n"
|
||||
f"[dim]Location:[/] {plugin_dir}",
|
||||
border_style="green",
|
||||
title="✓ Installed",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
console.print()
|
||||
|
||||
|
||||
def _display_removed(name: str, plugins_dir: Path) -> None:
|
||||
"""Show confirmation after removing a plugin."""
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
console.print()
|
||||
console.print(f"[red]✗[/red] Plugin [bold]{name}[/bold] removed from {plugins_dir}")
|
||||
console.print()
|
||||
|
||||
|
||||
def _require_installed_plugin(name: str, plugins_dir: Path, console) -> Path:
|
||||
"""Return the plugin path if it exists, or exit with an error listing installed plugins."""
|
||||
target = _sanitize_plugin_name(name, plugins_dir)
|
||||
if not target.exists():
|
||||
installed = ", ".join(d.name for d in plugins_dir.iterdir() if d.is_dir()) or "(none)"
|
||||
console.print(
|
||||
f"[red]Error:[/red] Plugin '{name}' not found in {plugins_dir}.\n"
|
||||
f"Installed plugins: {installed}"
|
||||
)
|
||||
sys.exit(1)
|
||||
return target
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Commands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def cmd_install(identifier: str, force: bool = False) -> None:
|
||||
"""Install a plugin from a Git URL or owner/repo shorthand."""
|
||||
import tempfile
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
|
||||
try:
|
||||
git_url = _resolve_git_url(identifier)
|
||||
except ValueError as e:
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Warn about insecure / local URL schemes
|
||||
if git_url.startswith("http://") or git_url.startswith("file://"):
|
||||
console.print(
|
||||
"[yellow]Warning:[/yellow] Using insecure/local URL scheme. "
|
||||
"Consider using https:// or git@ for production installs."
|
||||
)
|
||||
|
||||
plugins_dir = _plugins_dir()
|
||||
|
||||
# Clone into a temp directory first so we can read plugin.yaml for the name
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_target = Path(tmp) / "plugin"
|
||||
console.print(f"[dim]Cloning {git_url}...[/dim]")
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "clone", "--depth", "1", git_url, str(tmp_target)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
console.print("[red]Error:[/red] git is not installed or not in PATH.")
|
||||
sys.exit(1)
|
||||
except subprocess.TimeoutExpired:
|
||||
console.print("[red]Error:[/red] Git clone timed out after 60 seconds.")
|
||||
sys.exit(1)
|
||||
|
||||
if result.returncode != 0:
|
||||
console.print(
|
||||
f"[red]Error:[/red] Git clone failed:\n{result.stderr.strip()}"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Read manifest
|
||||
manifest = _read_manifest(tmp_target)
|
||||
plugin_name = manifest.get("name") or _repo_name_from_url(git_url)
|
||||
|
||||
# Sanitize plugin name against path traversal
|
||||
try:
|
||||
target = _sanitize_plugin_name(plugin_name, plugins_dir)
|
||||
except ValueError as e:
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Check manifest_version compatibility
|
||||
mv = manifest.get("manifest_version")
|
||||
if mv is not None:
|
||||
try:
|
||||
mv_int = int(mv)
|
||||
except (ValueError, TypeError):
|
||||
console.print(
|
||||
f"[red]Error:[/red] Plugin '{plugin_name}' has invalid "
|
||||
f"manifest_version '{mv}' (expected an integer)."
|
||||
)
|
||||
sys.exit(1)
|
||||
if mv_int > _SUPPORTED_MANIFEST_VERSION:
|
||||
console.print(
|
||||
f"[red]Error:[/red] Plugin '{plugin_name}' requires manifest_version "
|
||||
f"{mv}, but this installer only supports up to {_SUPPORTED_MANIFEST_VERSION}.\n"
|
||||
f"Run [bold]hermes update[/bold] to get a newer installer."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if target.exists():
|
||||
if not force:
|
||||
console.print(
|
||||
f"[red]Error:[/red] Plugin '{plugin_name}' already exists at {target}.\n"
|
||||
f"Use [bold]--force[/bold] to remove and reinstall, or "
|
||||
f"[bold]hermes plugins update {plugin_name}[/bold] to pull latest."
|
||||
)
|
||||
sys.exit(1)
|
||||
console.print(f"[dim] Removing existing {plugin_name}...[/dim]")
|
||||
shutil.rmtree(target)
|
||||
|
||||
# Move from temp to final location
|
||||
shutil.move(str(tmp_target), str(target))
|
||||
|
||||
# Validate it looks like a plugin
|
||||
if not (target / "plugin.yaml").exists() and not (target / "__init__.py").exists():
|
||||
console.print(
|
||||
f"[yellow]Warning:[/yellow] {plugin_name} doesn't contain plugin.yaml "
|
||||
f"or __init__.py. It may not be a valid Hermes plugin."
|
||||
)
|
||||
|
||||
# Copy .example files to their real names (e.g. config.yaml.example → config.yaml)
|
||||
_copy_example_files(target, console)
|
||||
|
||||
_display_after_install(target, identifier)
|
||||
|
||||
console.print("[dim]Restart the gateway for the plugin to take effect:[/dim]")
|
||||
console.print("[dim] hermes gateway restart[/dim]")
|
||||
console.print()
|
||||
|
||||
|
||||
def cmd_update(name: str) -> None:
|
||||
"""Update an installed plugin by pulling latest from its git remote."""
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
plugins_dir = _plugins_dir()
|
||||
|
||||
try:
|
||||
target = _require_installed_plugin(name, plugins_dir, console)
|
||||
except ValueError as e:
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if not (target / ".git").exists():
|
||||
console.print(
|
||||
f"[red]Error:[/red] Plugin '{name}' was not installed from git "
|
||||
f"(no .git directory). Cannot update."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
console.print(f"[dim]Updating {name}...[/dim]")
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "pull", "--ff-only"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
cwd=str(target),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
console.print("[red]Error:[/red] git is not installed or not in PATH.")
|
||||
sys.exit(1)
|
||||
except subprocess.TimeoutExpired:
|
||||
console.print("[red]Error:[/red] Git pull timed out after 60 seconds.")
|
||||
sys.exit(1)
|
||||
|
||||
if result.returncode != 0:
|
||||
console.print(f"[red]Error:[/red] Git pull failed:\n{result.stderr.strip()}")
|
||||
sys.exit(1)
|
||||
|
||||
# Copy any new .example files
|
||||
_copy_example_files(target, console)
|
||||
|
||||
output = result.stdout.strip()
|
||||
if "Already up to date" in output:
|
||||
console.print(
|
||||
f"[green]✓[/green] Plugin [bold]{name}[/bold] is already up to date."
|
||||
)
|
||||
else:
|
||||
console.print(f"[green]✓[/green] Plugin [bold]{name}[/bold] updated.")
|
||||
console.print(f"[dim]{output}[/dim]")
|
||||
|
||||
|
||||
def cmd_remove(name: str) -> None:
|
||||
"""Remove an installed plugin by name."""
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
plugins_dir = _plugins_dir()
|
||||
|
||||
try:
|
||||
target = _require_installed_plugin(name, plugins_dir, console)
|
||||
except ValueError as e:
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
sys.exit(1)
|
||||
|
||||
shutil.rmtree(target)
|
||||
_display_removed(name, plugins_dir)
|
||||
|
||||
|
||||
def cmd_list() -> None:
|
||||
"""List installed plugins."""
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ImportError:
|
||||
yaml = None
|
||||
|
||||
console = Console()
|
||||
plugins_dir = _plugins_dir()
|
||||
|
||||
dirs = sorted(d for d in plugins_dir.iterdir() if d.is_dir())
|
||||
if not dirs:
|
||||
console.print("[dim]No plugins installed.[/dim]")
|
||||
console.print("[dim]Install with:[/dim] hermes plugins install owner/repo")
|
||||
return
|
||||
|
||||
table = Table(title="Installed Plugins", show_lines=False)
|
||||
table.add_column("Name", style="bold")
|
||||
table.add_column("Version", style="dim")
|
||||
table.add_column("Description")
|
||||
table.add_column("Source", style="dim")
|
||||
|
||||
for d in dirs:
|
||||
manifest_file = d / "plugin.yaml"
|
||||
name = d.name
|
||||
version = ""
|
||||
description = ""
|
||||
source = "local"
|
||||
|
||||
if manifest_file.exists() and yaml:
|
||||
try:
|
||||
with open(manifest_file) as f:
|
||||
manifest = yaml.safe_load(f) or {}
|
||||
name = manifest.get("name", d.name)
|
||||
version = manifest.get("version", "")
|
||||
description = manifest.get("description", "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check if it's a git repo (installed via hermes plugins install)
|
||||
if (d / ".git").exists():
|
||||
source = "git"
|
||||
|
||||
table.add_row(name, str(version), description, source)
|
||||
|
||||
console.print()
|
||||
console.print(table)
|
||||
console.print()
|
||||
|
||||
|
||||
def plugins_command(args) -> None:
|
||||
"""Dispatch hermes plugins subcommands."""
|
||||
action = getattr(args, "plugins_action", None)
|
||||
|
||||
if action == "install":
|
||||
cmd_install(args.identifier, force=getattr(args, "force", False))
|
||||
elif action == "update":
|
||||
cmd_update(args.name)
|
||||
elif action in ("remove", "rm", "uninstall"):
|
||||
cmd_remove(args.name)
|
||||
elif action in ("list", "ls") or action is None:
|
||||
cmd_list()
|
||||
else:
|
||||
from rich.console import Console
|
||||
|
||||
Console().print(f"[red]Unknown plugins action: {action}[/red]")
|
||||
sys.exit(1)
|
||||
@@ -15,6 +15,7 @@ from hermes_cli.auth import (
|
||||
resolve_codex_runtime_credentials,
|
||||
resolve_api_key_provider_credentials,
|
||||
resolve_external_process_provider_credentials,
|
||||
has_usable_secret,
|
||||
)
|
||||
from hermes_cli.config import load_config
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
@@ -188,15 +189,16 @@ def _resolve_named_custom_runtime(
|
||||
if not base_url:
|
||||
return None
|
||||
|
||||
api_key = (
|
||||
(explicit_api_key or "").strip()
|
||||
or custom_provider.get("api_key", "")
|
||||
or os.getenv("OPENAI_API_KEY", "").strip()
|
||||
or os.getenv("OPENROUTER_API_KEY", "").strip()
|
||||
)
|
||||
api_key_candidates = [
|
||||
(explicit_api_key or "").strip(),
|
||||
str(custom_provider.get("api_key", "") or "").strip(),
|
||||
os.getenv("OPENAI_API_KEY", "").strip(),
|
||||
os.getenv("OPENROUTER_API_KEY", "").strip(),
|
||||
]
|
||||
api_key = next((candidate for candidate in api_key_candidates if has_usable_secret(candidate)), "")
|
||||
|
||||
return {
|
||||
"provider": "openrouter",
|
||||
"provider": "custom",
|
||||
"api_mode": custom_provider.get("api_mode")
|
||||
or _detect_api_mode_for_url(base_url)
|
||||
or "chat_completions",
|
||||
@@ -257,26 +259,36 @@ def _resolve_openrouter_runtime(
|
||||
# provider (issues #420, #560).
|
||||
_is_openrouter_url = "openrouter.ai" in base_url
|
||||
if _is_openrouter_url:
|
||||
api_key = (
|
||||
explicit_api_key
|
||||
or os.getenv("OPENROUTER_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or ""
|
||||
)
|
||||
api_key_candidates = [
|
||||
explicit_api_key,
|
||||
os.getenv("OPENROUTER_API_KEY"),
|
||||
os.getenv("OPENAI_API_KEY"),
|
||||
]
|
||||
else:
|
||||
# Custom endpoint: use api_key from config when using config base_url (#1760).
|
||||
api_key = (
|
||||
explicit_api_key
|
||||
or (cfg_api_key if use_config_base_url else "")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or os.getenv("OPENROUTER_API_KEY")
|
||||
or ""
|
||||
)
|
||||
api_key_candidates = [
|
||||
explicit_api_key,
|
||||
(cfg_api_key if use_config_base_url else ""),
|
||||
os.getenv("OPENAI_API_KEY"),
|
||||
os.getenv("OPENROUTER_API_KEY"),
|
||||
]
|
||||
api_key = next(
|
||||
(str(candidate or "").strip() for candidate in api_key_candidates if has_usable_secret(candidate)),
|
||||
"",
|
||||
)
|
||||
|
||||
source = "explicit" if (explicit_api_key or explicit_base_url) else "env/config"
|
||||
|
||||
# When "custom" was explicitly requested, preserve that as the provider
|
||||
# name instead of silently relabeling to "openrouter" (#2562).
|
||||
# Also provide a placeholder API key for local servers that don't require
|
||||
# authentication — the OpenAI SDK requires a non-empty api_key string.
|
||||
effective_provider = "custom" if requested_norm == "custom" else "openrouter"
|
||||
if effective_provider == "custom" and not api_key and not _is_openrouter_url:
|
||||
api_key = "no-key-required"
|
||||
|
||||
return {
|
||||
"provider": "openrouter",
|
||||
"provider": effective_provider,
|
||||
"api_mode": _parse_api_mode(model_cfg.get("api_mode"))
|
||||
or _detect_api_mode_for_url(base_url)
|
||||
or "chat_completions",
|
||||
@@ -359,9 +371,14 @@ def resolve_runtime_provider(
|
||||
"No Anthropic credentials found. Set ANTHROPIC_TOKEN or ANTHROPIC_API_KEY, "
|
||||
"run 'claude setup-token', or authenticate with 'claude /login'."
|
||||
)
|
||||
# Allow base URL override from config.yaml model.base_url
|
||||
# Allow base URL override from config.yaml model.base_url, but only
|
||||
# when the configured provider is anthropic — otherwise a non-Anthropic
|
||||
# base_url (e.g. Codex endpoint) would leak into Anthropic requests.
|
||||
model_cfg = _get_model_config()
|
||||
cfg_base_url = (model_cfg.get("base_url") or "").strip().rstrip("/")
|
||||
cfg_provider = str(model_cfg.get("provider") or "").strip().lower()
|
||||
cfg_base_url = ""
|
||||
if cfg_provider == "anthropic":
|
||||
cfg_base_url = (model_cfg.get("base_url") or "").strip().rstrip("/")
|
||||
base_url = cfg_base_url or "https://api.anthropic.com"
|
||||
return {
|
||||
"provider": "anthropic",
|
||||
@@ -372,19 +389,6 @@ def resolve_runtime_provider(
|
||||
"requested_provider": requested_provider,
|
||||
}
|
||||
|
||||
# Alibaba Cloud / DashScope (Anthropic-compatible endpoint)
|
||||
if provider == "alibaba":
|
||||
creds = resolve_api_key_provider_credentials(provider)
|
||||
base_url = creds.get("base_url", "").rstrip("/") or "https://dashscope-intl.aliyuncs.com/apps/anthropic"
|
||||
return {
|
||||
"provider": "alibaba",
|
||||
"api_mode": "anthropic_messages",
|
||||
"base_url": base_url,
|
||||
"api_key": creds.get("api_key", ""),
|
||||
"source": creds.get("source", "env"),
|
||||
"requested_provider": requested_provider,
|
||||
}
|
||||
|
||||
# API-key providers (z.ai/GLM, Kimi, MiniMax, MiniMax-CN)
|
||||
pconfig = PROVIDER_REGISTRY.get(provider)
|
||||
if pconfig and pconfig.auth_type == "api_key":
|
||||
|
||||
+223
-117
@@ -4,9 +4,9 @@ Interactive setup wizard for Hermes Agent.
|
||||
Modular wizard with independently-runnable sections:
|
||||
1. Model & Provider — choose your AI provider and model
|
||||
2. Terminal Backend — where your agent runs commands
|
||||
3. Messaging Platforms — connect Telegram, Discord, etc.
|
||||
4. Tools — configure TTS, web search, image generation, etc.
|
||||
5. Agent Settings — iterations, compression, session reset
|
||||
3. Agent Settings — iterations, compression, session reset
|
||||
4. Messaging Platforms — connect Telegram, Discord, etc.
|
||||
5. Tools — configure TTS, web search, image generation, etc.
|
||||
|
||||
Config files are stored in ~/.hermes/ for easy access.
|
||||
"""
|
||||
@@ -283,7 +283,6 @@ from hermes_cli.config import (
|
||||
save_env_value,
|
||||
get_env_value,
|
||||
ensure_hermes_home,
|
||||
DEFAULT_CONFIG,
|
||||
)
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
@@ -549,9 +548,9 @@ def _prompt_api_key(var: dict):
|
||||
|
||||
if value:
|
||||
save_env_value(var["name"], value)
|
||||
print_success(f" ✓ Saved")
|
||||
print_success(" ✓ Saved")
|
||||
else:
|
||||
print_warning(f" Skipped (configure later with 'hermes setup')")
|
||||
print_warning(" Skipped (configure later with 'hermes setup')")
|
||||
|
||||
|
||||
def _print_setup_summary(config: dict, hermes_home):
|
||||
@@ -726,9 +725,9 @@ def _print_setup_summary(config: dict, hermes_home):
|
||||
f" {color('hermes config edit', Colors.GREEN)} Open config in your editor"
|
||||
)
|
||||
print(f" {color('hermes config set <key> <value>', Colors.GREEN)}")
|
||||
print(f" Set a specific value")
|
||||
print(" Set a specific value")
|
||||
print()
|
||||
print(f" Or edit the files directly:")
|
||||
print(" Or edit the files directly:")
|
||||
print(f" {color(f'nano {get_config_path()}', Colors.DIM)}")
|
||||
print(f" {color(f'nano {get_env_path()}', Colors.DIM)}")
|
||||
print()
|
||||
@@ -756,13 +755,13 @@ def _prompt_container_resources(config: dict):
|
||||
print_info(" Persistent filesystem keeps files between sessions.")
|
||||
print_info(" Set to 'no' for ephemeral sandboxes that reset each time.")
|
||||
persist_str = prompt(
|
||||
f" Persist filesystem across sessions? (yes/no)", persist_label
|
||||
" Persist filesystem across sessions? (yes/no)", persist_label
|
||||
)
|
||||
terminal["container_persistent"] = persist_str.lower() in ("yes", "true", "y", "1")
|
||||
|
||||
# CPU
|
||||
current_cpu = terminal.get("container_cpu", 1)
|
||||
cpu_str = prompt(f" CPU cores", str(current_cpu))
|
||||
cpu_str = prompt(" CPU cores", str(current_cpu))
|
||||
try:
|
||||
terminal["container_cpu"] = float(cpu_str)
|
||||
except ValueError:
|
||||
@@ -770,7 +769,7 @@ def _prompt_container_resources(config: dict):
|
||||
|
||||
# Memory
|
||||
current_mem = terminal.get("container_memory", 5120)
|
||||
mem_str = prompt(f" Memory in MB (5120 = 5GB)", str(current_mem))
|
||||
mem_str = prompt(" Memory in MB (5120 = 5GB)", str(current_mem))
|
||||
try:
|
||||
terminal["container_memory"] = int(mem_str)
|
||||
except ValueError:
|
||||
@@ -778,7 +777,7 @@ def _prompt_container_resources(config: dict):
|
||||
|
||||
# Disk
|
||||
current_disk = terminal.get("container_disk", 51200)
|
||||
disk_str = prompt(f" Disk in MB (51200 = 50GB)", str(current_disk))
|
||||
disk_str = prompt(" Disk in MB (51200 = 50GB)", str(current_disk))
|
||||
try:
|
||||
terminal["container_disk"] = int(disk_str)
|
||||
except ValueError:
|
||||
@@ -798,15 +797,11 @@ def setup_model_provider(config: dict):
|
||||
"""Configure the inference provider and default model."""
|
||||
from hermes_cli.auth import (
|
||||
get_active_provider,
|
||||
get_provider_auth_state,
|
||||
PROVIDER_REGISTRY,
|
||||
format_auth_error,
|
||||
AuthError,
|
||||
fetch_nous_models,
|
||||
resolve_nous_runtime_credentials,
|
||||
_update_config_for_provider,
|
||||
_login_openai_codex,
|
||||
get_codex_auth_status,
|
||||
resolve_codex_runtime_credentials,
|
||||
DEFAULT_CODEX_BASE_URL,
|
||||
detect_external_credentials,
|
||||
@@ -873,9 +868,9 @@ def setup_model_provider(config: dict):
|
||||
keep_label = None # No provider configured — don't show "Keep current"
|
||||
|
||||
provider_choices = [
|
||||
"OpenRouter API key (100+ models, pay-per-use)",
|
||||
"Login with Nous Portal (Nous Research subscription — OAuth)",
|
||||
"Login with OpenAI Codex",
|
||||
"OpenRouter API key (100+ models, pay-per-use)",
|
||||
"Custom OpenAI-compatible endpoint (self-hosted / VLLM / etc.)",
|
||||
"Z.AI / GLM (Zhipu AI models)",
|
||||
"Kimi / Moonshot (Kimi coding models)",
|
||||
@@ -894,7 +889,7 @@ def setup_model_provider(config: dict):
|
||||
provider_choices.append(keep_label)
|
||||
|
||||
# Default to "Keep current" if a provider exists, otherwise OpenRouter (most common)
|
||||
default_provider = len(provider_choices) - 1 if has_any_provider else 2
|
||||
default_provider = len(provider_choices) - 1 if has_any_provider else 0
|
||||
|
||||
if not has_any_provider:
|
||||
print_warning("An inference provider is required for Hermes to work.")
|
||||
@@ -911,81 +906,7 @@ def setup_model_provider(config: dict):
|
||||
selected_base_url = None # deferred until after model selection
|
||||
nous_models = [] # populated if Nous login succeeds
|
||||
|
||||
if provider_idx == 0: # Nous Portal (OAuth)
|
||||
selected_provider = "nous"
|
||||
print()
|
||||
print_header("Nous Portal Login")
|
||||
print_info("This will open your browser to authenticate with Nous Portal.")
|
||||
print_info("You'll need a Nous Research account with an active subscription.")
|
||||
print()
|
||||
|
||||
try:
|
||||
from hermes_cli.auth import _login_nous, ProviderConfig
|
||||
import argparse
|
||||
|
||||
mock_args = argparse.Namespace(
|
||||
portal_url=None,
|
||||
inference_url=None,
|
||||
client_id=None,
|
||||
scope=None,
|
||||
no_browser=False,
|
||||
timeout=15.0,
|
||||
ca_bundle=None,
|
||||
insecure=False,
|
||||
)
|
||||
pconfig = PROVIDER_REGISTRY["nous"]
|
||||
_login_nous(mock_args, pconfig)
|
||||
_sync_model_from_disk(config)
|
||||
|
||||
# Fetch models for the selection step
|
||||
try:
|
||||
creds = resolve_nous_runtime_credentials(
|
||||
min_key_ttl_seconds=5 * 60,
|
||||
timeout_seconds=15.0,
|
||||
)
|
||||
nous_models = fetch_nous_models(
|
||||
inference_base_url=creds.get("base_url", ""),
|
||||
api_key=creds.get("api_key", ""),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Could not fetch Nous models after login: %s", e)
|
||||
|
||||
except SystemExit:
|
||||
print_warning("Nous Portal login was cancelled or failed.")
|
||||
print_info("You can try again later with: hermes model")
|
||||
selected_provider = None
|
||||
except Exception as e:
|
||||
print_error(f"Login failed: {e}")
|
||||
print_info("You can try again later with: hermes model")
|
||||
selected_provider = None
|
||||
|
||||
elif provider_idx == 1: # OpenAI Codex
|
||||
selected_provider = "openai-codex"
|
||||
print()
|
||||
print_header("OpenAI Codex Login")
|
||||
print()
|
||||
|
||||
try:
|
||||
import argparse
|
||||
|
||||
mock_args = argparse.Namespace()
|
||||
_login_openai_codex(mock_args, PROVIDER_REGISTRY["openai-codex"])
|
||||
# Clear custom endpoint vars that would override provider routing.
|
||||
if existing_custom:
|
||||
save_env_value("OPENAI_BASE_URL", "")
|
||||
save_env_value("OPENAI_API_KEY", "")
|
||||
_update_config_for_provider("openai-codex", DEFAULT_CODEX_BASE_URL)
|
||||
_set_model_provider(config, "openai-codex", DEFAULT_CODEX_BASE_URL)
|
||||
except SystemExit:
|
||||
print_warning("OpenAI Codex login was cancelled or failed.")
|
||||
print_info("You can try again later with: hermes model")
|
||||
selected_provider = None
|
||||
except Exception as e:
|
||||
print_error(f"Login failed: {e}")
|
||||
print_info("You can try again later with: hermes model")
|
||||
selected_provider = None
|
||||
|
||||
elif provider_idx == 2: # OpenRouter
|
||||
if provider_idx == 0: # OpenRouter
|
||||
selected_provider = "openrouter"
|
||||
print()
|
||||
print_header("OpenRouter API Key")
|
||||
@@ -1040,6 +961,80 @@ def setup_model_provider(config: dict):
|
||||
except Exception as e:
|
||||
logger.debug("Could not save provider to config.yaml: %s", e)
|
||||
|
||||
elif provider_idx == 1: # Nous Portal (OAuth)
|
||||
selected_provider = "nous"
|
||||
print()
|
||||
print_header("Nous Portal Login")
|
||||
print_info("This will open your browser to authenticate with Nous Portal.")
|
||||
print_info("You'll need a Nous Research account with an active subscription.")
|
||||
print()
|
||||
|
||||
try:
|
||||
from hermes_cli.auth import _login_nous
|
||||
import argparse
|
||||
|
||||
mock_args = argparse.Namespace(
|
||||
portal_url=None,
|
||||
inference_url=None,
|
||||
client_id=None,
|
||||
scope=None,
|
||||
no_browser=False,
|
||||
timeout=15.0,
|
||||
ca_bundle=None,
|
||||
insecure=False,
|
||||
)
|
||||
pconfig = PROVIDER_REGISTRY["nous"]
|
||||
_login_nous(mock_args, pconfig)
|
||||
_sync_model_from_disk(config)
|
||||
|
||||
# Fetch models for the selection step
|
||||
try:
|
||||
creds = resolve_nous_runtime_credentials(
|
||||
min_key_ttl_seconds=5 * 60,
|
||||
timeout_seconds=15.0,
|
||||
)
|
||||
nous_models = fetch_nous_models(
|
||||
inference_base_url=creds.get("base_url", ""),
|
||||
api_key=creds.get("api_key", ""),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Could not fetch Nous models after login: %s", e)
|
||||
|
||||
except SystemExit:
|
||||
print_warning("Nous Portal login was cancelled or failed.")
|
||||
print_info("You can try again later with: hermes model")
|
||||
selected_provider = None
|
||||
except Exception as e:
|
||||
print_error(f"Login failed: {e}")
|
||||
print_info("You can try again later with: hermes model")
|
||||
selected_provider = None
|
||||
|
||||
elif provider_idx == 2: # OpenAI Codex
|
||||
selected_provider = "openai-codex"
|
||||
print()
|
||||
print_header("OpenAI Codex Login")
|
||||
print()
|
||||
|
||||
try:
|
||||
import argparse
|
||||
|
||||
mock_args = argparse.Namespace()
|
||||
_login_openai_codex(mock_args, PROVIDER_REGISTRY["openai-codex"])
|
||||
# Clear custom endpoint vars that would override provider routing.
|
||||
if existing_custom:
|
||||
save_env_value("OPENAI_BASE_URL", "")
|
||||
save_env_value("OPENAI_API_KEY", "")
|
||||
_update_config_for_provider("openai-codex", DEFAULT_CODEX_BASE_URL)
|
||||
_set_model_provider(config, "openai-codex", DEFAULT_CODEX_BASE_URL)
|
||||
except SystemExit:
|
||||
print_warning("OpenAI Codex login was cancelled or failed.")
|
||||
print_info("You can try again later with: hermes model")
|
||||
selected_provider = None
|
||||
except Exception as e:
|
||||
print_error(f"Login failed: {e}")
|
||||
print_info("You can try again later with: hermes model")
|
||||
selected_provider = None
|
||||
|
||||
elif provider_idx == 3: # Custom endpoint
|
||||
selected_provider = "custom"
|
||||
print()
|
||||
@@ -2037,7 +2032,7 @@ def setup_terminal_backend(config: dict):
|
||||
|
||||
# Docker image
|
||||
current_image = config.get("terminal", {}).get(
|
||||
"docker_image", "python:3.11-slim"
|
||||
"docker_image", "nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
)
|
||||
image = prompt(" Docker image", current_image)
|
||||
config["terminal"]["docker_image"] = image
|
||||
@@ -2059,7 +2054,7 @@ def setup_terminal_backend(config: dict):
|
||||
print_info(f"Found: {sing_bin}")
|
||||
|
||||
current_image = config.get("terminal", {}).get(
|
||||
"singularity_image", "docker://python:3.11-slim"
|
||||
"singularity_image", "docker://nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
)
|
||||
image = prompt(" Container image", current_image)
|
||||
config["terminal"]["singularity_image"] = image
|
||||
@@ -2261,7 +2256,7 @@ def setup_agent_settings(config: dict):
|
||||
)
|
||||
print_info("Maximum tool-calling iterations per conversation.")
|
||||
print_info("Higher = more complex tasks, but costs more tokens.")
|
||||
print_info("Recommended: 30-60 for most tasks, 100+ for open exploration.")
|
||||
print_info("Default is 90, which works for most tasks. Use 150+ for open exploration.")
|
||||
|
||||
max_iter_str = prompt("Max iterations", current_max)
|
||||
try:
|
||||
@@ -2303,7 +2298,7 @@ def setup_agent_settings(config: dict):
|
||||
|
||||
config.setdefault("compression", {})["enabled"] = True
|
||||
|
||||
current_threshold = config.get("compression", {}).get("threshold", 0.85)
|
||||
current_threshold = config.get("compression", {}).get("threshold", 0.50)
|
||||
threshold_str = prompt("Compression threshold (0.5-0.95)", str(current_threshold))
|
||||
try:
|
||||
threshold = float(threshold_str)
|
||||
@@ -2313,7 +2308,7 @@ def setup_agent_settings(config: dict):
|
||||
pass
|
||||
|
||||
print_success(
|
||||
f"Context compression threshold set to {config['compression'].get('threshold', 0.85)}"
|
||||
f"Context compression threshold set to {config['compression'].get('threshold', 0.50)}"
|
||||
)
|
||||
|
||||
# ── Session Reset Policy ──
|
||||
@@ -2973,6 +2968,95 @@ def setup_tools(config: dict, first_install: bool = False):
|
||||
tools_command(first_install=first_install, config=config)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Post-Migration Section Skip Logic
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _get_section_config_summary(config: dict, section_key: str) -> Optional[str]:
|
||||
"""Return a short summary if a setup section is already configured, else None.
|
||||
|
||||
Used after OpenClaw migration to detect which sections can be skipped.
|
||||
``get_env_value`` is the module-level import from hermes_cli.config
|
||||
so that test patches on ``setup_mod.get_env_value`` take effect.
|
||||
"""
|
||||
if section_key == "model":
|
||||
has_key = bool(
|
||||
get_env_value("OPENROUTER_API_KEY")
|
||||
or get_env_value("OPENAI_API_KEY")
|
||||
or get_env_value("ANTHROPIC_API_KEY")
|
||||
)
|
||||
if not has_key:
|
||||
# Check for OAuth providers
|
||||
try:
|
||||
from hermes_cli.auth import get_active_provider
|
||||
if get_active_provider():
|
||||
has_key = True
|
||||
except Exception:
|
||||
pass
|
||||
if not has_key:
|
||||
return None
|
||||
model = config.get("model")
|
||||
if isinstance(model, str) and model.strip():
|
||||
return model.strip()
|
||||
if isinstance(model, dict):
|
||||
return str(model.get("default") or model.get("model") or "configured")
|
||||
return "configured"
|
||||
|
||||
elif section_key == "terminal":
|
||||
backend = config.get("terminal", {}).get("backend", "local")
|
||||
return f"backend: {backend}"
|
||||
|
||||
elif section_key == "agent":
|
||||
max_turns = config.get("agent", {}).get("max_turns", 90)
|
||||
return f"max turns: {max_turns}"
|
||||
|
||||
elif section_key == "gateway":
|
||||
platforms = []
|
||||
if get_env_value("TELEGRAM_BOT_TOKEN"):
|
||||
platforms.append("Telegram")
|
||||
if get_env_value("DISCORD_BOT_TOKEN"):
|
||||
platforms.append("Discord")
|
||||
if get_env_value("SLACK_BOT_TOKEN"):
|
||||
platforms.append("Slack")
|
||||
if get_env_value("WHATSAPP_PHONE_NUMBER_ID"):
|
||||
platforms.append("WhatsApp")
|
||||
if get_env_value("SIGNAL_ACCOUNT"):
|
||||
platforms.append("Signal")
|
||||
if platforms:
|
||||
return ", ".join(platforms)
|
||||
return None # No platforms configured — section must run
|
||||
|
||||
elif section_key == "tools":
|
||||
tools = []
|
||||
if get_env_value("ELEVENLABS_API_KEY"):
|
||||
tools.append("TTS/ElevenLabs")
|
||||
if get_env_value("BROWSERBASE_API_KEY"):
|
||||
tools.append("Browser")
|
||||
if get_env_value("FIRECRAWL_API_KEY"):
|
||||
tools.append("Firecrawl")
|
||||
if tools:
|
||||
return ", ".join(tools)
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _skip_configured_section(
|
||||
config: dict, section_key: str, label: str
|
||||
) -> bool:
|
||||
"""Show an already-configured section summary and offer to skip.
|
||||
|
||||
Returns True if the user chose to skip, False if the section should run.
|
||||
"""
|
||||
summary = _get_section_config_summary(config, section_key)
|
||||
if not summary:
|
||||
return False
|
||||
print()
|
||||
print_success(f" {label}: {summary}")
|
||||
return not prompt_yes_no(f" Reconfigure {label.lower()}?", default=False)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OpenClaw Migration
|
||||
# =============================================================================
|
||||
@@ -3044,7 +3128,7 @@ def _offer_openclaw_migration(hermes_home: Path) -> bool:
|
||||
target_root=hermes_home.resolve(),
|
||||
execute=True,
|
||||
workspace_target=None,
|
||||
overwrite=False,
|
||||
overwrite=True,
|
||||
migrate_secrets=True,
|
||||
output_dir=None,
|
||||
selected_options=selected,
|
||||
@@ -3106,6 +3190,10 @@ def run_setup_wizard(args):
|
||||
hermes setup tools — just tool configuration
|
||||
hermes setup agent — just agent settings
|
||||
"""
|
||||
from hermes_cli.config import is_managed, managed_error
|
||||
if is_managed():
|
||||
managed_error("run setup wizard")
|
||||
return
|
||||
ensure_hermes_home()
|
||||
|
||||
config = load_config()
|
||||
@@ -3196,6 +3284,8 @@ def run_setup_wizard(args):
|
||||
)
|
||||
)
|
||||
|
||||
migration_ran = False
|
||||
|
||||
if is_existing:
|
||||
# ── Returning User Menu ──
|
||||
print()
|
||||
@@ -3235,12 +3325,17 @@ def run_setup_wizard(args):
|
||||
print_info("Exiting. Run 'hermes setup' again when ready.")
|
||||
return
|
||||
elif 3 <= choice <= 7:
|
||||
# Individual section
|
||||
section_idx = choice - 3
|
||||
_, label, func = SETUP_SECTIONS[section_idx]
|
||||
func(config)
|
||||
save_config(config)
|
||||
_print_setup_summary(config, hermes_home)
|
||||
# Individual section — map by key, not by position.
|
||||
# SETUP_SECTIONS includes TTS but the returning-user menu skips it,
|
||||
# so positional indexing (choice - 3) would dispatch the wrong section.
|
||||
_RETURNING_USER_SECTION_KEYS = ["model", "terminal", "gateway", "tools", "agent"]
|
||||
section_key = _RETURNING_USER_SECTION_KEYS[choice - 3]
|
||||
section = next((s for s in SETUP_SECTIONS if s[0] == section_key), None)
|
||||
if section:
|
||||
_, label, func = section
|
||||
func(config)
|
||||
save_config(config)
|
||||
_print_setup_summary(config, hermes_home)
|
||||
return
|
||||
else:
|
||||
# ── First-Time Setup ──
|
||||
@@ -3248,9 +3343,9 @@ def run_setup_wizard(args):
|
||||
print_info("We'll walk you through:")
|
||||
print_info(" 1. Model & Provider — choose your AI provider and model")
|
||||
print_info(" 2. Terminal Backend — where your agent runs commands")
|
||||
print_info(" 3. Messaging Platforms — connect Telegram, Discord, etc.")
|
||||
print_info(" 4. Tools — configure TTS, web search, image generation, etc.")
|
||||
print_info(" 5. Agent Settings — iterations, compression, session reset")
|
||||
print_info(" 3. Agent Settings — iterations, compression, session reset")
|
||||
print_info(" 4. Messaging Platforms — connect Telegram, Discord, etc.")
|
||||
print_info(" 5. Tools — configure TTS, web search, image generation, etc.")
|
||||
print()
|
||||
print_info("Press Enter to begin, or Ctrl+C to exit.")
|
||||
try:
|
||||
@@ -3260,7 +3355,8 @@ def run_setup_wizard(args):
|
||||
return
|
||||
|
||||
# Offer OpenClaw migration before configuration begins
|
||||
if _offer_openclaw_migration(hermes_home):
|
||||
migration_ran = _offer_openclaw_migration(hermes_home)
|
||||
if migration_ran:
|
||||
# Reload config in case migration wrote to it
|
||||
config = load_config()
|
||||
|
||||
@@ -3273,20 +3369,31 @@ def run_setup_wizard(args):
|
||||
print()
|
||||
print_info("You can edit these files directly or use 'hermes config edit'")
|
||||
|
||||
if migration_ran:
|
||||
print()
|
||||
print_info("Settings were imported from OpenClaw.")
|
||||
print_info("Each section below will show what was imported — press Enter to keep,")
|
||||
print_info("or choose to reconfigure if needed.")
|
||||
|
||||
# Section 1: Model & Provider
|
||||
setup_model_provider(config)
|
||||
if not (migration_ran and _skip_configured_section(config, "model", "Model & Provider")):
|
||||
setup_model_provider(config)
|
||||
|
||||
# Section 2: Terminal Backend
|
||||
setup_terminal_backend(config)
|
||||
if not (migration_ran and _skip_configured_section(config, "terminal", "Terminal Backend")):
|
||||
setup_terminal_backend(config)
|
||||
|
||||
# Section 3: Agent Settings
|
||||
setup_agent_settings(config)
|
||||
if not (migration_ran and _skip_configured_section(config, "agent", "Agent Settings")):
|
||||
setup_agent_settings(config)
|
||||
|
||||
# Section 4: Messaging Platforms
|
||||
setup_gateway(config)
|
||||
if not (migration_ran and _skip_configured_section(config, "gateway", "Messaging Platforms")):
|
||||
setup_gateway(config)
|
||||
|
||||
# Section 5: Tools
|
||||
setup_tools(config, first_install=not is_existing)
|
||||
if not (migration_ran and _skip_configured_section(config, "tools", "Tools")):
|
||||
setup_tools(config, first_install=not is_existing)
|
||||
|
||||
# Save and show summary
|
||||
save_config(config)
|
||||
@@ -3299,7 +3406,6 @@ def _run_quick_setup(config: dict, hermes_home):
|
||||
get_missing_env_vars,
|
||||
get_missing_config_fields,
|
||||
check_config_version,
|
||||
migrate_config,
|
||||
)
|
||||
|
||||
print()
|
||||
@@ -3438,9 +3544,9 @@ def _run_quick_setup(config: dict, hermes_home):
|
||||
value = prompt(f" {var.get('prompt', var['name'])}")
|
||||
if value:
|
||||
save_env_value(var["name"], value)
|
||||
print_success(f" ✓ Saved")
|
||||
print_success(" ✓ Saved")
|
||||
else:
|
||||
print_warning(f" Skipped")
|
||||
print_warning(" Skipped")
|
||||
print()
|
||||
|
||||
# Handle missing config fields
|
||||
|
||||
@@ -11,7 +11,7 @@ Config stored in ~/.hermes/config.yaml under:
|
||||
telegram: [skill-c]
|
||||
cli: []
|
||||
"""
|
||||
from typing import Dict, List, Optional, Set
|
||||
from typing import List, Optional, Set
|
||||
|
||||
from hermes_cli.config import load_config, save_config
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
@@ -186,7 +186,7 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
Official skills are always shown first, regardless of source filter.
|
||||
"""
|
||||
from tools.skills_hub import (
|
||||
GitHubAuth, create_source_router, OptionalSkillSource, SkillMeta,
|
||||
GitHubAuth, create_source_router,
|
||||
)
|
||||
|
||||
# Clamp page_size to safe range
|
||||
@@ -357,7 +357,8 @@ def do_install(identifier: str, category: str = "", force: bool = False,
|
||||
|
||||
# Scan
|
||||
c.print("[bold]Running security scan...[/]")
|
||||
result = scan_skill(q_path, source=identifier)
|
||||
scan_source = getattr(bundle, "identifier", "") or getattr(meta, "identifier", "") or identifier
|
||||
result = scan_skill(q_path, source=scan_source)
|
||||
c.print(format_scan_report(result))
|
||||
|
||||
# Check install policy
|
||||
@@ -416,6 +417,13 @@ def do_install(identifier: str, category: str = "", force: bool = False,
|
||||
c.print(f"[bold green]Installed:[/] {install_dir.relative_to(SKILLS_DIR)}")
|
||||
c.print(f"[dim]Files: {', '.join(bundle.files.keys())}[/]\n")
|
||||
|
||||
# Invalidate the skills prompt cache so the new skill appears immediately
|
||||
try:
|
||||
from agent.prompt_builder import clear_skills_system_prompt_cache
|
||||
clear_skills_system_prompt_cache(clear_snapshot=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def do_inspect(identifier: str, console: Optional[Console] = None) -> None:
|
||||
"""Preview a skill's SKILL.md content without installing."""
|
||||
@@ -455,6 +463,8 @@ def do_inspect(identifier: str, console: Optional[Console] = None) -> None:
|
||||
|
||||
if bundle and "SKILL.md" in bundle.files:
|
||||
content = bundle.files["SKILL.md"]
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8", errors="replace")
|
||||
# Show first 50 lines as preview
|
||||
lines = content.split("\n")
|
||||
preview = "\n".join(lines[:50])
|
||||
@@ -620,6 +630,11 @@ def do_uninstall(name: str, console: Optional[Console] = None,
|
||||
success, msg = uninstall_skill(name)
|
||||
if success:
|
||||
c.print(f"[bold green]{msg}[/]\n")
|
||||
try:
|
||||
from agent.prompt_builder import clear_skills_system_prompt_cache
|
||||
clear_skills_system_prompt_cache(clear_snapshot=True)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
c.print(f"[bold red]Error:[/] {msg}\n")
|
||||
|
||||
@@ -640,7 +655,8 @@ def do_tap(action: str, repo: str = "", console: Optional[Console] = None) -> No
|
||||
table.add_column("Repo", style="bold cyan")
|
||||
table.add_column("Path", style="dim")
|
||||
for t in taps:
|
||||
table.add_row(t["repo"], t.get("path", "skills/"))
|
||||
label = t.get("repo") or t.get("name") or t.get("path", "unknown")
|
||||
table.add_row(label, t.get("path", "skills/"))
|
||||
c.print(table)
|
||||
c.print()
|
||||
|
||||
|
||||
@@ -101,6 +101,8 @@ from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -513,8 +515,7 @@ _active_skin_name: str = "default"
|
||||
|
||||
def _skins_dir() -> Path:
|
||||
"""User skins directory."""
|
||||
home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
return home / "skins"
|
||||
return get_hermes_home() / "skins"
|
||||
|
||||
|
||||
def _load_skin_from_yaml(path: Path) -> Optional[Dict[str, Any]]:
|
||||
|
||||
@@ -289,7 +289,7 @@ def show_status(args):
|
||||
)
|
||||
is_active = result.stdout.strip() == "active"
|
||||
print(f" Status: {check_mark(is_active)} {'running' if is_active else 'stopped'}")
|
||||
print(f" Manager: systemd (user)")
|
||||
print(" Manager: systemd (user)")
|
||||
|
||||
elif sys.platform == 'darwin':
|
||||
result = subprocess.run(
|
||||
@@ -299,10 +299,10 @@ def show_status(args):
|
||||
)
|
||||
is_loaded = result.returncode == 0
|
||||
print(f" Status: {check_mark(is_loaded)} {'loaded' if is_loaded else 'not loaded'}")
|
||||
print(f" Manager: launchd")
|
||||
print(" Manager: launchd")
|
||||
else:
|
||||
print(f" Status: {color('N/A', Colors.DIM)}")
|
||||
print(f" Manager: (not supported on this platform)")
|
||||
print(" Manager: (not supported on this platform)")
|
||||
|
||||
# =========================================================================
|
||||
# Cron Jobs
|
||||
@@ -320,9 +320,9 @@ def show_status(args):
|
||||
enabled_jobs = [j for j in jobs if j.get("enabled", True)]
|
||||
print(f" Jobs: {len(enabled_jobs)} active, {len(jobs)} total")
|
||||
except Exception:
|
||||
print(f" Jobs: (error reading jobs file)")
|
||||
print(" Jobs: (error reading jobs file)")
|
||||
else:
|
||||
print(f" Jobs: 0")
|
||||
print(" Jobs: 0")
|
||||
|
||||
# =========================================================================
|
||||
# Sessions
|
||||
@@ -338,9 +338,9 @@ def show_status(args):
|
||||
data = json.load(f)
|
||||
print(f" Active: {len(data)} session(s)")
|
||||
except Exception:
|
||||
print(f" Active: (error reading sessions file)")
|
||||
print(" Active: (error reading sessions file)")
|
||||
else:
|
||||
print(f" Active: 0")
|
||||
print(" Active: 0")
|
||||
|
||||
# =========================================================================
|
||||
# Deep checks
|
||||
|
||||
+184
-72
@@ -13,11 +13,9 @@ import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
import os
|
||||
|
||||
from hermes_cli.config import (
|
||||
load_config, save_config, get_env_value, save_env_value,
|
||||
get_hermes_home,
|
||||
)
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
@@ -101,6 +99,30 @@ CONFIGURABLE_TOOLSETS = [
|
||||
# but the setup checklist won't pre-select them for first-time users.
|
||||
_DEFAULT_OFF_TOOLSETS = {"moa", "homeassistant", "rl"}
|
||||
|
||||
|
||||
def _get_effective_configurable_toolsets():
|
||||
"""Return CONFIGURABLE_TOOLSETS + any plugin-provided toolsets.
|
||||
|
||||
Plugin toolsets are appended at the end so they appear after the
|
||||
built-in toolsets in the TUI checklist.
|
||||
"""
|
||||
result = list(CONFIGURABLE_TOOLSETS)
|
||||
try:
|
||||
from hermes_cli.plugins import get_plugin_toolsets
|
||||
result.extend(get_plugin_toolsets())
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
def _get_plugin_toolset_keys() -> set:
|
||||
"""Return the set of toolset keys provided by plugins."""
|
||||
try:
|
||||
from hermes_cli.plugins import get_plugin_toolsets
|
||||
return {ts_key for ts_key, _, _ in get_plugin_toolsets()}
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
# Platform display config
|
||||
PLATFORMS = {
|
||||
"cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"},
|
||||
@@ -109,8 +131,10 @@ PLATFORMS = {
|
||||
"slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"},
|
||||
"whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"},
|
||||
"signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"},
|
||||
"homeassistant": {"label": "🏠 Home Assistant", "default_toolset": "hermes-homeassistant"},
|
||||
"email": {"label": "📧 Email", "default_toolset": "hermes-email"},
|
||||
"dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"},
|
||||
"api_server": {"label": "🌐 API Server", "default_toolset": "hermes-api-server"},
|
||||
}
|
||||
|
||||
|
||||
@@ -356,9 +380,31 @@ def _platform_toolset_summary(config: dict, platforms: Optional[List[str]] = Non
|
||||
return summary
|
||||
|
||||
|
||||
def _get_platform_tools(config: dict, platform: str) -> Set[str]:
|
||||
def _parse_enabled_flag(value, default: bool = True) -> bool:
|
||||
"""Parse bool-like config values used by tool/platform settings."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, int):
|
||||
return value != 0
|
||||
if isinstance(value, str):
|
||||
lowered = value.strip().lower()
|
||||
if lowered in {"true", "1", "yes", "on"}:
|
||||
return True
|
||||
if lowered in {"false", "0", "no", "off"}:
|
||||
return False
|
||||
return default
|
||||
|
||||
|
||||
def _get_platform_tools(
|
||||
config: dict,
|
||||
platform: str,
|
||||
*,
|
||||
include_default_mcp_servers: bool = True,
|
||||
) -> Set[str]:
|
||||
"""Resolve which individual toolset names are enabled for a platform."""
|
||||
from toolsets import resolve_toolset, TOOLSETS
|
||||
from toolsets import resolve_toolset
|
||||
|
||||
platform_toolsets = config.get("platform_toolsets", {})
|
||||
toolset_names = platform_toolsets.get(platform)
|
||||
@@ -377,19 +423,67 @@ def _get_platform_tools(config: dict, platform: str) -> Set[str]:
|
||||
has_explicit_config = any(ts in configurable_keys for ts in toolset_names)
|
||||
|
||||
if has_explicit_config:
|
||||
return {ts for ts in toolset_names if ts in configurable_keys}
|
||||
enabled_toolsets = {ts for ts in toolset_names if ts in configurable_keys}
|
||||
else:
|
||||
# No explicit config — fall back to resolving composite toolset names
|
||||
# (e.g. "hermes-cli") to individual tool names and reverse-mapping.
|
||||
all_tool_names = set()
|
||||
for ts_name in toolset_names:
|
||||
all_tool_names.update(resolve_toolset(ts_name))
|
||||
|
||||
# No explicit config — fall back to resolving composite toolset names
|
||||
# (e.g. "hermes-cli") to individual tool names and reverse-mapping.
|
||||
all_tool_names = set()
|
||||
for ts_name in toolset_names:
|
||||
all_tool_names.update(resolve_toolset(ts_name))
|
||||
enabled_toolsets = set()
|
||||
for ts_key, _, _ in CONFIGURABLE_TOOLSETS:
|
||||
ts_tools = set(resolve_toolset(ts_key))
|
||||
if ts_tools and ts_tools.issubset(all_tool_names):
|
||||
enabled_toolsets.add(ts_key)
|
||||
|
||||
enabled_toolsets = set()
|
||||
for ts_key, _, _ in CONFIGURABLE_TOOLSETS:
|
||||
ts_tools = set(resolve_toolset(ts_key))
|
||||
if ts_tools and ts_tools.issubset(all_tool_names):
|
||||
enabled_toolsets.add(ts_key)
|
||||
# Plugin toolsets: enabled by default unless explicitly disabled.
|
||||
# A plugin toolset is "known" for a platform once `hermes tools`
|
||||
# has been saved for that platform (tracked via known_plugin_toolsets).
|
||||
# Unknown plugins default to enabled; known-but-absent = disabled.
|
||||
plugin_ts_keys = _get_plugin_toolset_keys()
|
||||
if plugin_ts_keys:
|
||||
known_map = config.get("known_plugin_toolsets", {})
|
||||
known_for_platform = set(known_map.get(platform, []))
|
||||
for pts in plugin_ts_keys:
|
||||
if pts in toolset_names:
|
||||
# Explicitly listed in config — enabled
|
||||
enabled_toolsets.add(pts)
|
||||
elif pts not in known_for_platform:
|
||||
# New plugin not yet seen by hermes tools — default enabled
|
||||
enabled_toolsets.add(pts)
|
||||
# else: known but not in config = user disabled it
|
||||
|
||||
# Preserve any explicit non-configurable toolset entries (for example,
|
||||
# custom toolsets or MCP server names saved in platform_toolsets).
|
||||
platform_default_keys = {p["default_toolset"] for p in PLATFORMS.values()}
|
||||
explicit_passthrough = {
|
||||
ts
|
||||
for ts in toolset_names
|
||||
if ts not in configurable_keys
|
||||
and ts not in plugin_ts_keys
|
||||
and ts not in platform_default_keys
|
||||
}
|
||||
|
||||
# MCP servers are expected to be available on all platforms by default.
|
||||
# If the platform explicitly lists one or more MCP server names, treat that
|
||||
# as an allowlist. Otherwise include every globally enabled MCP server.
|
||||
mcp_servers = config.get("mcp_servers", {})
|
||||
enabled_mcp_servers = {
|
||||
name
|
||||
for name, server_cfg in mcp_servers.items()
|
||||
if isinstance(server_cfg, dict)
|
||||
and _parse_enabled_flag(server_cfg.get("enabled", True), default=True)
|
||||
}
|
||||
explicit_mcp_servers = explicit_passthrough & enabled_mcp_servers
|
||||
enabled_toolsets.update(explicit_passthrough - enabled_mcp_servers)
|
||||
if include_default_mcp_servers:
|
||||
if explicit_mcp_servers:
|
||||
enabled_toolsets.update(explicit_mcp_servers)
|
||||
else:
|
||||
enabled_toolsets.update(enabled_mcp_servers)
|
||||
else:
|
||||
enabled_toolsets.update(explicit_mcp_servers)
|
||||
|
||||
return enabled_toolsets
|
||||
|
||||
@@ -397,41 +491,42 @@ def _get_platform_tools(config: dict, platform: str) -> Set[str]:
|
||||
def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: Set[str]):
|
||||
"""Save the selected toolset keys for a platform to config.
|
||||
|
||||
Preserves any non-configurable, non-composite entries (like MCP server
|
||||
names) that were already in the config for this platform.
|
||||
|
||||
Composite platform toolsets (hermes-cli, hermes-telegram, etc.) are
|
||||
dropped once the user has explicitly configured individual toolsets —
|
||||
keeping them would override the user's selections because they include
|
||||
all tools via _HERMES_CORE_TOOLS.
|
||||
Preserves any non-configurable toolset entries (like MCP server names)
|
||||
that were already in the config for this platform.
|
||||
"""
|
||||
from toolsets import TOOLSETS
|
||||
|
||||
config.setdefault("platform_toolsets", {})
|
||||
|
||||
# Keys the user can toggle in the checklist UI
|
||||
# Get the set of all configurable toolset keys (built-in + plugin)
|
||||
configurable_keys = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS}
|
||||
plugin_keys = _get_plugin_toolset_keys()
|
||||
configurable_keys |= plugin_keys
|
||||
|
||||
# Keys that are known composite/individual toolsets in toolsets.py
|
||||
# (hermes-cli, hermes-telegram, homeassistant, web, terminal, etc.)
|
||||
known_toolset_keys = set(TOOLSETS.keys())
|
||||
# Also exclude platform default toolsets (hermes-cli, hermes-telegram, etc.)
|
||||
# These are "super" toolsets that resolve to ALL tools, so preserving them
|
||||
# would silently override the user's unchecked selections on the next read.
|
||||
platform_default_keys = {p["default_toolset"] for p in PLATFORMS.values()}
|
||||
|
||||
# Get existing toolsets for this platform
|
||||
existing_toolsets = config.get("platform_toolsets", {}).get(platform, [])
|
||||
if not isinstance(existing_toolsets, list):
|
||||
existing_toolsets = []
|
||||
|
||||
# Preserve entries that are neither configurable toolsets nor known
|
||||
# composite toolsets — this keeps MCP server names and other custom
|
||||
# entries while dropping composites like "hermes-cli" that would
|
||||
# silently re-enable everything the user just disabled.
|
||||
# Preserve any entries that are NOT configurable toolsets and NOT platform
|
||||
# defaults (i.e. only MCP server names should be preserved)
|
||||
preserved_entries = {
|
||||
entry for entry in existing_toolsets
|
||||
if entry not in configurable_keys and entry not in known_toolset_keys
|
||||
if entry not in configurable_keys and entry not in platform_default_keys
|
||||
}
|
||||
|
||||
# Merge preserved entries with new enabled toolsets
|
||||
config["platform_toolsets"][platform] = sorted(enabled_toolset_keys | preserved_entries)
|
||||
|
||||
# Track which plugin toolsets are "known" for this platform so we can
|
||||
# distinguish "new plugin, default enabled" from "user disabled it".
|
||||
if plugin_keys:
|
||||
config.setdefault("known_plugin_toolsets", {})
|
||||
config["known_plugin_toolsets"][platform] = sorted(plugin_keys)
|
||||
|
||||
save_config(config)
|
||||
|
||||
|
||||
@@ -549,15 +644,17 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
|
||||
"""Multi-select checklist of toolsets. Returns set of selected toolset keys."""
|
||||
from hermes_cli.curses_ui import curses_checklist
|
||||
|
||||
effective = _get_effective_configurable_toolsets()
|
||||
|
||||
labels = []
|
||||
for ts_key, ts_label, ts_desc in CONFIGURABLE_TOOLSETS:
|
||||
for ts_key, ts_label, ts_desc in effective:
|
||||
suffix = ""
|
||||
if not _toolset_has_keys(ts_key) and (TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key)):
|
||||
suffix = " [no API key]"
|
||||
labels.append(f"{ts_label} ({ts_desc}){suffix}")
|
||||
|
||||
pre_selected = {
|
||||
i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS)
|
||||
i for i, (ts_key, _, _) in enumerate(effective)
|
||||
if ts_key in enabled
|
||||
}
|
||||
|
||||
@@ -567,7 +664,7 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
|
||||
pre_selected,
|
||||
cancel_returns=pre_selected,
|
||||
)
|
||||
return {CONFIGURABLE_TOOLSETS[i][0] for i in chosen}
|
||||
return {effective[i][0] for i in chosen}
|
||||
|
||||
|
||||
# ─── Provider-Aware Configuration ────────────────────────────────────────────
|
||||
@@ -617,7 +714,7 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict):
|
||||
# Multiple providers - let user choose
|
||||
print()
|
||||
# Use custom title if provided (e.g. "Select Search Provider")
|
||||
title = cat.get("setup_title", f"Choose a provider")
|
||||
title = cat.get("setup_title", "Choose a provider")
|
||||
print(color(f" --- {icon} {name} - {title} ---", Colors.CYAN))
|
||||
if cat.get("setup_note"):
|
||||
_print_info(f" {cat['setup_note']}")
|
||||
@@ -726,9 +823,9 @@ def _configure_provider(provider: dict, config: dict):
|
||||
|
||||
if value:
|
||||
save_env_value(var["key"], value)
|
||||
_print_success(f" Saved")
|
||||
_print_success(" Saved")
|
||||
else:
|
||||
_print_warning(f" Skipped")
|
||||
_print_warning(" Skipped")
|
||||
all_configured = False
|
||||
|
||||
# Run post-setup hooks if needed
|
||||
@@ -782,7 +879,7 @@ def _configure_simple_requirements(ts_key: str):
|
||||
if not missing:
|
||||
return
|
||||
|
||||
ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
|
||||
ts_label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts_key), ts_key)
|
||||
print()
|
||||
print(color(f" {ts_label} requires configuration:", Colors.YELLOW))
|
||||
|
||||
@@ -792,16 +889,16 @@ def _configure_simple_requirements(ts_key: str):
|
||||
value = _prompt(f" {var}", password=True)
|
||||
if value and value.strip():
|
||||
save_env_value(var, value.strip())
|
||||
_print_success(f" Saved")
|
||||
_print_success(" Saved")
|
||||
else:
|
||||
_print_warning(f" Skipped")
|
||||
_print_warning(" Skipped")
|
||||
|
||||
|
||||
def _reconfigure_tool(config: dict):
|
||||
"""Let user reconfigure an existing tool's provider or API key."""
|
||||
# Build list of configurable tools that are currently set up
|
||||
configurable = []
|
||||
for ts_key, ts_label, _ in CONFIGURABLE_TOOLSETS:
|
||||
for ts_key, ts_label, _ in _get_effective_configurable_toolsets():
|
||||
cat = TOOL_CATEGORIES.get(ts_key)
|
||||
reqs = TOOLSET_ENV_REQUIREMENTS.get(ts_key)
|
||||
if cat or reqs:
|
||||
@@ -882,7 +979,7 @@ def _reconfigure_provider(provider: dict, config: dict):
|
||||
_print_success(f" Browser cloud provider set to: {bp}")
|
||||
else:
|
||||
config.get("browser", {}).pop("cloud_provider", None)
|
||||
_print_success(f" Browser set to local mode")
|
||||
_print_success(" Browser set to local mode")
|
||||
|
||||
# Set web search backend in config if applicable
|
||||
if provider.get("web_backend"):
|
||||
@@ -904,9 +1001,9 @@ def _reconfigure_provider(provider: dict, config: dict):
|
||||
value = _prompt(f" {var.get('prompt', var['key'])} (Enter to keep current)", password=not default_val)
|
||||
if value and value.strip():
|
||||
save_env_value(var["key"], value.strip())
|
||||
_print_success(f" Updated")
|
||||
_print_success(" Updated")
|
||||
else:
|
||||
_print_info(f" Kept current")
|
||||
_print_info(" Kept current")
|
||||
|
||||
|
||||
def _reconfigure_simple_requirements(ts_key: str):
|
||||
@@ -915,7 +1012,7 @@ def _reconfigure_simple_requirements(ts_key: str):
|
||||
if not requirements:
|
||||
return
|
||||
|
||||
ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
|
||||
ts_label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts_key), ts_key)
|
||||
print()
|
||||
print(color(f" {ts_label}:", Colors.CYAN))
|
||||
|
||||
@@ -928,9 +1025,9 @@ def _reconfigure_simple_requirements(ts_key: str):
|
||||
value = _prompt(f" {var} (Enter to keep current)", password=True)
|
||||
if value and value.strip():
|
||||
save_env_value(var, value.strip())
|
||||
_print_success(f" Updated")
|
||||
_print_success(" Updated")
|
||||
else:
|
||||
_print_info(f" Kept current")
|
||||
_print_info(" Kept current")
|
||||
|
||||
|
||||
# ─── Main Entry Point ─────────────────────────────────────────────────────────
|
||||
@@ -954,7 +1051,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
|
||||
# Non-interactive summary mode for CLI usage
|
||||
if getattr(args, "summary", False):
|
||||
total = len(CONFIGURABLE_TOOLSETS)
|
||||
total = len(_get_effective_configurable_toolsets())
|
||||
print(color("⚕ Tool Summary", Colors.CYAN, Colors.BOLD))
|
||||
print()
|
||||
summary = _platform_toolset_summary(config, enabled_platforms)
|
||||
@@ -965,7 +1062,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
print(color(f" {pinfo['label']}", Colors.BOLD) + color(f" ({count}/{total})", Colors.DIM))
|
||||
if enabled:
|
||||
for ts_key in sorted(enabled):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts_key), ts_key)
|
||||
print(color(f" ✓ {label}", Colors.GREEN))
|
||||
else:
|
||||
print(color(" (none enabled)", Colors.DIM))
|
||||
@@ -980,7 +1077,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
if first_install:
|
||||
for pkey in enabled_platforms:
|
||||
pinfo = PLATFORMS[pkey]
|
||||
current_enabled = _get_platform_tools(config, pkey)
|
||||
current_enabled = _get_platform_tools(config, pkey, include_default_mcp_servers=False)
|
||||
|
||||
# Uncheck toolsets that should be off by default
|
||||
checklist_preselected = current_enabled - _DEFAULT_OFF_TOOLSETS
|
||||
@@ -992,11 +1089,11 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
removed = current_enabled - new_enabled
|
||||
if added:
|
||||
for ts in sorted(added):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" + {label}", Colors.GREEN))
|
||||
if removed:
|
||||
for ts in sorted(removed):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" - {label}", Colors.RED))
|
||||
|
||||
# Walk through ALL selected tools that have provider options or
|
||||
@@ -1012,7 +1109,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
print()
|
||||
print(color(f" Configuring {len(to_configure)} tool(s):", Colors.YELLOW))
|
||||
for ts_key in to_configure:
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts_key), ts_key)
|
||||
print(color(f" • {label}", Colors.DIM))
|
||||
print(color(" You can skip any tool you don't need right now.", Colors.DIM))
|
||||
print()
|
||||
@@ -1032,9 +1129,9 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
platform_keys = []
|
||||
for pkey in enabled_platforms:
|
||||
pinfo = PLATFORMS[pkey]
|
||||
current = _get_platform_tools(config, pkey)
|
||||
current = _get_platform_tools(config, pkey, include_default_mcp_servers=False)
|
||||
count = len(current)
|
||||
total = len(CONFIGURABLE_TOOLSETS)
|
||||
total = len(_get_effective_configurable_toolsets())
|
||||
platform_choices.append(f"Configure {pinfo['label']} ({count}/{total} enabled)")
|
||||
platform_keys.append(pkey)
|
||||
|
||||
@@ -1079,21 +1176,21 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
# Use the union of all platforms' current tools as the starting state
|
||||
all_current = set()
|
||||
for pk in platform_keys:
|
||||
all_current |= _get_platform_tools(config, pk)
|
||||
all_current |= _get_platform_tools(config, pk, include_default_mcp_servers=False)
|
||||
new_enabled = _prompt_toolset_checklist("All platforms", all_current)
|
||||
if new_enabled != all_current:
|
||||
for pk in platform_keys:
|
||||
prev = _get_platform_tools(config, pk)
|
||||
prev = _get_platform_tools(config, pk, include_default_mcp_servers=False)
|
||||
added = new_enabled - prev
|
||||
removed = prev - new_enabled
|
||||
pinfo_inner = PLATFORMS[pk]
|
||||
if added or removed:
|
||||
print(color(f" {pinfo_inner['label']}:", Colors.DIM))
|
||||
for ts in sorted(added):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" + {label}", Colors.GREEN))
|
||||
for ts in sorted(removed):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" - {label}", Colors.RED))
|
||||
# Configure API keys for newly enabled tools
|
||||
for ts_key in sorted(added):
|
||||
@@ -1105,8 +1202,8 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
print(color(" ✓ Saved configuration for all platforms", Colors.GREEN))
|
||||
# Update choice labels
|
||||
for ci, pk in enumerate(platform_keys):
|
||||
new_count = len(_get_platform_tools(config, pk))
|
||||
total = len(CONFIGURABLE_TOOLSETS)
|
||||
new_count = len(_get_platform_tools(config, pk, include_default_mcp_servers=False))
|
||||
total = len(_get_effective_configurable_toolsets())
|
||||
platform_choices[ci] = f"Configure {PLATFORMS[pk]['label']} ({new_count}/{total} enabled)"
|
||||
else:
|
||||
print(color(" No changes", Colors.DIM))
|
||||
@@ -1117,7 +1214,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
pinfo = PLATFORMS[pkey]
|
||||
|
||||
# Get current enabled toolsets for this platform
|
||||
current_enabled = _get_platform_tools(config, pkey)
|
||||
current_enabled = _get_platform_tools(config, pkey, include_default_mcp_servers=False)
|
||||
|
||||
# Show checklist
|
||||
new_enabled = _prompt_toolset_checklist(pinfo["label"], current_enabled)
|
||||
@@ -1128,11 +1225,11 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
|
||||
if added:
|
||||
for ts in sorted(added):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" + {label}", Colors.GREEN))
|
||||
if removed:
|
||||
for ts in sorted(removed):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" - {label}", Colors.RED))
|
||||
|
||||
# Configure newly enabled toolsets that need API keys
|
||||
@@ -1150,8 +1247,8 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
print()
|
||||
|
||||
# Update the choice label with new count
|
||||
new_count = len(_get_platform_tools(config, pkey))
|
||||
total = len(CONFIGURABLE_TOOLSETS)
|
||||
new_count = len(_get_platform_tools(config, pkey, include_default_mcp_servers=False))
|
||||
total = len(_get_effective_configurable_toolsets())
|
||||
platform_choices[idx] = f"Configure {pinfo['label']} ({new_count}/{total} enabled)"
|
||||
|
||||
print()
|
||||
@@ -1296,7 +1393,7 @@ def _configure_mcp_tools_interactive(config: dict):
|
||||
|
||||
def _apply_toolset_change(config: dict, platform: str, toolset_names: List[str], action: str):
|
||||
"""Add or remove built-in toolsets for a platform."""
|
||||
enabled = _get_platform_tools(config, platform)
|
||||
enabled = _get_platform_tools(config, platform, include_default_mcp_servers=False)
|
||||
if action == "disable":
|
||||
updated = enabled - set(toolset_names)
|
||||
else:
|
||||
@@ -1331,12 +1428,27 @@ def _apply_mcp_change(config: dict, targets: List[str], action: str) -> Set[str]
|
||||
|
||||
def _print_tools_list(enabled_toolsets: set, mcp_servers: dict, platform: str = "cli"):
|
||||
"""Print a summary of enabled/disabled toolsets and MCP tool filters."""
|
||||
effective = _get_effective_configurable_toolsets()
|
||||
builtin_keys = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS}
|
||||
|
||||
print(f"Built-in toolsets ({platform}):")
|
||||
for ts_key, label, _ in CONFIGURABLE_TOOLSETS:
|
||||
for ts_key, label, _ in effective:
|
||||
if ts_key not in builtin_keys:
|
||||
continue
|
||||
status = (color("✓ enabled", Colors.GREEN) if ts_key in enabled_toolsets
|
||||
else color("✗ disabled", Colors.RED))
|
||||
print(f" {status} {ts_key} {color(label, Colors.DIM)}")
|
||||
|
||||
# Plugin toolsets
|
||||
plugin_entries = [(k, l) for k, l, _ in effective if k not in builtin_keys]
|
||||
if plugin_entries:
|
||||
print()
|
||||
print(f"Plugin toolsets ({platform}):")
|
||||
for ts_key, label in plugin_entries:
|
||||
status = (color("✓ enabled", Colors.GREEN) if ts_key in enabled_toolsets
|
||||
else color("✗ disabled", Colors.RED))
|
||||
print(f" {status} {ts_key} {color(label, Colors.DIM)}")
|
||||
|
||||
if mcp_servers:
|
||||
print()
|
||||
print("MCP servers:")
|
||||
@@ -1367,7 +1479,7 @@ def tools_disable_enable_command(args):
|
||||
return
|
||||
|
||||
if action == "list":
|
||||
_print_tools_list(_get_platform_tools(config, platform),
|
||||
_print_tools_list(_get_platform_tools(config, platform, include_default_mcp_servers=False),
|
||||
config.get("mcp_servers") or {}, platform)
|
||||
return
|
||||
|
||||
@@ -1375,7 +1487,7 @@ def tools_disable_enable_command(args):
|
||||
toolset_targets = [t for t in targets if ":" not in t]
|
||||
mcp_targets = [t for t in targets if ":" in t]
|
||||
|
||||
valid_toolsets = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS}
|
||||
valid_toolsets = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS} | _get_plugin_toolset_keys()
|
||||
unknown_toolsets = [t for t in toolset_targets if t not in valid_toolsets]
|
||||
if unknown_toolsets:
|
||||
for name in unknown_toolsets:
|
||||
|
||||
@@ -7,11 +7,11 @@ Provides options for:
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
@@ -33,11 +33,6 @@ def get_project_root() -> Path:
|
||||
return Path(__file__).parent.parent.resolve()
|
||||
|
||||
|
||||
def get_hermes_home() -> Path:
|
||||
"""Get the Hermes home directory (~/.hermes)."""
|
||||
return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
||||
|
||||
def find_shell_configs() -> list:
|
||||
"""Find shell configuration files that might have PATH entries."""
|
||||
home = Path.home()
|
||||
@@ -278,7 +273,7 @@ def run_uninstall(args):
|
||||
log_info("No wrapper script found")
|
||||
|
||||
# 4. Remove installation directory (code)
|
||||
log_info(f"Removing installation directory...")
|
||||
log_info("Removing installation directory...")
|
||||
|
||||
# Check if we're running from within the install dir
|
||||
# We need to be careful here
|
||||
|
||||
@@ -4,6 +4,40 @@ Import-safe module with no dependencies — can be imported from anywhere
|
||||
without risk of circular imports.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def get_hermes_home() -> Path:
|
||||
"""Return the Hermes home directory (default: ~/.hermes).
|
||||
|
||||
Reads HERMES_HOME env var, falls back to ~/.hermes.
|
||||
This is the single source of truth — all other copies should import this.
|
||||
"""
|
||||
return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
||||
|
||||
VALID_REASONING_EFFORTS = ("xhigh", "high", "medium", "low", "minimal")
|
||||
|
||||
|
||||
def parse_reasoning_effort(effort: str) -> dict | None:
|
||||
"""Parse a reasoning effort level into a config dict.
|
||||
|
||||
Valid levels: "xhigh", "high", "medium", "low", "minimal", "none".
|
||||
Returns None when the input is empty or unrecognized (caller uses default).
|
||||
Returns {"enabled": False} for "none".
|
||||
Returns {"enabled": True, "effort": <level>} for valid effort levels.
|
||||
"""
|
||||
if not effort or not effort.strip():
|
||||
return None
|
||||
effort = effort.strip().lower()
|
||||
if effort == "none":
|
||||
return {"enabled": False}
|
||||
if effort in VALID_REASONING_EFFORTS:
|
||||
return {"enabled": True, "effort": effort}
|
||||
return None
|
||||
|
||||
|
||||
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
|
||||
OPENROUTER_MODELS_URL = f"{OPENROUTER_BASE_URL}/models"
|
||||
OPENROUTER_CHAT_URL = f"{OPENROUTER_BASE_URL}/chat/completions"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user