Compare commits
306 Commits
lsp-plugin
...
feat/sessi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ce0f4838b0 | ||
|
|
2ecad49113 | ||
|
|
8245173d61 | ||
|
|
327e577acf | ||
|
|
b5996b6451 | ||
|
|
ef10d2e7c9 | ||
|
|
d5416284f1 | ||
|
|
af1ea1f4ed | ||
|
|
db84a78e61 | ||
|
|
f199cd9f84 | ||
|
|
77276070f5 | ||
|
|
274217316e | ||
|
|
13c72fb486 | ||
|
|
6af9942327 | ||
|
|
8373956850 | ||
|
|
94bdc63ff5 | ||
|
|
eacb398f75 | ||
|
|
5301cc212b | ||
|
|
c4a21d7831 | ||
|
|
59c7cc64f0 | ||
|
|
55f3262e78 | ||
|
|
5360b54244 | ||
|
|
647cc0bb0d | ||
|
|
4f8aaf1046 | ||
|
|
b6e07417c5 | ||
|
|
47614dbfca | ||
|
|
09d9724a09 | ||
|
|
85782a4ed7 | ||
|
|
9f57f2286d | ||
|
|
6682f91b80 | ||
|
|
05d9f641c0 | ||
|
|
9329e06696 | ||
|
|
04b1fdaecf | ||
|
|
681778a0b7 | ||
|
|
0161d4bb6c | ||
|
|
814c60092b | ||
|
|
23ac522d37 | ||
|
|
e0e7397c32 | ||
|
|
e0e4856d46 | ||
|
|
0086cdaf93 | ||
|
|
fc2754dbdf | ||
|
|
3df26b925c | ||
|
|
80efe664ce | ||
|
|
d57a4b3eb5 | ||
|
|
6bdad1f3b2 | ||
|
|
f9ad7400e3 | ||
|
|
965ae7fa97 | ||
|
|
cbd1f8e4be | ||
|
|
f8745f59c2 | ||
|
|
bcca5ed34d | ||
|
|
c8c6ce1731 | ||
|
|
5af672c753 | ||
|
|
d364132114 | ||
|
|
4c94396206 | ||
|
|
e8b9f5ff9a | ||
|
|
d3d5916089 | ||
|
|
eabd8c1fd1 | ||
|
|
4695d2716f | ||
|
|
8ed2ef6f46 | ||
|
|
1702a94c88 | ||
|
|
55622b5525 | ||
|
|
74e47c081f | ||
|
|
d6c488f2dc | ||
|
|
09d970160b | ||
|
|
db82c453b9 | ||
|
|
38ea2a57a5 | ||
|
|
0854640537 | ||
|
|
19071529f6 | ||
|
|
ed84637d11 | ||
|
|
4abfb6bc24 | ||
|
|
e84fe483bc | ||
|
|
ccb5aae0d2 | ||
|
|
34fc94d1f4 | ||
|
|
4813aaf0ba | ||
|
|
2844c888f1 | ||
|
|
f491b07cb2 | ||
|
|
ac64d0c2ca | ||
|
|
6244535682 | ||
|
|
7bf66a07bd | ||
|
|
06c6c1f0f2 | ||
|
|
fe83c4001b | ||
|
|
a28add199d | ||
|
|
bc42e62b17 | ||
|
|
b4b8509fe8 | ||
|
|
d44dafdb4e | ||
|
|
5ce0067c08 | ||
|
|
cd64bed55e | ||
|
|
9ed751b967 | ||
|
|
b08f53a758 | ||
|
|
78b842c995 | ||
|
|
26933c2f59 | ||
|
|
72b5dd8658 | ||
|
|
436a0a271e | ||
|
|
529ec85c77 | ||
|
|
364ddd45e8 | ||
|
|
3adde245b7 | ||
|
|
a952ca3ff6 | ||
|
|
f26098e22f | ||
|
|
1247ff2dca | ||
|
|
1dd33988e2 | ||
|
|
c03acca508 | ||
|
|
8ae65d5c8c | ||
|
|
d8fdec16d5 | ||
|
|
12f755c9eb | ||
|
|
63991bbd97 | ||
|
|
26deeea830 | ||
|
|
a694040520 | ||
|
|
524490a409 | ||
|
|
17e0e9d174 | ||
|
|
1dca6a6960 | ||
|
|
c75e1a03f9 | ||
|
|
29575b3712 | ||
|
|
71558e753d | ||
|
|
4f7e64c845 | ||
|
|
ddb8d8fa84 | ||
|
|
2cbf0631a5 | ||
|
|
0f0e20ef81 | ||
|
|
1551ce46a4 | ||
|
|
c76e879574 | ||
|
|
55ba02befb | ||
|
|
7becb19ea0 | ||
|
|
8199ec3803 | ||
|
|
f0e46c5e9e | ||
|
|
71191b7e8e | ||
|
|
00ad3d3c9c | ||
|
|
bd33a48a58 | ||
|
|
fd9c1504da | ||
|
|
057f5a31d1 | ||
|
|
b59ed9c6bc | ||
|
|
efa97af7e2 | ||
|
|
8de26e280e | ||
|
|
796c8a2d63 | ||
|
|
2ff744ae2c | ||
|
|
16796acc84 | ||
|
|
31b4721791 | ||
|
|
35ce94a2f8 | ||
|
|
5f234d4057 | ||
|
|
8f19078c6a | ||
|
|
d110ce4493 | ||
|
|
8db544b4d0 | ||
|
|
c872f07c47 | ||
|
|
d18618f48f | ||
|
|
4ca5e72444 | ||
|
|
657e6d87cc | ||
|
|
21e3a863bb | ||
|
|
e8cee87e85 | ||
|
|
39b4ebfcea | ||
|
|
24fe60faa2 | ||
|
|
748f3e016b | ||
|
|
5e54330e27 | ||
|
|
b05253ceed | ||
|
|
143184e943 | ||
|
|
31fcde876c | ||
|
|
4816646109 | ||
|
|
ec8449e9c6 | ||
|
|
e3f0a88891 | ||
|
|
0a7cbd3342 | ||
|
|
6b219f5af6 | ||
|
|
714630110b | ||
|
|
6bd16a645b | ||
|
|
0d085d9454 | ||
|
|
5c7d098bee | ||
|
|
d403cf018c | ||
|
|
f29f02a73f | ||
|
|
007a630b16 | ||
|
|
2cea98e143 | ||
|
|
563077a47a | ||
|
|
efc32ab639 | ||
|
|
4ceab16893 | ||
|
|
dee71a31e5 | ||
|
|
ffbc21100d | ||
|
|
d863773c81 | ||
|
|
d557544560 | ||
|
|
3633c8690b | ||
|
|
d5775fe988 | ||
|
|
f7ad2f1115 | ||
|
|
e90508103c | ||
|
|
8c6b0c9ecd | ||
|
|
07349ce4df | ||
|
|
95d074cdb2 | ||
|
|
5fe0672260 | ||
|
|
3a30c605b3 | ||
|
|
d898e0eb7f | ||
|
|
52521c937a | ||
|
|
7f08cb5941 | ||
|
|
c875c0dc11 | ||
|
|
6122a79aab | ||
|
|
3f13d78088 | ||
|
|
3c106c89a1 | ||
|
|
dd5a9502e3 | ||
|
|
ef98e3f9e6 | ||
|
|
66c70966cd | ||
|
|
e3fc081499 | ||
|
|
aa1e2edd35 | ||
|
|
091d8e1030 | ||
|
|
9d42c2c286 | ||
|
|
b833d85019 | ||
|
|
cc64a04f61 | ||
|
|
9a815b6c8c | ||
|
|
08671d8771 | ||
|
|
e2b2d48610 | ||
|
|
59da8ec4ec | ||
|
|
256bedb632 | ||
|
|
6f2d1c88b7 | ||
|
|
1979ef5802 | ||
|
|
d6c9711ba8 | ||
|
|
a9b8254e5f | ||
|
|
a43d7e67b4 | ||
|
|
6d30b4a7e3 | ||
|
|
8c4bec6155 | ||
|
|
659af123c3 | ||
|
|
f4c43f0886 | ||
|
|
b54b246071 | ||
|
|
1a00d730eb | ||
|
|
76f40e6449 | ||
|
|
4fdfdf6749 | ||
|
|
2bed2124a4 | ||
|
|
1149e75db2 | ||
|
|
5d90386baa | ||
|
|
c3094b46e9 | ||
|
|
da0ddbf88a | ||
|
|
71c6dd0dcf | ||
|
|
8709e1ebec | ||
|
|
54d817f882 | ||
|
|
74fdfe6b50 | ||
|
|
02a54e01ce | ||
|
|
942adf6179 | ||
|
|
1e01b25e76 | ||
|
|
486b692ddd | ||
|
|
b06e999302 | ||
|
|
80374d4dd9 | ||
|
|
8ac351407e | ||
|
|
a4289d74ac | ||
|
|
1a4e8f7041 | ||
|
|
420762f867 | ||
|
|
e77fd75c44 | ||
|
|
7c67097325 | ||
|
|
afa5b81918 | ||
|
|
e474130c48 | ||
|
|
327b8cee9e | ||
|
|
dd1d4e9c5d | ||
|
|
80c4b27437 | ||
|
|
557deece6f | ||
|
|
081f9368bc | ||
|
|
e71393237e | ||
|
|
4c825554c1 | ||
|
|
2a18b6283b | ||
|
|
d8c4460fe3 | ||
|
|
6f92a21926 | ||
|
|
0c233e70f8 | ||
|
|
a54d4b0e46 | ||
|
|
0bc5f7b235 | ||
|
|
8d553056c0 | ||
|
|
1beb578fde | ||
|
|
a694a26330 | ||
|
|
29c9ff9ba5 | ||
|
|
6f285efb80 | ||
|
|
413990c945 | ||
|
|
a33ec10874 | ||
|
|
c7cfad5d96 | ||
|
|
7a4ad5ccb4 | ||
|
|
b7bd0f77f3 | ||
|
|
d33deb7cbe | ||
|
|
2a3140a814 | ||
|
|
6ec89d885d | ||
|
|
80375cbe2c | ||
|
|
782e3f5164 | ||
|
|
e3858772d0 | ||
|
|
b3ca6362a8 | ||
|
|
d68a0ec383 | ||
|
|
389c707e42 | ||
|
|
9b2488af2a | ||
|
|
29d7c244c5 | ||
|
|
76bbb94be4 | ||
|
|
f9559c39c4 | ||
|
|
24e2151cd6 | ||
|
|
88ede807c4 | ||
|
|
83b93898c2 | ||
|
|
d89553c2d6 | ||
|
|
38441a7d77 | ||
|
|
f63d520496 | ||
|
|
62fd905340 | ||
|
|
3955aefced | ||
|
|
4bb0a82a2b | ||
|
|
4fa5f7b765 | ||
|
|
1189ed7855 | ||
|
|
71198b9e19 | ||
|
|
954e854ccc | ||
|
|
629c33c633 | ||
|
|
653d304290 | ||
|
|
642768c5c7 | ||
|
|
a34998ee2f | ||
|
|
c23a87bc16 | ||
|
|
d186186e1a | ||
|
|
2863e9484a | ||
|
|
c594a23047 | ||
|
|
8a31985e4f | ||
|
|
fc3fd6bb6b | ||
|
|
41c13ba71d | ||
|
|
36c5b188b5 | ||
|
|
1e29fa8865 | ||
|
|
e74a682b0f | ||
|
|
2b606d20e2 | ||
|
|
3ac750ec07 | ||
|
|
aa2d3e2ee1 | ||
|
|
7d628eaa3d |
47
.env.example
47
.env.example
@@ -14,6 +14,14 @@
|
||||
# LLM_MODEL is no longer read from .env — this line is kept for reference only.
|
||||
# LLM_MODEL=anthropic/claude-opus-4.6
|
||||
|
||||
# =============================================================================
|
||||
# LLM PROVIDER (NovitaAI)
|
||||
# =============================================================================
|
||||
# NovitaAI — 90+ models, pay-per-use
|
||||
# Get your key at: https://novita.ai/settings/key-management
|
||||
# NOVITA_API_KEY=
|
||||
# NOVITA_BASE_URL=https://api.novita.ai/openai/v1 # Override default base URL
|
||||
|
||||
# =============================================================================
|
||||
# LLM PROVIDER (Google AI Studio / Gemini)
|
||||
# =============================================================================
|
||||
@@ -273,6 +281,27 @@ BROWSER_SESSION_TIMEOUT=300
|
||||
# Browser sessions are automatically closed after this period of no activity
|
||||
BROWSER_INACTIVITY_TIMEOUT=120
|
||||
|
||||
# Extra Chromium launch flags passed to agent-browser, comma- or newline-separated.
|
||||
# Hermes auto-injects "--no-sandbox,--disable-dev-shm-usage" when it detects root
|
||||
# or AppArmor-restricted unprivileged user namespaces (Ubuntu 23.10+, DGX Spark,
|
||||
# many container images), so leave this unset unless you need extra flags.
|
||||
# Setting this disables the auto-injection.
|
||||
# AGENT_BROWSER_ARGS=--no-sandbox
|
||||
|
||||
# Camofox local anti-detection browser (Camoufox-based Firefox).
|
||||
# Set CAMOFOX_URL to route the browser tools through a local Camofox server
|
||||
# instead of agent-browser/Browserbase. See docs/user-guide/features/browser.md.
|
||||
# CAMOFOX_URL=http://localhost:9377
|
||||
|
||||
# Externally managed Camofox sessions — when another app owns the visible
|
||||
# Camofox browser, set these so Hermes shares the same userId/profile instead
|
||||
# of creating its own isolated session.
|
||||
# CAMOFOX_USER_ID=
|
||||
# CAMOFOX_SESSION_KEY=
|
||||
# Set to true to reuse an already-open Camofox tab for this identity before
|
||||
# creating a new one (useful for gateway restarts).
|
||||
# CAMOFOX_ADOPT_EXISTING_TAB=false
|
||||
|
||||
# =============================================================================
|
||||
# SESSION LOGGING
|
||||
# =============================================================================
|
||||
@@ -365,24 +394,6 @@ IMAGE_TOOLS_DEBUG=false
|
||||
# CONTEXT_COMPRESSION_THRESHOLD=0.85 # Compress at 85% of context limit
|
||||
# Model is set via compression.summary_model in config.yaml (default: google/gemini-3-flash-preview)
|
||||
|
||||
# =============================================================================
|
||||
# RL TRAINING (Tinker + Atropos)
|
||||
# =============================================================================
|
||||
# Run reinforcement learning training on language models using the Tinker API.
|
||||
# Requires the rl-server to be running (from tinker-atropos package).
|
||||
|
||||
# Tinker API Key - RL training service
|
||||
# Get at: https://tinker-console.thinkingmachines.ai/keys
|
||||
# TINKER_API_KEY=
|
||||
|
||||
# Weights & Biases API Key - Experiment tracking and metrics
|
||||
# Get at: https://wandb.ai/authorize
|
||||
# WANDB_API_KEY=
|
||||
|
||||
# RL API Server URL (default: http://localhost:8080)
|
||||
# Change if running the rl-server on a different host/port
|
||||
# RL_API_URL=http://localhost:8080
|
||||
|
||||
# =============================================================================
|
||||
# SKILLS HUB (GitHub integration for skill search/install/publish)
|
||||
# =============================================================================
|
||||
|
||||
195
.github/workflows/docker-publish.yml
vendored
195
.github/workflows/docker-publish.yml
vendored
@@ -28,9 +28,10 @@ permissions:
|
||||
contents: read
|
||||
|
||||
# Concurrency: push/release runs are NEVER cancelled so every merge gets its
|
||||
# own SHA-tagged image; :latest is guarded separately by the move-latest job.
|
||||
# PR runs reuse a PR-scoped group with cancel-in-progress: true so rapid
|
||||
# pushes to the same PR collapse to the latest commit.
|
||||
# own SHA-tagged image; :main and :latest are guarded separately by the
|
||||
# move-main and move-latest jobs. PR runs reuse a PR-scoped group with
|
||||
# cancel-in-progress: true so rapid pushes to the same PR collapse to the
|
||||
# latest commit.
|
||||
concurrency:
|
||||
group: docker-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: ${{ github.event_name == 'pull_request' }}
|
||||
@@ -91,10 +92,10 @@ jobs:
|
||||
# pattern for multi-runner multi-platform builds.
|
||||
#
|
||||
# We apply the OCI revision label here (and again on arm64) because
|
||||
# the move-latest job reads it off the linux/amd64 sub-manifest config
|
||||
# of `:latest` to decide whether it's safe to advance. The label must
|
||||
# be on each per-arch image — manifest lists themselves don't carry
|
||||
# image config labels.
|
||||
# the move-main / move-latest jobs read it off the linux/amd64
|
||||
# sub-manifest config of the floating tag to decide whether it's safe
|
||||
# to advance. The label must be on each per-arch image — manifest
|
||||
# lists themselves don't carry image config labels.
|
||||
- name: Push amd64 by digest
|
||||
id: push
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main' || github.event_name == 'release'
|
||||
@@ -217,6 +218,8 @@ jobs:
|
||||
timeout-minutes: 10
|
||||
outputs:
|
||||
pushed_sha_tag: ${{ steps.mark_pushed.outputs.pushed }}
|
||||
pushed_release_tag: ${{ steps.mark_release_pushed.outputs.pushed }}
|
||||
release_tag: ${{ steps.tag.outputs.tag }}
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4
|
||||
@@ -271,33 +274,43 @@ jobs:
|
||||
IMAGE_NAME: ${{ env.IMAGE_NAME }}
|
||||
TAG: ${{ steps.tag.outputs.tag }}
|
||||
|
||||
# Signal to move-latest that the SHA tag is live. Only on main pushes;
|
||||
# releases don't trigger move-latest (they use their own release tag).
|
||||
# Signal to move-main that the SHA tag is live. Only on main pushes;
|
||||
# releases set pushed_release_tag instead.
|
||||
- name: Mark SHA tag pushed
|
||||
id: mark_pushed
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
run: echo "pushed=true" >> "$GITHUB_OUTPUT"
|
||||
|
||||
# Signal to move-latest that the release tag is live.
|
||||
- name: Mark release tag pushed
|
||||
id: mark_release_pushed
|
||||
if: github.event_name == 'release'
|
||||
run: echo "pushed=true" >> "$GITHUB_OUTPUT"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Move :latest to point at the SHA tag the merge job pushed.
|
||||
# Move :main to point at the SHA tag the merge job pushed.
|
||||
#
|
||||
# :main is the floating tag that tracks the tip of the main branch. Every
|
||||
# merge to main retags :main forward. Users who want "latest dev build"
|
||||
# pull :main; users who want stable releases pull :latest.
|
||||
#
|
||||
# The real serialization guarantee comes from the top-level concurrency
|
||||
# group (`docker-${{ github.ref }}` with `cancel-in-progress: false`),
|
||||
# which ensures at most one workflow run for this ref executes at a time.
|
||||
# That means two move-latest steps for the same ref cannot overlap.
|
||||
# That means two move-main steps for the same ref cannot overlap.
|
||||
#
|
||||
# This job has its own concurrency group as defense-in-depth: if the
|
||||
# top-level group is ever loosened, queued move-latests will run serially
|
||||
# top-level group is ever loosened, queued move-mains will run serially
|
||||
# in arrival order, each one running the ancestor check below and either
|
||||
# advancing :latest or skipping. `cancel-in-progress: false` matches the
|
||||
# advancing :main or skipping. `cancel-in-progress: false` matches the
|
||||
# top-level setting — we don't want rapid pushes to cancel a queued
|
||||
# move-latest, because the ancestor check is the real safety mechanism
|
||||
# and queueing is cheap (move-latest is a ~30s registry op).
|
||||
# move-main, because the ancestor check is the real safety mechanism
|
||||
# and queueing is cheap (move-main is a ~30s registry op).
|
||||
#
|
||||
# Combined with the ancestor check, this means :latest only ever moves
|
||||
# Combined with the ancestor check, this means :main only ever moves
|
||||
# forward in git history.
|
||||
# ---------------------------------------------------------------------------
|
||||
move-latest:
|
||||
move-main:
|
||||
if: |
|
||||
github.repository == 'NousResearch/hermes-agent'
|
||||
&& github.event_name == 'push'
|
||||
@@ -307,7 +320,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
concurrency:
|
||||
group: docker-move-latest-${{ github.ref }}
|
||||
group: docker-move-main-${{ github.ref }}
|
||||
cancel-in-progress: false
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@@ -324,13 +337,13 @@ jobs:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# Read the git revision label off the current :latest manifest, then
|
||||
# Read the git revision label off the current :main manifest, then
|
||||
# use `git merge-base --is-ancestor` to check whether our commit is a
|
||||
# descendant of it. If :latest doesn't exist yet, or its label is
|
||||
# descendant of it. If :main doesn't exist yet, or its label is
|
||||
# missing, we treat that as "safe to publish". If another run already
|
||||
# advanced :latest past us (or diverged), we skip and leave it alone.
|
||||
- name: Decide whether to move :latest
|
||||
id: latest_check
|
||||
# advanced :main past us (or diverged), we skip and leave it alone.
|
||||
- name: Decide whether to move :main
|
||||
id: main_check
|
||||
run: |
|
||||
set -euo pipefail
|
||||
image=nousresearch/hermes-agent
|
||||
@@ -338,6 +351,119 @@ jobs:
|
||||
# Pull the JSON for the linux/amd64 sub-manifest's config and extract
|
||||
# the OCI revision label with jq — Go template field access can't
|
||||
# handle dots in map keys, so using json+jq is the robust route.
|
||||
image_json=$(
|
||||
docker buildx imagetools inspect "${image}:main" \
|
||||
--format '{{ json (index .Image "linux/amd64") }}' \
|
||||
2>/dev/null || true
|
||||
)
|
||||
|
||||
if [ -z "${image_json}" ]; then
|
||||
echo "No existing :main (or inspect failed) — safe to publish."
|
||||
echo "push_main=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
current_sha=$(
|
||||
printf '%s' "${image_json}" \
|
||||
| jq -r '.config.Labels."org.opencontainers.image.revision" // ""'
|
||||
)
|
||||
|
||||
if [ -z "${current_sha}" ]; then
|
||||
echo "Registry :main has no revision label — safe to publish."
|
||||
echo "push_main=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Registry :main is at ${current_sha}"
|
||||
echo "This run is at ${GITHUB_SHA}"
|
||||
|
||||
if [ "${current_sha}" = "${GITHUB_SHA}" ]; then
|
||||
echo ":main already points at our SHA — nothing to do."
|
||||
echo "push_main=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Make sure we have the :main commit locally for merge-base.
|
||||
if ! git cat-file -e "${current_sha}^{commit}" 2>/dev/null; then
|
||||
git fetch --no-tags --prune origin \
|
||||
"+refs/heads/main:refs/remotes/origin/main" \
|
||||
|| true
|
||||
fi
|
||||
|
||||
if ! git cat-file -e "${current_sha}^{commit}" 2>/dev/null; then
|
||||
echo "Registry :main points at an unknown commit (${current_sha}); refusing to overwrite."
|
||||
echo "push_main=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Our SHA must be a descendant of the current :main to be safe.
|
||||
if git merge-base --is-ancestor "${current_sha}" "${GITHUB_SHA}"; then
|
||||
echo "Our commit is a descendant of :main — safe to advance."
|
||||
echo "push_main=true" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "Another run advanced :main past us (or diverged) — leaving it alone."
|
||||
echo "push_main=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
# Retag the already-pushed SHA manifest as :main. This is a registry-
|
||||
# side operation — no rebuild, no layer re-push — so it's quick and
|
||||
# atomic per-tag. The ancestor check above plus the cancel-in-progress
|
||||
# concurrency on this job together guarantee we only ever move :main
|
||||
# forward in git history.
|
||||
- name: Move :main to this SHA
|
||||
if: steps.main_check.outputs.push_main == 'true'
|
||||
run: |
|
||||
set -euo pipefail
|
||||
image=nousresearch/hermes-agent
|
||||
docker buildx imagetools create \
|
||||
--tag "${image}:main" \
|
||||
"${image}:sha-${GITHUB_SHA}"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Move :latest to point at the release tag the merge job pushed.
|
||||
#
|
||||
# :latest is the floating tag that tracks the most recent stable release.
|
||||
# Only `release: published` events advance it — never main pushes.
|
||||
#
|
||||
# We still run an ancestor check against the existing :latest so that a
|
||||
# backport release on an older branch (e.g. patching v1.1.5 after v1.2.3
|
||||
# is out) doesn't drag :latest backwards. The check is the same shape as
|
||||
# move-main: read the OCI revision label off the current :latest, look up
|
||||
# that commit in git, and only advance if our release commit is a strict
|
||||
# descendant.
|
||||
# ---------------------------------------------------------------------------
|
||||
move-latest:
|
||||
if: |
|
||||
github.repository == 'NousResearch/hermes-agent'
|
||||
&& github.event_name == 'release'
|
||||
&& needs.merge.outputs.pushed_release_tag == 'true'
|
||||
needs: merge
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
concurrency:
|
||||
group: docker-move-latest
|
||||
cancel-in-progress: false
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 1000
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Decide whether to move :latest
|
||||
id: latest_check
|
||||
run: |
|
||||
set -euo pipefail
|
||||
image=nousresearch/hermes-agent
|
||||
|
||||
image_json=$(
|
||||
docker buildx imagetools inspect "${image}:latest" \
|
||||
--format '{{ json (index .Image "linux/amd64") }}' \
|
||||
@@ -362,7 +488,7 @@ jobs:
|
||||
fi
|
||||
|
||||
echo "Registry :latest is at ${current_sha}"
|
||||
echo "This run is at ${GITHUB_SHA}"
|
||||
echo "This release is at ${GITHUB_SHA}"
|
||||
|
||||
if [ "${current_sha}" = "${GITHUB_SHA}" ]; then
|
||||
echo ":latest already points at our SHA — nothing to do."
|
||||
@@ -371,6 +497,7 @@ jobs:
|
||||
fi
|
||||
|
||||
# Make sure we have the :latest commit locally for merge-base.
|
||||
# Releases can be cut from any branch, so fetch broadly.
|
||||
if ! git cat-file -e "${current_sha}^{commit}" 2>/dev/null; then
|
||||
git fetch --no-tags --prune origin \
|
||||
"+refs/heads/main:refs/remotes/origin/main" \
|
||||
@@ -383,25 +510,25 @@ jobs:
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Our SHA must be a descendant of the current :latest to be safe.
|
||||
# Our release SHA must be a descendant of the current :latest.
|
||||
# Backport releases on older branches won't satisfy this and will
|
||||
# be left alone — :latest stays on the newer release.
|
||||
if git merge-base --is-ancestor "${current_sha}" "${GITHUB_SHA}"; then
|
||||
echo "Our commit is a descendant of :latest — safe to advance."
|
||||
echo "Our release commit is a descendant of :latest — safe to advance."
|
||||
echo "push_latest=true" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "Another run advanced :latest past us (or diverged) — leaving it alone."
|
||||
echo "Existing :latest is newer than this release (likely a backport) — leaving it alone."
|
||||
echo "push_latest=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
# Retag the already-pushed SHA manifest as :latest. This is a registry-
|
||||
# side operation — no rebuild, no layer re-push — so it's quick and
|
||||
# atomic per-tag. The ancestor check above plus the cancel-in-progress
|
||||
# concurrency on this job together guarantee we only ever move :latest
|
||||
# forward in git history.
|
||||
- name: Move :latest to this SHA
|
||||
# Retag the already-pushed release manifest as :latest.
|
||||
- name: Move :latest to this release tag
|
||||
if: steps.latest_check.outputs.push_latest == 'true'
|
||||
env:
|
||||
RELEASE_TAG: ${{ needs.merge.outputs.release_tag }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
image=nousresearch/hermes-agent
|
||||
docker buildx imagetools create \
|
||||
--tag "${image}:latest" \
|
||||
"${image}:sha-${GITHUB_SHA}"
|
||||
"${image}:${RELEASE_TAG}"
|
||||
|
||||
66
.github/workflows/supply-chain-audit.yml
vendored
66
.github/workflows/supply-chain-audit.yml
vendored
@@ -11,6 +11,7 @@ on:
|
||||
- '**/sitecustomize.py'
|
||||
- '**/usercustomize.py'
|
||||
- '**/__init__.pth'
|
||||
- 'pyproject.toml'
|
||||
|
||||
permissions:
|
||||
pull-requests: write
|
||||
@@ -137,3 +138,68 @@ jobs:
|
||||
run: |
|
||||
echo "::error::CRITICAL supply chain risk patterns detected in this PR. See the PR comment for details."
|
||||
exit 1
|
||||
|
||||
dep-bounds:
|
||||
name: Check PyPI dependency upper bounds
|
||||
runs-on: ubuntu-latest
|
||||
if: contains(github.event.pull_request.changed_files_url, 'pyproject.toml') || true
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Check for unbounded PyPI deps
|
||||
id: bounds
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
BASE="${{ github.event.pull_request.base.sha }}"
|
||||
HEAD="${{ github.event.pull_request.head.sha }}"
|
||||
|
||||
# Only check added lines in pyproject.toml
|
||||
ADDED=$(git diff "$BASE".."$HEAD" -- pyproject.toml | grep '^+' | grep -v '^+++' || true)
|
||||
|
||||
if [ -z "$ADDED" ]; then
|
||||
echo "found=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Match PyPI dep specs that have >= but no < ceiling.
|
||||
# Pattern: "package>=version" without a following ",<" bound.
|
||||
# Excludes git+ URLs (which use commit SHAs) and comments.
|
||||
UNBOUNDED=$(echo "$ADDED" | grep -oE '"[a-zA-Z0-9_-]+(\[[^\]]*\])?>=[ 0-9.]+"' | grep -v ',<' || true)
|
||||
|
||||
if [ -n "$UNBOUNDED" ]; then
|
||||
echo "found=true" >> "$GITHUB_OUTPUT"
|
||||
echo "$UNBOUNDED" > /tmp/unbounded.txt
|
||||
else
|
||||
echo "found=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Post unbounded dep warning
|
||||
if: steps.bounds.outputs.found == 'true'
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
BODY="## ⚠️ Unbounded PyPI Dependency Detected
|
||||
|
||||
This PR adds PyPI dependencies without a \`<next_major\` upper bound. Per our [supply chain policy](../blob/main/CONTRIBUTING.md#dependency-pinning-policy-supply-chain-hardening), all PyPI deps must be pinned as \`>=floor,<next_major\`.
|
||||
|
||||
**Unbounded specs found:**
|
||||
\`\`\`
|
||||
$(cat /tmp/unbounded.txt)
|
||||
\`\`\`
|
||||
|
||||
**Fix:** Add an upper bound, e.g. \`\"package>=1.2.0,<2\"\`
|
||||
|
||||
---
|
||||
*See PR #2810 and CONTRIBUTING.md for the full policy rationale.*"
|
||||
|
||||
gh pr comment "${{ github.event.pull_request.number }}" --body "$BODY" || echo "::warning::Could not post PR comment (expected for fork PRs)"
|
||||
|
||||
- name: Fail on unbounded deps
|
||||
if: steps.bounds.outputs.found == 'true'
|
||||
run: |
|
||||
echo "::error::PyPI dependencies without upper bounds detected. Add <next_major ceiling per CONTRIBUTING.md policy."
|
||||
exit 1
|
||||
|
||||
5
.github/workflows/tests.yml
vendored
5
.github/workflows/tests.yml
vendored
@@ -55,11 +55,14 @@ jobs:
|
||||
|
||||
e2e:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
|
||||
- name: Install system dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y ripgrep
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5
|
||||
|
||||
|
||||
137
.github/workflows/upload_to_pypi.yml
vendored
Normal file
137
.github/workflows/upload_to_pypi.yml
vendored
Normal file
@@ -0,0 +1,137 @@
|
||||
name: Publish to PyPI
|
||||
|
||||
# Triggered by CalVer tag pushes from scripts/release.py (e.g. v2026.5.15)
|
||||
# Can also be triggered manually from the Actions tab as an escape hatch.
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v20*' # CalVer tags: v2026.5.15, v2026.5.15.2, etc.
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
confirm_tag:
|
||||
description: 'Tag to publish (e.g. v2026.5.15). Must already exist.'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
# Restrict default token to read-only; each job escalates as needed.
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
# Prevent overlapping publishes (e.g. two same-day tags pushed quickly).
|
||||
concurrency:
|
||||
group: pypi-publish
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build distribution 📦
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
# On workflow_dispatch, check out the confirmed tag.
|
||||
ref: ${{ inputs.confirm_tag || github.ref }}
|
||||
fetch-tags: true
|
||||
|
||||
- name: Validate tag exists
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
run: |
|
||||
if ! git tag -l "${{ inputs.confirm_tag }}" | grep -q .; then
|
||||
echo "::error::Tag '${{ inputs.confirm_tag }}' does not exist in the repo"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
||||
with:
|
||||
python-version: '3.13'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@d0cc045d04ccac9d8b7881df0226f9e82c39688e # v6
|
||||
|
||||
- name: Build wheel and sdist
|
||||
run: uv build --sdist --wheel
|
||||
|
||||
- name: Upload distribution artifacts
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
|
||||
with:
|
||||
name: python-package-distributions
|
||||
path: dist/
|
||||
|
||||
publish:
|
||||
name: Publish to PyPI
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/hermes-agent
|
||||
permissions:
|
||||
id-token: write # OIDC trusted publishing
|
||||
|
||||
steps:
|
||||
- name: Download distribution artifacts
|
||||
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4
|
||||
with:
|
||||
name: python-package-distributions
|
||||
path: dist/
|
||||
|
||||
- name: Publish to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@cef221092ed1bacb1cc03d23a2d87d1d172e277b # v1.14.0
|
||||
with:
|
||||
skip-existing: true
|
||||
|
||||
sign:
|
||||
name: Sign and attach to GitHub Release
|
||||
# Only runs on tag pushes — release.py creates the GitHub Release,
|
||||
# and workflow_dispatch won't have a matching release to attach to.
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
needs: publish
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write # attach assets to the existing release
|
||||
id-token: write # sigstore signing
|
||||
|
||||
steps:
|
||||
- name: Download distribution artifacts
|
||||
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4
|
||||
with:
|
||||
name: python-package-distributions
|
||||
path: dist/
|
||||
|
||||
- name: Wait for GitHub Release to exist
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
# release.py creates the GitHub Release after pushing the tag,
|
||||
# but this workflow starts from the tag push — wait for it.
|
||||
run: |
|
||||
for i in $(seq 1 30); do
|
||||
if gh release view "$GITHUB_REF_NAME" --repo "$GITHUB_REPOSITORY" >/dev/null 2>&1; then
|
||||
echo "Release $GITHUB_REF_NAME found"
|
||||
exit 0
|
||||
fi
|
||||
echo "Waiting for release... ($i/30)"
|
||||
sleep 10
|
||||
done
|
||||
echo "::warning::Release $GITHUB_REF_NAME not found after 5 minutes — skipping signature upload"
|
||||
echo "skip_sign=true" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Sign with Sigstore
|
||||
if: env.skip_sign != 'true'
|
||||
uses: sigstore/gh-action-sigstore-python@f514d46b907ebcd5bedc05145c03b69c1edd8b46 # v3.0.0
|
||||
with:
|
||||
inputs: >-
|
||||
./dist/*.tar.gz
|
||||
./dist/*.whl
|
||||
|
||||
- name: Attach signed artifacts to GitHub Release
|
||||
if: env.skip_sign != 'true'
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
# release.py already created the GitHub Release — just upload
|
||||
# the Sigstore signatures alongside the existing assets.
|
||||
run: >-
|
||||
gh release upload
|
||||
"$GITHUB_REF_NAME" dist/*.sigstore.json
|
||||
--repo "$GITHUB_REPOSITORY"
|
||||
--clobber
|
||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -1,3 +0,0 @@
|
||||
[submodule "tinker-atropos"]
|
||||
path = tinker-atropos
|
||||
url = https://github.com/nousresearch/tinker-atropos
|
||||
115
AGENTS.md
115
AGENTS.md
@@ -56,7 +56,6 @@ hermes-agent/
|
||||
├── tui_gateway/ # Python JSON-RPC backend for the TUI
|
||||
├── acp_adapter/ # ACP server (VS Code / Zed / JetBrains integration)
|
||||
├── cron/ # Scheduler — jobs.py, scheduler.py
|
||||
├── environments/ # RL training environments (Atropos)
|
||||
├── scripts/ # run_tests.sh, release.py, auxiliary scripts
|
||||
├── website/ # Docusaurus docs site
|
||||
└── tests/ # Pytest suite (~17k tests across ~900 files as of May 2026)
|
||||
@@ -309,6 +308,29 @@ The registry handles schema collection, dispatch, availability checking, and err
|
||||
|
||||
---
|
||||
|
||||
## Dependency Pinning Policy
|
||||
|
||||
All dependencies must have upper bounds to limit supply-chain attack surface.
|
||||
This policy was established after the litellm compromise (PR #2796, #2810) and
|
||||
reinforced after the Mini Shai-Hulud worm campaign (May 2026).
|
||||
|
||||
| Source type | Treatment | Example |
|
||||
|---|---|---|
|
||||
| PyPI package | `>=floor,<next_major` | `"httpx>=0.28.1,<1"` |
|
||||
| Git URL | Commit SHA | `git+https://...@<40-char-sha>` |
|
||||
| GitHub Actions | Commit SHA + comment | `uses: actions/checkout@<sha> # v4` |
|
||||
| CI-only pip | `==exact` | `pyyaml==6.0.2` |
|
||||
|
||||
**When adding a new dependency to `pyproject.toml`:**
|
||||
1. Pin to `>=current_version,<next_major` for post-1.0 (e.g. `>=1.5.0,<2`).
|
||||
2. For pre-1.0 packages, use `<0.(current_minor + 2)` (e.g. `>=0.29,<0.32`).
|
||||
3. Never commit a bare `>=X.Y.Z` without a ceiling — CI and reviewers will reject it.
|
||||
4. Run `uv lock` to regenerate `uv.lock` with hashes.
|
||||
|
||||
Reference: #2810 (bounds pass), #9801 (SHA pinning + audit CI).
|
||||
|
||||
---
|
||||
|
||||
## Adding Configuration
|
||||
|
||||
### config.yaml options:
|
||||
@@ -513,6 +535,17 @@ generic plugin surface (new hook, new ctx method) — never hardcode
|
||||
plugin-specific logic into core. PR #5295 removed 95 lines of hardcoded
|
||||
honcho argparse from `main.py` for exactly this reason.
|
||||
|
||||
**No new in-tree memory providers (policy, May 2026):** the set of
|
||||
built-in memory providers under `plugins/memory/` is closed. New memory
|
||||
backends must ship as **standalone plugin repos** that users install
|
||||
into `~/.hermes/plugins/` (or via pip entry points) — they implement
|
||||
the same `MemoryProvider` ABC, register through the same discovery
|
||||
path, and integrate via `hermes memory setup` / `post_setup()` without
|
||||
landing in this tree. PRs that add a new directory under
|
||||
`plugins/memory/` will be closed with a pointer to publish the
|
||||
provider as its own repo. Existing in-tree providers stay; bug fixes
|
||||
to them are welcome.
|
||||
|
||||
### Model-provider plugins (`plugins/model-providers/<name>/`)
|
||||
|
||||
Every inference backend (openrouter, anthropic, gmi, deepseek, nvidia, …)
|
||||
@@ -580,6 +613,86 @@ during setup, injected at load time).
|
||||
Top-level `tags:` and `category:` are also accepted and mirrored from
|
||||
`metadata.hermes.*` by the loader.
|
||||
|
||||
### Skill authoring standards (HARDLINE)
|
||||
|
||||
Every new or modernized skill — bundled, optional, or contributed —
|
||||
must meet these standards before merge. Reviewers reject PRs that
|
||||
violate them.
|
||||
|
||||
1. **`description` ≤ 60 characters, one sentence, ends with a period.**
|
||||
Long descriptions bloat skill listings and dilute the model's
|
||||
attention when many skills are loaded. State the capability, not
|
||||
the implementation. No marketing words ("powerful",
|
||||
"comprehensive", "seamless", "advanced"). Don't repeat the skill
|
||||
name. Verify with:
|
||||
```python
|
||||
import re, pathlib
|
||||
m = re.search(r'^description: (.*)$',
|
||||
pathlib.Path('skills/<cat>/<name>/SKILL.md').read_text(),
|
||||
re.MULTILINE)
|
||||
assert len(m.group(1)) <= 60, len(m.group(1))
|
||||
```
|
||||
|
||||
2. **Tools referenced in SKILL.md prose must be native Hermes tools or
|
||||
MCP servers the skill explicitly expects.** When the skill needs a
|
||||
capability, point at the proper tool by name in backticks
|
||||
(`` `terminal` ``, `` `web_extract` ``, `` `read_file` ``,
|
||||
`` `patch` ``, `` `search_files` ``, `` `vision_analyze` ``,
|
||||
`` `browser_navigate` ``, `` `delegate_task` ``, etc.). Do NOT
|
||||
name shell utilities the agent already has wrapped — `grep` →
|
||||
`search_files`, `cat`/`head`/`tail` → `read_file`, `sed`/`awk` →
|
||||
`patch`, `find`/`ls` → `search_files target='files'`. If the skill
|
||||
depends on an MCP server, name the MCP server and document the
|
||||
expected setup in `## Prerequisites`. Anything else (third-party
|
||||
CLIs, shell pipelines, etc.) is fair game inside script files but
|
||||
should not be the headline interaction surface in the prose.
|
||||
|
||||
3. **`platforms:` gating audited against actual script imports.**
|
||||
Skills that use POSIX-only primitives (`fcntl`, `termios`,
|
||||
`os.setsid`, `os.kill(pid, 0)` for liveness, `/proc`, `/tmp`
|
||||
hardcoded, `signal.SIGKILL`, bash heredocs, `osascript`, `apt`,
|
||||
`systemctl`) must declare their supported platforms. Default
|
||||
posture: try to fix it cross-platform first — `tempfile.gettempdir`,
|
||||
`pathlib.Path`, `psutil.pid_exists`, Python-level filtering instead
|
||||
of `grep`. Gate to a narrower set only when the dependency is
|
||||
genuinely platform-bound.
|
||||
|
||||
4. **`author` credits the human contributor first.** For external
|
||||
contributions, the contributor's real name + GitHub handle goes
|
||||
first; "Hermes Agent" is the secondary collaborator. If the
|
||||
contributor's commit shows "Hermes Agent" as author (because they
|
||||
used Hermes to draft the skill), replace it with their actual name
|
||||
— credit the human, not the tool.
|
||||
|
||||
5. **SKILL.md body uses the modern section order.** `# <Skill> Skill`
|
||||
title, 2-3 sentence intro stating what it does and doesn't do,
|
||||
`## When to Use`, `## Prerequisites`, `## How to Run`,
|
||||
`## Quick Reference`, `## Procedure`, `## Pitfalls`,
|
||||
`## Verification`. Target ~200 lines for a complex skill,
|
||||
~100 lines for a simple one. Cut redundant intro fluff, marketing
|
||||
prose, and re-explanations of env vars already in
|
||||
`## Prerequisites`.
|
||||
|
||||
6. **Scripts go in `scripts/`, references in `references/`,
|
||||
templates in `templates/`.** Don't expect the model to inline-write
|
||||
parsers, XML walkers, or non-trivial logic every call — ship a
|
||||
helper script. Reference it from SKILL.md by path relative to the
|
||||
skill directory.
|
||||
|
||||
7. **Tests live at `tests/skills/test_<skill>_skill.py`** and use only
|
||||
stdlib + pytest + `unittest.mock`. No live network calls. Run via
|
||||
`scripts/run_tests.sh tests/skills/test_<skill>_skill.py -q`.
|
||||
|
||||
8. **`.env.example` additions are isolated to a clearly delimited
|
||||
block.** Don't touch the surrounding file — contributor-supplied
|
||||
`.env.example` versions are usually stale and edits outside the
|
||||
skill's own block must be dropped during salvage.
|
||||
|
||||
The full salvage / modernization checklist for external skill PRs
|
||||
lives in the `hermes-agent-dev` skill at
|
||||
`references/new-skill-pr-salvage.md` — load it before polishing
|
||||
contributor skill PRs.
|
||||
|
||||
---
|
||||
|
||||
## Toolsets
|
||||
|
||||
115
CONTRIBUTING.md
115
CONTRIBUTING.md
@@ -49,6 +49,24 @@ If your skill is specialized, community-contributed, or niche, it's better suite
|
||||
|
||||
---
|
||||
|
||||
## Memory Providers: Ship as a Standalone Plugin
|
||||
|
||||
**We are no longer accepting new memory providers into this repo.** The set of built-in providers under `plugins/memory/` (honcho, mem0, supermemory, byterover, hindsight, holographic, openviking, retaindb) is closed. If you want to add a new memory backend, publish it as a **standalone plugin repo** that users install into `~/.hermes/plugins/` (or via a pip entry point).
|
||||
|
||||
Standalone memory plugins:
|
||||
|
||||
- Implement the same `MemoryProvider` ABC (`agent/memory_provider.py`) — `sync_turn`, `prefetch`, `shutdown`, and optionally `post_setup(hermes_home, config)` for setup-wizard integration
|
||||
- Use the same discovery system — `discover_memory_providers()` picks them up from user/project plugin directories and pip entry points
|
||||
- Integrate with `hermes memory setup` via `post_setup()` — no need to touch core code
|
||||
- Can register their own CLI subcommands via `register_cli(subparser)` in a `cli.py` file
|
||||
- Get all the same lifecycle hooks and config plumbing as in-tree providers
|
||||
|
||||
PRs that add a new directory under `plugins/memory/` will be closed with a pointer to publish the provider as its own repo. Existing in-tree providers stay; bug fixes to them are welcome.
|
||||
|
||||
This isn't a quality bar — it's a coupling-and-maintenance decision. Memory providers are the most common plugin type and they shouldn't all live in this tree.
|
||||
|
||||
---
|
||||
|
||||
## Development Setup
|
||||
|
||||
### Prerequisites
|
||||
@@ -73,9 +91,6 @@ export VIRTUAL_ENV="$(pwd)/venv"
|
||||
# Install with all extras (messaging, cron, CLI menus, dev tools)
|
||||
uv pip install -e ".[all,dev]"
|
||||
|
||||
# Optional: RL training submodule
|
||||
# git submodule update --init tinker-atropos && uv pip install -e "./tinker-atropos"
|
||||
|
||||
# Optional: browser tools
|
||||
npm install
|
||||
```
|
||||
@@ -178,7 +193,6 @@ hermes-agent/
|
||||
│
|
||||
├── skills/ # Bundled skills (copied to ~/.hermes/skills/ on install)
|
||||
├── optional-skills/ # Official optional skills (discoverable via hub, not activated by default)
|
||||
├── environments/ # RL training environments (Atropos integration)
|
||||
├── tests/ # Test suite
|
||||
├── website/ # Documentation site (hermes-agent.nousresearch.com)
|
||||
│
|
||||
@@ -461,6 +475,58 @@ Gateway and messaging sessions never collect secrets in-band; they instruct the
|
||||
|
||||
See `skills/gifs/gif-search/` and `skills/email/himalaya/` for examples.
|
||||
|
||||
### Skill authoring standards (HARDLINE)
|
||||
|
||||
Every new or modernized skill — bundled, optional, or contributed — must meet these standards before merge. Reviewers reject PRs that violate them.
|
||||
|
||||
1. **`description` ≤ 60 characters, one sentence, ends with a period.** Long descriptions bloat the skill listing UI and dilute the model's attention when many skills are loaded. State the capability, not the implementation. No marketing words ("powerful", "comprehensive", "seamless", "advanced"). Don't repeat the skill name. Verify with:
|
||||
```python
|
||||
import re, pathlib
|
||||
m = re.search(r'^description: (.*)$',
|
||||
pathlib.Path('skills/<cat>/<name>/SKILL.md').read_text(),
|
||||
re.MULTILINE)
|
||||
assert len(m.group(1)) <= 60, len(m.group(1))
|
||||
```
|
||||
|
||||
Good: `Search arXiv papers by keyword, author, category, or ID.`
|
||||
Bad: `A powerful and comprehensive skill that allows the agent to search arXiv for relevant academic papers using various criteria including keywords, authors, and categories.`
|
||||
|
||||
2. **Tools referenced in SKILL.md prose must be native Hermes tools or MCP servers the skill explicitly expects.** When the skill needs a capability, point at the proper tool by name in backticks: `` `terminal` ``, `` `web_extract` ``, `` `web_search` ``, `` `read_file` ``, `` `write_file` ``, `` `patch` ``, `` `search_files` ``, `` `vision_analyze` ``, `` `browser_navigate` ``, `` `delegate_task` ``, `` `image_generate` ``, `` `text_to_speech` ``, `` `cronjob` ``, `` `memory` ``, `` `skill_view` ``, `` `todo` ``, `` `execute_code` ``.
|
||||
|
||||
Do NOT name shell utilities the agent already has wrapped:
|
||||
|
||||
| Don't say | Say |
|
||||
|---|---|
|
||||
| `grep`, `rg` | `search_files` |
|
||||
| `cat`, `head`, `tail` | `read_file` |
|
||||
| `sed`, `awk` | `patch` |
|
||||
| `find`, `ls` | `search_files` (with `target='files'`) |
|
||||
| `curl` for content extraction | `web_extract` |
|
||||
| `echo > file`, `cat <<EOF` | `write_file` |
|
||||
|
||||
If the skill depends on an MCP server, name the MCP server and document its setup in `## Prerequisites`. Third-party CLIs (e.g. `ffmpeg`, `gh`, a specific SDK) are fine to invoke from inside script files, but the prose should frame the interaction as "invoke through the `terminal` tool", not as a manual shell session.
|
||||
|
||||
3. **`platforms:` gating audited against actual script imports.** Skills that use POSIX-only primitives (`fcntl`, `termios`, `os.setsid`, `os.kill(pid, 0)` for liveness, `/proc`, hardcoded `/tmp` paths, `signal.SIGKILL`, bash heredocs, `osascript`, `apt`, `systemctl`) must declare their supported platforms via the `platforms:` frontmatter. Default posture is to fix it cross-platform first — `tempfile.gettempdir()`, `pathlib.Path`, `psutil.pid_exists()`, Python-level filtering instead of `grep`. Gate to a narrower set only when the dependency is genuinely platform-bound (e.g. `osascript` is macOS-only, `/proc` is Linux-only).
|
||||
|
||||
4. **`author` credits the human contributor first.** For external contributions, the contributor's real name + GitHub handle goes first (`Jane Doe (jane-doe)`); "Hermes Agent" is the secondary collaborator. If the contributor's commit shows "Hermes Agent" as author because they used Hermes to draft the skill, replace it with their actual name — credit the human, not the tool.
|
||||
|
||||
5. **SKILL.md body uses the modern section order.** `# <Skill> Skill` title, 2-3 sentence intro stating what it does and what it doesn't do, then:
|
||||
- `## When to Use` — trigger conditions
|
||||
- `## Prerequisites` — env vars, install steps, MCP setup, API key sourcing
|
||||
- `## How to Run` — canonical invocation through the `terminal` tool
|
||||
- `## Quick Reference` — flat command/API reference
|
||||
- `## Procedure` — numbered steps with copy-paste commands
|
||||
- `## Pitfalls` — known limits, rate limits, things that look broken but aren't
|
||||
- `## Verification` — single command that proves the skill works
|
||||
|
||||
Target ~200 lines for a complex skill, ~100 lines for a simple one. Cut redundant intro fluff, marketing prose, and re-explanations of env vars already documented in `## Prerequisites`.
|
||||
|
||||
6. **Scripts go in `scripts/`, references in `references/`, templates in `templates/`.** Don't expect the model to inline-write parsers, XML walkers, or non-trivial logic every call — ship a helper script. Reference scripts from SKILL.md by path relative to the skill directory.
|
||||
|
||||
7. **Tests live at `tests/skills/test_<skill>_skill.py`** and use only stdlib + pytest + `unittest.mock`. No live network calls. Run via `scripts/run_tests.sh tests/skills/test_<skill>_skill.py -q`. Must pass under the hermetic CI env (no API keys leaking through). Use `monkeypatch` and `tmp_path` for any env-var or filesystem dependencies.
|
||||
|
||||
8. **`.env.example` additions are isolated to a clearly delimited block.** Don't touch the surrounding file — contributor-supplied `.env.example` versions are usually stale, and edits outside the skill's own block will be dropped during salvage. Comment all values with `#` (it's documentation, not live config).
|
||||
|
||||
### Skill guidelines
|
||||
|
||||
- **No external dependencies unless absolutely necessary.** Prefer stdlib Python, curl, and existing Hermes tools (`web_extract`, `terminal`, `read_file`).
|
||||
@@ -734,6 +800,47 @@ Hermes has terminal access. Security matters.
|
||||
|
||||
If your PR affects security, note it explicitly in the description.
|
||||
|
||||
### Dependency pinning policy (supply chain hardening)
|
||||
|
||||
After the [litellm supply chain compromise](https://github.com/BerriAI/litellm/issues/24512) in March 2026 and the [Mini Shai-Hulud worm campaign](https://socket.dev/blog/tanstack-npm-packages-compromised-mini-shai-hulud-supply-chain-attack) in May 2026, all dependencies must follow these rules:
|
||||
|
||||
| Source type | Required treatment | Rationale |
|
||||
|---|---|---|
|
||||
| **PyPI package** | `>=floor,<next_major` | PyPI versions are immutable once published, but new versions can be pushed into your range. A `<next_major` ceiling stops a 1.x install from upgrading to a malicious 2.0.0. |
|
||||
| **Git URL** (atroposlib, tinker, yc-bench, Baileys) | Full commit SHA | Branches and tags are mutable refs; SHA is content-addressed. |
|
||||
| **GitHub Actions** | Full commit SHA + version comment | Action tags are mutable refs (e.g. tj-actions/changed-files March 2025). Pin as `uses: owner/action@<sha> # vX.Y.Z` |
|
||||
| **CI-only pip installs** | `==exact` | Hermetic CI builds; churn is acceptable. |
|
||||
|
||||
**Every new PyPI dependency in a PR must have a `<next_major` upper bound.** PRs adding unbounded `>=X.Y.Z` specs will be rejected by reviewers. The `supply-chain-audit.yml` CI workflow also flags dependency manifest changes for manual review.
|
||||
|
||||
**How to determine the ceiling:**
|
||||
- If the package is at version `1.x.y`, use `<2`.
|
||||
- If the package is at version `0.x.y` (pre-1.0), use `<0.(current_minor + 2)` — e.g. if current is `0.29.x`, use `<0.32`. This gives ~2 minor versions of headroom while keeping the window small enough that a hostile takeover version is unlikely to land inside it.
|
||||
- Exception: packages with very stable APIs (e.g. `aiohttp-socks`) can use `<1` at reviewer discretion.
|
||||
|
||||
**Examples:**
|
||||
```toml
|
||||
# ✅ Correct — post-1.0
|
||||
"openai>=2.21.0,<3"
|
||||
"pydantic>=2.12.5,<3"
|
||||
|
||||
# ✅ Correct — pre-1.0 (tight minor window)
|
||||
"asyncpg>=0.29,<0.32"
|
||||
"aiosqlite>=0.20,<0.23"
|
||||
"hindsight-client>=0.4.22,<0.5"
|
||||
|
||||
# ❌ Rejected — no upper bound
|
||||
"some-package>=1.2.3"
|
||||
|
||||
# ❌ Rejected — too tight (blocks legitimate patches)
|
||||
"some-package==1.2.3"
|
||||
|
||||
# ❌ Rejected — too loose for pre-1.0 (allows 80 minor versions)
|
||||
"some-package>=0.20,<1"
|
||||
```
|
||||
|
||||
**Reference PRs:** #2796 (litellm removal), #2810 (upper bounds pass), #9801 (SHA pinning + supply-chain-audit CI).
|
||||
|
||||
---
|
||||
|
||||
## Pull Request Process
|
||||
|
||||
@@ -94,9 +94,13 @@ RUN cd web && npm run build && \
|
||||
# hermes_cli/main.py succeeds (see #18800). /opt/hermes/web is build-time
|
||||
# only (HERMES_WEB_DIST points at hermes_cli/web_dist) and is intentionally
|
||||
# not chowned here.
|
||||
# The .venv MUST be hermes-writable so lazy_deps.py can install platform
|
||||
# packages (discord.py, telegram, slack, etc.) at first gateway boot.
|
||||
# Without this, `uv pip install` fails with EACCES and all messaging
|
||||
# adapters silently fail to load. See tools/lazy_deps.py.
|
||||
USER root
|
||||
RUN chmod -R a+rX /opt/hermes && \
|
||||
chown -R hermes:hermes /opt/hermes/ui-tui /opt/hermes/node_modules
|
||||
chown -R hermes:hermes /opt/hermes/.venv /opt/hermes/ui-tui /opt/hermes/node_modules
|
||||
# Start as root so the entrypoint can usermod/groupmod + gosu.
|
||||
# If HERMES_UID is unset, the entrypoint drops to the default hermes user (10000).
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
**The self-improving AI agent built by [Nous Research](https://nousresearch.com).** It's the only agent with a built-in learning loop — it creates skills from experience, improves them during use, nudges itself to persist knowledge, searches its own past conversations, and builds a deepening model of who you are across sessions. Run it on a $5 VPS, a GPU cluster, or serverless infrastructure that costs nearly nothing when idle. It's not tied to your laptop — talk to it from Telegram while it works on a cloud VM.
|
||||
|
||||
Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [NVIDIA NIM](https://build.nvidia.com) (Nemotron), [Xiaomi MiMo](https://platform.xiaomimimo.com), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), [Hugging Face](https://huggingface.co), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
|
||||
Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [NovitaAI](https://novita.ai) (AI-native cloud for Model API, Agent Sandbox, and GPU Cloud), [NVIDIA NIM](https://build.nvidia.com) (Nemotron), [Xiaomi MiMo](https://platform.xiaomimimo.com), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), [Hugging Face](https://huggingface.co), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
|
||||
|
||||
<table>
|
||||
<tr><td><b>A real terminal interface</b></td><td>Full TUI with multiline editing, slash-command autocomplete, conversation history, interrupt-and-redirect, and streaming tool output.</td></tr>
|
||||
@@ -23,7 +23,7 @@ Use any model you want — [Nous Portal](https://portal.nousresearch.com), [Open
|
||||
<tr><td><b>Scheduled automations</b></td><td>Built-in cron scheduler with delivery to any platform. Daily reports, nightly backups, weekly audits — all in natural language, running unattended.</td></tr>
|
||||
<tr><td><b>Delegates and parallelizes</b></td><td>Spawn isolated subagents for parallel workstreams. Write Python scripts that call tools via RPC, collapsing multi-step pipelines into zero-context-cost turns.</td></tr>
|
||||
<tr><td><b>Runs anywhere, not just your laptop</b></td><td>Seven terminal backends — local, Docker, SSH, Singularity, Modal, Daytona, and Vercel Sandbox. Daytona and Modal offer serverless persistence — your agent's environment hibernates when idle and wakes on demand, costing nearly nothing between sessions. Run it on a $5 VPS or a GPU cluster.</td></tr>
|
||||
<tr><td><b>Research-ready</b></td><td>Batch trajectory generation, Atropos RL environments, trajectory compression for training the next generation of tool-calling models.</td></tr>
|
||||
<tr><td><b>Research-ready</b></td><td>Batch trajectory generation, trajectory compression for training the next generation of tool-calling models.</td></tr>
|
||||
</table>
|
||||
|
||||
---
|
||||
@@ -175,8 +175,6 @@ uv pip install -e ".[all,dev]"
|
||||
scripts/run_tests.sh
|
||||
```
|
||||
|
||||
> **RL Training (optional):** The RL/Atropos integration (`environments/`) — see [`CONTRIBUTING.md`](https://github.com/NousResearch/hermes-agent/blob/main/CONTRIBUTING.md#development-setup) for the full setup.
|
||||
|
||||
---
|
||||
|
||||
## Community
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
<tr><td><b>定时自动化</b></td><td>内置 cron 调度器,支持向任何平台投递。日报、夜间备份、周审计——全部用自然语言描述,无人值守运行。</td></tr>
|
||||
<tr><td><b>委派与并行</b></td><td>生成隔离子代理处理并行工作流。编写 Python 脚本通过 RPC 调用工具,将多步管道压缩为零上下文开销的轮次。</td></tr>
|
||||
<tr><td><b>随处运行</b></td><td>六种终端后端——本地、Docker、SSH、Daytona、Singularity 和 Modal。Daytona 和 Modal 提供 Serverless 持久化——代理环境空闲时休眠、按需唤醒,空闲期间几乎零成本。$5 VPS 或 GPU 集群都能跑。</td></tr>
|
||||
<tr><td><b>研究就绪</b></td><td>批量轨迹生成、Atropos RL 环境、轨迹压缩——用于训练下一代工具调用模型。</td></tr>
|
||||
<tr><td><b>研究就绪</b></td><td>批量轨迹生成、轨迹压缩——用于训练下一代工具调用模型。</td></tr>
|
||||
</table>
|
||||
|
||||
---
|
||||
@@ -161,12 +161,6 @@ uv pip install -e ".[all,dev]"
|
||||
python -m pytest tests/ -q
|
||||
```
|
||||
|
||||
> **RL 训练(可选):** 如需参与 RL/Tinker-Atropos 集成开发:
|
||||
> ```bash
|
||||
> git submodule update --init tinker-atropos
|
||||
> uv pip install -e "./tinker-atropos"
|
||||
> ```
|
||||
|
||||
---
|
||||
|
||||
## 社区
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
"""ACP auth helpers — detect the currently configured Hermes provider."""
|
||||
"""ACP auth helpers — detect and advertise Hermes authentication methods."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
TERMINAL_SETUP_AUTH_METHOD_ID = "hermes-setup"
|
||||
|
||||
|
||||
def detect_provider() -> Optional[str]:
|
||||
@@ -22,3 +25,44 @@ def detect_provider() -> Optional[str]:
|
||||
def has_provider() -> bool:
|
||||
"""Return True if Hermes can resolve any runtime provider credentials."""
|
||||
return detect_provider() is not None
|
||||
|
||||
|
||||
def build_auth_methods() -> list[Any]:
|
||||
"""Return registry-compatible ACP auth methods for Hermes.
|
||||
|
||||
The official ACP registry validates that agents advertise at least one
|
||||
usable auth method during the initial handshake. A fresh Zed install may
|
||||
not have Hermes provider credentials configured yet, so Hermes always
|
||||
advertises a terminal setup method. When credentials are already present,
|
||||
it also advertises the resolved provider as the default agent-managed
|
||||
runtime credential method.
|
||||
"""
|
||||
from acp.schema import AuthMethodAgent, TerminalAuthMethod
|
||||
|
||||
methods: list[Any] = []
|
||||
provider = detect_provider()
|
||||
if provider:
|
||||
methods.append(
|
||||
AuthMethodAgent(
|
||||
id=provider,
|
||||
name=f"{provider} runtime credentials",
|
||||
description=(
|
||||
"Authenticate Hermes using the currently configured "
|
||||
f"{provider} runtime credentials."
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
methods.append(
|
||||
TerminalAuthMethod(
|
||||
id=TERMINAL_SETUP_AUTH_METHOD_ID,
|
||||
name="Configure Hermes provider",
|
||||
description=(
|
||||
"Open Hermes' interactive model/provider setup in a terminal. "
|
||||
"Use this when Hermes has not been configured on this machine yet."
|
||||
),
|
||||
type="terminal",
|
||||
args=["--setup"],
|
||||
)
|
||||
)
|
||||
return methods
|
||||
|
||||
288
acp_adapter/bootstrap/bootstrap_browser_tools.ps1
Normal file
288
acp_adapter/bootstrap/bootstrap_browser_tools.ps1
Normal file
@@ -0,0 +1,288 @@
|
||||
# bootstrap_browser_tools.ps1 — install agent-browser + Playwright Chromium
|
||||
# into ~/.hermes/node/ for use by Hermes Agent's browser tools on Windows.
|
||||
#
|
||||
# Targets the registry-install path: users who got Hermes via
|
||||
# `uvx --from 'hermes-agent[acp]==X' hermes-acp` don't have a repo clone,
|
||||
# so the install.ps1 `npm install`-in-repo flow doesn't apply. This script
|
||||
# is a self-contained, idempotent slice of install.ps1's browser block.
|
||||
#
|
||||
# Usage:
|
||||
# .\bootstrap_browser_tools.ps1 # use defaults
|
||||
# .\bootstrap_browser_tools.ps1 -Yes # accept Chromium download
|
||||
# .\bootstrap_browser_tools.ps1 -SkipChromium # Node + agent-browser only
|
||||
#
|
||||
# Idempotent: re-running this is safe and fast.
|
||||
|
||||
[CmdletBinding()]
|
||||
param(
|
||||
[switch]$Yes,
|
||||
[switch]$SkipChromium
|
||||
)
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
$NodeVersion = "22"
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Logging
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
function Write-Info { param([string]$msg) Write-Host "[*] $msg" -ForegroundColor Cyan }
|
||||
function Write-Success { param([string]$msg) Write-Host "[+] $msg" -ForegroundColor Green }
|
||||
function Write-Warn { param([string]$msg) Write-Host "[!] $msg" -ForegroundColor Yellow }
|
||||
function Write-Err { param([string]$msg) Write-Host "[x] $msg" -ForegroundColor Red }
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Paths
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
$HermesHome = $env:HERMES_HOME
|
||||
if (-not $HermesHome) {
|
||||
$HermesHome = Join-Path $env:USERPROFILE ".hermes"
|
||||
}
|
||||
$NodePrefix = Join-Path $HermesHome "node"
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Step 1: Node.js
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
function Resolve-NpmExe {
|
||||
# Same gotcha as install.ps1: prefer npm.cmd over npm.ps1 so the
|
||||
# PowerShell execution policy doesn't block us.
|
||||
$cmd = Get-Command npm -ErrorAction SilentlyContinue
|
||||
if (-not $cmd) { return $null }
|
||||
$npmExe = $cmd.Source
|
||||
if ($npmExe -like "*.ps1") {
|
||||
$sibling = Join-Path (Split-Path $npmExe -Parent) "npm.cmd"
|
||||
if (Test-Path $sibling) { return $sibling }
|
||||
}
|
||||
return $npmExe
|
||||
}
|
||||
|
||||
function Resolve-NpxExe {
|
||||
$cmd = Get-Command npx -ErrorAction SilentlyContinue
|
||||
if (-not $cmd) { return $null }
|
||||
$npxExe = $cmd.Source
|
||||
if ($npxExe -like "*.ps1") {
|
||||
$sibling = Join-Path (Split-Path $npxExe -Parent) "npx.cmd"
|
||||
if (Test-Path $sibling) { return $sibling }
|
||||
}
|
||||
return $npxExe
|
||||
}
|
||||
|
||||
function Ensure-Node {
|
||||
# System Node on PATH?
|
||||
$sysNode = Get-Command node -ErrorAction SilentlyContinue
|
||||
if ($sysNode) {
|
||||
try {
|
||||
$v = & $sysNode.Source --version
|
||||
$major = [int]($v -replace '^v(\d+).*', '$1')
|
||||
if ($major -ge 20) {
|
||||
Write-Success "Node.js $v found on PATH"
|
||||
return
|
||||
}
|
||||
Write-Warn "Node.js $v is older than v20 — installing managed Node."
|
||||
} catch {
|
||||
Write-Warn "Failed to query Node version: $_"
|
||||
}
|
||||
}
|
||||
|
||||
# Hermes-managed Node?
|
||||
$managedNode = Join-Path $NodePrefix "node.exe"
|
||||
if (Test-Path $managedNode) {
|
||||
$v = & $managedNode --version
|
||||
Write-Success "Node.js $v found (Hermes-managed at $NodePrefix)"
|
||||
# Prepend to current-process PATH so subsequent npm/npx calls find it.
|
||||
$env:PATH = "$NodePrefix;$env:PATH"
|
||||
return
|
||||
}
|
||||
|
||||
Write-Info "Installing Node.js $NodeVersion LTS into $NodePrefix ..."
|
||||
|
||||
$arch = if ([Environment]::Is64BitOperatingSystem) { "x64" } else { "x86" }
|
||||
$indexUrl = "https://nodejs.org/dist/latest-v${NodeVersion}.x/"
|
||||
|
||||
try {
|
||||
$indexPage = Invoke-WebRequest -Uri $indexUrl -UseBasicParsing
|
||||
$matches = [regex]::Matches($indexPage.Content, "node-v${NodeVersion}\.\d+\.\d+-win-${arch}\.zip")
|
||||
if ($matches.Count -eq 0) {
|
||||
Write-Err "Could not locate Node.js $NodeVersion zip for win-$arch"
|
||||
throw "no tarball"
|
||||
}
|
||||
$zipName = $matches[0].Value
|
||||
$zipUrl = "$indexUrl$zipName"
|
||||
|
||||
$tmpDir = Join-Path $env:TEMP "hermes-node-$([guid]::NewGuid().ToString('N'))"
|
||||
New-Item -ItemType Directory -Force -Path $tmpDir | Out-Null
|
||||
$zipPath = Join-Path $tmpDir $zipName
|
||||
|
||||
Write-Info "Downloading $zipName ..."
|
||||
Invoke-WebRequest -Uri $zipUrl -OutFile $zipPath -UseBasicParsing
|
||||
|
||||
Expand-Archive -Path $zipPath -DestinationPath $tmpDir -Force
|
||||
$extracted = Get-ChildItem -Path $tmpDir -Directory | Where-Object { $_.Name -like "node-v*" } | Select-Object -First 1
|
||||
|
||||
if (-not $extracted) { Write-Err "Node.js extraction failed"; throw "extract" }
|
||||
|
||||
if (Test-Path $NodePrefix) { Remove-Item -Recurse -Force $NodePrefix }
|
||||
New-Item -ItemType Directory -Force -Path $HermesHome | Out-Null
|
||||
Move-Item -Path $extracted.FullName -Destination $NodePrefix
|
||||
|
||||
Remove-Item -Recurse -Force $tmpDir -ErrorAction SilentlyContinue
|
||||
|
||||
$env:PATH = "$NodePrefix;$env:PATH"
|
||||
$v = & "$NodePrefix\node.exe" --version
|
||||
Write-Success "Node.js $v installed to $NodePrefix"
|
||||
} catch {
|
||||
Write-Err "Node.js install failed: $_"
|
||||
Write-Info "Install Node 20+ manually from https://nodejs.org/en/download/ and re-run."
|
||||
throw
|
||||
}
|
||||
}
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Step 2: agent-browser
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
function Ensure-AgentBrowser {
|
||||
$npmExe = Resolve-NpmExe
|
||||
if (-not $npmExe) {
|
||||
Write-Err "npm not on PATH after Node install — aborting"
|
||||
throw "npm missing"
|
||||
}
|
||||
|
||||
# Already installed?
|
||||
$existing = Get-Command agent-browser -ErrorAction SilentlyContinue
|
||||
if ($existing) {
|
||||
Write-Success "agent-browser already installed at $($existing.Source)"
|
||||
return
|
||||
}
|
||||
|
||||
# When the user has system Node (winget / installer-based), `npm install
|
||||
# -g` writes to a directory that may require admin rights. Force the
|
||||
# prefix to the user-writable Hermes-managed Node directory so we never
|
||||
# need elevation and the agent can always find the result. Mirrors the
|
||||
# bash bootstrap's `--prefix $NODE_PREFIX` strategy.
|
||||
New-Item -ItemType Directory -Force -Path $NodePrefix | Out-Null
|
||||
|
||||
Write-Info "Installing agent-browser (npm, prefix=$NodePrefix)..."
|
||||
& $npmExe install -g --prefix $NodePrefix --silent `
|
||||
"agent-browser@^0.26.0" "@askjo/camofox-browser@^1.5.2"
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Err "npm install -g agent-browser failed (exit $LASTEXITCODE)"
|
||||
throw "npm install"
|
||||
}
|
||||
|
||||
# Windows npm global installs drop shims at $NodePrefix\ root (not bin/).
|
||||
# Prepend to PATH so any subsequent npx call resolves them.
|
||||
$env:PATH = "$NodePrefix;$env:PATH"
|
||||
|
||||
Write-Success "agent-browser installed to $NodePrefix"
|
||||
}
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Step 3: Playwright Chromium
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
function Find-SystemBrowser {
|
||||
$candidates = @(
|
||||
"C:\Program Files\Google\Chrome\Application\chrome.exe",
|
||||
"C:\Program Files (x86)\Google\Chrome\Application\chrome.exe",
|
||||
"C:\Program Files\Chromium\Application\chromium.exe",
|
||||
"${env:LOCALAPPDATA}\Google\Chrome\Application\chrome.exe",
|
||||
"${env:LOCALAPPDATA}\Chromium\Application\chromium.exe"
|
||||
)
|
||||
foreach ($p in $candidates) {
|
||||
if (Test-Path $p) { return $p }
|
||||
}
|
||||
# Edge — Chromium-based, agent-browser can use it
|
||||
foreach ($p in @(
|
||||
"C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe",
|
||||
"C:\Program Files\Microsoft\Edge\Application\msedge.exe"
|
||||
)) {
|
||||
if (Test-Path $p) { return $p }
|
||||
}
|
||||
return $null
|
||||
}
|
||||
|
||||
function Write-BrowserEnv {
|
||||
param([string]$BrowserPath)
|
||||
$envFile = Join-Path $HermesHome ".env"
|
||||
New-Item -ItemType Directory -Force -Path $HermesHome | Out-Null
|
||||
if (Test-Path $envFile) {
|
||||
$existing = Get-Content $envFile -Raw -ErrorAction SilentlyContinue
|
||||
if ($existing -and ($existing -match "(?m)^AGENT_BROWSER_EXECUTABLE_PATH=")) {
|
||||
return
|
||||
}
|
||||
}
|
||||
Add-Content -Path $envFile -Value ""
|
||||
Add-Content -Path $envFile -Value "# Hermes Agent browser tools — use the system Chrome/Chromium/Edge binary."
|
||||
Add-Content -Path $envFile -Value "AGENT_BROWSER_EXECUTABLE_PATH=$BrowserPath"
|
||||
Write-Success "Configured browser tools to use $BrowserPath"
|
||||
}
|
||||
|
||||
function Confirm-ChromiumDownload {
|
||||
if ($Yes) { return $true }
|
||||
if (-not [Environment]::UserInteractive) {
|
||||
Write-Warn "Non-interactive shell — skipping Chromium prompt."
|
||||
Write-Info "Re-run with -Yes to install Chromium (~400 MB download)."
|
||||
return $false
|
||||
}
|
||||
$reply = Read-Host "Install Playwright Chromium (~400 MB download)? [y/N]"
|
||||
return ($reply -match "^(y|yes)$")
|
||||
}
|
||||
|
||||
function Ensure-Chromium {
|
||||
if ($SkipChromium) {
|
||||
Write-Info "Skipping Chromium install (-SkipChromium)"
|
||||
return
|
||||
}
|
||||
|
||||
# agent-browser on Windows expects a Playwright-managed Chromium under
|
||||
# %LOCALAPPDATA%\ms-playwright. The system-browser shortcut from the
|
||||
# Linux/macOS path doesn't apply the same way on Windows — Playwright's
|
||||
# default launch path won't pick up a stock Chrome install without an
|
||||
# explicit AGENT_BROWSER_EXECUTABLE_PATH. We still offer it as a
|
||||
# fallback when the user doesn't want the download.
|
||||
|
||||
if (-not (Confirm-ChromiumDownload)) {
|
||||
$sys = Find-SystemBrowser
|
||||
if ($sys) {
|
||||
Write-Info "Using system browser at $sys (Chromium download skipped)."
|
||||
Write-BrowserEnv -BrowserPath $sys
|
||||
} else {
|
||||
Write-Info "Chromium install skipped. Browser tools won't launch until"
|
||||
Write-Info "Chromium is installed or AGENT_BROWSER_EXECUTABLE_PATH is set."
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
$npxExe = Resolve-NpxExe
|
||||
if (-not $npxExe) {
|
||||
Write-Err "npx not on PATH — cannot install Playwright Chromium"
|
||||
throw "npx missing"
|
||||
}
|
||||
|
||||
Write-Info "Installing Playwright Chromium (~400 MB) ..."
|
||||
& $npxExe --yes playwright install chromium
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Err "Playwright Chromium install failed (exit $LASTEXITCODE)"
|
||||
Write-Info "Try again later: npx --yes playwright install chromium"
|
||||
throw "playwright"
|
||||
}
|
||||
Write-Success "Playwright Chromium installed"
|
||||
}
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Main
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
Write-Info "Hermes Agent: bootstrapping browser tools"
|
||||
Write-Info " HERMES_HOME = $HermesHome"
|
||||
Write-Info " OS = Windows"
|
||||
|
||||
Ensure-Node
|
||||
Ensure-AgentBrowser
|
||||
Ensure-Chromium
|
||||
|
||||
Write-Success "Browser tools setup complete."
|
||||
Write-Info "Hermes Agent will pick up agent-browser from $NodePrefix on next launch."
|
||||
399
acp_adapter/bootstrap/bootstrap_browser_tools.sh
Executable file
399
acp_adapter/bootstrap/bootstrap_browser_tools.sh
Executable file
@@ -0,0 +1,399 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# bootstrap_browser_tools.sh — install agent-browser + Playwright Chromium
|
||||
# into ~/.hermes/node/ for use by Hermes Agent's browser tools.
|
||||
#
|
||||
# Targets the registry-install path: users who got Hermes via
|
||||
# `uvx --from 'hermes-agent[acp]==X' hermes-acp` don't have a repo clone,
|
||||
# so the install.sh `npm install`-in-repo flow doesn't apply. This script
|
||||
# is a self-contained, idempotent slice of install.sh's browser block —
|
||||
# safe to run from `hermes-acp --setup-browser`, from a fresh terminal,
|
||||
# or from install.sh itself (it's a no-op when everything is already in place).
|
||||
#
|
||||
# Usage:
|
||||
# bootstrap_browser_tools.sh # use defaults
|
||||
# bootstrap_browser_tools.sh --yes # accept the ~400MB Chromium download
|
||||
# bootstrap_browser_tools.sh --skip-chromium # only install Node + agent-browser
|
||||
# HERMES_HOME=/custom/path bootstrap_browser_tools.sh
|
||||
#
|
||||
# Idempotent: re-running this is safe and fast. Each step checks whether
|
||||
# the work is already done.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Config
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
NODE_VERSION="22"
|
||||
HERMES_HOME="${HERMES_HOME:-$HOME/.hermes}"
|
||||
NODE_PREFIX="$HERMES_HOME/node"
|
||||
|
||||
SKIP_CHROMIUM=false
|
||||
ASSUME_YES=false
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Logging
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
if [ -t 1 ]; then
|
||||
C_GREEN='\033[0;32m'
|
||||
C_YELLOW='\033[0;33m'
|
||||
C_BLUE='\033[0;34m'
|
||||
C_RED='\033[0;31m'
|
||||
C_RESET='\033[0m'
|
||||
else
|
||||
C_GREEN='' ; C_YELLOW='' ; C_BLUE='' ; C_RED='' ; C_RESET=''
|
||||
fi
|
||||
|
||||
log_info() { printf "${C_BLUE}[*]${C_RESET} %s\n" "$*"; }
|
||||
log_success() { printf "${C_GREEN}[✓]${C_RESET} %s\n" "$*"; }
|
||||
log_warn() { printf "${C_YELLOW}[!]${C_RESET} %s\n" "$*" >&2; }
|
||||
log_error() { printf "${C_RED}[✗]${C_RESET} %s\n" "$*" >&2; }
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Arg parsing
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
while [ $# -gt 0 ]; do
|
||||
case "$1" in
|
||||
--skip-chromium) SKIP_CHROMIUM=true ;;
|
||||
--yes|-y) ASSUME_YES=true ;;
|
||||
-h|--help)
|
||||
cat <<EOF
|
||||
Bootstrap Hermes Agent browser tools.
|
||||
|
||||
Installs Node.js (into ~/.hermes/node/), the agent-browser npm package,
|
||||
and the Playwright Chromium browser engine.
|
||||
|
||||
Options:
|
||||
--skip-chromium Install Node + agent-browser but skip Chromium download
|
||||
--yes, -y Accept the ~400 MB Chromium download without prompting
|
||||
-h, --help Show this help
|
||||
|
||||
Environment:
|
||||
HERMES_HOME Override Hermes data dir (default: \$HOME/.hermes)
|
||||
EOF
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
log_error "Unknown option: $1"
|
||||
exit 2
|
||||
;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# OS / arch detection
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
OS="unknown"
|
||||
case "$(uname -s)" in
|
||||
Linux*) OS="linux" ;;
|
||||
Darwin*) OS="macos" ;;
|
||||
*)
|
||||
log_error "Unsupported OS: $(uname -s)"
|
||||
log_info "Windows users: run scripts/bootstrap_browser_tools.ps1 in PowerShell."
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
NODE_ARCH=""
|
||||
case "$(uname -m)" in
|
||||
x86_64) NODE_ARCH="x64" ;;
|
||||
aarch64|arm64) NODE_ARCH="arm64" ;;
|
||||
armv7l) NODE_ARCH="armv7l" ;;
|
||||
*)
|
||||
log_error "Unsupported architecture: $(uname -m)"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
NODE_OS=""
|
||||
case "$OS" in
|
||||
linux) NODE_OS="linux" ;;
|
||||
macos) NODE_OS="darwin" ;;
|
||||
esac
|
||||
|
||||
DISTRO=""
|
||||
if [ -f /etc/os-release ]; then
|
||||
# shellcheck disable=SC1091
|
||||
. /etc/os-release
|
||||
DISTRO="${ID:-}"
|
||||
fi
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Step 1: Node.js
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
ensure_node() {
|
||||
# Already on PATH and recent enough?
|
||||
if command -v node >/dev/null 2>&1; then
|
||||
local found_ver major
|
||||
found_ver=$(node --version 2>/dev/null)
|
||||
major=$(echo "$found_ver" | sed -E 's/^v([0-9]+).*/\1/')
|
||||
if [ -n "$major" ] && [ "$major" -ge 20 ]; then
|
||||
log_success "Node.js $found_ver found on PATH"
|
||||
return 0
|
||||
fi
|
||||
log_warn "Node.js $found_ver is older than v20 — installing managed Node."
|
||||
fi
|
||||
|
||||
if [ -x "$NODE_PREFIX/bin/node" ]; then
|
||||
local found_ver
|
||||
found_ver=$("$NODE_PREFIX/bin/node" --version 2>/dev/null || echo "?")
|
||||
export PATH="$NODE_PREFIX/bin:$PATH"
|
||||
log_success "Node.js $found_ver found (Hermes-managed at $NODE_PREFIX)"
|
||||
return 0
|
||||
fi
|
||||
|
||||
log_info "Installing Node.js $NODE_VERSION LTS into $NODE_PREFIX ..."
|
||||
|
||||
local index_url="https://nodejs.org/dist/latest-v${NODE_VERSION}.x/"
|
||||
local tarball_name
|
||||
tarball_name=$(curl -fsSL "$index_url" \
|
||||
| grep -oE "node-v${NODE_VERSION}\.[0-9]+\.[0-9]+-${NODE_OS}-${NODE_ARCH}\.tar\.xz" \
|
||||
| head -1)
|
||||
|
||||
if [ -z "$tarball_name" ]; then
|
||||
tarball_name=$(curl -fsSL "$index_url" \
|
||||
| grep -oE "node-v${NODE_VERSION}\.[0-9]+\.[0-9]+-${NODE_OS}-${NODE_ARCH}\.tar\.gz" \
|
||||
| head -1)
|
||||
fi
|
||||
|
||||
if [ -z "$tarball_name" ]; then
|
||||
log_error "Could not locate Node.js $NODE_VERSION tarball for $NODE_OS-$NODE_ARCH"
|
||||
log_info "Install Node 20+ manually: https://nodejs.org/en/download/"
|
||||
return 1
|
||||
fi
|
||||
|
||||
local tmp_dir
|
||||
tmp_dir=$(mktemp -d)
|
||||
trap 'rm -rf "$tmp_dir"' RETURN
|
||||
|
||||
log_info "Downloading $tarball_name ..."
|
||||
if ! curl -fsSL "${index_url}${tarball_name}" -o "$tmp_dir/$tarball_name"; then
|
||||
log_error "Node.js download failed"
|
||||
return 1
|
||||
fi
|
||||
|
||||
if [[ "$tarball_name" == *.tar.xz ]]; then
|
||||
tar xf "$tmp_dir/$tarball_name" -C "$tmp_dir"
|
||||
else
|
||||
tar xzf "$tmp_dir/$tarball_name" -C "$tmp_dir"
|
||||
fi
|
||||
|
||||
local extracted_dir
|
||||
extracted_dir=$(ls -d "$tmp_dir"/node-v* 2>/dev/null | head -1)
|
||||
if [ ! -d "$extracted_dir" ]; then
|
||||
log_error "Node.js extraction failed"
|
||||
return 1
|
||||
fi
|
||||
|
||||
mkdir -p "$HERMES_HOME"
|
||||
rm -rf "$NODE_PREFIX"
|
||||
mv "$extracted_dir" "$NODE_PREFIX"
|
||||
|
||||
export PATH="$NODE_PREFIX/bin:$PATH"
|
||||
|
||||
local installed_ver
|
||||
installed_ver=$("$NODE_PREFIX/bin/node" --version 2>/dev/null || echo "?")
|
||||
log_success "Node.js $installed_ver installed to $NODE_PREFIX"
|
||||
}
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Step 2: agent-browser + @askjo/camofox-browser via global npm install
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
ensure_agent_browser() {
|
||||
if ! command -v npm >/dev/null 2>&1; then
|
||||
log_error "npm not on PATH after Node install — aborting"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# _find_agent_browser() in tools/browser_tool.py walks ~/.hermes/node/bin
|
||||
# plus a few standard prefixes, so installing globally into the managed
|
||||
# Node prefix is enough — no PATH manipulation needed from the agent side.
|
||||
if [ -x "$NODE_PREFIX/bin/agent-browser" ] || command -v agent-browser >/dev/null 2>&1; then
|
||||
log_success "agent-browser already installed"
|
||||
return 0
|
||||
fi
|
||||
|
||||
# When the system's `npm` resolves to a root-owned prefix (e.g.
|
||||
# /usr/lib/node_modules), `npm install -g` fails with EACCES without
|
||||
# sudo. Force the prefix to the user-writable Hermes-managed Node
|
||||
# directory so we never need sudo and the agent can always find the
|
||||
# result. If we installed Node ourselves above, this is a no-op
|
||||
# (managed Node already uses $NODE_PREFIX). If the user has system
|
||||
# Node, we still drop agent-browser under $NODE_PREFIX/bin/ — which
|
||||
# is exactly where _browser_candidate_path_dirs() looks first.
|
||||
mkdir -p "$NODE_PREFIX"
|
||||
|
||||
log_info "Installing agent-browser (npm, prefix=$NODE_PREFIX)..."
|
||||
if ! npm install -g --prefix "$NODE_PREFIX" --silent \
|
||||
agent-browser@^0.26.0 \
|
||||
"@askjo/camofox-browser@^1.5.2"; then
|
||||
log_error "npm install -g agent-browser failed"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# macOS/Linux global installs place the shim into $NODE_PREFIX/bin/.
|
||||
# Add it to PATH for any subsequent steps (npx playwright).
|
||||
export PATH="$NODE_PREFIX/bin:$PATH"
|
||||
|
||||
log_success "agent-browser installed to $NODE_PREFIX/bin/"
|
||||
}
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Step 3: Playwright Chromium
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
confirm_chromium_download() {
|
||||
if [ "$ASSUME_YES" = true ]; then return 0; fi
|
||||
if [ ! -t 0 ]; then
|
||||
log_warn "Non-interactive shell — skipping Chromium prompt."
|
||||
log_info "Re-run with --yes to install Chromium (~400 MB download)."
|
||||
return 1
|
||||
fi
|
||||
printf "Install Playwright Chromium (~400 MB download)? [y/N] "
|
||||
local reply=""
|
||||
read -r reply || reply=""
|
||||
case "$reply" in
|
||||
y|Y|yes|YES) return 0 ;;
|
||||
*) return 1 ;;
|
||||
esac
|
||||
}
|
||||
|
||||
# Detect a usable system Chrome/Chromium. agent-browser's Chrome engine can
|
||||
# use it instead of downloading Playwright's bundled Chromium, saving the
|
||||
# download cost. Returns the path or empty string.
|
||||
find_system_browser() {
|
||||
local candidate
|
||||
for candidate in google-chrome google-chrome-stable chromium chromium-browser chrome; do
|
||||
if command -v "$candidate" >/dev/null 2>&1; then
|
||||
command -v "$candidate"
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
# macOS app-bundle locations
|
||||
if [ "$OS" = "macos" ]; then
|
||||
for candidate in \
|
||||
"/Applications/Google Chrome.app/Contents/MacOS/Google Chrome" \
|
||||
"/Applications/Chromium.app/Contents/MacOS/Chromium" ; do
|
||||
if [ -x "$candidate" ]; then
|
||||
echo "$candidate"
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
fi
|
||||
return 1
|
||||
}
|
||||
|
||||
write_browser_env() {
|
||||
local browser_path="$1"
|
||||
local env_file="$HERMES_HOME/.env"
|
||||
mkdir -p "$HERMES_HOME"
|
||||
if [ -f "$env_file" ] && grep -q "^AGENT_BROWSER_EXECUTABLE_PATH=" "$env_file"; then
|
||||
return 0
|
||||
fi
|
||||
{
|
||||
echo ""
|
||||
echo "# Hermes Agent browser tools — use the system Chrome/Chromium binary."
|
||||
echo "AGENT_BROWSER_EXECUTABLE_PATH=$browser_path"
|
||||
} >> "$env_file"
|
||||
log_success "Configured browser tools to use $browser_path"
|
||||
}
|
||||
|
||||
ensure_chromium() {
|
||||
if [ "$SKIP_CHROMIUM" = true ]; then
|
||||
log_info "Skipping Chromium install (--skip-chromium)"
|
||||
return 0
|
||||
fi
|
||||
|
||||
local system_browser
|
||||
system_browser="$(find_system_browser 2>/dev/null || true)"
|
||||
if [ -n "$system_browser" ]; then
|
||||
log_success "Found system browser: $system_browser"
|
||||
log_info "Skipping Playwright Chromium download; agent-browser will use it."
|
||||
write_browser_env "$system_browser"
|
||||
return 0
|
||||
fi
|
||||
|
||||
if ! confirm_chromium_download; then
|
||||
log_info "Chromium install skipped. Browser tools will only work if you"
|
||||
log_info "set AGENT_BROWSER_EXECUTABLE_PATH or install Chromium later."
|
||||
return 0
|
||||
fi
|
||||
|
||||
if ! command -v npx >/dev/null 2>&1; then
|
||||
log_error "npx not on PATH — cannot install Playwright Chromium"
|
||||
return 1
|
||||
fi
|
||||
|
||||
log_info "Installing Playwright Chromium (~400 MB) ..."
|
||||
|
||||
# On apt-based distros, --with-deps requires sudo. Try non-interactively
|
||||
# only — never prompt — and fall back to the bare browser-only install.
|
||||
local installed=false
|
||||
if [ "$OS" = "linux" ]; then
|
||||
case "$DISTRO" in
|
||||
ubuntu|debian|raspbian|pop|linuxmint|elementary|zorin|kali|parrot)
|
||||
if [ "$(id -u)" -eq 0 ] || (command -v sudo >/dev/null 2>&1 && sudo -n true 2>/dev/null); then
|
||||
log_info "Installing system deps with --with-deps (sudo available)"
|
||||
if npx --yes playwright install --with-deps chromium; then
|
||||
installed=true
|
||||
fi
|
||||
else
|
||||
log_warn "sudo not available non-interactively — installing Chromium without system deps."
|
||||
log_info "If browser tools fail to launch, an administrator should run:"
|
||||
log_info " sudo npx playwright install-deps chromium"
|
||||
fi
|
||||
;;
|
||||
arch|manjaro|cachyos|endeavouros|garuda)
|
||||
log_info "Arch-family system dependencies are not auto-installed."
|
||||
log_info "If launch fails, run: sudo pacman -S nss atk at-spi2-core cups libdrm libxkbcommon mesa pango cairo alsa-lib"
|
||||
;;
|
||||
fedora|rhel|centos|rocky|alma)
|
||||
log_info "Fedora/RHEL system dependencies are not auto-installed."
|
||||
log_info "If launch fails, run: sudo dnf install nss atk at-spi2-core cups-libs libdrm libxkbcommon mesa-libgbm pango cairo alsa-lib"
|
||||
;;
|
||||
opensuse*|sles)
|
||||
log_info "openSUSE system dependencies are not auto-installed."
|
||||
;;
|
||||
esac
|
||||
fi
|
||||
|
||||
if [ "$installed" = false ]; then
|
||||
if npx --yes playwright install chromium; then
|
||||
installed=true
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ "$installed" = true ]; then
|
||||
log_success "Playwright Chromium installed"
|
||||
else
|
||||
log_error "Playwright Chromium install failed"
|
||||
log_info "Try again later: npx --yes playwright install chromium"
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Main
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
main() {
|
||||
log_info "Hermes Agent: bootstrapping browser tools"
|
||||
log_info " HERMES_HOME = $HERMES_HOME"
|
||||
log_info " OS / arch = $NODE_OS-$NODE_ARCH ${DISTRO:+($DISTRO)}"
|
||||
|
||||
ensure_node
|
||||
ensure_agent_browser
|
||||
ensure_chromium
|
||||
|
||||
log_success "Browser tools setup complete."
|
||||
log_info "Hermes Agent will pick up agent-browser from $NODE_PREFIX/bin/ on next launch."
|
||||
}
|
||||
|
||||
main
|
||||
@@ -24,6 +24,7 @@ except ModuleNotFoundError:
|
||||
# means UTF-8 stdio setup is skipped on Windows; POSIX is unaffected.
|
||||
pass
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
@@ -107,8 +108,150 @@ def _load_env() -> None:
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="hermes-acp",
|
||||
description="Run Hermes Agent as an ACP stdio server.",
|
||||
)
|
||||
parser.add_argument("--version", action="store_true", help="Print Hermes version and exit")
|
||||
parser.add_argument(
|
||||
"--check",
|
||||
action="store_true",
|
||||
help="Verify ACP dependencies and adapter imports, then exit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--setup",
|
||||
action="store_true",
|
||||
help="Run interactive Hermes provider/model setup for ACP terminal auth",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--setup-browser",
|
||||
action="store_true",
|
||||
help="Install agent-browser + Playwright Chromium into ~/.hermes/node/ "
|
||||
"for browser tool support. Idempotent.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--yes",
|
||||
"-y",
|
||||
action="store_true",
|
||||
dest="assume_yes",
|
||||
help="Accept all prompts (currently used by --setup-browser to skip the "
|
||||
"~400 MB Chromium download confirmation).",
|
||||
)
|
||||
return parser.parse_args(argv)
|
||||
|
||||
|
||||
def _print_version() -> None:
|
||||
from hermes_cli import __version__ as hermes_version
|
||||
|
||||
print(hermes_version)
|
||||
|
||||
|
||||
def _run_check() -> None:
|
||||
import acp # noqa: F401
|
||||
from acp_adapter.server import HermesACPAgent # noqa: F401
|
||||
|
||||
print("Hermes ACP check OK")
|
||||
|
||||
|
||||
def _run_setup() -> None:
|
||||
from hermes_cli.main import main as hermes_main
|
||||
|
||||
old_argv = sys.argv[:]
|
||||
try:
|
||||
sys.argv = [old_argv[0] if old_argv else "hermes", "model"]
|
||||
hermes_main()
|
||||
finally:
|
||||
sys.argv = old_argv
|
||||
|
||||
# Offer browser-tools install as a follow-up. The terminal auth method
|
||||
# is the one supported first-run UX for registry installs, so this is
|
||||
# the natural moment to ask. Skip silently if stdin isn't a TTY (the
|
||||
# answer can't be collected anyway).
|
||||
if not sys.stdin.isatty():
|
||||
return
|
||||
try:
|
||||
reply = input(
|
||||
"\nInstall browser tools? Downloads agent-browser (npm) and "
|
||||
"optionally Playwright Chromium (~400 MB). [y/N] "
|
||||
).strip().lower()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return
|
||||
if reply in {"y", "yes"}:
|
||||
_run_setup_browser(assume_yes=False)
|
||||
|
||||
|
||||
def _run_setup_browser(assume_yes: bool = False) -> int:
|
||||
"""Bootstrap agent-browser + Playwright Chromium for the registry-install path.
|
||||
|
||||
Shells out to the bundled platform-specific bootstrap script
|
||||
(acp_adapter/bootstrap/bootstrap_browser_tools.{sh,ps1}) so the install
|
||||
logic lives in one place — readable, debuggable, and shareable with
|
||||
install.sh / install.ps1 if we ever want to call it from there too.
|
||||
|
||||
Returns the script's exit code (0 on success).
|
||||
"""
|
||||
import platform
|
||||
import subprocess
|
||||
|
||||
bootstrap_dir = Path(__file__).resolve().parent / "bootstrap"
|
||||
|
||||
if platform.system() == "Windows":
|
||||
script = bootstrap_dir / "bootstrap_browser_tools.ps1"
|
||||
if not script.is_file():
|
||||
print(
|
||||
f"Bootstrap script not found at {script} — wheel may be incomplete.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
cmd = [
|
||||
"powershell.exe",
|
||||
"-NoProfile",
|
||||
"-ExecutionPolicy", "Bypass",
|
||||
"-File", str(script),
|
||||
]
|
||||
if assume_yes:
|
||||
cmd.append("-Yes")
|
||||
else:
|
||||
script = bootstrap_dir / "bootstrap_browser_tools.sh"
|
||||
if not script.is_file():
|
||||
print(
|
||||
f"Bootstrap script not found at {script} — wheel may be incomplete.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
cmd = ["bash", str(script)]
|
||||
if assume_yes:
|
||||
cmd.append("--yes")
|
||||
|
||||
# stdio is inherited so the user sees the bootstrap's progress live.
|
||||
try:
|
||||
result = subprocess.run(cmd, check=False)
|
||||
except FileNotFoundError as exc:
|
||||
# bash / powershell.exe not on PATH
|
||||
print(f"Could not launch browser bootstrap: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
return result.returncode
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> None:
|
||||
"""Entry point: load env, configure logging, run the ACP agent."""
|
||||
args = _parse_args(argv)
|
||||
if args.version:
|
||||
_print_version()
|
||||
return
|
||||
if args.check:
|
||||
_run_check()
|
||||
return
|
||||
if args.setup:
|
||||
_run_setup()
|
||||
return
|
||||
if args.setup_browser:
|
||||
rc = _run_setup_browser(assume_yes=args.assume_yes)
|
||||
if rc != 0:
|
||||
sys.exit(rc)
|
||||
return
|
||||
|
||||
_setup_logging()
|
||||
_load_env()
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""ACP permission bridging — maps ACP approval requests to hermes approval callbacks."""
|
||||
"""ACP permission bridging for Hermes dangerous-command approvals."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import TimeoutError as FutureTimeout
|
||||
from itertools import count
|
||||
from typing import Callable
|
||||
|
||||
from acp.schema import (
|
||||
@@ -14,24 +15,87 @@ from acp.schema import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maps ACP PermissionOptionKind -> hermes approval result strings
|
||||
_KIND_TO_HERMES = {
|
||||
# Maps ACP permission option ids to Hermes approval result strings.
|
||||
# Option ids are stable across both the ``allow_permanent=True`` and
|
||||
# ``allow_permanent=False`` paths even though the option list differs.
|
||||
_OPTION_ID_TO_HERMES = {
|
||||
"allow_once": "once",
|
||||
"allow_session": "session",
|
||||
"allow_always": "always",
|
||||
"reject_once": "deny",
|
||||
"reject_always": "deny",
|
||||
"deny": "deny",
|
||||
}
|
||||
|
||||
_PERMISSION_REQUEST_IDS = count(1)
|
||||
|
||||
|
||||
def _build_permission_options(*, allow_permanent: bool) -> list[PermissionOption]:
|
||||
"""Return ACP options that match Hermes approval semantics."""
|
||||
options = [
|
||||
PermissionOption(option_id="allow_once", kind="allow_once", name="Allow once"),
|
||||
PermissionOption(
|
||||
option_id="allow_session",
|
||||
# ACP has no session-scoped kind, so use the closest persistent
|
||||
# hint while keeping Hermes semantics in the option id.
|
||||
kind="allow_always",
|
||||
name="Allow for session",
|
||||
),
|
||||
]
|
||||
if allow_permanent:
|
||||
options.append(
|
||||
PermissionOption(
|
||||
option_id="allow_always",
|
||||
kind="allow_always",
|
||||
name="Allow always",
|
||||
),
|
||||
)
|
||||
options.append(PermissionOption(option_id="deny", kind="reject_once", name="Deny"))
|
||||
return options
|
||||
|
||||
|
||||
def _build_permission_tool_call(command: str, description: str):
|
||||
"""Return the ACP tool-call update attached to a permission request.
|
||||
|
||||
``request_permission`` expects a ``ToolCallUpdate`` payload — produced
|
||||
by ``_acp.update_tool_call`` — not a ``ToolCallStart``. Each request
|
||||
gets a unique ``perm-check-N`` id so concurrent requests don't collide.
|
||||
"""
|
||||
import acp as _acp
|
||||
|
||||
tool_call_id = f"perm-check-{next(_PERMISSION_REQUEST_IDS)}"
|
||||
return _acp.update_tool_call(
|
||||
tool_call_id,
|
||||
title=description,
|
||||
kind="execute",
|
||||
status="pending",
|
||||
content=[_acp.tool_content(_acp.text_block(f"$ {command}"))],
|
||||
raw_input={"command": command, "description": description},
|
||||
)
|
||||
|
||||
|
||||
def _map_outcome_to_hermes(outcome: object, *, allowed_option_ids: set[str]) -> str:
|
||||
"""Map an ACP permission outcome into Hermes approval strings."""
|
||||
if not isinstance(outcome, AllowedOutcome):
|
||||
return "deny"
|
||||
|
||||
option_id = outcome.option_id
|
||||
if option_id not in allowed_option_ids:
|
||||
logger.warning("Permission request returned unknown option_id: %s", option_id)
|
||||
return "deny"
|
||||
return _OPTION_ID_TO_HERMES.get(option_id, "deny")
|
||||
|
||||
|
||||
def make_approval_callback(
|
||||
request_permission_fn: Callable,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
session_id: str,
|
||||
timeout: float = 60.0,
|
||||
) -> Callable[[str, str], str]:
|
||||
) -> Callable[..., str]:
|
||||
"""
|
||||
Return a hermes-compatible ``approval_callback(command, description) -> str``
|
||||
that bridges to the ACP client's ``request_permission`` call.
|
||||
Return a Hermes-compatible approval callback that bridges to ACP.
|
||||
|
||||
The callback accepts ``command`` and ``description`` plus optional
|
||||
keyword arguments such as ``allow_permanent`` used by
|
||||
``tools.approval.prompt_dangerous_approval()``.
|
||||
|
||||
Args:
|
||||
request_permission_fn: The ACP connection's ``request_permission`` coroutine.
|
||||
@@ -40,41 +104,38 @@ def make_approval_callback(
|
||||
timeout: Seconds to wait for a response before auto-denying.
|
||||
"""
|
||||
|
||||
def _callback(command: str, description: str) -> str:
|
||||
options = [
|
||||
PermissionOption(option_id="allow_once", kind="allow_once", name="Allow once"),
|
||||
PermissionOption(option_id="allow_always", kind="allow_always", name="Allow always"),
|
||||
PermissionOption(option_id="deny", kind="reject_once", name="Deny"),
|
||||
]
|
||||
import acp as _acp
|
||||
|
||||
tool_call = _acp.start_tool_call("perm-check", command, kind="execute")
|
||||
|
||||
coro = request_permission_fn(
|
||||
session_id=session_id,
|
||||
tool_call=tool_call,
|
||||
options=options,
|
||||
)
|
||||
def _callback(
|
||||
command: str,
|
||||
description: str,
|
||||
*,
|
||||
allow_permanent: bool = True,
|
||||
**_: object,
|
||||
) -> str:
|
||||
options = _build_permission_options(allow_permanent=allow_permanent)
|
||||
|
||||
future = None
|
||||
try:
|
||||
tool_call = _build_permission_tool_call(command, description)
|
||||
coro = request_permission_fn(
|
||||
session_id=session_id,
|
||||
tool_call=tool_call,
|
||||
options=options,
|
||||
)
|
||||
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||
response = future.result(timeout=timeout)
|
||||
except (FutureTimeout, Exception) as exc:
|
||||
if future is not None:
|
||||
future.cancel()
|
||||
logger.warning("Permission request timed out or failed: %s", exc)
|
||||
return "deny"
|
||||
|
||||
if response is None:
|
||||
return "deny"
|
||||
|
||||
outcome = response.outcome
|
||||
if isinstance(outcome, AllowedOutcome):
|
||||
option_id = outcome.option_id
|
||||
# Look up the kind from our options list
|
||||
for opt in options:
|
||||
if opt.option_id == option_id:
|
||||
return _KIND_TO_HERMES.get(opt.kind, "deny")
|
||||
return "once" # fallback for unknown option_id
|
||||
else:
|
||||
return "deny"
|
||||
allowed_option_ids = {option.option_id for option in options}
|
||||
return _map_outcome_to_hermes(
|
||||
response.outcome,
|
||||
allowed_option_ids=allowed_option_ids,
|
||||
)
|
||||
|
||||
return _callback
|
||||
|
||||
@@ -57,13 +57,7 @@ from acp.schema import (
|
||||
UserMessageChunk,
|
||||
)
|
||||
|
||||
# AuthMethodAgent was renamed from AuthMethod in agent-client-protocol 0.9.0
|
||||
try:
|
||||
from acp.schema import AuthMethodAgent
|
||||
except ImportError:
|
||||
from acp.schema import AuthMethod as AuthMethodAgent # type: ignore[attr-defined]
|
||||
|
||||
from acp_adapter.auth import detect_provider
|
||||
from acp_adapter.auth import TERMINAL_SETUP_AUTH_METHOD_ID, build_auth_methods, detect_provider
|
||||
from acp_adapter.events import (
|
||||
make_message_cb,
|
||||
make_step_cb,
|
||||
@@ -744,16 +738,7 @@ class HermesACPAgent(acp.Agent):
|
||||
resolved_protocol_version = (
|
||||
protocol_version if isinstance(protocol_version, int) else acp.PROTOCOL_VERSION
|
||||
)
|
||||
provider = detect_provider()
|
||||
auth_methods = None
|
||||
if provider:
|
||||
auth_methods = [
|
||||
AuthMethodAgent(
|
||||
id=provider,
|
||||
name=f"{provider} runtime credentials",
|
||||
description=f"Authenticate Hermes using the currently configured {provider} runtime credentials.",
|
||||
)
|
||||
]
|
||||
auth_methods = build_auth_methods()
|
||||
|
||||
client_name = client_info.name if client_info else "unknown"
|
||||
logger.info(
|
||||
@@ -784,10 +769,18 @@ class HermesACPAgent(acp.Agent):
|
||||
# server has provider credentials configured — harmless under
|
||||
# Hermes' threat model (ACP is stdio-only, local-trust), but poor
|
||||
# API hygiene and confusing if ACP ever grows multi-method auth.
|
||||
provider = detect_provider()
|
||||
if not provider:
|
||||
if not isinstance(method_id, str):
|
||||
return None
|
||||
if not isinstance(method_id, str) or method_id.strip().lower() != provider:
|
||||
normalized_method = method_id.strip().lower()
|
||||
provider = detect_provider()
|
||||
|
||||
if normalized_method == TERMINAL_SETUP_AUTH_METHOD_ID:
|
||||
# Terminal auth launches Hermes setup/model selection out-of-band.
|
||||
# Only report success once that flow has produced usable runtime
|
||||
# credentials for the normal ACP session.
|
||||
return AuthenticateResponse() if provider else None
|
||||
|
||||
if not provider or normalized_method != provider:
|
||||
return None
|
||||
return AuthenticateResponse()
|
||||
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
{
|
||||
"schema_version": 1,
|
||||
"name": "hermes-agent",
|
||||
"display_name": "Hermes Agent",
|
||||
"description": "AI agent by Nous Research with 90+ tools, persistent memory, and multi-platform support",
|
||||
"icon": "icon.svg",
|
||||
"id": "hermes-agent",
|
||||
"name": "Hermes Agent",
|
||||
"version": "0.13.0",
|
||||
"description": "Self-improving open-source AI agent by Nous Research with ACP editor integration, persistent memory, skills, and rich tool support.",
|
||||
"repository": "https://github.com/NousResearch/hermes-agent",
|
||||
"website": "https://hermes-agent.nousresearch.com/docs/user-guide/features/acp",
|
||||
"authors": ["Nous Research"],
|
||||
"license": "MIT",
|
||||
"distribution": {
|
||||
"type": "command",
|
||||
"command": "hermes",
|
||||
"args": ["acp"]
|
||||
"uvx": {
|
||||
"package": "hermes-agent[acp]==0.13.0",
|
||||
"args": ["hermes-acp"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,25 +1,8 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 64 64" width="64" height="64">
|
||||
<defs>
|
||||
<linearGradient id="gold" x1="0%" y1="0%" x2="0%" y2="100%">
|
||||
<stop offset="0%" style="stop-color:#F5C542;stop-opacity:1" />
|
||||
<stop offset="100%" style="stop-color:#D4961C;stop-opacity:1" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
<!-- Staff -->
|
||||
<rect x="30" y="10" width="4" height="46" rx="2" fill="url(#gold)" />
|
||||
<!-- Wings (left) -->
|
||||
<path d="M30 18 C24 14, 14 14, 10 18 C14 16, 22 16, 28 20" fill="#F5C542" opacity="0.9" />
|
||||
<path d="M30 22 C26 19, 18 19, 14 22 C18 20, 24 20, 28 24" fill="#D4961C" opacity="0.8" />
|
||||
<!-- Wings (right) -->
|
||||
<path d="M34 18 C40 14, 50 14, 54 18 C50 16, 42 16, 36 20" fill="#F5C542" opacity="0.9" />
|
||||
<path d="M34 22 C38 19, 46 19, 50 22 C46 20, 40 20, 36 24" fill="#D4961C" opacity="0.8" />
|
||||
<!-- Left serpent -->
|
||||
<path d="M32 48 C22 44, 20 38, 26 34 C20 36, 18 42, 24 46 C18 40, 22 30, 30 28 C24 32, 22 38, 28 42"
|
||||
fill="none" stroke="#F5C542" stroke-width="2.5" stroke-linecap="round" />
|
||||
<!-- Right serpent -->
|
||||
<path d="M32 48 C42 44, 44 38, 38 34 C44 36, 46 42, 40 46 C46 40, 42 30, 34 28 C40 32, 42 38, 36 42"
|
||||
fill="none" stroke="#D4961C" stroke-width="2.5" stroke-linecap="round" />
|
||||
<!-- Orb at top -->
|
||||
<circle cx="32" cy="10" r="4" fill="#F5C542" />
|
||||
<circle cx="32" cy="10" r="2" fill="#FFF8E1" opacity="0.7" />
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" width="16" height="16" fill="none">
|
||||
<path d="M8 1.5v13" stroke="currentColor" stroke-width="1.5" stroke-linecap="round"/>
|
||||
<path d="M8 3.25c-2.35-1.4-4.7-.95-6.25.35 1.85-.2 3.8.2 5.55 1.55" stroke="currentColor" stroke-width="1.1" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M8 3.25c2.35-1.4 4.7-.95 6.25.35-1.85-.2-3.8.2-5.55 1.55" stroke="currentColor" stroke-width="1.1" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M8 13.25c-2.3-1-3.05-2.65-1.35-4.15-2 .8-2.35 2.95-.35 4" stroke="currentColor" stroke-width="1.1" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M8 13.25c2.3-1 3.05-2.65 1.35-4.15 2 .8 2.35 2.95.35 4" stroke="currentColor" stroke-width="1.1" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<circle cx="8" cy="1.8" r="1.1" fill="currentColor"/>
|
||||
</svg>
|
||||
|
||||
|
Before Width: | Height: | Size: 1.4 KiB After Width: | Height: | Size: 882 B |
@@ -1305,9 +1305,8 @@ def convert_tools_to_anthropic(tools: List[Dict]) -> List[Dict]:
|
||||
),
|
||||
}
|
||||
# Forward cache_control marker when present on the OpenAI-format
|
||||
# tool dict (set by ``mark_tools_for_long_lived_cache``). Anthropic's
|
||||
# tools array supports cache_control on the last tool to cache the
|
||||
# entire schema cross-session.
|
||||
# tool dict. Anthropic's tools array supports cache_control on the
|
||||
# last tool to cache the entire schema cross-session.
|
||||
cache_control = t.get("cache_control")
|
||||
if isinstance(cache_control, dict):
|
||||
anthropic_tool["cache_control"] = dict(cache_control)
|
||||
|
||||
@@ -382,7 +382,28 @@ _AI_GATEWAY_HEADERS = {
|
||||
# Nous Portal extra_body for product attribution.
|
||||
# Callers should pass this as extra_body in chat.completions.create()
|
||||
# when the auxiliary client is backed by Nous Portal.
|
||||
NOUS_EXTRA_BODY = {"tags": ["product=hermes-agent", "client=aux"]}
|
||||
#
|
||||
# The tags are computed from agent.portal_tags so the client= marker stays
|
||||
# in lockstep with hermes_cli.__version__ across every Portal call site
|
||||
# (main loop, aux, compression, web_extract). Do not inline a literal here;
|
||||
# see agent/portal_tags.py for the rationale.
|
||||
from agent.portal_tags import nous_portal_tags as _nous_portal_tags
|
||||
|
||||
|
||||
def _nous_extra_body() -> dict:
|
||||
"""Return a fresh Nous Portal ``extra_body`` dict.
|
||||
|
||||
Computed at call time so a hot-reloaded ``hermes_cli.__version__`` is
|
||||
reflected without restarting long-running processes.
|
||||
"""
|
||||
return {"tags": _nous_portal_tags()}
|
||||
|
||||
|
||||
# Backwards-compatible module attribute. Some callers (tests, third-party
|
||||
# plugins) read ``NOUS_EXTRA_BODY`` directly; keep it as a snapshot of the
|
||||
# current tags. Callers that need the freshest value should call
|
||||
# ``_nous_extra_body()`` or import ``nous_portal_tags`` directly.
|
||||
NOUS_EXTRA_BODY = _nous_extra_body()
|
||||
|
||||
# Set at resolve time — True if the auxiliary client points to Nous Portal
|
||||
auxiliary_is_nous: bool = False
|
||||
@@ -1386,6 +1407,7 @@ def _try_openrouter(explicit_api_key: str = None) -> Tuple[Optional[OpenAI], Opt
|
||||
if pool_present:
|
||||
or_key = explicit_api_key or _pool_runtime_api_key(entry)
|
||||
if not or_key:
|
||||
_mark_provider_unhealthy("openrouter", ttl=60)
|
||||
return None, None
|
||||
base_url = _pool_runtime_base_url(entry, OPENROUTER_BASE_URL) or OPENROUTER_BASE_URL
|
||||
logger.debug("Auxiliary client: OpenRouter via pool")
|
||||
@@ -1394,6 +1416,7 @@ def _try_openrouter(explicit_api_key: str = None) -> Tuple[Optional[OpenAI], Opt
|
||||
|
||||
or_key = explicit_api_key or os.getenv("OPENROUTER_API_KEY")
|
||||
if not or_key:
|
||||
_mark_provider_unhealthy("openrouter", ttl=60)
|
||||
return None, None
|
||||
logger.debug("Auxiliary client: OpenRouter")
|
||||
return OpenAI(api_key=or_key, base_url=OPENROUTER_BASE_URL,
|
||||
@@ -1425,6 +1448,7 @@ def _try_nous(vision: bool = False) -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"Auxiliary: skipping Nous Portal (rate-limited, resets in %.0fs)",
|
||||
_remaining,
|
||||
)
|
||||
_mark_provider_unhealthy("nous", ttl=_remaining)
|
||||
return None, None
|
||||
except Exception:
|
||||
pass
|
||||
@@ -1432,7 +1456,21 @@ def _try_nous(vision: bool = False) -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
nous = _read_nous_auth()
|
||||
runtime = _resolve_nous_runtime_api(force_refresh=False)
|
||||
if runtime is None and not nous:
|
||||
logger.warning(
|
||||
"Auxiliary Nous client unavailable: no Nous authentication found "
|
||||
"(run: hermes auth)."
|
||||
)
|
||||
_mark_provider_unhealthy("nous", ttl=60)
|
||||
return None, None
|
||||
if runtime is None and nous:
|
||||
# Runtime credential mint failed but stored Nous auth is still present.
|
||||
# Falls back to the raw stored token below; surface a debug line so
|
||||
# operators investigating expired/invalid sessions have a breadcrumb,
|
||||
# without blocking the fallback path the rest of this function relies on.
|
||||
logger.debug(
|
||||
"Auxiliary Nous: runtime credential mint failed; falling back to "
|
||||
"stored auth.json token."
|
||||
)
|
||||
global auxiliary_is_nous
|
||||
auxiliary_is_nous = True
|
||||
logger.debug("Auxiliary client: Nous Portal")
|
||||
@@ -3437,7 +3475,7 @@ def get_auxiliary_extra_body() -> dict:
|
||||
Includes Nous Portal product tags when the auxiliary client is backed
|
||||
by Nous Portal. Returns empty dict otherwise.
|
||||
"""
|
||||
return dict(NOUS_EXTRA_BODY) if auxiliary_is_nous else {}
|
||||
return _nous_extra_body() if auxiliary_is_nous else {}
|
||||
|
||||
|
||||
def auxiliary_max_tokens_param(value: int) -> dict:
|
||||
@@ -3828,7 +3866,7 @@ def _resolve_task_provider_model(
|
||||
# (e.g. OPENROUTER_API_KEY) instead of locking into "custom".
|
||||
return cfg_provider, resolved_model, cfg_base_url, None, resolved_api_mode
|
||||
if cfg_provider and cfg_provider != "auto":
|
||||
return cfg_provider, resolved_model, None, None, resolved_api_mode
|
||||
return cfg_provider, resolved_model, cfg_base_url, cfg_api_key, resolved_api_mode
|
||||
|
||||
return "auto", resolved_model, None, None, resolved_api_mode
|
||||
|
||||
@@ -4026,7 +4064,7 @@ def _build_call_kwargs(
|
||||
# Provider-specific extra_body
|
||||
merged_extra = dict(extra_body or {})
|
||||
if provider == "nous" or auxiliary_is_nous:
|
||||
merged_extra.setdefault("tags", []).extend(NOUS_EXTRA_BODY["tags"])
|
||||
merged_extra.setdefault("tags", []).extend(_nous_portal_tags())
|
||||
if merged_extra:
|
||||
kwargs["extra_body"] = merged_extra
|
||||
|
||||
@@ -4411,7 +4449,7 @@ def extract_content_or_reasoning(response) -> str:
|
||||
1. ``message.content`` — strip inline think/reasoning blocks, check for
|
||||
remaining non-whitespace text.
|
||||
2. ``message.reasoning`` / ``message.reasoning_content`` — direct
|
||||
structured reasoning fields (DeepSeek, Moonshot, Novita, etc.).
|
||||
structured reasoning fields (DeepSeek, Moonshot, NovitaAI, etc.).
|
||||
3. ``message.reasoning_details`` — OpenRouter unified array format.
|
||||
|
||||
Returns the best available text, or ``""`` if nothing found.
|
||||
|
||||
@@ -1185,6 +1185,26 @@ The user has requested that this compaction PRIORITISE preserving all informatio
|
||||
idx += 1
|
||||
return idx
|
||||
|
||||
def _protect_head_size(self, messages: List[Dict[str, Any]]) -> int:
|
||||
"""Total count of head messages to protect.
|
||||
|
||||
``protect_first_n`` is defined as *additional* messages protected
|
||||
beyond the system prompt. The system prompt (if present at index 0)
|
||||
is always implicitly protected — it's load-bearing context that
|
||||
must never be summarised away. This keeps semantics stable across
|
||||
call paths where the system prompt may or may not be included in
|
||||
the ``messages`` list (e.g. the gateway ``/compress`` handler
|
||||
strips it before calling compress()).
|
||||
|
||||
Examples:
|
||||
protect_first_n=0 → system prompt only (or nothing if no system msg)
|
||||
protect_first_n=3 → system + first 3 non-system messages
|
||||
"""
|
||||
head = 0
|
||||
if messages and messages[0].get("role") == "system":
|
||||
head = 1
|
||||
return head + self.protect_first_n
|
||||
|
||||
def _align_boundary_backward(self, messages: List[Dict[str, Any]], idx: int) -> int:
|
||||
"""Pull a compress-end boundary backward to avoid splitting a
|
||||
tool_call / result group.
|
||||
@@ -1343,7 +1363,7 @@ The user has requested that this compaction PRIORITISE preserving all informatio
|
||||
skip the LLM call when the transcript is still entirely inside
|
||||
the protected head/tail.
|
||||
"""
|
||||
compress_start = self._align_boundary_forward(messages, self.protect_first_n)
|
||||
compress_start = self._align_boundary_forward(messages, self._protect_head_size(messages))
|
||||
compress_end = self._find_tail_cut_by_tokens(messages, compress_start)
|
||||
return compress_start < compress_end
|
||||
|
||||
@@ -1379,7 +1399,7 @@ The user has requested that this compaction PRIORITISE preserving all informatio
|
||||
self._last_aux_model_failure_model = None
|
||||
n_messages = len(messages)
|
||||
# Only need head + 3 tail messages minimum (token budget decides the real tail size)
|
||||
_min_for_compress = self.protect_first_n + 3 + 1
|
||||
_min_for_compress = self._protect_head_size(messages) + 3 + 1
|
||||
if n_messages <= _min_for_compress:
|
||||
if not self.quiet_mode:
|
||||
logger.warning(
|
||||
@@ -1399,7 +1419,7 @@ The user has requested that this compaction PRIORITISE preserving all informatio
|
||||
logger.info("Pre-compression: pruned %d old tool result(s)", pruned_count)
|
||||
|
||||
# Phase 2: Determine boundaries
|
||||
compress_start = self.protect_first_n
|
||||
compress_start = self._protect_head_size(messages)
|
||||
compress_start = self._align_boundary_forward(messages, compress_start)
|
||||
|
||||
# Use token-budget tail protection instead of fixed message count
|
||||
@@ -1409,15 +1429,23 @@ The user has requested that this compaction PRIORITISE preserving all informatio
|
||||
return messages
|
||||
|
||||
turns_to_summarize = messages[compress_start:compress_end]
|
||||
# A persisted handoff summary can sit in the protected head after a
|
||||
# resume (commonly immediately after the system prompt). Search from
|
||||
# the first non-system message through the compression window so we can
|
||||
# rehydrate iterative-summary state without serializing that handoff as
|
||||
# a new turn. Protected messages after the handoff remain live context,
|
||||
# so only summarize messages that are both after the handoff and inside
|
||||
# the current compression window.
|
||||
summary_search_start = 1 if messages and messages[0].get("role") == "system" else 0
|
||||
summary_idx, summary_body = self._find_latest_context_summary(
|
||||
messages,
|
||||
compress_start,
|
||||
summary_search_start,
|
||||
compress_end,
|
||||
)
|
||||
if summary_idx is not None:
|
||||
if summary_body and not self._previous_summary:
|
||||
self._previous_summary = summary_body
|
||||
turns_to_summarize = messages[summary_idx + 1:compress_end]
|
||||
turns_to_summarize = messages[max(compress_start, summary_idx + 1):compress_end]
|
||||
|
||||
if not self.quiet_mode:
|
||||
logger.info(
|
||||
|
||||
@@ -55,6 +55,11 @@ class ContextEngine(ABC):
|
||||
# These control the preflight compression check. Subclasses may
|
||||
# override via __init__ or property; defaults are sensible for most
|
||||
# engines.
|
||||
#
|
||||
# protect_first_n semantics (since PR #13754): count of non-system head
|
||||
# messages always preserved verbatim, IN ADDITION to the system prompt
|
||||
# which is always implicitly protected. Default 3 keeps the
|
||||
# historical "system + first 3 non-system messages" head shape.
|
||||
|
||||
threshold_percent: float = 0.75
|
||||
protect_first_n: int = 3
|
||||
|
||||
@@ -14,6 +14,7 @@ from difflib import unified_diff
|
||||
from pathlib import Path
|
||||
|
||||
from utils import safe_json_loads
|
||||
from agent.tool_result_classification import file_mutation_result_landed
|
||||
|
||||
# ANSI escape codes for coloring tool failure indicators
|
||||
_RED = "\033[31m"
|
||||
@@ -239,21 +240,6 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int | None = None) -
|
||||
msg = msg[:17] + "..."
|
||||
return f"to {target}: \"{msg}\""
|
||||
|
||||
if tool_name.startswith("rl_"):
|
||||
rl_previews = {
|
||||
"rl_list_environments": "listing envs",
|
||||
"rl_select_environment": args.get("name", ""),
|
||||
"rl_get_current_config": "reading config",
|
||||
"rl_edit_config": f"{args.get('field', '')}={args.get('value', '')}",
|
||||
"rl_start_training": "starting",
|
||||
"rl_check_status": args.get("run_id", "")[:16],
|
||||
"rl_stop_training": f"stopping {args.get('run_id', '')[:16]}",
|
||||
"rl_get_results": args.get("run_id", "")[:16],
|
||||
"rl_list_runs": "listing runs",
|
||||
"rl_test_inference": f"{args.get('num_steps', 3)} steps",
|
||||
}
|
||||
return rl_previews.get(tool_name)
|
||||
|
||||
key = primary_args.get(tool_name)
|
||||
if not key:
|
||||
for fallback_key in ("query", "text", "command", "path", "name", "prompt", "code", "goal"):
|
||||
@@ -810,6 +796,8 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]
|
||||
"""
|
||||
if result is None:
|
||||
return False, ""
|
||||
if file_mutation_result_landed(tool_name, result):
|
||||
return False, ""
|
||||
|
||||
if tool_name == "terminal":
|
||||
data = safe_json_loads(result)
|
||||
@@ -978,15 +966,6 @@ def get_cute_tool_message(
|
||||
if action == "list":
|
||||
return _wrap(f"┊ ⏰ cron listing {dur}")
|
||||
return _wrap(f"┊ ⏰ cron {action} {args.get('job_id', '')} {dur}")
|
||||
if tool_name.startswith("rl_"):
|
||||
rl = {
|
||||
"rl_list_environments": "list envs", "rl_select_environment": f"select {args.get('name', '')}",
|
||||
"rl_get_current_config": "get config", "rl_edit_config": f"set {args.get('field', '?')}",
|
||||
"rl_start_training": "start training", "rl_check_status": f"status {args.get('run_id', '?')[:12]}",
|
||||
"rl_stop_training": f"stop {args.get('run_id', '?')[:12]}", "rl_get_results": f"results {args.get('run_id', '?')[:12]}",
|
||||
"rl_list_runs": "list runs", "rl_test_inference": "test inference",
|
||||
}
|
||||
return _wrap(f"┊ 🧪 rl {rl.get(tool_name, tool_name.replace('rl_', ''))} {dur}")
|
||||
if tool_name == "execute_code":
|
||||
code = args.get("code", "")
|
||||
first_line = code.strip().split("\n")[0] if code.strip() else ""
|
||||
|
||||
@@ -450,7 +450,13 @@ def _make_stream_chunk(
|
||||
finish_reason: Optional[str] = None,
|
||||
reasoning: str = "",
|
||||
) -> _GeminiStreamChunk:
|
||||
delta_kwargs: Dict[str, Any] = {"role": "assistant"}
|
||||
delta_kwargs: Dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": None,
|
||||
"reasoning": None,
|
||||
"reasoning_content": None,
|
||||
}
|
||||
if content:
|
||||
delta_kwargs["content"] = content
|
||||
if tool_call_delta is not None:
|
||||
|
||||
@@ -77,6 +77,17 @@ def get_active_provider() -> Optional[ImageGenProvider]:
|
||||
|
||||
Reads ``image_gen.provider`` from config.yaml; falls back per the
|
||||
module docstring.
|
||||
|
||||
**Availability semantics** (mirrors :mod:`agent.web_search_registry`):
|
||||
|
||||
- When ``image_gen.provider`` is explicitly set, the configured
|
||||
provider is returned even if :meth:`ImageGenProvider.is_available`
|
||||
reports False — the dispatcher surfaces a precise "X_API_KEY is not
|
||||
set" error rather than silently switching backends.
|
||||
- When ``image_gen.provider`` is unset, the fallback path (single-
|
||||
provider shortcut and the FAL legacy preference) is filtered by
|
||||
``is_available()`` so we don't pick a provider the user has no
|
||||
credentials for.
|
||||
"""
|
||||
configured: Optional[str] = None
|
||||
try:
|
||||
@@ -94,6 +105,17 @@ def get_active_provider() -> Optional[ImageGenProvider]:
|
||||
with _lock:
|
||||
snapshot = dict(_providers)
|
||||
|
||||
def _is_available_safe(p: ImageGenProvider) -> bool:
|
||||
"""Wrap ``is_available()`` so a buggy provider doesn't kill resolution."""
|
||||
try:
|
||||
return bool(p.is_available())
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("image_gen provider %s.is_available() raised %s", p.name, exc)
|
||||
return False
|
||||
|
||||
# 1. Explicit config wins — return regardless of is_available() so the
|
||||
# user gets a precise downstream error message rather than a silent
|
||||
# backend switch.
|
||||
if configured:
|
||||
provider = snapshot.get(configured)
|
||||
if provider is not None:
|
||||
@@ -103,13 +125,16 @@ def get_active_provider() -> Optional[ImageGenProvider]:
|
||||
configured,
|
||||
)
|
||||
|
||||
# Fallback: single-provider case
|
||||
if len(snapshot) == 1:
|
||||
return next(iter(snapshot.values()))
|
||||
# 2. Fallback: single registered provider — but only if it's actually
|
||||
# available (no credentials = don't surface it as "active").
|
||||
available = [p for p in snapshot.values() if _is_available_safe(p)]
|
||||
if len(available) == 1:
|
||||
return available[0]
|
||||
|
||||
# Fallback: prefer legacy FAL for backward compat
|
||||
if "fal" in snapshot:
|
||||
return snapshot["fal"]
|
||||
# 3. Fallback: prefer legacy FAL for backward compat, when available.
|
||||
fal = snapshot.get("fal")
|
||||
if fal is not None and _is_available_safe(fal):
|
||||
return fal
|
||||
|
||||
return None
|
||||
|
||||
|
||||
106
agent/lsp/__init__.py
Normal file
106
agent/lsp/__init__.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Language Server Protocol (LSP) integration for Hermes Agent.
|
||||
|
||||
Hermes runs full language servers (pyright, gopls, rust-analyzer,
|
||||
typescript-language-server, etc.) as subprocesses and pipes their
|
||||
``textDocument/publishDiagnostics`` output into the post-write lint
|
||||
delta filter used by ``write_file`` and ``patch``.
|
||||
|
||||
LSP is **gated on git workspace detection** — if the agent's cwd is
|
||||
inside a git repository, LSP runs against that workspace; otherwise the
|
||||
file_operations layer falls back to its existing in-process syntax
|
||||
checks. This keeps users on user-home cwd's (e.g. Telegram gateway
|
||||
chats) from spawning daemons they don't need.
|
||||
|
||||
Public API:
|
||||
|
||||
from agent.lsp import get_service
|
||||
|
||||
svc = get_service()
|
||||
if svc and svc.enabled_for(path):
|
||||
await svc.touch_file(path)
|
||||
diags = svc.diagnostics_for(path)
|
||||
|
||||
The bulk of the wiring is internal — most callers only need the layer
|
||||
in :func:`tools.file_operations.FileOperations._check_lint_delta`,
|
||||
which is already wired (see that module).
|
||||
|
||||
Architecture is documented in ``website/docs/user-guide/features/lsp.md``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import logging
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from agent.lsp.manager import LSPService
|
||||
|
||||
logger = logging.getLogger("agent.lsp")
|
||||
|
||||
_service: Optional[LSPService] = None
|
||||
_atexit_registered = False
|
||||
_service_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_service() -> Optional[LSPService]:
|
||||
"""Return the process-wide LSP service singleton, or None when disabled.
|
||||
|
||||
The service is created lazily on first call. ``None`` is returned
|
||||
when LSP is disabled in config, when no workspace can be detected,
|
||||
or when the platform doesn't support subprocess-based LSP servers.
|
||||
|
||||
On first creation, registers an :mod:`atexit` handler that tears
|
||||
down spawned language servers on Python exit so a long-running
|
||||
CLI or gateway session doesn't leak pyright/gopls/etc. processes
|
||||
when it terminates.
|
||||
"""
|
||||
global _service, _atexit_registered
|
||||
if _service is not None:
|
||||
return _service if _service.is_active() else None
|
||||
with _service_lock:
|
||||
if _service is not None:
|
||||
return _service if _service.is_active() else None
|
||||
_service = LSPService.create_from_config()
|
||||
if not _atexit_registered:
|
||||
# ``atexit`` handlers run in LIFO order on normal Python
|
||||
# exit and on SystemExit, but NOT on os._exit() or
|
||||
# uncaught signals. Language servers are stateless
|
||||
# subprocesses — losing them on SIGKILL is fine; they'll
|
||||
# be reaped by the kernel along with their parent. We
|
||||
# care about clean exits where Python flushes stdio
|
||||
# before terminating; without this hook every
|
||||
# ``hermes chat`` exit would leak pyright processes that
|
||||
# outlive the parent for a few seconds while their
|
||||
# stdout buffers drain.
|
||||
atexit.register(_atexit_shutdown)
|
||||
_atexit_registered = True
|
||||
return _service if (_service is not None and _service.is_active()) else None
|
||||
|
||||
|
||||
def shutdown_service() -> None:
|
||||
"""Tear down the LSP service if one was started.
|
||||
|
||||
Safe to call multiple times; safe to call when no service was created.
|
||||
"""
|
||||
global _service
|
||||
with _service_lock:
|
||||
svc = _service
|
||||
_service = None
|
||||
if svc is not None:
|
||||
try:
|
||||
svc.shutdown()
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.debug("LSP shutdown error: %s", e)
|
||||
|
||||
|
||||
def _atexit_shutdown() -> None:
|
||||
"""atexit-registered wrapper. Logs at debug because by the time
|
||||
atexit fires the user has already seen the agent's final output —
|
||||
a noisy shutdown line on top of that is just clutter."""
|
||||
try:
|
||||
shutdown_service()
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.debug("atexit LSP shutdown failed: %s", e)
|
||||
|
||||
|
||||
__all__ = ["get_service", "shutdown_service", "LSPService"]
|
||||
308
agent/lsp/cli.py
Normal file
308
agent/lsp/cli.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""``hermes lsp`` CLI subcommand.
|
||||
|
||||
Subcommands:
|
||||
|
||||
- ``status`` — show service state, configured servers, install status.
|
||||
- ``install <server_id>`` — eagerly install one server's binary.
|
||||
- ``install-all`` — try to install every server with a known recipe.
|
||||
- ``restart`` — tear down running clients so the next edit re-spawns.
|
||||
- ``which <server_id>`` — print the resolved binary path for one server.
|
||||
- ``list`` — print the registry of supported servers.
|
||||
|
||||
The handlers are kept here (rather than in
|
||||
``hermes_cli/main.py``) so the LSP module ships self-contained.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def register_subparser(subparsers: argparse._SubParsersAction) -> None:
|
||||
"""Wire the ``hermes lsp`` subcommand tree into the main argparse."""
|
||||
parser = subparsers.add_parser(
|
||||
"lsp",
|
||||
help="Language Server Protocol management",
|
||||
description=(
|
||||
"Manage the LSP layer that powers post-write semantic "
|
||||
"diagnostics in write_file/patch."
|
||||
),
|
||||
)
|
||||
sub = parser.add_subparsers(dest="lsp_command")
|
||||
|
||||
sub_status = sub.add_parser("status", help="Show LSP service status")
|
||||
sub_status.add_argument(
|
||||
"--json", action="store_true", help="Emit machine-readable JSON"
|
||||
)
|
||||
|
||||
sub_list = sub.add_parser("list", help="List supported language servers")
|
||||
sub_list.add_argument(
|
||||
"--installed-only",
|
||||
action="store_true",
|
||||
help="Only show servers whose binary is currently available",
|
||||
)
|
||||
|
||||
sub_install = sub.add_parser("install", help="Install a server binary")
|
||||
sub_install.add_argument("server", help="Server id (e.g. pyright, gopls)")
|
||||
|
||||
sub_install_all = sub.add_parser(
|
||||
"install-all",
|
||||
help="Install every server with a known auto-install recipe",
|
||||
)
|
||||
sub_install_all.add_argument(
|
||||
"--include-manual",
|
||||
action="store_true",
|
||||
help="Even attempt servers marked manual-install (best effort)",
|
||||
)
|
||||
|
||||
sub_restart = sub.add_parser(
|
||||
"restart",
|
||||
help="Tear down running LSP clients (next edit re-spawns)",
|
||||
)
|
||||
|
||||
sub_which = sub.add_parser("which", help="Print binary path for a server")
|
||||
sub_which.add_argument("server", help="Server id")
|
||||
|
||||
parser.set_defaults(func=run_lsp_command)
|
||||
|
||||
|
||||
def run_lsp_command(args: argparse.Namespace) -> int:
|
||||
"""Top-level dispatcher for ``hermes lsp <subcommand>``."""
|
||||
sub = getattr(args, "lsp_command", None) or "status"
|
||||
try:
|
||||
if sub == "status":
|
||||
return _cmd_status(getattr(args, "json", False))
|
||||
if sub == "list":
|
||||
return _cmd_list(getattr(args, "installed_only", False))
|
||||
if sub == "install":
|
||||
return _cmd_install(args.server)
|
||||
if sub == "install-all":
|
||||
return _cmd_install_all(getattr(args, "include_manual", False))
|
||||
if sub == "restart":
|
||||
return _cmd_restart()
|
||||
if sub == "which":
|
||||
return _cmd_which(args.server)
|
||||
sys.stderr.write(f"unknown lsp subcommand: {sub}\n")
|
||||
return 2
|
||||
except KeyboardInterrupt:
|
||||
return 130
|
||||
|
||||
|
||||
def _cmd_status(emit_json: bool) -> int:
|
||||
from agent.lsp import get_service
|
||||
from agent.lsp.servers import SERVERS
|
||||
from agent.lsp.install import detect_status
|
||||
|
||||
svc = get_service()
|
||||
service_active = svc is not None
|
||||
info = svc.get_status() if svc is not None else {"enabled": False}
|
||||
|
||||
if emit_json:
|
||||
import json
|
||||
payload = {
|
||||
"service": info,
|
||||
"registry": [
|
||||
{
|
||||
"server_id": s.server_id,
|
||||
"extensions": list(s.extensions),
|
||||
"description": s.description,
|
||||
"binary_status": detect_status(_recipe_pkg_for(s.server_id)),
|
||||
}
|
||||
for s in SERVERS
|
||||
],
|
||||
}
|
||||
sys.stdout.write(json.dumps(payload, indent=2) + "\n")
|
||||
return 0
|
||||
|
||||
out = []
|
||||
out.append("LSP Service")
|
||||
out.append("===========")
|
||||
out.append(f" enabled: {info.get('enabled', False)}")
|
||||
if service_active:
|
||||
out.append(f" wait_mode: {info.get('wait_mode')}")
|
||||
out.append(f" wait_timeout: {info.get('wait_timeout')}s")
|
||||
out.append(f" install_strategy:{info.get('install_strategy')}")
|
||||
clients = info.get("clients") or []
|
||||
if clients:
|
||||
out.append(f" active clients: {len(clients)}")
|
||||
for c in clients:
|
||||
out.append(
|
||||
f" - {c['server_id']:20s} state={c['state']:10s} root={c['workspace_root']}"
|
||||
)
|
||||
else:
|
||||
out.append(" active clients: none")
|
||||
broken = info.get("broken") or []
|
||||
if broken:
|
||||
out.append(f" broken pairs: {len(broken)}")
|
||||
for b in broken:
|
||||
out.append(f" - {b}")
|
||||
disabled = info.get("disabled_servers") or []
|
||||
if disabled:
|
||||
out.append(f" disabled in cfg: {', '.join(disabled)}")
|
||||
|
||||
# Surface backend-tool gaps that aren't visible in the registry table:
|
||||
# some servers spawn fine but emit no diagnostics without a sidecar
|
||||
# binary (bash-language-server -> shellcheck).
|
||||
backend_warnings = _backend_warnings()
|
||||
if backend_warnings:
|
||||
out.append("")
|
||||
out.append("Backend warnings")
|
||||
out.append("================")
|
||||
for line in backend_warnings:
|
||||
out.append(f" ! {line}")
|
||||
out.append("")
|
||||
out.append("Registered Servers")
|
||||
out.append("==================")
|
||||
for s in SERVERS:
|
||||
pkg = _recipe_pkg_for(s.server_id)
|
||||
status = detect_status(pkg)
|
||||
marker = {
|
||||
"installed": "✓",
|
||||
"missing": "·",
|
||||
"manual-only": "?",
|
||||
}.get(status, " ")
|
||||
ext_summary = ", ".join(list(s.extensions)[:5])
|
||||
if len(s.extensions) > 5:
|
||||
ext_summary += f", … (+{len(s.extensions) - 5})"
|
||||
out.append(
|
||||
f" {marker} {s.server_id:24s} [{status:11s}] {ext_summary}"
|
||||
)
|
||||
if s.description:
|
||||
out.append(f" {s.description}")
|
||||
sys.stdout.write("\n".join(out) + "\n")
|
||||
return 0
|
||||
|
||||
|
||||
def _cmd_list(installed_only: bool) -> int:
|
||||
from agent.lsp.servers import SERVERS
|
||||
from agent.lsp.install import detect_status
|
||||
|
||||
for s in SERVERS:
|
||||
pkg = _recipe_pkg_for(s.server_id)
|
||||
status = detect_status(pkg)
|
||||
if installed_only and status != "installed":
|
||||
continue
|
||||
sys.stdout.write(
|
||||
f"{s.server_id:24s} [{status:11s}] {','.join(s.extensions)}\n"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
def _cmd_install(server_id: str) -> int:
|
||||
from agent.lsp.install import try_install, INSTALL_RECIPES, detect_status
|
||||
pkg = _recipe_pkg_for(server_id)
|
||||
pre_status = detect_status(pkg)
|
||||
if pre_status == "installed":
|
||||
sys.stdout.write(f"{server_id} already installed\n")
|
||||
return 0
|
||||
sys.stdout.write(f"installing {server_id} (pkg={pkg}) ...\n")
|
||||
sys.stdout.flush()
|
||||
bin_path = try_install(pkg, "auto")
|
||||
if bin_path is None:
|
||||
recipe = INSTALL_RECIPES.get(pkg)
|
||||
if recipe and recipe.get("strategy") == "manual":
|
||||
sys.stderr.write(
|
||||
f"{server_id}: this server requires a manual install. "
|
||||
f"See documentation.\n"
|
||||
)
|
||||
else:
|
||||
sys.stderr.write(f"{server_id}: install failed (see logs).\n")
|
||||
return 1
|
||||
sys.stdout.write(f"installed: {bin_path}\n")
|
||||
return 0
|
||||
|
||||
|
||||
def _cmd_install_all(include_manual: bool) -> int:
|
||||
from agent.lsp.servers import SERVERS
|
||||
from agent.lsp.install import try_install, INSTALL_RECIPES, detect_status
|
||||
|
||||
rc = 0
|
||||
for s in SERVERS:
|
||||
pkg = _recipe_pkg_for(s.server_id)
|
||||
recipe = INSTALL_RECIPES.get(pkg)
|
||||
if recipe is None:
|
||||
continue
|
||||
if recipe.get("strategy") == "manual" and not include_manual:
|
||||
continue
|
||||
if detect_status(pkg) == "installed":
|
||||
sys.stdout.write(f" {s.server_id:24s} already installed\n")
|
||||
continue
|
||||
sys.stdout.write(f" installing {s.server_id} (pkg={pkg}) ... ")
|
||||
sys.stdout.flush()
|
||||
path = try_install(pkg, "auto")
|
||||
if path:
|
||||
sys.stdout.write(f"ok ({path})\n")
|
||||
else:
|
||||
sys.stdout.write("FAILED\n")
|
||||
rc = 1
|
||||
return rc
|
||||
|
||||
|
||||
def _cmd_restart() -> int:
|
||||
from agent.lsp import shutdown_service
|
||||
|
||||
shutdown_service()
|
||||
sys.stdout.write("LSP service shut down. Next edit will respawn clients.\n")
|
||||
return 0
|
||||
|
||||
|
||||
def _cmd_which(server_id: str) -> int:
|
||||
from agent.lsp.install import INSTALL_RECIPES, hermes_lsp_bin_dir
|
||||
import os
|
||||
import shutil as _shutil
|
||||
|
||||
recipe = INSTALL_RECIPES.get(server_id)
|
||||
bin_name = (recipe or {}).get("bin", server_id)
|
||||
staged = hermes_lsp_bin_dir() / bin_name
|
||||
if staged.exists():
|
||||
sys.stdout.write(str(staged) + "\n")
|
||||
return 0
|
||||
on_path = _shutil.which(bin_name)
|
||||
if on_path:
|
||||
sys.stdout.write(on_path + "\n")
|
||||
return 0
|
||||
sys.stderr.write(f"{server_id}: not installed\n")
|
||||
return 1
|
||||
|
||||
|
||||
def _recipe_pkg_for(server_id: str) -> str:
|
||||
"""Map a registry ``server_id`` to its install-recipe package key."""
|
||||
# The mapping lives here (not in install.py) because it's a CLI
|
||||
# convenience layer. Most server_ids are also their own recipe
|
||||
# key, but a few differ (e.g. ``vue-language-server`` →
|
||||
# ``@vue/language-server``).
|
||||
aliases = {
|
||||
"vue-language-server": "@vue/language-server",
|
||||
"astro-language-server": "@astrojs/language-server",
|
||||
"dockerfile-ls": "dockerfile-language-server-nodejs",
|
||||
"typescript": "typescript-language-server",
|
||||
}
|
||||
return aliases.get(server_id, server_id)
|
||||
|
||||
|
||||
def _backend_warnings() -> list:
|
||||
"""Return human-readable notes about LSP backend tools that are missing
|
||||
in a way that won't surface elsewhere.
|
||||
|
||||
Some language servers ship as thin wrappers around an external CLI for
|
||||
actual diagnostics — they spawn cleanly but never emit any errors when
|
||||
the sidecar binary isn't on PATH. bash-language-server / shellcheck
|
||||
is the load-bearing example.
|
||||
|
||||
Returned strings are short, actionable, and include the install
|
||||
suggestion across common platforms.
|
||||
"""
|
||||
import shutil as _shutil
|
||||
from agent.lsp.install import hermes_lsp_bin_dir
|
||||
notes: list = []
|
||||
bash_installed = _shutil.which("bash-language-server") is not None or (
|
||||
(hermes_lsp_bin_dir() / "bash-language-server").exists()
|
||||
)
|
||||
if bash_installed and _shutil.which("shellcheck") is None:
|
||||
notes.append(
|
||||
"bash-language-server is installed but shellcheck is missing — "
|
||||
"diagnostics will be empty (apt: shellcheck, brew: shellcheck, "
|
||||
"scoop: shellcheck)."
|
||||
)
|
||||
return notes
|
||||
930
agent/lsp/client.py
Normal file
930
agent/lsp/client.py
Normal file
@@ -0,0 +1,930 @@
|
||||
"""Async LSP client over stdin/stdout.
|
||||
|
||||
One :class:`LSPClient` corresponds to one ``(language_server, workspace_root)``
|
||||
pair — exactly what OpenCode keys clients on, and the same shape Claude
|
||||
Code uses. The client owns a child process, drives the JSON-RPC
|
||||
exchange, and exposes:
|
||||
|
||||
- :meth:`open_file` / :meth:`change_file` — text document sync
|
||||
- :meth:`wait_for_diagnostics` — block until the server emits fresh
|
||||
diagnostics for a specific file (or a timeout fires)
|
||||
- :meth:`diagnostics_for` — read the current per-file diagnostic store
|
||||
- :meth:`shutdown` — graceful close + SIGTERM/SIGKILL fallback
|
||||
|
||||
The class is designed for async use from a single asyncio event loop.
|
||||
The :class:`agent.lsp.manager.LSPService` runs an event loop in a
|
||||
background thread so the synchronous file_operations layer can call
|
||||
into it via :func:`agent.lsp.manager.LSPService.touch_file`.
|
||||
|
||||
Implementation notes:
|
||||
|
||||
- Push diagnostics are stored per-URI in :attr:`_push_diagnostics` from
|
||||
``textDocument/publishDiagnostics`` notifications. Pull diagnostics
|
||||
go in :attr:`_pull_diagnostics`. The merged view dedupes by content.
|
||||
|
||||
- Whole-document sync. Even when the server advertises incremental
|
||||
sync, we send a single ``contentChanges`` entry replacing the
|
||||
entire document. Pretending to be incremental while sending a
|
||||
full replacement is well-tolerated by every major server and saves
|
||||
range bookkeeping. See OpenCode's ``client.ts:584-659`` for the
|
||||
same trick.
|
||||
|
||||
- The "touch-file dance": every ``open_file`` call also fires a
|
||||
``workspace/didChangeWatchedFiles`` notification (CREATED on the
|
||||
first open, CHANGED thereafter). Some servers (clangd, eslint)
|
||||
only re-scan when this notification fires, even though the LSP spec
|
||||
doesn't strictly require it.
|
||||
|
||||
- ``ContentModified`` (-32801) errors get retried with exponential
|
||||
backoff up to 3 times. This matches Claude Code's
|
||||
``LSPServerInstance.sendRequest``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set
|
||||
from urllib.parse import quote, unquote
|
||||
|
||||
from agent.lsp.protocol import (
|
||||
ERROR_CONTENT_MODIFIED,
|
||||
ERROR_METHOD_NOT_FOUND,
|
||||
LSPProtocolError,
|
||||
LSPRequestError,
|
||||
classify_message,
|
||||
encode_message,
|
||||
make_error_response,
|
||||
make_notification,
|
||||
make_request,
|
||||
make_response,
|
||||
read_message,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("agent.lsp.client")
|
||||
|
||||
# Timeouts (seconds) — mirror OpenCode's constants, scaled to seconds.
|
||||
INITIALIZE_TIMEOUT = 45.0
|
||||
DIAGNOSTICS_DOCUMENT_WAIT = 5.0
|
||||
DIAGNOSTICS_FULL_WAIT = 10.0
|
||||
DIAGNOSTICS_REQUEST_TIMEOUT = 3.0
|
||||
PUSH_DEBOUNCE = 0.15
|
||||
SHUTDOWN_GRACE = 1.0 # seconds between SIGTERM and SIGKILL
|
||||
|
||||
# Retry policy for transient ContentModified errors.
|
||||
MAX_CONTENT_MODIFIED_RETRIES = 3
|
||||
RETRY_BASE_DELAY = 0.5 # 0.5, 1.0, 2.0 — exponential
|
||||
|
||||
|
||||
def file_uri(path: str) -> str:
|
||||
"""Return ``file://`` URI for an absolute filesystem path.
|
||||
|
||||
Mirrors Node's ``pathToFileURL`` — handles spaces, unicode, and
|
||||
Windows drive letters (``C:\\foo`` → ``file:///C:/foo``).
|
||||
"""
|
||||
abs_path = os.path.abspath(path)
|
||||
if os.name == "nt":
|
||||
# Windows: backslash → forward slash, prepend extra slash so
|
||||
# the drive letter shows up as part of the path component.
|
||||
abs_path = abs_path.replace("\\", "/")
|
||||
if not abs_path.startswith("/"):
|
||||
abs_path = "/" + abs_path
|
||||
return "file://" + quote(abs_path, safe="/:")
|
||||
|
||||
|
||||
def uri_to_path(uri: str) -> str:
|
||||
"""Inverse of :func:`file_uri`."""
|
||||
if not uri.startswith("file://"):
|
||||
return uri
|
||||
raw = uri[len("file://"):]
|
||||
if os.name == "nt" and raw.startswith("/") and len(raw) > 2 and raw[2] == ":":
|
||||
raw = raw[1:] # strip leading slash before drive letter
|
||||
return os.path.normpath(unquote(raw))
|
||||
|
||||
|
||||
def _end_position(text: str) -> Dict[str, int]:
|
||||
"""Return the LSP Position at the end of ``text``.
|
||||
|
||||
Used to construct a single-range "replace whole document" change
|
||||
for ``textDocument/didChange`` regardless of the server's declared
|
||||
sync mode.
|
||||
"""
|
||||
if not text:
|
||||
return {"line": 0, "character": 0}
|
||||
lines = text.splitlines(keepends=False)
|
||||
last_line = len(lines) - 1
|
||||
last_col = len(lines[-1]) if lines else 0
|
||||
# If the text ends with a trailing newline, ``splitlines`` won't
|
||||
# represent it. The end position is then the start of the next
|
||||
# (empty) line — line index is len(lines), column 0.
|
||||
if text.endswith(("\n", "\r")):
|
||||
return {"line": last_line + 1, "character": 0}
|
||||
return {"line": last_line, "character": last_col}
|
||||
|
||||
|
||||
class LSPClient:
|
||||
"""Async LSP client tied to one server process and one workspace root.
|
||||
|
||||
Lifecycle:
|
||||
|
||||
c = LSPClient(server_id, workspace_root, command, args, init_options)
|
||||
await c.start() # spawn + initialize
|
||||
ver = await c.open_file("/path/to/foo.py")
|
||||
await c.wait_for_diagnostics("/path/to/foo.py", ver)
|
||||
diags = c.diagnostics_for("/path/to/foo.py")
|
||||
await c.shutdown()
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# construction + lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
server_id: str,
|
||||
workspace_root: str,
|
||||
command: List[str],
|
||||
env: Optional[Dict[str, str]] = None,
|
||||
cwd: Optional[str] = None,
|
||||
initialization_options: Optional[Dict[str, Any]] = None,
|
||||
seed_diagnostics_on_first_push: bool = False,
|
||||
) -> None:
|
||||
self.server_id = server_id
|
||||
self.workspace_root = workspace_root
|
||||
self._command = list(command)
|
||||
self._env = env
|
||||
self._cwd = cwd or workspace_root
|
||||
self._init_options = initialization_options or {}
|
||||
self._seed_first_push = seed_diagnostics_on_first_push
|
||||
|
||||
# Process + streams
|
||||
self._proc: Optional[asyncio.subprocess.Process] = None
|
||||
self._stderr_task: Optional[asyncio.Task] = None
|
||||
self._reader_task: Optional[asyncio.Task] = None
|
||||
|
||||
# Request/response correlation
|
||||
self._next_id: int = 0
|
||||
self._pending: Dict[int, asyncio.Future] = {}
|
||||
|
||||
# Server-side request handlers (server → client requests).
|
||||
# Kept small and explicit; everything else returns method-not-found.
|
||||
self._request_handlers: Dict[str, Callable[[Any], Awaitable[Any]]] = {
|
||||
"window/workDoneProgress/create": self._handle_work_done_create,
|
||||
"workspace/configuration": self._handle_workspace_configuration,
|
||||
"client/registerCapability": self._handle_register_capability,
|
||||
"client/unregisterCapability": self._handle_unregister_capability,
|
||||
"workspace/workspaceFolders": self._handle_workspace_folders,
|
||||
"workspace/diagnostic/refresh": self._handle_diagnostic_refresh,
|
||||
}
|
||||
# Notifications (server → client) we care about.
|
||||
self._notification_handlers: Dict[str, Callable[[Any], None]] = {
|
||||
"textDocument/publishDiagnostics": self._handle_publish_diagnostics,
|
||||
# Everything else (window/showMessage, $/progress, etc.)
|
||||
# is silently dropped by default.
|
||||
}
|
||||
|
||||
# Tracked file state — required for didChange version bumps.
|
||||
self._files: Dict[str, Dict[str, Any]] = {}
|
||||
# Diagnostic stores, keyed by file path (NOT URI).
|
||||
self._push_diagnostics: Dict[str, List[Dict[str, Any]]] = {}
|
||||
self._pull_diagnostics: Dict[str, List[Dict[str, Any]]] = {}
|
||||
# Per-path "last published" time so wait-for-fresh logic works.
|
||||
self._published: Dict[str, float] = {}
|
||||
# Per-path version of the latest push (matches our didChange
|
||||
# version when the server respects it).
|
||||
self._published_version: Dict[str, int] = {}
|
||||
# First-push seen flag, for typescript-style seed-on-first-push.
|
||||
self._first_push_seen: Set[str] = set()
|
||||
# Capability registrations — only diagnostic ones are tracked.
|
||||
self._diagnostic_registrations: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# State machine
|
||||
self._state: str = "stopped"
|
||||
self._initialize_result: Optional[Dict[str, Any]] = None
|
||||
self._sync_kind: int = 1 # 1=Full, 2=Incremental
|
||||
self._stopping: bool = False
|
||||
|
||||
# Push event for waiters.
|
||||
self._push_event = asyncio.Event()
|
||||
# Monotonic counter incremented on every publishDiagnostics push.
|
||||
# Waiters snapshot it on entry and treat any increase as
|
||||
# "something happened, recheck the predicate". Avoids the
|
||||
# asyncio.Event sticky-state trap.
|
||||
self._push_counter = 0
|
||||
# Registration change event so wait_for_diagnostics can re-loop
|
||||
# when the server announces a new dynamic provider.
|
||||
self._registration_event = asyncio.Event()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return self._state == "running" and self._proc is not None and self._proc.returncode is None
|
||||
|
||||
@property
|
||||
def state(self) -> str:
|
||||
return self._state
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Spawn the server and complete the initialize handshake.
|
||||
|
||||
Raises any exception encountered during spawn/init. On failure
|
||||
the process is killed and the client is left in state
|
||||
``"error"`` — re-call ``start()`` to retry.
|
||||
"""
|
||||
if self._state in ("running", "starting"):
|
||||
return
|
||||
self._state = "starting"
|
||||
try:
|
||||
await self._spawn()
|
||||
await self._initialize()
|
||||
self._state = "running"
|
||||
except Exception:
|
||||
self._state = "error"
|
||||
await self._cleanup_process()
|
||||
raise
|
||||
|
||||
async def _spawn(self) -> None:
|
||||
env = dict(os.environ)
|
||||
if self._env:
|
||||
env.update(self._env)
|
||||
|
||||
try:
|
||||
self._proc = await asyncio.create_subprocess_exec(
|
||||
self._command[0],
|
||||
*self._command[1:],
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=env,
|
||||
cwd=self._cwd,
|
||||
)
|
||||
except FileNotFoundError as e:
|
||||
raise LSPProtocolError(
|
||||
f"LSP server binary not found: {self._command[0]} ({e})"
|
||||
) from e
|
||||
|
||||
# Drain stderr at debug level — if we don't, the pipe buffer
|
||||
# fills and the server hangs.
|
||||
self._stderr_task = asyncio.create_task(self._drain_stderr())
|
||||
# Start the reader loop.
|
||||
self._reader_task = asyncio.create_task(self._reader_loop())
|
||||
|
||||
async def _drain_stderr(self) -> None:
|
||||
if self._proc is None or self._proc.stderr is None:
|
||||
return
|
||||
try:
|
||||
while True:
|
||||
line = await self._proc.stderr.readline()
|
||||
if not line:
|
||||
break
|
||||
text = line.decode("utf-8", errors="replace").rstrip()
|
||||
if text:
|
||||
logger.debug("[%s] stderr: %s", self.server_id, text[:1000])
|
||||
except (asyncio.CancelledError, OSError):
|
||||
pass
|
||||
|
||||
async def _reader_loop(self) -> None:
|
||||
if self._proc is None or self._proc.stdout is None:
|
||||
return
|
||||
try:
|
||||
while True:
|
||||
msg = await read_message(self._proc.stdout)
|
||||
if msg is None:
|
||||
logger.debug("[%s] server closed stdout cleanly", self.server_id)
|
||||
break
|
||||
kind, key = classify_message(msg)
|
||||
if kind == "response":
|
||||
self._dispatch_response(key, msg)
|
||||
elif kind == "request":
|
||||
asyncio.create_task(self._dispatch_request(key, msg))
|
||||
elif kind == "notification":
|
||||
self._dispatch_notification(key, msg)
|
||||
else:
|
||||
logger.warning("[%s] dropping invalid message: %r", self.server_id, msg)
|
||||
except LSPProtocolError as e:
|
||||
logger.warning("[%s] protocol error in reader loop: %s", self.server_id, e)
|
||||
except (asyncio.CancelledError, OSError):
|
||||
pass
|
||||
finally:
|
||||
# Wake up any pending requests so they can fail fast.
|
||||
for fut in list(self._pending.values()):
|
||||
if not fut.done():
|
||||
fut.set_exception(LSPProtocolError("server connection closed"))
|
||||
self._pending.clear()
|
||||
|
||||
async def _initialize(self) -> None:
|
||||
params = {
|
||||
"rootUri": file_uri(self.workspace_root),
|
||||
"rootPath": self.workspace_root,
|
||||
"processId": os.getpid(),
|
||||
"workspaceFolders": [
|
||||
{"name": "workspace", "uri": file_uri(self.workspace_root)}
|
||||
],
|
||||
"initializationOptions": self._init_options,
|
||||
"capabilities": {
|
||||
"window": {"workDoneProgress": True},
|
||||
"workspace": {
|
||||
"configuration": True,
|
||||
"workspaceFolders": True,
|
||||
"didChangeWatchedFiles": {"dynamicRegistration": True},
|
||||
"diagnostics": {"refreshSupport": False},
|
||||
},
|
||||
"textDocument": {
|
||||
"synchronization": {
|
||||
"dynamicRegistration": False,
|
||||
"didOpen": True,
|
||||
"didChange": True,
|
||||
"didSave": True,
|
||||
"willSave": False,
|
||||
"willSaveWaitUntil": False,
|
||||
},
|
||||
"diagnostic": {
|
||||
"dynamicRegistration": True,
|
||||
"relatedDocumentSupport": True,
|
||||
},
|
||||
"publishDiagnostics": {
|
||||
"relatedInformation": True,
|
||||
"tagSupport": {"valueSet": [1, 2]},
|
||||
"versionSupport": True,
|
||||
"codeDescriptionSupport": True,
|
||||
"dataSupport": False,
|
||||
},
|
||||
"hover": {"contentFormat": ["markdown", "plaintext"]},
|
||||
"definition": {"linkSupport": True},
|
||||
"references": {},
|
||||
"documentSymbol": {"hierarchicalDocumentSymbolSupport": True},
|
||||
},
|
||||
"general": {"positionEncodings": ["utf-16"]},
|
||||
},
|
||||
}
|
||||
|
||||
result = await asyncio.wait_for(
|
||||
self._send_request("initialize", params),
|
||||
timeout=INITIALIZE_TIMEOUT,
|
||||
)
|
||||
self._initialize_result = result
|
||||
self._sync_kind = self._extract_sync_kind(result.get("capabilities") or {})
|
||||
|
||||
await self._send_notification("initialized", {})
|
||||
if self._init_options:
|
||||
# Some servers (vtsls, eslint) want config pushed via
|
||||
# didChangeConfiguration even if it was sent in
|
||||
# initializationOptions.
|
||||
await self._send_notification(
|
||||
"workspace/didChangeConfiguration",
|
||||
{"settings": self._init_options},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_sync_kind(capabilities: dict) -> int:
|
||||
sync = capabilities.get("textDocumentSync")
|
||||
if isinstance(sync, int):
|
||||
return sync
|
||||
if isinstance(sync, dict):
|
||||
change = sync.get("change")
|
||||
if isinstance(change, int):
|
||||
return change
|
||||
return 1 # default to Full
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Best-effort graceful shutdown.
|
||||
|
||||
Sends ``shutdown`` + ``exit``, then SIGTERMs/SIGKILLs the
|
||||
process if it doesn't exit cleanly. Idempotent.
|
||||
"""
|
||||
if self._stopping:
|
||||
return
|
||||
self._stopping = True
|
||||
try:
|
||||
if self.is_running:
|
||||
try:
|
||||
await asyncio.wait_for(self._send_request("shutdown", None), timeout=2.0)
|
||||
except (asyncio.TimeoutError, LSPRequestError, LSPProtocolError):
|
||||
pass
|
||||
try:
|
||||
await self._send_notification("exit", None)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self._state = "stopped"
|
||||
await self._cleanup_process()
|
||||
|
||||
async def _cleanup_process(self) -> None:
|
||||
if self._reader_task is not None and not self._reader_task.done():
|
||||
self._reader_task.cancel()
|
||||
try:
|
||||
await self._reader_task
|
||||
except (asyncio.CancelledError, Exception): # noqa: BLE001
|
||||
pass
|
||||
if self._stderr_task is not None and not self._stderr_task.done():
|
||||
self._stderr_task.cancel()
|
||||
try:
|
||||
await self._stderr_task
|
||||
except (asyncio.CancelledError, Exception): # noqa: BLE001
|
||||
pass
|
||||
proc = self._proc
|
||||
self._proc = None
|
||||
if proc is None:
|
||||
return
|
||||
if proc.returncode is None:
|
||||
try:
|
||||
proc.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(proc.wait(), timeout=SHUTDOWN_GRACE)
|
||||
except asyncio.TimeoutError:
|
||||
try:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# request / notification plumbing
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _send_request(self, method: str, params: Any) -> Any:
|
||||
if self._proc is None or self._proc.stdin is None or self._proc.stdin.is_closing():
|
||||
raise LSPProtocolError(f"cannot send {method!r}: stdin closed")
|
||||
loop = asyncio.get_running_loop()
|
||||
req_id = self._next_id
|
||||
self._next_id += 1
|
||||
fut: asyncio.Future = loop.create_future()
|
||||
self._pending[req_id] = fut
|
||||
try:
|
||||
self._proc.stdin.write(encode_message(make_request(req_id, method, params)))
|
||||
await self._proc.stdin.drain()
|
||||
except (BrokenPipeError, ConnectionResetError, OSError) as e:
|
||||
self._pending.pop(req_id, None)
|
||||
raise LSPProtocolError(f"send failed for {method!r}: {e}") from e
|
||||
try:
|
||||
return await fut
|
||||
finally:
|
||||
self._pending.pop(req_id, None)
|
||||
|
||||
async def _send_request_with_retry(self, method: str, params: Any, *, timeout: float) -> Any:
|
||||
"""Send a request, retrying on ``ContentModified`` (-32801).
|
||||
|
||||
Other errors propagate. The retry policy matches Claude Code's
|
||||
``LSPServerInstance.sendRequest`` — 3 attempts with delays
|
||||
0.5s, 1.0s, 2.0s.
|
||||
"""
|
||||
for attempt in range(MAX_CONTENT_MODIFIED_RETRIES + 1):
|
||||
try:
|
||||
return await asyncio.wait_for(self._send_request(method, params), timeout=timeout)
|
||||
except LSPRequestError as e:
|
||||
if e.code == ERROR_CONTENT_MODIFIED and attempt < MAX_CONTENT_MODIFIED_RETRIES:
|
||||
await asyncio.sleep(RETRY_BASE_DELAY * (2 ** attempt))
|
||||
continue
|
||||
raise
|
||||
|
||||
async def _send_notification(self, method: str, params: Any) -> None:
|
||||
if self._proc is None or self._proc.stdin is None or self._proc.stdin.is_closing():
|
||||
return
|
||||
try:
|
||||
self._proc.stdin.write(encode_message(make_notification(method, params)))
|
||||
await self._proc.stdin.drain()
|
||||
except (BrokenPipeError, ConnectionResetError, OSError) as e:
|
||||
logger.debug("[%s] notify %s failed: %s", self.server_id, method, e)
|
||||
|
||||
async def _send_response(self, req_id: Any, result: Any) -> None:
|
||||
if self._proc is None or self._proc.stdin is None or self._proc.stdin.is_closing():
|
||||
return
|
||||
try:
|
||||
self._proc.stdin.write(encode_message(make_response(req_id, result)))
|
||||
await self._proc.stdin.drain()
|
||||
except (BrokenPipeError, ConnectionResetError, OSError):
|
||||
pass
|
||||
|
||||
async def _send_error_response(self, req_id: Any, code: int, message: str) -> None:
|
||||
if self._proc is None or self._proc.stdin is None or self._proc.stdin.is_closing():
|
||||
return
|
||||
try:
|
||||
self._proc.stdin.write(encode_message(make_error_response(req_id, code, message)))
|
||||
await self._proc.stdin.drain()
|
||||
except (BrokenPipeError, ConnectionResetError, OSError):
|
||||
pass
|
||||
|
||||
def _dispatch_response(self, req_id: int, msg: dict) -> None:
|
||||
fut = self._pending.get(req_id)
|
||||
if fut is None or fut.done():
|
||||
return
|
||||
if "error" in msg:
|
||||
err = msg["error"] or {}
|
||||
fut.set_exception(
|
||||
LSPRequestError(
|
||||
code=int(err.get("code", -32000)),
|
||||
message=str(err.get("message", "unknown")),
|
||||
data=err.get("data"),
|
||||
)
|
||||
)
|
||||
else:
|
||||
fut.set_result(msg.get("result"))
|
||||
|
||||
async def _dispatch_request(self, req_id: Any, msg: dict) -> None:
|
||||
method = msg.get("method", "")
|
||||
params = msg.get("params")
|
||||
handler = self._request_handlers.get(method)
|
||||
if handler is None:
|
||||
await self._send_error_response(req_id, ERROR_METHOD_NOT_FOUND, f"method not found: {method}")
|
||||
return
|
||||
try:
|
||||
result = await handler(params)
|
||||
except Exception as e: # noqa: BLE001 — protocol must not blow up
|
||||
logger.warning("[%s] request handler %s failed: %s", self.server_id, method, e)
|
||||
await self._send_error_response(req_id, -32000, f"handler failed: {e}")
|
||||
return
|
||||
await self._send_response(req_id, result)
|
||||
|
||||
def _dispatch_notification(self, method: str, msg: dict) -> None:
|
||||
handler = self._notification_handlers.get(method)
|
||||
if handler is None:
|
||||
return
|
||||
try:
|
||||
handler(msg.get("params"))
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.debug("[%s] notification handler %s failed: %s", self.server_id, method, e)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# built-in server-→-client request handlers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _handle_work_done_create(self, params: Any) -> Any:
|
||||
# Acknowledge progress tokens — required by some servers.
|
||||
return None
|
||||
|
||||
async def _handle_workspace_configuration(self, params: Any) -> Any:
|
||||
# Walk dotted sections through initializationOptions. Mirrors
|
||||
# OpenCode's `client.ts:198-220` — return null when missing.
|
||||
if not isinstance(params, dict):
|
||||
return [None]
|
||||
items = params.get("items") or []
|
||||
out: List[Any] = []
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
out.append(None)
|
||||
continue
|
||||
section = item.get("section")
|
||||
if not section or not self._init_options:
|
||||
out.append(self._init_options or None)
|
||||
continue
|
||||
cur: Any = self._init_options
|
||||
for part in str(section).split("."):
|
||||
if isinstance(cur, dict) and part in cur:
|
||||
cur = cur[part]
|
||||
else:
|
||||
cur = None
|
||||
break
|
||||
out.append(cur)
|
||||
return out
|
||||
|
||||
async def _handle_register_capability(self, params: Any) -> Any:
|
||||
if not isinstance(params, dict):
|
||||
return None
|
||||
for reg in params.get("registrations") or []:
|
||||
if not isinstance(reg, dict):
|
||||
continue
|
||||
method = reg.get("method")
|
||||
reg_id = reg.get("id")
|
||||
if method == "textDocument/diagnostic" and reg_id:
|
||||
self._diagnostic_registrations[str(reg_id)] = reg
|
||||
self._registration_event.set()
|
||||
return None
|
||||
|
||||
async def _handle_unregister_capability(self, params: Any) -> Any:
|
||||
if not isinstance(params, dict):
|
||||
return None
|
||||
for unreg in params.get("unregisterations") or []:
|
||||
if not isinstance(unreg, dict):
|
||||
continue
|
||||
reg_id = unreg.get("id")
|
||||
if reg_id:
|
||||
self._diagnostic_registrations.pop(str(reg_id), None)
|
||||
return None
|
||||
|
||||
async def _handle_workspace_folders(self, params: Any) -> Any:
|
||||
return [{"name": "workspace", "uri": file_uri(self.workspace_root)}]
|
||||
|
||||
async def _handle_diagnostic_refresh(self, params: Any) -> Any:
|
||||
# We don't honour refresh — we re-pull on every touchFile.
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# publishDiagnostics handler
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handle_publish_diagnostics(self, params: Any) -> None:
|
||||
if not isinstance(params, dict):
|
||||
return
|
||||
uri = params.get("uri")
|
||||
if not isinstance(uri, str):
|
||||
return
|
||||
path = uri_to_path(uri)
|
||||
diagnostics = params.get("diagnostics") or []
|
||||
if not isinstance(diagnostics, list):
|
||||
diagnostics = []
|
||||
version = params.get("version")
|
||||
loop_time = asyncio.get_event_loop().time()
|
||||
|
||||
if self._seed_first_push and path not in self._first_push_seen:
|
||||
# First push: seed without firing the event so a waiter
|
||||
# doesn't resolve on the very first push (which arrives
|
||||
# before the user-triggered didChange could've produced
|
||||
# fresh diagnostics).
|
||||
self._first_push_seen.add(path)
|
||||
self._push_diagnostics[path] = diagnostics
|
||||
self._published[path] = loop_time
|
||||
if isinstance(version, int):
|
||||
self._published_version[path] = version
|
||||
return
|
||||
|
||||
self._push_diagnostics[path] = diagnostics
|
||||
self._published[path] = loop_time
|
||||
if isinstance(version, int):
|
||||
self._published_version[path] = version
|
||||
self._first_push_seen.add(path)
|
||||
# Bump the monotonic push counter and wake every waiter. We
|
||||
# keep the Event sticky-set so any wait already in progress
|
||||
# resolves; waiters re-check their predicate after waking and
|
||||
# decide whether to keep waiting. ``_push_counter`` is what
|
||||
# they actually compare against to detect a fresh event.
|
||||
self._push_counter += 1
|
||||
self._push_event.set()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# public file-sync API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def open_file(self, path: str, *, language_id: str = "plaintext") -> int:
|
||||
"""Send didOpen (first time) or didChange (subsequent) for ``path``.
|
||||
|
||||
Returns the new document version number that the agent's
|
||||
``wait_for_diagnostics`` should match against.
|
||||
"""
|
||||
if not self.is_running:
|
||||
raise LSPProtocolError("client not running")
|
||||
|
||||
abs_path = os.path.abspath(path)
|
||||
try:
|
||||
text = Path(abs_path).read_text(encoding="utf-8", errors="replace")
|
||||
except OSError as e:
|
||||
raise LSPProtocolError(f"cannot read {abs_path}: {e}") from e
|
||||
|
||||
uri = file_uri(abs_path)
|
||||
existing = self._files.get(abs_path)
|
||||
|
||||
if existing is not None:
|
||||
# Re-open: bump version, fire didChangeWatchedFiles + didChange.
|
||||
await self._send_notification(
|
||||
"workspace/didChangeWatchedFiles",
|
||||
{"changes": [{"uri": uri, "type": 2}]}, # 2 = CHANGED
|
||||
)
|
||||
new_version = existing["version"] + 1
|
||||
old_text = existing["text"]
|
||||
content_changes: List[Dict[str, Any]]
|
||||
if self._sync_kind == 2:
|
||||
content_changes = [
|
||||
{
|
||||
"range": {
|
||||
"start": {"line": 0, "character": 0},
|
||||
"end": _end_position(old_text),
|
||||
},
|
||||
"text": text,
|
||||
}
|
||||
]
|
||||
else:
|
||||
content_changes = [{"text": text}]
|
||||
await self._send_notification(
|
||||
"textDocument/didChange",
|
||||
{
|
||||
"textDocument": {"uri": uri, "version": new_version},
|
||||
"contentChanges": content_changes,
|
||||
},
|
||||
)
|
||||
self._files[abs_path] = {"version": new_version, "text": text}
|
||||
return new_version
|
||||
|
||||
# First open: didChangeWatchedFiles CREATED + didOpen.
|
||||
await self._send_notification(
|
||||
"workspace/didChangeWatchedFiles",
|
||||
{"changes": [{"uri": uri, "type": 1}]}, # 1 = CREATED
|
||||
)
|
||||
# Clear any stale push/pull entries — fresh open should start
|
||||
# from scratch.
|
||||
self._push_diagnostics.pop(abs_path, None)
|
||||
self._pull_diagnostics.pop(abs_path, None)
|
||||
self._published.pop(abs_path, None)
|
||||
self._published_version.pop(abs_path, None)
|
||||
await self._send_notification(
|
||||
"textDocument/didOpen",
|
||||
{
|
||||
"textDocument": {
|
||||
"uri": uri,
|
||||
"languageId": language_id,
|
||||
"version": 0,
|
||||
"text": text,
|
||||
}
|
||||
},
|
||||
)
|
||||
self._files[abs_path] = {"version": 0, "text": text}
|
||||
return 0
|
||||
|
||||
async def save_file(self, path: str) -> None:
|
||||
"""Send didSave for ``path``. Some linters re-scan only on save."""
|
||||
if not self.is_running:
|
||||
return
|
||||
abs_path = os.path.abspath(path)
|
||||
await self._send_notification(
|
||||
"textDocument/didSave",
|
||||
{"textDocument": {"uri": file_uri(abs_path)}},
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# diagnostics: pull + wait
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _pull_document_diagnostics(self, path: str) -> None:
|
||||
"""Send ``textDocument/diagnostic`` for one file.
|
||||
|
||||
Stores results into :attr:`_pull_diagnostics`. Silently
|
||||
no-ops on errors (server may not support the pull endpoint).
|
||||
"""
|
||||
try:
|
||||
params: Dict[str, Any] = {
|
||||
"textDocument": {"uri": file_uri(os.path.abspath(path))}
|
||||
}
|
||||
result = await self._send_request_with_retry(
|
||||
"textDocument/diagnostic",
|
||||
params,
|
||||
timeout=DIAGNOSTICS_REQUEST_TIMEOUT,
|
||||
)
|
||||
except (LSPRequestError, LSPProtocolError, asyncio.TimeoutError) as e:
|
||||
logger.debug("[%s] document diagnostic pull failed: %s", self.server_id, e)
|
||||
return
|
||||
if not isinstance(result, dict):
|
||||
return
|
||||
items = result.get("items")
|
||||
if isinstance(items, list):
|
||||
self._pull_diagnostics[os.path.abspath(path)] = items
|
||||
related = result.get("relatedDocuments")
|
||||
if isinstance(related, dict):
|
||||
for uri, sub in related.items():
|
||||
if not isinstance(sub, dict):
|
||||
continue
|
||||
sub_items = sub.get("items")
|
||||
if isinstance(sub_items, list):
|
||||
self._pull_diagnostics[uri_to_path(uri)] = sub_items
|
||||
|
||||
async def wait_for_diagnostics(
|
||||
self,
|
||||
path: str,
|
||||
version: int,
|
||||
*,
|
||||
mode: str = "document",
|
||||
) -> None:
|
||||
"""Wait for the server to publish diagnostics for ``path`` at ``version``.
|
||||
|
||||
``mode`` is ``"document"`` (5s budget, document pulls) or
|
||||
``"full"`` (10s budget, also workspace pulls). Best-effort —
|
||||
returns silently on timeout. Does NOT throw if the server
|
||||
doesn't support pull diagnostics; we still get the push side.
|
||||
"""
|
||||
budget = DIAGNOSTICS_FULL_WAIT if mode == "full" else DIAGNOSTICS_DOCUMENT_WAIT
|
||||
deadline = asyncio.get_event_loop().time() + budget
|
||||
abs_path = os.path.abspath(path)
|
||||
|
||||
while True:
|
||||
remaining = deadline - asyncio.get_event_loop().time()
|
||||
if remaining <= 0:
|
||||
return
|
||||
|
||||
# Concurrent: document pull + push wait.
|
||||
pull_task = asyncio.create_task(self._pull_document_diagnostics(abs_path))
|
||||
push_task = asyncio.create_task(self._wait_for_fresh_push(abs_path, version, remaining))
|
||||
done, pending = await asyncio.wait(
|
||||
{pull_task, push_task},
|
||||
timeout=remaining,
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
for t in pending:
|
||||
try:
|
||||
await t
|
||||
except (asyncio.CancelledError, Exception): # noqa: BLE001
|
||||
pass
|
||||
|
||||
# If we got a fresh push for our version, we're done.
|
||||
current_v = self._published_version.get(abs_path)
|
||||
if abs_path in self._published and (
|
||||
current_v is None or current_v >= version
|
||||
):
|
||||
return
|
||||
|
||||
# Pull may have populated _pull_diagnostics — that's also
|
||||
# success.
|
||||
if abs_path in self._pull_diagnostics:
|
||||
return
|
||||
|
||||
# Loop until budget runs out.
|
||||
|
||||
async def _wait_for_fresh_push(self, path: str, version: int, timeout: float) -> None:
|
||||
"""Wait until a publishDiagnostics arrives for ``path`` at ``version``+."""
|
||||
deadline = asyncio.get_event_loop().time() + timeout
|
||||
baseline = self._push_counter
|
||||
while True:
|
||||
current_v = self._published_version.get(path)
|
||||
if path in self._published and (current_v is None or current_v >= version):
|
||||
# Debounce — wait a tick in case more diagnostics arrive
|
||||
# immediately after. TS often emits in pairs. We
|
||||
# snapshot the counter so we wake on a *new* push, not
|
||||
# on the one that satisfied us a moment ago.
|
||||
debounce_baseline = self._push_counter
|
||||
debounce_deadline = asyncio.get_event_loop().time() + PUSH_DEBOUNCE
|
||||
while self._push_counter == debounce_baseline:
|
||||
remaining = debounce_deadline - asyncio.get_event_loop().time()
|
||||
if remaining <= 0:
|
||||
break
|
||||
self._push_event.clear()
|
||||
try:
|
||||
await asyncio.wait_for(self._push_event.wait(), timeout=remaining)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
return
|
||||
remaining = deadline - asyncio.get_event_loop().time()
|
||||
if remaining <= 0:
|
||||
return
|
||||
if self._push_counter > baseline:
|
||||
# New event arrived but predicate still false — re-check
|
||||
# immediately without waiting again.
|
||||
baseline = self._push_counter
|
||||
continue
|
||||
self._push_event.clear()
|
||||
try:
|
||||
await asyncio.wait_for(self._push_event.wait(), timeout=min(remaining, 0.5))
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
def diagnostics_for(self, path: str) -> List[Dict[str, Any]]:
|
||||
"""Return current merged + deduped diagnostics for one file.
|
||||
|
||||
Diagnostics from push and pull stores are concatenated and
|
||||
deduplicated by ``(severity, code, message, range)`` content
|
||||
key. Empty list if the server hasn't published anything.
|
||||
"""
|
||||
abs_path = os.path.abspath(path)
|
||||
push = self._push_diagnostics.get(abs_path) or []
|
||||
pull = self._pull_diagnostics.get(abs_path) or []
|
||||
return _dedupe(push, pull)
|
||||
|
||||
|
||||
def _dedupe(*lists: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
seen: Set[str] = set()
|
||||
out: List[Dict[str, Any]] = []
|
||||
for lst in lists:
|
||||
for d in lst:
|
||||
if not isinstance(d, dict):
|
||||
continue
|
||||
key = _diagnostic_key(d)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
out.append(d)
|
||||
return out
|
||||
|
||||
|
||||
def _diagnostic_key(d: Dict[str, Any]) -> str:
|
||||
"""Content-equality key for a diagnostic.
|
||||
|
||||
Matches the structural-equality used in claude-code's
|
||||
``areDiagnosticsEqual`` — message + severity + source + code +
|
||||
range coords. The range is reduced to a tuple to keep the key
|
||||
stable across dict orderings.
|
||||
"""
|
||||
rng = d.get("range") or {}
|
||||
start = rng.get("start") or {}
|
||||
end = rng.get("end") or {}
|
||||
code = d.get("code")
|
||||
if code is not None and not isinstance(code, str):
|
||||
code = str(code)
|
||||
return "\x00".join(
|
||||
[
|
||||
str(d.get("severity") or 1),
|
||||
str(code or ""),
|
||||
str(d.get("source") or ""),
|
||||
str(d.get("message") or "").strip(),
|
||||
f"{start.get('line', 0)}:{start.get('character', 0)}-{end.get('line', 0)}:{end.get('character', 0)}",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LSPClient",
|
||||
"file_uri",
|
||||
"uri_to_path",
|
||||
"INITIALIZE_TIMEOUT",
|
||||
"DIAGNOSTICS_DOCUMENT_WAIT",
|
||||
"DIAGNOSTICS_FULL_WAIT",
|
||||
]
|
||||
213
agent/lsp/eventlog.py
Normal file
213
agent/lsp/eventlog.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Structured logging with steady-state silence for the LSP layer.
|
||||
|
||||
The LSP layer fires on every write_file/patch. In a busy session
|
||||
that's hundreds of events. We want users to be able to ``rg`` the
|
||||
log for "did LSP fire on that edit?" without drowning in noise.
|
||||
|
||||
The level model:
|
||||
|
||||
- ``DEBUG`` for steady-state events that have no novel signal:
|
||||
``clean``, ``feature off``, ``extension not mapped``, ``no project
|
||||
root for already-announced file``, ``server unavailable for
|
||||
already-announced binary``. These never reach ``agent.log`` at the
|
||||
default INFO threshold.
|
||||
|
||||
- ``INFO`` for state transitions worth surfacing exactly once per
|
||||
session: ``active for <root>`` the first time a (server_id,
|
||||
workspace_root) client starts, ``no project root for <path>``
|
||||
the first time we see that file. Plus every diagnostic event
|
||||
(those are inherently rare and per-edit, exactly what users grep
|
||||
for).
|
||||
|
||||
- ``WARNING`` for action-required failures: ``server unavailable``
|
||||
(binary not on PATH) the first time per (server_id, binary),
|
||||
``no server configured`` once per language. Per-call WARNING for
|
||||
timeouts and unexpected bridge exceptions.
|
||||
|
||||
The dedup is in-process module-level sets. Each set grows at most by
|
||||
the number of distinct (server_id, root) and (server_id, binary)
|
||||
pairs touched in one Python process — bytes of memory in even an
|
||||
aggressive monorepo session. Bounded LRU was rejected: evicting an
|
||||
entry would risk re-firing the WARNING/INFO line we explicitly want
|
||||
to suppress.
|
||||
|
||||
Grep recipe::
|
||||
|
||||
tail -f ~/.hermes/logs/agent.log | rg 'lsp\\['
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Tuple
|
||||
|
||||
# Dedicated logger name so the documented grep recipe survives a
|
||||
# ``logging.getLogger(__name__)`` rename of any internal module.
|
||||
event_log = logging.getLogger("hermes.lint.lsp")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Once-per-X dedup sets
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_announce_lock = threading.Lock()
|
||||
_announced_active: set = set() # keys: (server_id, workspace_root)
|
||||
_announced_unavailable: set = set() # keys: (server_id, binary_path_or_name)
|
||||
_announced_no_root: set = set() # keys: (server_id, file_path)
|
||||
_announced_no_server: set = set() # keys: (server_id,)
|
||||
|
||||
|
||||
def _short_path(file_path: str) -> str:
|
||||
"""Render *file_path* relative to the cwd when sensible, else absolute.
|
||||
|
||||
Keeps log lines readable for the common case (the user is inside
|
||||
the project they're editing) without emitting brittle ``../../..``
|
||||
chains for the cross-tree case.
|
||||
"""
|
||||
if not file_path:
|
||||
return file_path
|
||||
try:
|
||||
rel = os.path.relpath(file_path)
|
||||
except ValueError:
|
||||
return file_path
|
||||
if rel.startswith(".." + os.sep) or rel == "..":
|
||||
return file_path
|
||||
return rel
|
||||
|
||||
|
||||
def _emit(server_id: str, level: int, message: str) -> None:
|
||||
event_log.log(level, "lsp[%s] %s", server_id, message)
|
||||
|
||||
|
||||
def _announce_once(bucket: set, key: Tuple) -> bool:
|
||||
"""Return True if *key* has not been announced for *bucket* yet.
|
||||
|
||||
Atomically marks the key as announced so concurrent callers
|
||||
cannot both win the race and double-log.
|
||||
"""
|
||||
with _announce_lock:
|
||||
if key in bucket:
|
||||
return False
|
||||
bucket.add(key)
|
||||
return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public event helpers — call these from the LSP layer.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def log_clean(server_id: str, file_path: str) -> None:
|
||||
"""No diagnostics emitted for *file_path*. DEBUG (silent at default)."""
|
||||
_emit(server_id, logging.DEBUG, f"clean ({_short_path(file_path)})")
|
||||
|
||||
|
||||
def log_disabled(server_id: str, file_path: str, reason: str) -> None:
|
||||
"""LSP intentionally skipped for this file (feature off, ext unmapped,
|
||||
backend not local, etc.). DEBUG."""
|
||||
_emit(server_id, logging.DEBUG, f"skipped: {reason} ({_short_path(file_path)})")
|
||||
|
||||
|
||||
def log_active(server_id: str, workspace_root: str) -> None:
|
||||
"""A new LSP client started for (server_id, workspace_root).
|
||||
|
||||
INFO once per (server_id, workspace_root); DEBUG thereafter.
|
||||
Lets users verify "is LSP actually running?" with a single grep.
|
||||
"""
|
||||
key = (server_id, workspace_root)
|
||||
if _announce_once(_announced_active, key):
|
||||
_emit(server_id, logging.INFO, f"active for {workspace_root}")
|
||||
else:
|
||||
_emit(server_id, logging.DEBUG, f"reused client for {workspace_root}")
|
||||
|
||||
|
||||
def log_diagnostics(server_id: str, file_path: str, count: int) -> None:
|
||||
"""Diagnostics arrived for a file. INFO every time — these are the
|
||||
failure signals users actually want to grep for, and they are
|
||||
inherently rare per edit."""
|
||||
_emit(server_id, logging.INFO, f"{count} diags ({_short_path(file_path)})")
|
||||
|
||||
|
||||
def log_no_project_root(server_id: str, file_path: str) -> None:
|
||||
"""File had no recognised project marker. INFO once per file,
|
||||
DEBUG thereafter."""
|
||||
key = (server_id, file_path)
|
||||
if _announce_once(_announced_no_root, key):
|
||||
_emit(server_id, logging.INFO, f"no project root for {_short_path(file_path)}")
|
||||
else:
|
||||
_emit(server_id, logging.DEBUG, f"no project root for {_short_path(file_path)}")
|
||||
|
||||
|
||||
def log_server_unavailable(server_id: str, binary_or_pkg: str) -> None:
|
||||
"""The server binary couldn't be resolved. WARNING once per
|
||||
(server_id, binary), DEBUG thereafter so a hundred subsequent
|
||||
.py edits don't spam the log."""
|
||||
key = (server_id, binary_or_pkg)
|
||||
if _announce_once(_announced_unavailable, key):
|
||||
_emit(
|
||||
server_id,
|
||||
logging.WARNING,
|
||||
f"server unavailable: {binary_or_pkg} not found "
|
||||
"(install via `hermes lsp install <id>` or set lsp.servers.<id>.command)",
|
||||
)
|
||||
else:
|
||||
_emit(server_id, logging.DEBUG, f"server still unavailable: {binary_or_pkg}")
|
||||
|
||||
|
||||
def log_no_server_configured(server_id: str) -> None:
|
||||
"""No spawn recipe for this language. WARNING once."""
|
||||
if _announce_once(_announced_no_server, (server_id,)):
|
||||
_emit(server_id, logging.WARNING, "no server configured")
|
||||
|
||||
|
||||
def log_timeout(server_id: str, file_path: str, kind: str = "diagnostics") -> None:
|
||||
"""A request to the server timed out. WARNING every time — these are
|
||||
inherently novel events worth surfacing on each occurrence."""
|
||||
_emit(
|
||||
server_id,
|
||||
logging.WARNING,
|
||||
f"{kind} timed out for {_short_path(file_path)}",
|
||||
)
|
||||
|
||||
|
||||
def log_server_error(server_id: str, file_path: str, exc: BaseException) -> None:
|
||||
"""An unexpected exception bubbled out of the LSP layer. WARNING."""
|
||||
_emit(
|
||||
server_id,
|
||||
logging.WARNING,
|
||||
f"unexpected error for {_short_path(file_path)}: {type(exc).__name__}: {exc}",
|
||||
)
|
||||
|
||||
|
||||
def log_spawn_failed(server_id: str, workspace_root: str, exc: BaseException) -> None:
|
||||
"""The LSP server failed to spawn or initialize. WARNING."""
|
||||
_emit(
|
||||
server_id,
|
||||
logging.WARNING,
|
||||
f"spawn/initialize failed for {workspace_root}: {type(exc).__name__}: {exc}",
|
||||
)
|
||||
|
||||
|
||||
def reset_announce_caches() -> None:
|
||||
"""Test-only: clear the dedup caches. Production code never calls this."""
|
||||
with _announce_lock:
|
||||
_announced_active.clear()
|
||||
_announced_unavailable.clear()
|
||||
_announced_no_root.clear()
|
||||
_announced_no_server.clear()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"event_log",
|
||||
"log_clean",
|
||||
"log_disabled",
|
||||
"log_active",
|
||||
"log_diagnostics",
|
||||
"log_no_project_root",
|
||||
"log_server_unavailable",
|
||||
"log_no_server_configured",
|
||||
"log_timeout",
|
||||
"log_server_error",
|
||||
"log_spawn_failed",
|
||||
"reset_announce_caches",
|
||||
]
|
||||
376
agent/lsp/install.py
Normal file
376
agent/lsp/install.py
Normal file
@@ -0,0 +1,376 @@
|
||||
"""Auto-installation of LSP server binaries.
|
||||
|
||||
Tries to install missing servers using whatever package manager is
|
||||
appropriate. All installs go to a Hermes-owned bin staging dir,
|
||||
``<HERMES_HOME>/lsp/bin/``, so we don't pollute the user's global
|
||||
toolchain.
|
||||
|
||||
Strategies:
|
||||
|
||||
- ``auto`` — attempt to install with the best available package
|
||||
manager. This is the default.
|
||||
- ``manual`` — never install; if a binary is missing, the server is
|
||||
silently skipped and the user is told about it via ``hermes lsp
|
||||
status``.
|
||||
- ``off`` — same as ``manual`` for now (kept distinct so we can
|
||||
evolve behavior later, e.g. logging differently).
|
||||
|
||||
The actual installs happen synchronously the first time a server is
|
||||
needed and concurrent calls to :func:`try_install` for the same
|
||||
package are deduplicated via a per-package lock.
|
||||
|
||||
Failure modes are non-fatal: every install path is wrapped in
|
||||
try/except and returns ``None`` on failure. The tool layer then
|
||||
falls back to its in-process syntax checker, exactly as if the user
|
||||
hadn't enabled LSP at all.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger("agent.lsp.install")
|
||||
|
||||
# Package-name → install-strategy hint registry. Each entry is a
|
||||
# tuple of strategy name + package name + executable name. When the
|
||||
# install completes, we look for the executable in
|
||||
# ``<HERMES_HOME>/lsp/bin/`` first, then on PATH.
|
||||
#
|
||||
# Optional fields:
|
||||
# - ``extra_pkgs``: list of sibling packages to install alongside
|
||||
# ``pkg`` in the same node_modules tree. Used when an LSP server
|
||||
# has a runtime peer dependency that npm doesn't auto-pull (e.g.
|
||||
# typescript-language-server needs ``typescript``).
|
||||
INSTALL_RECIPES: Dict[str, Dict[str, Any]] = {
|
||||
# Python
|
||||
"pyright": {"strategy": "npm", "pkg": "pyright", "bin": "pyright-langserver"},
|
||||
# JS/TS family
|
||||
"typescript-language-server": {
|
||||
"strategy": "npm",
|
||||
"pkg": "typescript-language-server",
|
||||
"bin": "typescript-language-server",
|
||||
# typescript-language-server requires the `typescript` SDK
|
||||
# (tsserver) to be importable from the same node_modules tree;
|
||||
# otherwise initialize() fails with "Could not find a valid
|
||||
# TypeScript installation". Install them together.
|
||||
"extra_pkgs": ["typescript"],
|
||||
},
|
||||
"@vue/language-server": {
|
||||
"strategy": "npm",
|
||||
"pkg": "@vue/language-server",
|
||||
"bin": "vue-language-server",
|
||||
},
|
||||
"svelte-language-server": {
|
||||
"strategy": "npm",
|
||||
"pkg": "svelte-language-server",
|
||||
"bin": "svelteserver",
|
||||
},
|
||||
"@astrojs/language-server": {
|
||||
"strategy": "npm",
|
||||
"pkg": "@astrojs/language-server",
|
||||
"bin": "astro-ls",
|
||||
},
|
||||
"yaml-language-server": {
|
||||
"strategy": "npm",
|
||||
"pkg": "yaml-language-server",
|
||||
"bin": "yaml-language-server",
|
||||
},
|
||||
"bash-language-server": {
|
||||
"strategy": "npm",
|
||||
"pkg": "bash-language-server",
|
||||
"bin": "bash-language-server",
|
||||
},
|
||||
"intelephense": {"strategy": "npm", "pkg": "intelephense", "bin": "intelephense"},
|
||||
"dockerfile-language-server-nodejs": {
|
||||
"strategy": "npm",
|
||||
"pkg": "dockerfile-language-server-nodejs",
|
||||
"bin": "docker-langserver",
|
||||
},
|
||||
# Go
|
||||
"gopls": {"strategy": "go", "pkg": "golang.org/x/tools/gopls@latest", "bin": "gopls"},
|
||||
# Rust — too heavy (hundreds of MB to bootstrap). We do NOT
|
||||
# auto-install rust-analyzer; users install via rustup.
|
||||
"rust-analyzer": {"strategy": "manual", "pkg": "", "bin": "rust-analyzer"},
|
||||
# C/C++ — manual (clangd ships with LLVM, very heavy)
|
||||
"clangd": {"strategy": "manual", "pkg": "", "bin": "clangd"},
|
||||
# Lua — manual (LuaLS is platform-specific binaries from GitHub
|
||||
# releases; complex enough that we punt to the user)
|
||||
"lua-language-server": {"strategy": "manual", "pkg": "", "bin": "lua-language-server"},
|
||||
}
|
||||
|
||||
|
||||
_install_locks: Dict[str, threading.Lock] = {}
|
||||
_install_results: Dict[str, Optional[str]] = {}
|
||||
_install_lock_meta = threading.Lock()
|
||||
|
||||
|
||||
def hermes_lsp_bin_dir() -> Path:
|
||||
"""Return the Hermes-owned bin staging dir for LSP servers."""
|
||||
home = os.environ.get("HERMES_HOME")
|
||||
if home is None:
|
||||
home = os.path.join(os.path.expanduser("~"), ".hermes")
|
||||
p = Path(home) / "lsp" / "bin"
|
||||
p.mkdir(parents=True, exist_ok=True)
|
||||
return p
|
||||
|
||||
|
||||
def _existing_binary(name: str) -> Optional[str]:
|
||||
"""Probe the staging dir + PATH for a binary named ``name``."""
|
||||
staged = hermes_lsp_bin_dir() / name
|
||||
if staged.exists() and os.access(staged, os.X_OK):
|
||||
return str(staged)
|
||||
on_path = shutil.which(name)
|
||||
if on_path:
|
||||
return on_path
|
||||
return None
|
||||
|
||||
|
||||
def _get_lock(pkg: str) -> threading.Lock:
|
||||
with _install_lock_meta:
|
||||
lock = _install_locks.get(pkg)
|
||||
if lock is None:
|
||||
lock = threading.Lock()
|
||||
_install_locks[pkg] = lock
|
||||
return lock
|
||||
|
||||
|
||||
def try_install(pkg: str, strategy: str = "auto") -> Optional[str]:
|
||||
"""Try to install ``pkg`` and return the binary path if successful.
|
||||
|
||||
``strategy`` is ``"auto"``, ``"manual"``, or ``"off"``. In
|
||||
``manual``/``off`` mode, this function only probes for an
|
||||
existing binary and returns ``None`` if not found.
|
||||
|
||||
The install is cached per-package — a second call returns the
|
||||
same path (or ``None``) without reinstalling. Concurrent calls
|
||||
are serialized.
|
||||
"""
|
||||
if strategy not in ("auto",):
|
||||
# Only ``auto`` triggers an actual install. In manual/off,
|
||||
# we still check whether the binary already exists.
|
||||
recipe = INSTALL_RECIPES.get(pkg, {})
|
||||
bin_name = recipe.get("bin", pkg)
|
||||
return _existing_binary(bin_name)
|
||||
|
||||
if pkg in _install_results:
|
||||
return _install_results[pkg]
|
||||
|
||||
lock = _get_lock(pkg)
|
||||
with lock:
|
||||
# Double-check after acquiring lock.
|
||||
if pkg in _install_results:
|
||||
return _install_results[pkg]
|
||||
result = _do_install(pkg)
|
||||
_install_results[pkg] = result
|
||||
return result
|
||||
|
||||
|
||||
def _do_install(pkg: str) -> Optional[str]:
|
||||
recipe = INSTALL_RECIPES.get(pkg)
|
||||
if recipe is None:
|
||||
# Not in our registry — best-effort: just probe PATH.
|
||||
return shutil.which(pkg)
|
||||
|
||||
strategy = recipe.get("strategy", "manual")
|
||||
bin_name = recipe.get("bin", pkg)
|
||||
|
||||
# Check if already present (shutil.which or staging dir)
|
||||
existing = _existing_binary(bin_name)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
if strategy == "manual":
|
||||
logger.debug("[install] %s requires manual install (recipe=%s)", pkg, recipe)
|
||||
return None
|
||||
|
||||
if strategy == "npm":
|
||||
return _install_npm(
|
||||
recipe.get("pkg", pkg),
|
||||
bin_name,
|
||||
extra_pkgs=recipe.get("extra_pkgs") or [],
|
||||
)
|
||||
if strategy == "go":
|
||||
return _install_go(recipe.get("pkg", pkg), bin_name)
|
||||
if strategy == "pip":
|
||||
return _install_pip(recipe.get("pkg", pkg), bin_name)
|
||||
|
||||
logger.warning("[install] unknown strategy %r for %s", strategy, pkg)
|
||||
return None
|
||||
|
||||
|
||||
def _install_npm(
|
||||
pkg: str,
|
||||
bin_name: str,
|
||||
extra_pkgs: Optional[list] = None,
|
||||
) -> Optional[str]:
|
||||
"""Install an npm package into our staging dir.
|
||||
|
||||
Uses ``npm install --prefix`` so the binaries land in
|
||||
``<staging>/node_modules/.bin/<bin_name>`` and we symlink them up
|
||||
one level for direct PATH-style access.
|
||||
|
||||
``extra_pkgs`` is a list of sibling packages to install in the
|
||||
same ``node_modules`` tree. Used for LSP servers with runtime
|
||||
peer deps that npm doesn't auto-pull (typescript-language-server
|
||||
needs ``typescript`` next to it; intelephense ships standalone).
|
||||
"""
|
||||
npm = shutil.which("npm")
|
||||
if npm is None:
|
||||
logger.info("[install] cannot install %s: npm not on PATH", pkg)
|
||||
return None
|
||||
staging = hermes_lsp_bin_dir().parent # <HERMES_HOME>/lsp/
|
||||
install_targets = [pkg] + list(extra_pkgs or [])
|
||||
try:
|
||||
logger.info(
|
||||
"[install] npm install --prefix %s %s",
|
||||
staging,
|
||||
" ".join(install_targets),
|
||||
)
|
||||
proc = subprocess.run(
|
||||
[npm, "install", "--prefix", str(staging), "--silent", "--no-fund", "--no-audit", *install_targets],
|
||||
check=False,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
)
|
||||
if proc.returncode != 0:
|
||||
logger.warning(
|
||||
"[install] npm install failed for %s: %s", pkg, proc.stderr.strip()[:500]
|
||||
)
|
||||
return None
|
||||
except (subprocess.TimeoutExpired, OSError) as e:
|
||||
logger.warning("[install] npm install errored for %s: %s", pkg, e)
|
||||
return None
|
||||
|
||||
# Find the bin
|
||||
nm_bin = staging / "node_modules" / ".bin" / bin_name
|
||||
if os.name == "nt":
|
||||
# On Windows npm sometimes drops `.cmd` shims
|
||||
candidates = [nm_bin, nm_bin.with_suffix(".cmd")]
|
||||
else:
|
||||
candidates = [nm_bin]
|
||||
for c in candidates:
|
||||
if c.exists():
|
||||
# Symlink into our `lsp/bin/` for stable PATH access.
|
||||
link = hermes_lsp_bin_dir() / c.name
|
||||
if not link.exists():
|
||||
try:
|
||||
link.symlink_to(c)
|
||||
except (OSError, NotImplementedError):
|
||||
# Symlinks fail on some Windows setups — copy instead.
|
||||
try:
|
||||
shutil.copy2(c, link)
|
||||
except OSError:
|
||||
return str(c)
|
||||
return str(link if link.exists() else c)
|
||||
logger.warning("[install] npm install for %s succeeded but bin %s not found", pkg, bin_name)
|
||||
return None
|
||||
|
||||
|
||||
def _install_go(pkg: str, bin_name: str) -> Optional[str]:
|
||||
"""Install a Go module to GOBIN=<staging>."""
|
||||
go = shutil.which("go")
|
||||
if go is None:
|
||||
logger.info("[install] cannot install %s: go not on PATH", pkg)
|
||||
return None
|
||||
staging = hermes_lsp_bin_dir()
|
||||
env = dict(os.environ)
|
||||
env["GOBIN"] = str(staging)
|
||||
try:
|
||||
logger.info("[install] go install %s (GOBIN=%s)", pkg, staging)
|
||||
proc = subprocess.run(
|
||||
[go, "install", pkg],
|
||||
check=False,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=600,
|
||||
env=env,
|
||||
)
|
||||
if proc.returncode != 0:
|
||||
logger.warning(
|
||||
"[install] go install failed for %s: %s", pkg, proc.stderr.strip()[:500]
|
||||
)
|
||||
return None
|
||||
except (subprocess.TimeoutExpired, OSError) as e:
|
||||
logger.warning("[install] go install errored for %s: %s", pkg, e)
|
||||
return None
|
||||
bin_path = staging / bin_name
|
||||
if os.name == "nt":
|
||||
bin_path = bin_path.with_suffix(".exe")
|
||||
if bin_path.exists():
|
||||
return str(bin_path)
|
||||
logger.warning("[install] go install for %s succeeded but bin %s not found", pkg, bin_name)
|
||||
return None
|
||||
|
||||
|
||||
def _install_pip(pkg: str, bin_name: str) -> Optional[str]:
|
||||
"""Install a Python package into a hermes-owned target dir.
|
||||
|
||||
We avoid polluting the user's site-packages by using
|
||||
``pip install --target``. Bins go into
|
||||
``<staging>/python-packages/bin/`` which we symlink into
|
||||
``<staging>/bin``. Note: this only works for packages that ship a
|
||||
console script.
|
||||
"""
|
||||
pip_target = hermes_lsp_bin_dir().parent / "python-packages"
|
||||
pip_target.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
logger.info("[install] pip install --target %s %s", pip_target, pkg)
|
||||
proc = subprocess.run(
|
||||
[sys.executable, "-m", "pip", "install", "--target", str(pip_target), "--quiet", pkg],
|
||||
check=False,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
)
|
||||
if proc.returncode != 0:
|
||||
logger.warning(
|
||||
"[install] pip install failed for %s: %s", pkg, proc.stderr.strip()[:500]
|
||||
)
|
||||
return None
|
||||
except (subprocess.TimeoutExpired, OSError) as e:
|
||||
logger.warning("[install] pip install errored for %s: %s", pkg, e)
|
||||
return None
|
||||
# Look for the script
|
||||
bin_path = pip_target / "bin" / bin_name
|
||||
if bin_path.exists():
|
||||
link = hermes_lsp_bin_dir() / bin_name
|
||||
if not link.exists():
|
||||
try:
|
||||
link.symlink_to(bin_path)
|
||||
except (OSError, NotImplementedError):
|
||||
try:
|
||||
shutil.copy2(bin_path, link)
|
||||
except OSError:
|
||||
return str(bin_path)
|
||||
return str(link if link.exists() else bin_path)
|
||||
return None
|
||||
|
||||
|
||||
def detect_status(pkg: str) -> str:
|
||||
"""Return ``installed``, ``missing``, or ``manual-only`` for a package.
|
||||
|
||||
Used by the ``hermes lsp status`` CLI to give users a quick
|
||||
overview of what's available without spawning anything.
|
||||
"""
|
||||
recipe = INSTALL_RECIPES.get(pkg)
|
||||
bin_name = recipe.get("bin", pkg) if recipe else pkg
|
||||
if _existing_binary(bin_name):
|
||||
return "installed"
|
||||
if recipe and recipe.get("strategy") == "manual":
|
||||
return "manual-only"
|
||||
return "missing"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"INSTALL_RECIPES",
|
||||
"try_install",
|
||||
"detect_status",
|
||||
"hermes_lsp_bin_dir",
|
||||
]
|
||||
639
agent/lsp/manager.py
Normal file
639
agent/lsp/manager.py
Normal file
@@ -0,0 +1,639 @@
|
||||
"""Service-level orchestration for LSP clients.
|
||||
|
||||
The :class:`LSPService` is the bridge between the synchronous
|
||||
file_operations layer and the async :class:`agent.lsp.client.LSPClient`.
|
||||
|
||||
Design choices:
|
||||
|
||||
- A **single asyncio event loop** runs in a background thread. All
|
||||
client work happens on that loop. Synchronous callers from
|
||||
``tools/file_operations.py`` use :meth:`get_diagnostics_sync` to
|
||||
open + wait + drain in one blocking call.
|
||||
|
||||
- One client per ``(server_id, workspace_root)`` key. Lazy spawn:
|
||||
the first request for a key spawns the client; subsequent requests
|
||||
re-use it.
|
||||
|
||||
- A **broken-set** records ``(server_id, workspace_root)`` pairs that
|
||||
failed to spawn or initialize. These are never retried for the
|
||||
life of the service. Mirrors OpenCode's design.
|
||||
|
||||
- A **delta baseline** map keeps "diagnostics-as-of-the-last-snapshot"
|
||||
per file. ``snapshot_baseline()`` is called BEFORE a write; the
|
||||
next ``get_diagnostics_sync()`` returns only diagnostics that
|
||||
weren't in the baseline. This is the lift from Claude Code's
|
||||
``beforeFileEdited`` / ``getNewDiagnostics`` pattern, except wired
|
||||
to the local LSP layer instead of MCP IDE RPC.
|
||||
|
||||
The service is **off by default** — call :meth:`is_active` to check
|
||||
whether it's actually doing anything. When LSP is disabled in
|
||||
config, when no git workspace can be detected, when all configured
|
||||
servers are missing binaries and auto-install is off, ``is_active``
|
||||
returns False and the file_operations layer falls through to the
|
||||
in-process syntax check.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import Future as ConcurrentFuture
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from agent.lsp import eventlog
|
||||
from agent.lsp.client import (
|
||||
DIAGNOSTICS_DOCUMENT_WAIT,
|
||||
LSPClient,
|
||||
file_uri,
|
||||
)
|
||||
from agent.lsp.servers import (
|
||||
ServerContext,
|
||||
ServerDef,
|
||||
SpawnSpec,
|
||||
find_server_for_file,
|
||||
language_id_for,
|
||||
)
|
||||
from agent.lsp.workspace import (
|
||||
clear_cache,
|
||||
is_inside_workspace,
|
||||
resolve_workspace_for_file,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("agent.lsp.manager")
|
||||
|
||||
DEFAULT_IDLE_TIMEOUT = 600 # seconds; servers idle for >10min get reaped
|
||||
|
||||
|
||||
class _BackgroundLoop:
|
||||
"""A daemon thread that owns one asyncio event loop.
|
||||
|
||||
Provides :meth:`run` for synchronous callers — submits a coroutine
|
||||
to the loop and blocks until it finishes (or a timeout fires).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._ready = threading.Event()
|
||||
|
||||
def start(self) -> None:
|
||||
if self._thread is not None:
|
||||
return
|
||||
self._thread = threading.Thread(
|
||||
target=self._run_forever,
|
||||
name="hermes-lsp-loop",
|
||||
daemon=True,
|
||||
)
|
||||
self._thread.start()
|
||||
self._ready.wait(timeout=5.0)
|
||||
|
||||
def _run_forever(self) -> None:
|
||||
loop = asyncio.new_event_loop()
|
||||
self._loop = loop
|
||||
asyncio.set_event_loop(loop)
|
||||
self._ready.set()
|
||||
try:
|
||||
loop.run_forever()
|
||||
finally:
|
||||
try:
|
||||
loop.close()
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
|
||||
def run(self, coro, *, timeout: Optional[float] = None) -> Any:
|
||||
"""Submit a coroutine to the loop and block until done.
|
||||
|
||||
Returns the coroutine's result, or raises its exception.
|
||||
"""
|
||||
if self._loop is None:
|
||||
raise RuntimeError("background loop not started")
|
||||
fut: ConcurrentFuture = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
||||
try:
|
||||
return fut.result(timeout=timeout)
|
||||
except Exception:
|
||||
fut.cancel()
|
||||
raise
|
||||
|
||||
def stop(self) -> None:
|
||||
loop = self._loop
|
||||
if loop is None:
|
||||
return
|
||||
try:
|
||||
loop.call_soon_threadsafe(loop.stop)
|
||||
except RuntimeError:
|
||||
pass
|
||||
if self._thread is not None:
|
||||
self._thread.join(timeout=2.0)
|
||||
self._loop = None
|
||||
self._thread = None
|
||||
|
||||
|
||||
class LSPService:
|
||||
"""The process-wide LSP service.
|
||||
|
||||
Created once via :meth:`create_from_config`; the
|
||||
:func:`agent.lsp.get_service` accessor manages the singleton.
|
||||
Most callers should use that accessor rather than constructing
|
||||
:class:`LSPService` directly.
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# construction + factory
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
enabled: bool,
|
||||
wait_mode: str,
|
||||
wait_timeout: float,
|
||||
install_strategy: str,
|
||||
binary_overrides: Optional[Dict[str, List[str]]] = None,
|
||||
env_overrides: Optional[Dict[str, Dict[str, str]]] = None,
|
||||
init_overrides: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
disabled_servers: Optional[List[str]] = None,
|
||||
idle_timeout: float = DEFAULT_IDLE_TIMEOUT,
|
||||
) -> None:
|
||||
self._enabled = enabled
|
||||
self._wait_mode = wait_mode if wait_mode in ("document", "full") else "document"
|
||||
self._wait_timeout = wait_timeout
|
||||
self._install_strategy = install_strategy
|
||||
self._binary_overrides = binary_overrides or {}
|
||||
self._env_overrides = env_overrides or {}
|
||||
self._init_overrides = init_overrides or {}
|
||||
self._disabled_servers = set(disabled_servers or [])
|
||||
self._idle_timeout = idle_timeout
|
||||
|
||||
self._loop = _BackgroundLoop()
|
||||
if self._enabled:
|
||||
self._loop.start()
|
||||
|
||||
# Per-(server_id, workspace_root) state
|
||||
self._clients: Dict[Tuple[str, str], LSPClient] = {}
|
||||
self._broken: set = set()
|
||||
self._spawning: Dict[Tuple[str, str], asyncio.Future] = {}
|
||||
self._last_used: Dict[Tuple[str, str], float] = {}
|
||||
self._state_lock = threading.Lock()
|
||||
|
||||
# Delta baseline: file path → snapshot of diagnostics taken
|
||||
# immediately before a write. ``get_diagnostics_sync`` filters
|
||||
# out anything in the baseline so the agent only sees errors
|
||||
# introduced by the current edit.
|
||||
self._delta_baseline: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
@classmethod
|
||||
def create_from_config(cls) -> Optional["LSPService"]:
|
||||
"""Build a service from ``hermes_cli.config`` settings.
|
||||
|
||||
Returns ``None`` if the config can't be loaded. The service
|
||||
itself returns ``is_active()`` False when LSP is disabled.
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
cfg = load_config()
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.debug("LSP config load failed: %s", e)
|
||||
return None
|
||||
|
||||
lsp_cfg = (cfg.get("lsp") or {}) if isinstance(cfg, dict) else {}
|
||||
if not isinstance(lsp_cfg, dict):
|
||||
lsp_cfg = {}
|
||||
|
||||
enabled = bool(lsp_cfg.get("enabled", True))
|
||||
wait_mode = lsp_cfg.get("wait_mode", "document")
|
||||
wait_timeout = float(lsp_cfg.get("wait_timeout", DIAGNOSTICS_DOCUMENT_WAIT))
|
||||
install_strategy = lsp_cfg.get("install_strategy", "auto")
|
||||
servers_cfg = lsp_cfg.get("servers") or {}
|
||||
disabled = []
|
||||
binary_overrides: Dict[str, List[str]] = {}
|
||||
env_overrides: Dict[str, Dict[str, str]] = {}
|
||||
init_overrides: Dict[str, Dict[str, Any]] = {}
|
||||
if isinstance(servers_cfg, dict):
|
||||
for name, sub in servers_cfg.items():
|
||||
if not isinstance(sub, dict):
|
||||
continue
|
||||
if sub.get("disabled"):
|
||||
disabled.append(name)
|
||||
cmd = sub.get("command")
|
||||
if isinstance(cmd, list) and cmd:
|
||||
binary_overrides[name] = cmd
|
||||
env = sub.get("env")
|
||||
if isinstance(env, dict):
|
||||
env_overrides[name] = {k: str(v) for k, v in env.items()}
|
||||
init = sub.get("initialization_options")
|
||||
if isinstance(init, dict):
|
||||
init_overrides[name] = init
|
||||
|
||||
return cls(
|
||||
enabled=enabled,
|
||||
wait_mode=wait_mode,
|
||||
wait_timeout=wait_timeout,
|
||||
install_strategy=install_strategy,
|
||||
binary_overrides=binary_overrides,
|
||||
env_overrides=env_overrides,
|
||||
init_overrides=init_overrides,
|
||||
disabled_servers=disabled,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def is_active(self) -> bool:
|
||||
"""Return True iff this service should be consulted at all."""
|
||||
return self._enabled
|
||||
|
||||
def enabled_for(self, file_path: str) -> bool:
|
||||
"""Return True iff LSP should run for this specific file.
|
||||
|
||||
Gates on workspace detection (file or cwd inside a git worktree),
|
||||
on whether any registered server matches the extension, and
|
||||
on whether the (server_id, workspace_root) pair is in the
|
||||
broken-set from a previous spawn failure.
|
||||
|
||||
Files in already-broken pairs return False so the file_operations
|
||||
layer skips the LSP path entirely — no spawn attempts, no
|
||||
timeout cost — until the service is restarted (``hermes lsp
|
||||
restart``) or the process exits.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return False
|
||||
srv = find_server_for_file(file_path)
|
||||
if srv is None or srv.server_id in self._disabled_servers:
|
||||
return False
|
||||
ws_root, gated_in = resolve_workspace_for_file(file_path)
|
||||
if not (ws_root and gated_in):
|
||||
return False
|
||||
# Broken-set short-circuit. Use the per-server root if we can
|
||||
# compute one cheaply; otherwise fall back to the workspace
|
||||
# root as the broken key (which is what _get_or_spawn would
|
||||
# have used anyway when it failed).
|
||||
try:
|
||||
per_server_root = srv.resolve_root(file_path, ws_root) or ws_root
|
||||
except Exception: # noqa: BLE001
|
||||
per_server_root = ws_root
|
||||
if (srv.server_id, per_server_root) in self._broken:
|
||||
return False
|
||||
return True
|
||||
|
||||
def snapshot_baseline(self, file_path: str) -> None:
|
||||
"""Snapshot current diagnostics for ``file_path`` as the delta baseline.
|
||||
|
||||
Called BEFORE a write so the next ``get_diagnostics_sync()``
|
||||
can filter out pre-existing errors. Best-effort — failures
|
||||
are silently swallowed so a flaky server can't break a write.
|
||||
|
||||
Outer timeouts (e.g. server hangs during initialize) mark the
|
||||
(server_id, workspace_root) pair as broken so subsequent edits
|
||||
skip it instantly instead of re-paying the timeout cost.
|
||||
"""
|
||||
if not self.enabled_for(file_path):
|
||||
return
|
||||
try:
|
||||
diags = self._loop.run(self._snapshot_async(file_path), timeout=8.0)
|
||||
self._delta_baseline[os.path.abspath(file_path)] = diags or []
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.debug("baseline snapshot failed for %s: %s", file_path, e)
|
||||
self._mark_broken_for_file(file_path, e)
|
||||
self._delta_baseline[os.path.abspath(file_path)] = []
|
||||
|
||||
def get_diagnostics_sync(
|
||||
self,
|
||||
file_path: str,
|
||||
*,
|
||||
delta: bool = True,
|
||||
timeout: Optional[float] = None,
|
||||
line_shift: Optional[Callable[[int], Optional[int]]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Synchronously open ``file_path`` in the right server, wait for
|
||||
diagnostics, return them.
|
||||
|
||||
If ``delta`` is True (default), the result is filtered against
|
||||
any baseline previously captured via :meth:`snapshot_baseline`.
|
||||
Diagnostics present in the baseline are removed so the caller
|
||||
only sees errors introduced by the current edit.
|
||||
|
||||
When ``line_shift`` is provided, baseline diagnostics are
|
||||
remapped through it before the set-difference. This handles
|
||||
the case where the edit deleted or inserted lines, causing
|
||||
pre-existing diagnostics below the edit point to surface at
|
||||
different line numbers in the post-edit snapshot — without
|
||||
the shift, they'd all look "introduced by this edit". Pass
|
||||
a callable built by
|
||||
:func:`agent.lsp.range_shift.build_line_shift` (pre_text,
|
||||
post_text). Omit when pre/post content isn't available;
|
||||
the unshifted comparison still catches diagnostics that
|
||||
didn't move.
|
||||
|
||||
Returns an empty list when LSP is disabled, when no workspace
|
||||
can be detected, when no server matches, or when the server
|
||||
can't be spawned. Never raises.
|
||||
"""
|
||||
if not self.enabled_for(file_path):
|
||||
return []
|
||||
|
||||
# Resolve server_id eagerly so we can emit structured logs even
|
||||
# when the request errors out below.
|
||||
srv = find_server_for_file(file_path)
|
||||
server_id = srv.server_id if srv else "?"
|
||||
|
||||
try:
|
||||
t = timeout if timeout is not None else self._wait_timeout + 2.0
|
||||
diags = self._loop.run(self._open_and_wait_async(file_path), timeout=t) or []
|
||||
except asyncio.TimeoutError as e:
|
||||
eventlog.log_timeout(server_id, file_path)
|
||||
logger.debug("LSP diagnostics timeout for %s: %s", file_path, e)
|
||||
self._mark_broken_for_file(file_path, e)
|
||||
return []
|
||||
except Exception as e: # noqa: BLE001
|
||||
eventlog.log_server_error(server_id, file_path, e)
|
||||
logger.debug("LSP diagnostics fetch failed for %s: %s", file_path, e)
|
||||
self._mark_broken_for_file(file_path, e)
|
||||
return []
|
||||
|
||||
abs_path = os.path.abspath(file_path)
|
||||
if delta:
|
||||
baseline = self._delta_baseline.get(abs_path) or []
|
||||
if baseline:
|
||||
if line_shift is not None:
|
||||
# Remap baseline diagnostics into post-edit
|
||||
# coordinates so shifted-but-otherwise-identical
|
||||
# entries hash equal under _diag_key. Entries
|
||||
# that mapped into a deleted region drop out
|
||||
# silently — they no longer apply.
|
||||
from agent.lsp.range_shift import shift_baseline
|
||||
baseline = shift_baseline(baseline, line_shift)
|
||||
seen = {_diag_key(d) for d in baseline}
|
||||
diags = [d for d in diags if _diag_key(d) not in seen]
|
||||
# Roll baseline forward — next call returns deltas relative
|
||||
# to the just-emitted state, mirroring claude-code's
|
||||
# diagnosticTracking.
|
||||
try:
|
||||
fresh = self._loop.run(self._current_diags_async(file_path), timeout=2.0) or []
|
||||
except Exception: # noqa: BLE001
|
||||
fresh = []
|
||||
if fresh:
|
||||
self._delta_baseline[abs_path] = fresh
|
||||
|
||||
if diags:
|
||||
eventlog.log_diagnostics(server_id, file_path, len(diags))
|
||||
else:
|
||||
eventlog.log_clean(server_id, file_path)
|
||||
return diags
|
||||
|
||||
def _mark_broken_for_file(self, file_path: str, exc: BaseException) -> None:
|
||||
"""Mark the (server_id, workspace_root) pair as broken so subsequent
|
||||
edits skip it instantly instead of re-paying timeout cost.
|
||||
|
||||
Called when the outer ``_loop.run`` timeout cancels an in-flight
|
||||
spawn/initialize that the inner ``_get_or_spawn`` task was still
|
||||
holding open. Without this, every subsequent write would re-enter
|
||||
the spawn path and re-pay the full ``snapshot_baseline``
|
||||
timeout (8s) until the binary is fixed.
|
||||
|
||||
Also kills any orphan client process that survived the cancelled
|
||||
future, and emits a single eventlog WARNING so the user knows
|
||||
which server gave up.
|
||||
|
||||
``exc`` is whatever exception the outer wrapper caught — used
|
||||
only for logging, never re-raised.
|
||||
"""
|
||||
srv = find_server_for_file(file_path)
|
||||
if srv is None:
|
||||
return
|
||||
ws_root, gated = resolve_workspace_for_file(file_path)
|
||||
if not (ws_root and gated):
|
||||
return
|
||||
try:
|
||||
per_server_root = srv.resolve_root(file_path, ws_root) or ws_root
|
||||
except Exception: # noqa: BLE001
|
||||
per_server_root = ws_root
|
||||
key = (srv.server_id, per_server_root)
|
||||
already_broken = key in self._broken
|
||||
self._broken.add(key)
|
||||
|
||||
# Kill any client we managed to spawn before the timeout. The
|
||||
# cancelled future never reached the broken-set add inside
|
||||
# ``_get_or_spawn`` so the client may still be hanging in
|
||||
# ``_clients`` with a half-initialized state.
|
||||
with self._state_lock:
|
||||
client = self._clients.pop(key, None)
|
||||
if client is not None:
|
||||
try:
|
||||
# Fire-and-forget shutdown — give it a second to cleanup,
|
||||
# but don't block. We're already on a slow path.
|
||||
self._loop.run(client.shutdown(), timeout=1.0)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
|
||||
if not already_broken:
|
||||
eventlog.log_spawn_failed(srv.server_id, per_server_root, exc)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Tear down all clients and stop the background loop."""
|
||||
if not self._enabled:
|
||||
return
|
||||
try:
|
||||
self._loop.run(self._shutdown_async(), timeout=10.0)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.debug("LSP shutdown error: %s", e)
|
||||
self._loop.stop()
|
||||
clear_cache()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# async internals
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _snapshot_async(self, file_path: str) -> List[Dict[str, Any]]:
|
||||
client = await self._get_or_spawn(file_path)
|
||||
if client is None:
|
||||
return []
|
||||
try:
|
||||
version = await client.open_file(file_path, language_id=language_id_for(file_path))
|
||||
await client.wait_for_diagnostics(file_path, version, mode=self._wait_mode)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.debug("snapshot open/wait failed: %s", e)
|
||||
return []
|
||||
self._last_used[(client.server_id, client.workspace_root)] = time.time()
|
||||
return list(client.diagnostics_for(file_path))
|
||||
|
||||
async def _open_and_wait_async(self, file_path: str) -> List[Dict[str, Any]]:
|
||||
client = await self._get_or_spawn(file_path)
|
||||
if client is None:
|
||||
return []
|
||||
try:
|
||||
version = await client.open_file(file_path, language_id=language_id_for(file_path))
|
||||
await client.save_file(file_path)
|
||||
await client.wait_for_diagnostics(file_path, version, mode=self._wait_mode)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.debug("open/wait failed for %s: %s", file_path, e)
|
||||
return []
|
||||
self._last_used[(client.server_id, client.workspace_root)] = time.time()
|
||||
return list(client.diagnostics_for(file_path))
|
||||
|
||||
async def _current_diags_async(self, file_path: str) -> List[Dict[str, Any]]:
|
||||
ws, gated = resolve_workspace_for_file(file_path)
|
||||
srv = find_server_for_file(file_path)
|
||||
if not (ws and gated and srv):
|
||||
return []
|
||||
with self._state_lock:
|
||||
client = self._clients.get((srv.server_id, ws))
|
||||
if client is None:
|
||||
return []
|
||||
return list(client.diagnostics_for(file_path))
|
||||
|
||||
async def _get_or_spawn(self, file_path: str) -> Optional[LSPClient]:
|
||||
srv = find_server_for_file(file_path)
|
||||
if srv is None:
|
||||
return None
|
||||
if srv.server_id in self._disabled_servers:
|
||||
eventlog.log_disabled(srv.server_id, file_path, "disabled in config")
|
||||
return None
|
||||
ws_root, gated = resolve_workspace_for_file(file_path)
|
||||
if not (ws_root and gated):
|
||||
eventlog.log_no_project_root(srv.server_id, file_path)
|
||||
return None
|
||||
per_server_root = srv.resolve_root(file_path, ws_root)
|
||||
if per_server_root is None:
|
||||
eventlog.log_disabled(
|
||||
srv.server_id, file_path, "exclude marker hit (server gated off)"
|
||||
)
|
||||
return None # exclude marker hit, server gated off
|
||||
|
||||
key = (srv.server_id, per_server_root)
|
||||
if key in self._broken:
|
||||
return None
|
||||
with self._state_lock:
|
||||
client = self._clients.get(key)
|
||||
if client is not None and client.is_running:
|
||||
eventlog.log_active(srv.server_id, per_server_root)
|
||||
return client
|
||||
spawning = self._spawning.get(key)
|
||||
if spawning is not None:
|
||||
try:
|
||||
return await spawning
|
||||
except Exception: # noqa: BLE001
|
||||
return None
|
||||
|
||||
# Begin spawn
|
||||
loop = asyncio.get_running_loop()
|
||||
spawn_future: asyncio.Future = loop.create_future()
|
||||
with self._state_lock:
|
||||
self._spawning[key] = spawn_future
|
||||
try:
|
||||
ctx = ServerContext(
|
||||
workspace_root=per_server_root,
|
||||
install_strategy=self._install_strategy,
|
||||
binary_overrides=self._binary_overrides,
|
||||
env_overrides=self._env_overrides,
|
||||
init_overrides=self._init_overrides,
|
||||
)
|
||||
spec = srv.build_spawn(per_server_root, ctx)
|
||||
if spec is None:
|
||||
# ``build_spawn`` returns None when the binary can't be
|
||||
# located (auto-install disabled, manual-only server,
|
||||
# or install attempt failed). Surface this once via
|
||||
# the structured logger so the user can act on it.
|
||||
eventlog.log_server_unavailable(srv.server_id, srv.server_id)
|
||||
self._broken.add(key)
|
||||
spawn_future.set_result(None)
|
||||
return None
|
||||
client = LSPClient(
|
||||
server_id=srv.server_id,
|
||||
workspace_root=spec.workspace_root,
|
||||
command=spec.command,
|
||||
env=spec.env,
|
||||
cwd=spec.cwd,
|
||||
initialization_options=spec.initialization_options,
|
||||
seed_diagnostics_on_first_push=spec.seed_diagnostics_on_first_push or srv.seed_first_push,
|
||||
)
|
||||
try:
|
||||
await client.start()
|
||||
except Exception as e: # noqa: BLE001
|
||||
eventlog.log_spawn_failed(srv.server_id, per_server_root, e)
|
||||
self._broken.add(key)
|
||||
spawn_future.set_result(None)
|
||||
return None
|
||||
with self._state_lock:
|
||||
self._clients[key] = client
|
||||
self._last_used[key] = time.time()
|
||||
eventlog.log_active(srv.server_id, per_server_root)
|
||||
spawn_future.set_result(client)
|
||||
return client
|
||||
finally:
|
||||
with self._state_lock:
|
||||
self._spawning.pop(key, None)
|
||||
|
||||
async def _shutdown_async(self) -> None:
|
||||
with self._state_lock:
|
||||
clients = list(self._clients.values())
|
||||
self._clients.clear()
|
||||
self._broken.clear()
|
||||
self._last_used.clear()
|
||||
await asyncio.gather(
|
||||
*(c.shutdown() for c in clients),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# status / introspection (used by ``hermes lsp status``)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""Return a snapshot of the service for the CLI status command."""
|
||||
with self._state_lock:
|
||||
clients = [
|
||||
{
|
||||
"server_id": k[0],
|
||||
"workspace_root": k[1],
|
||||
"state": c.state,
|
||||
"running": c.is_running,
|
||||
}
|
||||
for k, c in self._clients.items()
|
||||
]
|
||||
broken = list(self._broken)
|
||||
return {
|
||||
"enabled": self._enabled,
|
||||
"wait_mode": self._wait_mode,
|
||||
"wait_timeout": self._wait_timeout,
|
||||
"install_strategy": self._install_strategy,
|
||||
"clients": clients,
|
||||
"broken": broken,
|
||||
"disabled_servers": sorted(self._disabled_servers),
|
||||
}
|
||||
|
||||
|
||||
def _diag_key(d: Dict[str, Any]) -> str:
|
||||
"""Content equality key used for cross-edit delta filtering.
|
||||
|
||||
Includes the diagnostic's position range — when used together
|
||||
with :func:`agent.lsp.range_shift.shift_baseline`, the baseline
|
||||
is line-shifted into post-edit coordinates BEFORE this key is
|
||||
computed, so identical-but-shifted diagnostics hash equal. Two
|
||||
genuinely distinct diagnostics at different lines (e.g. the same
|
||||
error class introduced at a second site) hash differently and
|
||||
are surfaced as new.
|
||||
|
||||
Mirrors :func:`agent.lsp.client._diagnostic_key`; intentionally
|
||||
identical so the two layers agree on diagnostic identity.
|
||||
"""
|
||||
rng = d.get("range") or {}
|
||||
start = rng.get("start") or {}
|
||||
end = rng.get("end") or {}
|
||||
code = d.get("code")
|
||||
if code is not None and not isinstance(code, str):
|
||||
code = str(code)
|
||||
return "\x00".join(
|
||||
[
|
||||
str(d.get("severity") or 1),
|
||||
str(code or ""),
|
||||
str(d.get("source") or ""),
|
||||
str(d.get("message") or "").strip(),
|
||||
f"{start.get('line', 0)}:{start.get('character', 0)}-{end.get('line', 0)}:{end.get('character', 0)}",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["LSPService"]
|
||||
196
agent/lsp/protocol.py
Normal file
196
agent/lsp/protocol.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Minimal LSP JSON-RPC 2.0 framer over async streams.
|
||||
|
||||
LSP wire format:
|
||||
|
||||
Content-Length: <bytes>\\r\\n
|
||||
\\r\\n
|
||||
<utf-8 JSON body>
|
||||
|
||||
The body is a JSON-RPC 2.0 envelope: request, response, or notification.
|
||||
|
||||
This module replaces what ``vscode-jsonrpc/node`` would do in a
|
||||
TypeScript implementation. We keep it deliberately small — just the
|
||||
framer + envelope helpers — so :class:`agent.lsp.client.LSPClient` can
|
||||
focus on protocol semantics.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger("agent.lsp.protocol")
|
||||
|
||||
# LSP error codes we care about. Full list in
|
||||
# https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#errorCodes
|
||||
ERROR_CONTENT_MODIFIED = -32801
|
||||
ERROR_REQUEST_CANCELLED = -32800
|
||||
ERROR_METHOD_NOT_FOUND = -32601
|
||||
|
||||
|
||||
class LSPProtocolError(Exception):
|
||||
"""Raised when the wire protocol is violated.
|
||||
|
||||
Distinct from :class:`LSPRequestError` which represents a server
|
||||
returning a JSON-RPC error response — that's protocol-conformant.
|
||||
This exception means the framing or envelope itself is broken.
|
||||
"""
|
||||
|
||||
|
||||
class LSPRequestError(Exception):
|
||||
"""Raised when an LSP request returns an error response.
|
||||
|
||||
Carries the JSON-RPC ``code``, ``message``, and optional ``data``.
|
||||
"""
|
||||
|
||||
def __init__(self, code: int, message: str, data: Any = None) -> None:
|
||||
super().__init__(f"LSP error {code}: {message}")
|
||||
self.code = code
|
||||
self.message = message
|
||||
self.data = data
|
||||
|
||||
|
||||
def encode_message(obj: dict) -> bytes:
|
||||
"""Encode a JSON-RPC envelope as a Content-Length framed byte string.
|
||||
|
||||
The body is encoded as compact UTF-8 JSON (no spaces between
|
||||
separators) — matches what ``vscode-jsonrpc`` emits and keeps the
|
||||
Content-Length count exact.
|
||||
"""
|
||||
body = json.dumps(obj, separators=(",", ":"), ensure_ascii=False).encode("utf-8")
|
||||
header = f"Content-Length: {len(body)}\r\n\r\n".encode("ascii")
|
||||
return header + body
|
||||
|
||||
|
||||
async def read_message(reader: asyncio.StreamReader) -> Optional[dict]:
|
||||
"""Read one Content-Length framed JSON-RPC message from the stream.
|
||||
|
||||
Returns ``None`` on clean EOF (server closed stdout cleanly between
|
||||
messages — typical shutdown). Raises :class:`LSPProtocolError` on
|
||||
malformed framing.
|
||||
|
||||
The reader is advanced to just past the JSON body on success.
|
||||
"""
|
||||
headers: dict = {}
|
||||
header_bytes = 0
|
||||
while True:
|
||||
try:
|
||||
line = await reader.readuntil(b"\r\n")
|
||||
except asyncio.IncompleteReadError as e:
|
||||
# EOF while reading headers. If we hadn't started a header
|
||||
# block, treat as clean EOF; otherwise the framing is bad.
|
||||
if not e.partial and not headers:
|
||||
return None
|
||||
raise LSPProtocolError(
|
||||
f"unexpected EOF while reading LSP headers (partial={e.partial!r})"
|
||||
) from e
|
||||
# Defensive cap against a server streaming headers without ever
|
||||
# emitting CRLF-CRLF. Caps total header bytes at 8 KiB — a
|
||||
# well-behaved server fits in well under 200 bytes.
|
||||
header_bytes += len(line)
|
||||
if header_bytes > 8192:
|
||||
raise LSPProtocolError(
|
||||
f"LSP header block exceeded 8 KiB without terminator"
|
||||
)
|
||||
line = line[:-2] # strip CRLF
|
||||
if not line:
|
||||
break # blank line ends header block
|
||||
try:
|
||||
key, _, value = line.decode("ascii").partition(":")
|
||||
except UnicodeDecodeError as e:
|
||||
raise LSPProtocolError(f"non-ASCII LSP header: {line!r}") from e
|
||||
if not key:
|
||||
raise LSPProtocolError(f"malformed LSP header line: {line!r}")
|
||||
headers[key.strip().lower()] = value.strip()
|
||||
|
||||
cl = headers.get("content-length")
|
||||
if cl is None:
|
||||
raise LSPProtocolError(f"LSP message missing Content-Length: {headers!r}")
|
||||
try:
|
||||
n = int(cl)
|
||||
except ValueError as e:
|
||||
raise LSPProtocolError(f"non-integer Content-Length: {cl!r}") from e
|
||||
if n < 0 or n > 64 * 1024 * 1024: # 64 MiB sanity cap
|
||||
raise LSPProtocolError(f"unreasonable Content-Length: {n}")
|
||||
|
||||
try:
|
||||
body = await reader.readexactly(n)
|
||||
except asyncio.IncompleteReadError as e:
|
||||
raise LSPProtocolError(
|
||||
f"truncated LSP body: expected {n} bytes, got {len(e.partial)}"
|
||||
) from e
|
||||
|
||||
try:
|
||||
return json.loads(body.decode("utf-8"))
|
||||
except json.JSONDecodeError as e:
|
||||
raise LSPProtocolError(f"invalid JSON in LSP body: {e}") from e
|
||||
except UnicodeDecodeError as e:
|
||||
raise LSPProtocolError(f"non-UTF-8 LSP body: {e}") from e
|
||||
|
||||
|
||||
def make_request(req_id: int, method: str, params: Any) -> dict:
|
||||
"""Build a JSON-RPC 2.0 request envelope."""
|
||||
msg: dict = {"jsonrpc": "2.0", "id": req_id, "method": method}
|
||||
if params is not None:
|
||||
msg["params"] = params
|
||||
return msg
|
||||
|
||||
|
||||
def make_notification(method: str, params: Any) -> dict:
|
||||
"""Build a JSON-RPC 2.0 notification envelope (no ``id``)."""
|
||||
msg: dict = {"jsonrpc": "2.0", "method": method}
|
||||
if params is not None:
|
||||
msg["params"] = params
|
||||
return msg
|
||||
|
||||
|
||||
def make_response(req_id: Any, result: Any) -> dict:
|
||||
"""Build a JSON-RPC 2.0 success response envelope."""
|
||||
return {"jsonrpc": "2.0", "id": req_id, "result": result}
|
||||
|
||||
|
||||
def make_error_response(req_id: Any, code: int, message: str, data: Any = None) -> dict:
|
||||
"""Build a JSON-RPC 2.0 error response envelope."""
|
||||
err: dict = {"code": code, "message": message}
|
||||
if data is not None:
|
||||
err["data"] = data
|
||||
return {"jsonrpc": "2.0", "id": req_id, "error": err}
|
||||
|
||||
|
||||
def classify_message(msg: dict) -> Tuple[str, Any]:
|
||||
"""Return ``(kind, key)`` where kind is one of ``request``,
|
||||
``response``, ``notification``, ``invalid``.
|
||||
|
||||
The key is the request id for request/response, the method name
|
||||
for notifications, and ``None`` for invalid messages.
|
||||
"""
|
||||
if not isinstance(msg, dict):
|
||||
return "invalid", None
|
||||
if msg.get("jsonrpc") != "2.0":
|
||||
return "invalid", None
|
||||
has_id = "id" in msg
|
||||
has_method = "method" in msg
|
||||
if has_id and has_method:
|
||||
return "request", msg["id"]
|
||||
if has_id and ("result" in msg or "error" in msg):
|
||||
return "response", msg["id"]
|
||||
if has_method and not has_id:
|
||||
return "notification", msg["method"]
|
||||
return "invalid", None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ERROR_CONTENT_MODIFIED",
|
||||
"ERROR_REQUEST_CANCELLED",
|
||||
"ERROR_METHOD_NOT_FOUND",
|
||||
"LSPProtocolError",
|
||||
"LSPRequestError",
|
||||
"encode_message",
|
||||
"read_message",
|
||||
"make_request",
|
||||
"make_notification",
|
||||
"make_response",
|
||||
"make_error_response",
|
||||
"classify_message",
|
||||
]
|
||||
149
agent/lsp/range_shift.py
Normal file
149
agent/lsp/range_shift.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Diff-aware line-shift map for cross-edit LSP delta filtering.
|
||||
|
||||
When an edit deletes or inserts lines in the middle of a file, every
|
||||
diagnostic below the edit point shifts to a new line number. The
|
||||
LSPService delta filter subtracts the pre-edit baseline from the
|
||||
post-edit diagnostics keyed on ``(severity, code, source, message,
|
||||
range)`` — without an adjustment, the shifted-but-otherwise-identical
|
||||
diagnostics look brand-new and the agent gets flooded with noise.
|
||||
|
||||
The fix used here is the same trick git's blame and unified diff use:
|
||||
build a piecewise-linear map from pre-edit line numbers to post-edit
|
||||
line numbers, then apply that map to baseline diagnostics before the
|
||||
set-difference. Diagnostics whose pre-edit line is in a region the
|
||||
edit deleted return ``None`` and are dropped from the baseline (they
|
||||
genuinely no longer apply).
|
||||
|
||||
Trade-off vs. dropping range from the key entirely (the previous
|
||||
fix): preserves the "new instance of an identical error at a
|
||||
different line" signal — if the model introduces a second instance
|
||||
of the same error class at a different location, that one will be
|
||||
surfaced as new instead of swallowed by content-only dedup.
|
||||
|
||||
The map is derived from ``difflib.SequenceMatcher.get_opcodes()`` and
|
||||
exposed as a single callable so callers don't have to reason about
|
||||
diff regions.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import difflib
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
|
||||
def build_line_shift(pre_text: str, post_text: str) -> Callable[[int], Optional[int]]:
|
||||
"""Build a function mapping pre-edit line numbers to post-edit line numbers.
|
||||
|
||||
Lines are 0-indexed to match the LSP wire format
|
||||
(``range.start.line`` is 0-indexed).
|
||||
|
||||
The returned callable takes a pre-edit 0-indexed line number and
|
||||
returns the corresponding post-edit 0-indexed line number, or
|
||||
``None`` if that line was deleted by the edit (no post-edit
|
||||
counterpart exists).
|
||||
|
||||
Cost: one ``SequenceMatcher.get_opcodes()`` call up front; the
|
||||
returned closure is O(log n) per call (binary search over opcode
|
||||
regions). Cheap enough to call once per write/patch and apply to
|
||||
every baseline diagnostic.
|
||||
"""
|
||||
pre_lines = pre_text.splitlines() if pre_text else []
|
||||
post_lines = post_text.splitlines() if post_text else []
|
||||
|
||||
# Trivial case: identical content or no content — identity map.
|
||||
if pre_lines == post_lines:
|
||||
return lambda line: line
|
||||
|
||||
# SequenceMatcher.get_opcodes() returns a list of
|
||||
# (tag, i1, i2, j1, j2) where tag is 'equal', 'replace', 'delete',
|
||||
# or 'insert'. i1:i2 is the range in pre, j1:j2 is the range in
|
||||
# post. We build a list of (i1, i2, j1, j2, tag) tuples and
|
||||
# binary-search by i for each lookup.
|
||||
sm = difflib.SequenceMatcher(a=pre_lines, b=post_lines, autojunk=False)
|
||||
opcodes = sm.get_opcodes()
|
||||
|
||||
def shift(line: int) -> Optional[int]:
|
||||
# Find the opcode region whose i1 <= line < i2.
|
||||
# Linear scan is fine — typical opcode count is small (single
|
||||
# digits for a typical patch-tool edit).
|
||||
for tag, i1, i2, j1, j2 in opcodes:
|
||||
if i1 <= line < i2:
|
||||
if tag == "equal":
|
||||
# Pre-line N → post-line (N - i1 + j1).
|
||||
return line - i1 + j1
|
||||
if tag == "delete":
|
||||
# Pre-line is in a deleted region — no post counterpart.
|
||||
return None
|
||||
if tag == "replace":
|
||||
# Replace == delete + insert; the pre-line has no
|
||||
# post counterpart in any meaningful sense. Drop.
|
||||
return None
|
||||
# 'insert' has i1 == i2 so line < i2 can't be hit.
|
||||
if line < i1:
|
||||
# Past the relevant region — handled in earlier iteration.
|
||||
break
|
||||
# Past the last opcode region (line >= len(pre_lines)).
|
||||
# Anchor at end of post.
|
||||
return max(0, len(post_lines) - 1) if post_lines else None
|
||||
|
||||
return shift
|
||||
|
||||
|
||||
def shift_diagnostic_range(diag: Dict[str, Any],
|
||||
shift: Callable[[int], Optional[int]]) -> Optional[Dict[str, Any]]:
|
||||
"""Return a copy of ``diag`` with its line range remapped through ``shift``.
|
||||
|
||||
Returns ``None`` if the diagnostic's start line maps to ``None``
|
||||
(the line was deleted by the edit) — caller drops it from the
|
||||
baseline since the diagnostic no longer applies.
|
||||
|
||||
Both ``start.line`` and ``end.line`` are remapped independently;
|
||||
when only the end maps to ``None`` (rare, multi-line diagnostic
|
||||
straddling the edit boundary) we collapse to a single-line range
|
||||
at the shifted start to keep the diagnostic in the baseline.
|
||||
|
||||
The original ``diag`` is not mutated.
|
||||
"""
|
||||
rng = diag.get("range") or {}
|
||||
start = rng.get("start") or {}
|
||||
end = rng.get("end") or {}
|
||||
|
||||
pre_start_line = int(start.get("line", 0))
|
||||
pre_end_line = int(end.get("line", pre_start_line))
|
||||
|
||||
new_start_line = shift(pre_start_line)
|
||||
if new_start_line is None:
|
||||
return None
|
||||
|
||||
new_end_line = shift(pre_end_line)
|
||||
if new_end_line is None:
|
||||
# Diagnostic straddled the deletion — collapse to start.
|
||||
new_end_line = new_start_line
|
||||
|
||||
shifted = dict(diag)
|
||||
shifted["range"] = {
|
||||
"start": {
|
||||
"line": new_start_line,
|
||||
"character": int(start.get("character", 0)),
|
||||
},
|
||||
"end": {
|
||||
"line": new_end_line,
|
||||
"character": int(end.get("character", 0)),
|
||||
},
|
||||
}
|
||||
return shifted
|
||||
|
||||
|
||||
def shift_baseline(baseline: List[Dict[str, Any]],
|
||||
shift: Callable[[int], Optional[int]]) -> List[Dict[str, Any]]:
|
||||
"""Apply ``shift`` to every diagnostic in ``baseline``, dropping deleted entries."""
|
||||
out: List[Dict[str, Any]] = []
|
||||
for d in baseline:
|
||||
if not isinstance(d, dict):
|
||||
continue
|
||||
shifted = shift_diagnostic_range(d, shift)
|
||||
if shifted is not None:
|
||||
out.append(shifted)
|
||||
return out
|
||||
|
||||
|
||||
__all__ = ["build_line_shift", "shift_diagnostic_range", "shift_baseline"]
|
||||
78
agent/lsp/reporter.py
Normal file
78
agent/lsp/reporter.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Format LSP diagnostics for inclusion in tool output.
|
||||
|
||||
The model sees a compact, severity-filtered, line-bounded summary of
|
||||
diagnostics introduced by the latest edit. Format matches what
|
||||
OpenCode's ``lsp/diagnostic.ts`` and Claude Code's
|
||||
``formatDiagnosticsSummary`` produce — ``<diagnostics>`` blocks with
|
||||
1-indexed line/column, capped at ``MAX_PER_FILE`` errors.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
# Severity-1 only by default — warnings/info/hints would flood the
|
||||
# agent. Lift this in config under ``lsp.severities`` if needed.
|
||||
SEVERITY_NAMES = {1: "ERROR", 2: "WARN", 3: "INFO", 4: "HINT"}
|
||||
DEFAULT_SEVERITIES = frozenset({1}) # ERROR only
|
||||
|
||||
MAX_PER_FILE = 20
|
||||
MAX_TOTAL_CHARS = 4000
|
||||
|
||||
|
||||
def format_diagnostic(d: Dict[str, Any]) -> str:
|
||||
"""One-line representation of a single diagnostic."""
|
||||
sev = SEVERITY_NAMES.get(d.get("severity") or 1, "ERROR")
|
||||
rng = d.get("range") or {}
|
||||
start = rng.get("start") or {}
|
||||
line = int(start.get("line", 0)) + 1
|
||||
col = int(start.get("character", 0)) + 1
|
||||
msg = str(d.get("message") or "").rstrip()
|
||||
code = d.get("code")
|
||||
code_part = f" [{code}]" if code not in (None, "") else ""
|
||||
source = d.get("source")
|
||||
source_part = f" ({source})" if source else ""
|
||||
return f"{sev} [{line}:{col}] {msg}{code_part}{source_part}"
|
||||
|
||||
|
||||
def report_for_file(
|
||||
file_path: str,
|
||||
diagnostics: List[Dict[str, Any]],
|
||||
*,
|
||||
severities: frozenset = DEFAULT_SEVERITIES,
|
||||
max_per_file: int = MAX_PER_FILE,
|
||||
) -> str:
|
||||
"""Build a ``<diagnostics file=...>`` block for one file.
|
||||
|
||||
Returns an empty string when no diagnostics pass the severity
|
||||
filter, so callers can do ``if block:`` to skip empty cases.
|
||||
"""
|
||||
if not diagnostics:
|
||||
return ""
|
||||
filtered = [d for d in diagnostics if (d.get("severity") or 1) in severities]
|
||||
if not filtered:
|
||||
return ""
|
||||
limited = filtered[:max_per_file]
|
||||
extra = len(filtered) - len(limited)
|
||||
lines = [format_diagnostic(d) for d in limited]
|
||||
body = "\n".join(lines)
|
||||
if extra > 0:
|
||||
body += f"\n... and {extra} more"
|
||||
return f"<diagnostics file=\"{file_path}\">\n{body}\n</diagnostics>"
|
||||
|
||||
|
||||
def truncate(s: str, *, limit: int = MAX_TOTAL_CHARS) -> str:
|
||||
"""Hard-cap a formatted summary string."""
|
||||
if len(s) <= limit:
|
||||
return s
|
||||
marker = "\n…[truncated]"
|
||||
return s[: limit - len(marker)] + marker
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SEVERITY_NAMES",
|
||||
"DEFAULT_SEVERITIES",
|
||||
"MAX_PER_FILE",
|
||||
"format_diagnostic",
|
||||
"report_for_file",
|
||||
"truncate",
|
||||
]
|
||||
1040
agent/lsp/servers.py
Normal file
1040
agent/lsp/servers.py
Normal file
File diff suppressed because it is too large
Load Diff
223
agent/lsp/workspace.py
Normal file
223
agent/lsp/workspace.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""Workspace and project-root resolution for LSP.
|
||||
|
||||
Two concerns live here:
|
||||
|
||||
1. **Workspace gate** — the upper-level "is this directory a project?"
|
||||
check. Hermes only runs LSP when the cwd (or the file being edited)
|
||||
sits inside a git worktree. Files outside any git root never
|
||||
trigger LSP, even if a server is configured. This keeps Telegram
|
||||
gateway users on user-home cwd's from spawning daemons.
|
||||
|
||||
2. **NearestRoot** — the per-server project-root walk. Each language
|
||||
server cares about a different marker (``pyproject.toml`` for
|
||||
Python, ``Cargo.toml`` for Rust, ``go.mod`` for Go, etc.) and
|
||||
wants the directory containing that marker. ``nearest_root()``
|
||||
walks up from a starting path looking for any of a list of marker
|
||||
files, optionally bailing if an exclude marker shows up first.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger("agent.lsp.workspace")
|
||||
|
||||
# Cache: cwd → (worktree_root, is_git) so repeated calls don't re-stat.
|
||||
# Cleared on shutdown. Keyed by absolute resolved path so symlink
|
||||
# folds collapse to one entry.
|
||||
_workspace_cache: dict = {}
|
||||
|
||||
|
||||
def normalize_path(path: str) -> str:
|
||||
"""Normalize a path for use as a stable map key.
|
||||
|
||||
Resolves ``~``, makes absolute, and collapses ``.``/``..``. We do
|
||||
NOT resolve symlinks here — symlink stability matters for some
|
||||
LSP servers (rust-analyzer cares about Cargo workspace identity)
|
||||
and we want the canonical path the user typed when possible.
|
||||
"""
|
||||
return os.path.abspath(os.path.expanduser(path))
|
||||
|
||||
|
||||
def find_git_worktree(start: str) -> Optional[str]:
|
||||
"""Walk up from ``start`` looking for a ``.git`` entry (file or dir).
|
||||
|
||||
Returns the directory containing ``.git``, or ``None`` if no git
|
||||
root is found before hitting the filesystem root.
|
||||
|
||||
A ``.git`` *file* (not directory) means we're inside a git
|
||||
worktree set up via ``git worktree add`` — both forms count.
|
||||
"""
|
||||
try:
|
||||
start_path = Path(normalize_path(start))
|
||||
if start_path.is_file():
|
||||
start_path = start_path.parent
|
||||
except (OSError, RuntimeError, ValueError):
|
||||
# Pathological input (loop in symlinks, encoding error, etc.) —
|
||||
# bail out rather than crash the lint hook.
|
||||
return None
|
||||
|
||||
# Cache check
|
||||
cached = _workspace_cache.get(str(start_path))
|
||||
if cached is not None:
|
||||
root, _is_git = cached
|
||||
return root
|
||||
|
||||
cur = start_path
|
||||
# Defensive cap: the deepest reasonable monorepo is well under 64
|
||||
# levels. Caps the walk so a pathological cwd or a symlink cycle
|
||||
# we somehow traverse can't keep us looping.
|
||||
for _ in range(64):
|
||||
git_marker = cur / ".git"
|
||||
try:
|
||||
if git_marker.exists():
|
||||
resolved = str(cur)
|
||||
_workspace_cache[str(start_path)] = (resolved, True)
|
||||
return resolved
|
||||
except OSError:
|
||||
# Permission error on a parent dir — bail out cleanly.
|
||||
break
|
||||
parent = cur.parent
|
||||
if parent == cur:
|
||||
break
|
||||
cur = parent
|
||||
|
||||
_workspace_cache[str(start_path)] = (None, False)
|
||||
return None
|
||||
|
||||
|
||||
def is_inside_workspace(path: str, workspace_root: str) -> bool:
|
||||
"""Return True iff ``path`` is inside (or equal to) ``workspace_root``.
|
||||
|
||||
Uses absolute paths but does not resolve symlinks — a file accessed
|
||||
via a symlink that points outside the workspace still counts as
|
||||
outside. This is the conservative interpretation; matches LSP
|
||||
behaviour where servers reject didOpen for unrelated files.
|
||||
"""
|
||||
p = normalize_path(path)
|
||||
root = normalize_path(workspace_root)
|
||||
if p == root:
|
||||
return True
|
||||
# Use os.path.commonpath to handle case-insensitive filesystems
|
||||
# correctly on macOS/Windows.
|
||||
try:
|
||||
common = os.path.commonpath([p, root])
|
||||
except ValueError:
|
||||
# Different drives on Windows.
|
||||
return False
|
||||
return common == root
|
||||
|
||||
|
||||
def nearest_root(
|
||||
start: str,
|
||||
markers: Iterable[str],
|
||||
*,
|
||||
excludes: Optional[Iterable[str]] = None,
|
||||
ceiling: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""Walk up from ``start`` looking for any of the given marker files.
|
||||
|
||||
Returns the **directory containing** the first matched marker, or
|
||||
``None`` if no marker is found before hitting ``ceiling`` (or the
|
||||
filesystem root if no ceiling).
|
||||
|
||||
If ``excludes`` is provided and an exclude marker matches *first*
|
||||
in the upward walk, returns ``None`` — the server is gated off
|
||||
for that file. Mirrors OpenCode's NearestRoot exclude semantics
|
||||
(e.g. typescript skips deno projects when ``deno.json`` is found
|
||||
before ``package.json``).
|
||||
"""
|
||||
start_path = Path(normalize_path(start))
|
||||
try:
|
||||
if start_path.is_file():
|
||||
start_path = start_path.parent
|
||||
except (OSError, RuntimeError, ValueError):
|
||||
return None
|
||||
ceiling_path = Path(normalize_path(ceiling)) if ceiling else None
|
||||
|
||||
markers_list = list(markers)
|
||||
excludes_list = list(excludes) if excludes else []
|
||||
|
||||
cur = start_path
|
||||
# Defensive cap matching ``find_git_worktree``. Bounded walk
|
||||
# protects against pathological inputs even though the
|
||||
# parent-equality stop normally terminates within ~10 steps.
|
||||
for _ in range(64):
|
||||
# Check excludes first — if an exclude is found at this level,
|
||||
# the server is gated off for this file.
|
||||
for exc in excludes_list:
|
||||
try:
|
||||
if (cur / exc).exists():
|
||||
return None
|
||||
except OSError:
|
||||
continue
|
||||
# Then check markers.
|
||||
for marker in markers_list:
|
||||
try:
|
||||
if (cur / marker).exists():
|
||||
return str(cur)
|
||||
except OSError:
|
||||
continue
|
||||
# Stop conditions.
|
||||
if ceiling_path is not None and cur == ceiling_path:
|
||||
return None
|
||||
parent = cur.parent
|
||||
if parent == cur:
|
||||
return None
|
||||
cur = parent
|
||||
return None
|
||||
|
||||
|
||||
def resolve_workspace_for_file(
|
||||
file_path: str,
|
||||
*,
|
||||
cwd: Optional[str] = None,
|
||||
) -> Tuple[Optional[str], bool]:
|
||||
"""Resolve the workspace root for a file.
|
||||
|
||||
Returns ``(workspace_root, gated_in)`` where ``gated_in`` is True
|
||||
iff LSP should run for this file at all. Currently the gate is
|
||||
"file is inside a git worktree found by walking up from cwd OR
|
||||
from the file itself".
|
||||
|
||||
The cwd path takes precedence — if the agent was launched in a
|
||||
git project, that worktree is the workspace, and any edit inside
|
||||
it (regardless of where the file lives) is in-scope. If the cwd
|
||||
isn't in a git worktree, we try the file's own location as a
|
||||
fallback.
|
||||
|
||||
Returns ``(None, False)`` when neither path is in a git worktree.
|
||||
"""
|
||||
cwd = cwd or os.getcwd()
|
||||
cwd_root = find_git_worktree(cwd)
|
||||
if cwd_root is not None:
|
||||
if is_inside_workspace(file_path, cwd_root):
|
||||
return cwd_root, True
|
||||
# File is outside the cwd's worktree — try the file's own
|
||||
# location as a secondary anchor. Useful for monorepos where
|
||||
# the user opens an unrelated checkout.
|
||||
file_root = find_git_worktree(file_path)
|
||||
if file_root is not None:
|
||||
return file_root, True
|
||||
return None, False
|
||||
|
||||
|
||||
def clear_cache() -> None:
|
||||
"""Clear the workspace-resolution cache.
|
||||
|
||||
Called on service shutdown so a subsequent re-init doesn't pick
|
||||
up stale results from a previous session.
|
||||
"""
|
||||
_workspace_cache.clear()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"find_git_worktree",
|
||||
"is_inside_workspace",
|
||||
"nearest_root",
|
||||
"normalize_path",
|
||||
"resolve_workspace_for_file",
|
||||
"clear_cache",
|
||||
]
|
||||
@@ -10,7 +10,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
@@ -47,7 +47,7 @@ def _resolve_requests_verify() -> bool | str:
|
||||
_PROVIDER_PREFIXES: frozenset[str] = frozenset({
|
||||
"openrouter", "nous", "openai-codex", "copilot", "copilot-acp",
|
||||
"gemini", "ollama-cloud", "zai", "kimi-coding", "kimi-coding-cn", "stepfun", "minimax", "minimax-oauth", "minimax-cn", "anthropic", "deepseek",
|
||||
"opencode-zen", "opencode-go", "ai-gateway", "kilocode", "alibaba",
|
||||
"opencode-zen", "opencode-go", "ai-gateway", "kilocode", "alibaba", "novita",
|
||||
"qwen-oauth",
|
||||
"xiaomi",
|
||||
"arcee",
|
||||
@@ -66,7 +66,7 @@ _PROVIDER_PREFIXES: frozenset[str] = frozenset({
|
||||
"gmi-cloud", "gmicloud",
|
||||
"xai", "x-ai", "x.ai", "grok",
|
||||
"nvidia", "nim", "nvidia-nim", "nemotron",
|
||||
"qwen-portal",
|
||||
"qwen-portal", "novita-ai", "novitaai",
|
||||
})
|
||||
|
||||
|
||||
@@ -104,6 +104,8 @@ def _strip_provider_prefix(model: str) -> str:
|
||||
|
||||
_model_metadata_cache: Dict[str, Dict[str, Any]] = {}
|
||||
_model_metadata_cache_time: float = 0
|
||||
_novita_metadata_cache: Dict[str, Dict[str, Any]] = {}
|
||||
_novita_metadata_cache_time: float = 0
|
||||
_MODEL_CACHE_TTL = 3600
|
||||
_endpoint_model_metadata_cache: Dict[str, Dict[str, Dict[str, Any]]] = {}
|
||||
_endpoint_model_metadata_cache_time: Dict[str, float] = {}
|
||||
@@ -285,6 +287,7 @@ def grok_supports_reasoning_effort(model: str) -> bool:
|
||||
_CONTEXT_LENGTH_KEYS = (
|
||||
"context_length",
|
||||
"context_window",
|
||||
"context_size",
|
||||
"max_context_length",
|
||||
"max_position_embeddings",
|
||||
"max_model_len",
|
||||
@@ -361,6 +364,7 @@ _URL_TO_PROVIDER: Dict[str, str] = {
|
||||
"api.xiaomimimo.com": "xiaomi",
|
||||
"xiaomimimo.com": "xiaomi",
|
||||
"api.gmi-serving.com": "gmi",
|
||||
"api.novita.ai": "novita",
|
||||
"tokenhub.tencentmaas.com": "tencent-tokenhub",
|
||||
"ollama.com": "ollama-cloud",
|
||||
}
|
||||
@@ -557,6 +561,16 @@ def _extract_max_completion_tokens(payload: Dict[str, Any]) -> Optional[int]:
|
||||
|
||||
|
||||
def _extract_pricing(payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
novita_input = payload.get("input_token_price_per_m")
|
||||
novita_output = payload.get("output_token_price_per_m")
|
||||
if novita_input is not None or novita_output is not None:
|
||||
pricing: Dict[str, Any] = {}
|
||||
if novita_input is not None:
|
||||
pricing["prompt"] = str(float(novita_input) / 10_000 / 1_000_000)
|
||||
if novita_output is not None:
|
||||
pricing["completion"] = str(float(novita_output) / 10_000 / 1_000_000)
|
||||
return pricing
|
||||
|
||||
alias_map = {
|
||||
"prompt": ("prompt", "input", "input_cost_per_token", "prompt_token_cost"),
|
||||
"completion": ("completion", "output", "output_cost_per_token", "completion_token_cost"),
|
||||
@@ -1330,21 +1344,40 @@ def _resolve_codex_oauth_context_length(
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_nous_context_length(model: str) -> Optional[int]:
|
||||
"""Resolve Nous Portal model context length via OpenRouter metadata.
|
||||
def _resolve_nous_context_length(
|
||||
model: str,
|
||||
base_url: str = "",
|
||||
api_key: str = "",
|
||||
) -> Tuple[Optional[int], str]:
|
||||
"""Resolve Nous Portal model context length.
|
||||
|
||||
Nous model IDs are bare (e.g. 'claude-opus-4-6') while OpenRouter uses
|
||||
prefixed IDs (e.g. 'anthropic/claude-opus-4.6'). Try suffix matching
|
||||
with version normalization (dot↔dash).
|
||||
Tries the live Nous inference endpoint first (authoritative), then falls
|
||||
back to OpenRouter metadata with suffix/version matching.
|
||||
|
||||
Nous model IDs are bare after prefix-stripping (e.g. 'qwen3.6-plus',
|
||||
'claude-opus-4-6') while OpenRouter uses prefixed IDs (e.g.
|
||||
'qwen/qwen3.6-plus', 'anthropic/claude-opus-4.6'). Version
|
||||
normalization (dot↔dash) is applied to handle name drifts.
|
||||
|
||||
Returns ``(context_length, source)`` where ``source`` is one of:
|
||||
- ``"portal"`` — live /v1/models response (authoritative)
|
||||
- ``"openrouter"`` — OpenRouter cache fallback (non-authoritative;
|
||||
callers must NOT persist this to the on-disk cache or a single
|
||||
portal blip will freeze the wrong value in forever)
|
||||
- ``""`` — could not resolve
|
||||
"""
|
||||
metadata = fetch_model_metadata() # OpenRouter cache
|
||||
# Portal first — the Nous /models endpoint is authoritative for what our
|
||||
# infrastructure enforces and may differ from OR (e.g. OR reports 1M for
|
||||
# qwen3.6-plus; the portal correctly says 262144). Fall back to the OR
|
||||
# catalog only if the portal doesn't list the model.
|
||||
if base_url:
|
||||
portal_ctx = _resolve_endpoint_context_length(model, base_url, api_key=api_key)
|
||||
if portal_ctx is not None:
|
||||
return portal_ctx, "portal"
|
||||
|
||||
metadata = fetch_model_metadata()
|
||||
|
||||
def _safe_ctx(or_id: str, entry: dict) -> Optional[int]:
|
||||
"""Return context length, but reject stale 32k values for Kimi models.
|
||||
|
||||
Apply the same guard used for the generic OpenRouter path (step 6 in
|
||||
resolve_context_length) so the Nous portal path does not short-circuit it.
|
||||
"""
|
||||
ctx = entry.get("context_length")
|
||||
if ctx is None:
|
||||
return None
|
||||
@@ -1357,19 +1390,20 @@ def _resolve_nous_context_length(model: str) -> Optional[int]:
|
||||
return None
|
||||
return ctx
|
||||
|
||||
# Exact match first
|
||||
if model in metadata:
|
||||
return _safe_ctx(model, metadata[model])
|
||||
ctx = _safe_ctx(model, metadata[model])
|
||||
if ctx is not None:
|
||||
return ctx, "openrouter"
|
||||
|
||||
normalized = _normalize_model_version(model).lower()
|
||||
|
||||
for or_id, entry in metadata.items():
|
||||
bare = or_id.split("/", 1)[1] if "/" in or_id else or_id
|
||||
if bare.lower() == model.lower() or _normalize_model_version(bare).lower() == normalized:
|
||||
return _safe_ctx(or_id, entry)
|
||||
ctx = _safe_ctx(or_id, entry)
|
||||
if ctx is not None:
|
||||
return ctx, "openrouter"
|
||||
|
||||
# Partial prefix match for cases like gemini-3-flash → gemini-3-flash-preview
|
||||
# Require match to be at a word boundary (followed by -, :, or end of string)
|
||||
model_lower = model.lower()
|
||||
for or_id, entry in metadata.items():
|
||||
bare = or_id.split("/", 1)[1] if "/" in or_id else or_id
|
||||
@@ -1377,9 +1411,11 @@ def _resolve_nous_context_length(model: str) -> Optional[int]:
|
||||
if candidate.startswith(query) and (
|
||||
len(candidate) == len(query) or candidate[len(query)] in "-:."
|
||||
):
|
||||
return _safe_ctx(or_id, entry)
|
||||
ctx = _safe_ctx(or_id, entry)
|
||||
if ctx is not None:
|
||||
return ctx, "openrouter"
|
||||
|
||||
return None
|
||||
return None, ""
|
||||
|
||||
|
||||
def get_model_context_length(
|
||||
@@ -1394,14 +1430,18 @@ def get_model_context_length(
|
||||
|
||||
Resolution order:
|
||||
0. Explicit config override (model.context_length or custom_providers per-model)
|
||||
1. Persistent cache (previously discovered via probing)
|
||||
1. Persistent cache (previously discovered via probing). Nous URLs
|
||||
bypass the cache here so step 5b can always reconcile against
|
||||
the authoritative portal /v1/models response.
|
||||
1b. AWS Bedrock static table (must precede custom-endpoint probe)
|
||||
2. Active endpoint metadata (/models for explicit custom endpoints)
|
||||
3. Local server query (for local endpoints)
|
||||
4. Anthropic /v1/models API (API-key users only, not OAuth)
|
||||
5. Provider-aware lookups (before generic OpenRouter cache):
|
||||
a. Copilot live /models API
|
||||
b. Nous suffix-match via OpenRouter cache
|
||||
b. Nous: live /v1/models probe first (authoritative), then OR
|
||||
cache fallback with suffix/version normalisation. Only
|
||||
portal-derived values are persisted to disk.
|
||||
c. Codex OAuth /models probe
|
||||
d. GMI /models endpoint
|
||||
e. Ollama native /api/show probe (any base_url, provider-agnostic)
|
||||
@@ -1464,6 +1504,20 @@ def get_model_context_length(
|
||||
model, base_url, f"{cached:,}",
|
||||
)
|
||||
_invalidate_cached_context_length(model, base_url)
|
||||
# Nous Portal: the portal /v1/models endpoint is authoritative.
|
||||
# Bypass the persistent cache so step 5b can always reconcile
|
||||
# against it — this corrects pre-fix entries seeded from the
|
||||
# OR catalog (the same OR underreport class that the Kimi/Qwen
|
||||
# DEFAULT_CONTEXT_LENGTHS overrides exist to mitigate) without
|
||||
# touching the on-disk file when the portal is unreachable.
|
||||
# The in-memory 300s endpoint metadata cache makes the per-call
|
||||
# cost amortise to ~0 within a process.
|
||||
elif _infer_provider_from_url(base_url) == "nous":
|
||||
logger.debug(
|
||||
"Bypassing persistent cache for %s@%s (Nous portal authoritative)",
|
||||
model, base_url,
|
||||
)
|
||||
# Fall through; step 5b reconciles and overwrites if portal responds.
|
||||
else:
|
||||
return cached
|
||||
|
||||
@@ -1487,6 +1541,13 @@ def get_model_context_length(
|
||||
except ImportError:
|
||||
pass # boto3 not installed — fall through to generic resolution
|
||||
|
||||
if provider == "novita" or (base_url and base_url_host_matches(base_url, "api.novita.ai")):
|
||||
ctx = _resolve_endpoint_context_length(model, base_url or "https://api.novita.ai/openai/v1", api_key=api_key)
|
||||
if ctx is not None:
|
||||
if base_url:
|
||||
save_context_length(model, base_url, ctx)
|
||||
return ctx
|
||||
|
||||
# 2. Active endpoint metadata for truly custom/unknown endpoints.
|
||||
# Known providers (Copilot, OpenAI, Anthropic, etc.) skip this — their
|
||||
# /models endpoint may report a provider-imposed limit (e.g. Copilot
|
||||
@@ -1555,8 +1616,18 @@ def get_model_context_length(
|
||||
pass # Fall through to models.dev
|
||||
|
||||
if effective_provider == "nous":
|
||||
ctx = _resolve_nous_context_length(model)
|
||||
ctx, source = _resolve_nous_context_length(
|
||||
model, base_url=base_url or "", api_key=api_key or ""
|
||||
)
|
||||
if ctx:
|
||||
# Persist ONLY portal-derived values. Caching an OR-fallback
|
||||
# value here would freeze in a wrong number on the first portal
|
||||
# blip / auth glitch and step-1 would short-circuit it forever.
|
||||
# OR's catalog is community-maintained and is precisely why the
|
||||
# Kimi/Qwen DEFAULT_CONTEXT_LENGTHS overrides exist — we don't
|
||||
# want it leaking into the persistent cache for Nous URLs.
|
||||
if base_url and source == "portal":
|
||||
save_context_length(model, base_url, ctx)
|
||||
return ctx
|
||||
if effective_provider == "openai-codex":
|
||||
# Codex OAuth enforces lower context limits than the direct OpenAI
|
||||
|
||||
@@ -141,6 +141,7 @@ class ProviderInfo:
|
||||
# Hermes provider names → models.dev provider IDs
|
||||
PROVIDER_TO_MODELS_DEV: Dict[str, str] = {
|
||||
"openrouter": "openrouter",
|
||||
"novita": "novita-ai",
|
||||
"anthropic": "anthropic",
|
||||
"openai": "openai",
|
||||
"openai-codex": "openai",
|
||||
|
||||
64
agent/portal_tags.py
Normal file
64
agent/portal_tags.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Centralized Nous Portal request tags.
|
||||
|
||||
Every Hermes request that hits the Nous Portal — main agent loop, auxiliary
|
||||
client (compression / titles / vision / web_extract / session_search / etc.),
|
||||
and any future code path — must carry the same product-attribution tags so
|
||||
Nous can attribute usage to Hermes Agent and bucket it by client release.
|
||||
|
||||
Tag shape (sent in OpenAI-compatible ``extra_body['tags']``):
|
||||
|
||||
[
|
||||
"product=hermes-agent",
|
||||
"client=hermes-client-v<__version__>",
|
||||
]
|
||||
|
||||
The version is sourced live from ``hermes_cli.__version__`` so it auto-aligns
|
||||
to whatever release is installed; the release script
|
||||
(``scripts/release.py``) regex-bumps that single string, and every Portal
|
||||
request picks up the new tag on the next process start.
|
||||
|
||||
Why one helper instead of inlining the literal at each site:
|
||||
* Four call sites (main loop profile, aux client, run_agent compression
|
||||
fallback, web_tools fallback) used to drift apart — see PR #24194 which
|
||||
only got the aux site, leaving the main loop sending a different tag set.
|
||||
* Tests should assert the same tag list everywhere; centralizing makes that
|
||||
assertion a one-liner against this module.
|
||||
|
||||
Do NOT pre-compute these as module-level constants in the consumers. The
|
||||
version can change at runtime (editable installs, hot-reload tooling), and
|
||||
``hermes_cli.__version__`` is the canonical source of truth.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
def _hermes_version() -> str:
|
||||
"""Return the current Hermes release version, e.g. ``"0.13.0"``.
|
||||
|
||||
Falls back to ``"unknown"`` if ``hermes_cli`` cannot be imported (should
|
||||
never happen in a real install — guarded for defensive testing).
|
||||
"""
|
||||
try:
|
||||
from hermes_cli import __version__
|
||||
return __version__
|
||||
except Exception:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def hermes_client_tag() -> str:
|
||||
"""Return the ``client=...`` tag for Nous Portal requests.
|
||||
|
||||
Format: ``client=hermes-client-v<MAJOR>.<MINOR>.<PATCH>``.
|
||||
"""
|
||||
return f"client=hermes-client-v{_hermes_version()}"
|
||||
|
||||
|
||||
def nous_portal_tags() -> List[str]:
|
||||
"""Return the canonical list of Nous Portal product tags.
|
||||
|
||||
Always returns a fresh list so callers can mutate it freely
|
||||
(e.g. ``merged_extra.setdefault("tags", []).extend(nous_portal_tags())``).
|
||||
"""
|
||||
return ["product=hermes-agent", hermes_client_tag()]
|
||||
@@ -268,7 +268,7 @@ TOOL_USE_ENFORCEMENT_GUIDANCE = (
|
||||
|
||||
# Model name substrings that trigger tool-use enforcement guidance.
|
||||
# Add new patterns here when a model family needs explicit steering.
|
||||
TOOL_USE_ENFORCEMENT_MODELS = ("gpt", "codex", "gemini", "gemma", "grok")
|
||||
TOOL_USE_ENFORCEMENT_MODELS = ("gpt", "codex", "gemini", "gemma", "grok", "glm")
|
||||
|
||||
# OpenAI GPT/Codex-specific execution guidance. Addresses known failure modes
|
||||
# where GPT models abandon work on partial results, skip prerequisite lookups,
|
||||
|
||||
@@ -1,25 +1,15 @@
|
||||
"""Anthropic prompt caching strategies.
|
||||
"""Anthropic prompt caching strategy.
|
||||
|
||||
Two layouts:
|
||||
|
||||
* ``system_and_3`` (default, used everywhere except the long-lived path):
|
||||
4 cache_control breakpoints — system prompt + last 3 non-system messages.
|
||||
All at the same TTL (5m or 1h). Reduces input token costs by ~75% on
|
||||
multi-turn conversations within a single session.
|
||||
|
||||
* ``prefix_and_2`` (Claude on Anthropic / OpenRouter / Nous Portal):
|
||||
4 breakpoints split across two TTL tiers — tools[-1] (1h) +
|
||||
stable system prefix (1h) + last 2 non-system messages (5m). The
|
||||
long-lived prefix is byte-stable across sessions for a given user
|
||||
config, so every fresh session reads the cached system+tools instead
|
||||
of re-paying for them. Within-session rolling window shrinks from 3
|
||||
messages to 2 to free the breakpoint budget.
|
||||
Single layout: ``system_and_3``. 4 cache_control breakpoints — system
|
||||
prompt + last 3 non-system messages, all at the same TTL (5m or 1h).
|
||||
Reduces input token costs by ~75% on multi-turn conversations within a
|
||||
single session.
|
||||
|
||||
Pure functions -- no class state, no AIAgent dependency.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
def _apply_cache_marker(msg: dict, cache_marker: dict, native_anthropic: bool = False) -> None:
|
||||
@@ -87,115 +77,3 @@ def apply_anthropic_cache_control(
|
||||
_apply_cache_marker(messages[idx], marker, native_anthropic=native_anthropic)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def _mark_system_stable_block(
|
||||
messages: List[Dict[str, Any]],
|
||||
long_lived_marker: Dict[str, str],
|
||||
) -> bool:
|
||||
"""Mark the *first* content block of the system message with the 1h marker.
|
||||
|
||||
The system message is expected to have been split into multiple content
|
||||
blocks beforehand by the caller — block[0] is the cross-session-stable
|
||||
prefix, subsequent blocks carry context files + volatile suffix.
|
||||
Falls back to marking the whole system message as a single block when
|
||||
the message hasn't been split (preserves correctness on the fallback path).
|
||||
|
||||
Returns True when a marker was placed.
|
||||
"""
|
||||
if not messages or messages[0].get("role") != "system":
|
||||
return False
|
||||
|
||||
sys_msg = messages[0]
|
||||
content = sys_msg.get("content")
|
||||
|
||||
# Already a list of blocks → mark the first block.
|
||||
if isinstance(content, list) and content:
|
||||
first = content[0]
|
||||
if isinstance(first, dict):
|
||||
first["cache_control"] = long_lived_marker
|
||||
return True
|
||||
return False
|
||||
|
||||
# String content (no split) → cannot place a stable-prefix breakpoint
|
||||
# without changing the byte content. Caller is responsible for
|
||||
# splitting; if they didn't, fall through to envelope marker so we still
|
||||
# cache *something* for this turn.
|
||||
if isinstance(content, str) and content:
|
||||
sys_msg["content"] = [
|
||||
{"type": "text", "text": content, "cache_control": long_lived_marker}
|
||||
]
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def apply_anthropic_cache_control_long_lived(
|
||||
api_messages: List[Dict[str, Any]],
|
||||
long_lived_ttl: str = "1h",
|
||||
rolling_ttl: str = "5m",
|
||||
native_anthropic: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Apply prefix_and_2 caching: long-lived stable prefix + rolling window.
|
||||
|
||||
Layout (4 breakpoints total):
|
||||
* Stable system prefix (block[0]) → ``long_lived_ttl`` TTL
|
||||
* Last 2 non-system messages → ``rolling_ttl`` TTL each
|
||||
|
||||
NOTE: this function does NOT mark the tools array. Tools cache_control
|
||||
is attached separately (see ``mark_tools_for_long_lived_cache``) because
|
||||
tools live outside the messages list in the API payload.
|
||||
|
||||
The caller MUST have split the system message into ordered content
|
||||
blocks where block[0] is the cross-session-stable portion. If the system
|
||||
message is still a single string, it is wrapped into a single block and
|
||||
marked — this is correct, just less effective (the volatile suffix is
|
||||
not isolated, so the prefix invalidates per-session).
|
||||
|
||||
Returns:
|
||||
Deep copy of messages with cache_control breakpoints injected.
|
||||
"""
|
||||
messages = copy.deepcopy(api_messages)
|
||||
if not messages:
|
||||
return messages
|
||||
|
||||
long_marker = _build_marker(long_lived_ttl)
|
||||
rolling_marker = _build_marker(rolling_ttl)
|
||||
|
||||
placed_prefix = _mark_system_stable_block(messages, long_marker)
|
||||
|
||||
# Reserve 1 breakpoint for the system prefix (when placed); spend the
|
||||
# remaining 3 on the rolling tail. Anthropic max is 4 total —
|
||||
# tools[-1] (when marked) consumes the 4th, so we cap rolling at 2 here.
|
||||
rolling_budget = 2 if placed_prefix else 3
|
||||
non_sys = [i for i in range(len(messages)) if messages[i].get("role") != "system"]
|
||||
for idx in non_sys[-rolling_budget:]:
|
||||
_apply_cache_marker(messages[idx], rolling_marker, native_anthropic=native_anthropic)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def mark_tools_for_long_lived_cache(
|
||||
tools: Optional[List[Dict[str, Any]]],
|
||||
long_lived_ttl: str = "1h",
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""Attach cache_control to the last tool in the OpenAI-format tools list.
|
||||
|
||||
Anthropic prefix-cache order is ``tools → system → messages``. Marking
|
||||
the last tool dict caches the entire tools array (Anthropic's docs:
|
||||
"the marker is placed on the last block you want included in the cached
|
||||
prefix"). Marker is preserved across the OpenAI-wire boundary on
|
||||
OpenRouter and Nous Portal (which proxies to OpenRouter); on native
|
||||
Anthropic the marker is forwarded by ``convert_tools_to_anthropic``.
|
||||
|
||||
Returns a deep copy of the tools list with the marker attached, or the
|
||||
input unchanged when tools is empty/None. Pure function — does not
|
||||
mutate the input.
|
||||
"""
|
||||
if not tools:
|
||||
return tools
|
||||
out = copy.deepcopy(tools)
|
||||
last = out[-1]
|
||||
if isinstance(last, dict):
|
||||
last["cache_control"] = _build_marker(long_lived_ttl)
|
||||
return out
|
||||
|
||||
@@ -14,6 +14,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Any, Mapping
|
||||
|
||||
from utils import safe_json_loads
|
||||
from agent.tool_result_classification import file_mutation_result_landed
|
||||
|
||||
|
||||
IDEMPOTENT_TOOL_NAMES = frozenset(
|
||||
@@ -196,6 +197,8 @@ def classify_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str
|
||||
"""
|
||||
if result is None:
|
||||
return False, ""
|
||||
if file_mutation_result_landed(tool_name, result):
|
||||
return False, ""
|
||||
|
||||
if tool_name == "terminal":
|
||||
data = safe_json_loads(result)
|
||||
|
||||
26
agent/tool_result_classification.py
Normal file
26
agent/tool_result_classification.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Shared helpers for classifying tool result payloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
||||
FILE_MUTATING_TOOL_NAMES = frozenset({"write_file", "patch"})
|
||||
|
||||
|
||||
def file_mutation_result_landed(tool_name: str, result: Any) -> bool:
|
||||
"""Return True when a file mutation result proves the write landed."""
|
||||
if tool_name not in FILE_MUTATING_TOOL_NAMES or not isinstance(result, str):
|
||||
return False
|
||||
try:
|
||||
data = json.loads(result.strip())
|
||||
except Exception:
|
||||
return False
|
||||
if not isinstance(data, dict) or data.get("error"):
|
||||
return False
|
||||
if tool_name == "write_file":
|
||||
return "bytes_written" in data
|
||||
if tool_name == "patch":
|
||||
return data.get("success") is True
|
||||
return False
|
||||
368
agent/transports/codex_app_server.py
Normal file
368
agent/transports/codex_app_server.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""Codex app-server JSON-RPC client.
|
||||
|
||||
Speaks the protocol documented in codex-rs/app-server/README.md (codex 0.125+).
|
||||
Transport is newline-delimited JSON-RPC 2.0 over stdio: spawn `codex app-server`,
|
||||
do an `initialize` handshake, then drive `thread/start` + `turn/start` and
|
||||
consume streaming `item/*` notifications until `turn/completed`.
|
||||
|
||||
This module is the wire-level speaker only. Higher-level concerns (event
|
||||
projection into Hermes' display, approval bridging, transcript projection into
|
||||
AIAgent.messages, plugin migration) live in sibling modules.
|
||||
|
||||
Status: optional opt-in runtime gated behind `model.openai_runtime ==
|
||||
"codex_app_server"`. Hermes' default tool dispatch is unchanged when this
|
||||
runtime is not selected.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
# Default minimum codex version we test against. The PR sets this from the
|
||||
# `codex --version` parsed at install time; bumping is a one-line change here.
|
||||
MIN_CODEX_VERSION = (0, 125, 0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodexAppServerError(RuntimeError):
|
||||
"""Raised on JSON-RPC errors from the app-server."""
|
||||
|
||||
code: int
|
||||
message: str
|
||||
data: Optional[Any] = None
|
||||
|
||||
def __str__(self) -> str: # pragma: no cover - trivial
|
||||
return f"codex app-server error {self.code}: {self.message}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Pending:
|
||||
queue: queue.Queue
|
||||
method: str
|
||||
sent_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
class CodexAppServerClient:
|
||||
"""Minimal JSON-RPC 2.0 client for `codex app-server` over stdio.
|
||||
|
||||
Threading model:
|
||||
- Spawning thread (caller) drives request/response pairs synchronously.
|
||||
- One reader thread parses stdout, dispatches replies to the right
|
||||
pending future, and routes notifications + server-initiated requests
|
||||
to bounded queues that the caller drains on their own cadence.
|
||||
- One reader thread captures stderr for diagnostics; codex emits
|
||||
tracing logs there at RUST_LOG-controlled levels.
|
||||
|
||||
Intentionally NOT async. AIAgent.run_conversation() is synchronous and
|
||||
runs on the main thread; layering asyncio just to drive a stdio child
|
||||
creates surprising interrupt semantics. We use blocking queues with
|
||||
timeouts and rely on `turn/interrupt` for cancellation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
codex_bin: str = "codex",
|
||||
codex_home: Optional[str] = None,
|
||||
extra_args: Optional[list[str]] = None,
|
||||
env: Optional[dict[str, str]] = None,
|
||||
) -> None:
|
||||
self._codex_bin = codex_bin
|
||||
cmd = [codex_bin, "app-server"] + list(extra_args or [])
|
||||
spawn_env = os.environ.copy()
|
||||
if env:
|
||||
spawn_env.update(env)
|
||||
if codex_home:
|
||||
spawn_env["CODEX_HOME"] = codex_home
|
||||
# Codex emits tracing to stderr; default WARN keeps it quiet for users.
|
||||
spawn_env.setdefault("RUST_LOG", "warn")
|
||||
|
||||
self._proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
bufsize=0,
|
||||
env=spawn_env,
|
||||
)
|
||||
self._next_id = 1
|
||||
self._pending: dict[int, _Pending] = {}
|
||||
self._pending_lock = threading.Lock()
|
||||
self._notifications: queue.Queue = queue.Queue()
|
||||
self._server_requests: queue.Queue = queue.Queue()
|
||||
self._stderr_lines: list[str] = []
|
||||
self._stderr_lock = threading.Lock()
|
||||
self._closed = False
|
||||
self._initialized = False
|
||||
|
||||
self._reader = threading.Thread(target=self._read_stdout, daemon=True)
|
||||
self._reader.start()
|
||||
self._stderr_reader = threading.Thread(target=self._read_stderr, daemon=True)
|
||||
self._stderr_reader.start()
|
||||
|
||||
# ---------- lifecycle ----------
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
client_name: str = "hermes",
|
||||
client_title: str = "Hermes Agent",
|
||||
client_version: str = "0.1",
|
||||
capabilities: Optional[dict] = None,
|
||||
timeout: float = 10.0,
|
||||
) -> dict:
|
||||
"""Send `initialize` + `initialized` handshake. Returns the server's
|
||||
InitializeResponse (userAgent, codexHome, platformFamily, platformOs)."""
|
||||
if self._initialized:
|
||||
raise RuntimeError("already initialized")
|
||||
params = {
|
||||
"clientInfo": {
|
||||
"name": client_name,
|
||||
"title": client_title,
|
||||
"version": client_version,
|
||||
},
|
||||
"capabilities": capabilities or {},
|
||||
}
|
||||
result = self.request("initialize", params, timeout=timeout)
|
||||
self.notify("initialized")
|
||||
self._initialized = True
|
||||
return result
|
||||
|
||||
def close(self, timeout: float = 3.0) -> None:
|
||||
"""Close stdin and wait for the subprocess to exit, escalating to kill."""
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
try:
|
||||
if self._proc.stdin and not self._proc.stdin.closed:
|
||||
self._proc.stdin.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._proc.terminate()
|
||||
self._proc.wait(timeout=timeout)
|
||||
except subprocess.TimeoutExpired:
|
||||
try:
|
||||
self._proc.kill()
|
||||
self._proc.wait(timeout=1.0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def __enter__(self) -> "CodexAppServerClient":
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc: Any) -> None:
|
||||
self.close()
|
||||
|
||||
# ---------- send/receive ----------
|
||||
|
||||
def request(
|
||||
self,
|
||||
method: str,
|
||||
params: Optional[dict] = None,
|
||||
timeout: float = 30.0,
|
||||
) -> dict:
|
||||
"""Send a JSON-RPC request and block on the response. Returns `result`,
|
||||
raises CodexAppServerError on `error`."""
|
||||
rid = self._take_id()
|
||||
q: queue.Queue = queue.Queue(maxsize=1)
|
||||
with self._pending_lock:
|
||||
self._pending[rid] = _Pending(queue=q, method=method)
|
||||
self._send({"id": rid, "method": method, "params": params or {}})
|
||||
try:
|
||||
msg = q.get(timeout=timeout)
|
||||
except queue.Empty:
|
||||
with self._pending_lock:
|
||||
self._pending.pop(rid, None)
|
||||
raise TimeoutError(
|
||||
f"codex app-server method {method!r} timed out after {timeout}s"
|
||||
)
|
||||
if "error" in msg:
|
||||
err = msg["error"]
|
||||
raise CodexAppServerError(
|
||||
code=err.get("code", -1),
|
||||
message=err.get("message", ""),
|
||||
data=err.get("data"),
|
||||
)
|
||||
return msg.get("result", {})
|
||||
|
||||
def notify(self, method: str, params: Optional[dict] = None) -> None:
|
||||
"""Send a JSON-RPC notification (no id, no response expected)."""
|
||||
self._send({"method": method, "params": params or {}})
|
||||
|
||||
def respond(self, request_id: Any, result: dict) -> None:
|
||||
"""Reply to a server-initiated request (e.g. approval prompts)."""
|
||||
self._send({"id": request_id, "result": result})
|
||||
|
||||
def respond_error(
|
||||
self, request_id: Any, code: int, message: str, data: Optional[Any] = None
|
||||
) -> None:
|
||||
"""Reply to a server-initiated request with an error."""
|
||||
err: dict[str, Any] = {"code": code, "message": message}
|
||||
if data is not None:
|
||||
err["data"] = data
|
||||
self._send({"id": request_id, "error": err})
|
||||
|
||||
def take_notification(self, timeout: float = 0.0) -> Optional[dict]:
|
||||
"""Pop the next streaming notification, or return None on timeout.
|
||||
|
||||
timeout=0.0 means non-blocking. Use small positive timeouts inside the
|
||||
AIAgent turn loop to interleave reads with interrupt checks."""
|
||||
try:
|
||||
if timeout <= 0:
|
||||
return self._notifications.get_nowait()
|
||||
return self._notifications.get(timeout=timeout)
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
def take_server_request(self, timeout: float = 0.0) -> Optional[dict]:
|
||||
"""Pop the next server-initiated request (e.g. exec/applyPatch approval)."""
|
||||
try:
|
||||
if timeout <= 0:
|
||||
return self._server_requests.get_nowait()
|
||||
return self._server_requests.get(timeout=timeout)
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
# ---------- diagnostics ----------
|
||||
|
||||
def stderr_tail(self, n: int = 20) -> list[str]:
|
||||
"""Return last n lines of codex's stderr (for error reports)."""
|
||||
with self._stderr_lock:
|
||||
return list(self._stderr_lines[-n:])
|
||||
|
||||
def is_alive(self) -> bool:
|
||||
return self._proc.poll() is None
|
||||
|
||||
# ---------- internals ----------
|
||||
|
||||
def _take_id(self) -> int:
|
||||
# JSON-RPC ids only need to be unique per-connection. A simple
|
||||
# monotonically increasing int is the common choice and matches what
|
||||
# codex's own clients use.
|
||||
rid = self._next_id
|
||||
self._next_id += 1
|
||||
return rid
|
||||
|
||||
def _send(self, obj: dict) -> None:
|
||||
if self._closed:
|
||||
raise RuntimeError("codex app-server client is closed")
|
||||
if self._proc.stdin is None:
|
||||
raise RuntimeError("codex app-server stdin not available")
|
||||
try:
|
||||
self._proc.stdin.write((json.dumps(obj) + "\n").encode("utf-8"))
|
||||
self._proc.stdin.flush()
|
||||
except (BrokenPipeError, ValueError) as exc:
|
||||
raise RuntimeError(
|
||||
f"codex app-server stdin closed unexpectedly: {exc}"
|
||||
) from exc
|
||||
|
||||
def _read_stdout(self) -> None:
|
||||
if self._proc.stdout is None:
|
||||
return
|
||||
try:
|
||||
for line in iter(self._proc.stdout.readline, b""):
|
||||
if not line:
|
||||
break
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
msg = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
# Non-JSON output is unexpected on stdout; tracing belongs
|
||||
# on stderr. Surface it via stderr buffer for diagnostics.
|
||||
with self._stderr_lock:
|
||||
self._stderr_lines.append(
|
||||
f"<non-json on stdout> {line[:200]!r}"
|
||||
)
|
||||
continue
|
||||
self._dispatch(msg)
|
||||
except Exception as exc:
|
||||
with self._stderr_lock:
|
||||
self._stderr_lines.append(f"<stdout reader error> {exc}")
|
||||
|
||||
def _dispatch(self, msg: dict) -> None:
|
||||
# Reply (has id + result/error, no method)
|
||||
if "id" in msg and ("result" in msg or "error" in msg):
|
||||
with self._pending_lock:
|
||||
pending = self._pending.pop(msg["id"], None)
|
||||
if pending is not None:
|
||||
try:
|
||||
pending.queue.put_nowait(msg)
|
||||
except queue.Full: # pragma: no cover - defensive
|
||||
pass
|
||||
return
|
||||
# Server-initiated request (has id + method)
|
||||
if "id" in msg and "method" in msg:
|
||||
self._server_requests.put(msg)
|
||||
return
|
||||
# Notification (no id)
|
||||
if "method" in msg:
|
||||
self._notifications.put(msg)
|
||||
|
||||
def _read_stderr(self) -> None:
|
||||
if self._proc.stderr is None:
|
||||
return
|
||||
try:
|
||||
for line in iter(self._proc.stderr.readline, b""):
|
||||
if not line:
|
||||
break
|
||||
with self._stderr_lock:
|
||||
self._stderr_lines.append(
|
||||
line.decode("utf-8", "replace").rstrip()
|
||||
)
|
||||
# Bound memory: keep last 500 lines.
|
||||
if len(self._stderr_lines) > 500:
|
||||
self._stderr_lines = self._stderr_lines[-500:]
|
||||
except Exception: # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
def parse_codex_version(output: str) -> Optional[tuple[int, int, int]]:
|
||||
"""Parse `codex --version` output. Returns (major, minor, patch) or None."""
|
||||
# Output format: "codex-cli 0.130.0" possibly followed by metadata.
|
||||
import re
|
||||
|
||||
match = re.search(r"(\d+)\.(\d+)\.(\d+)", output or "")
|
||||
if not match:
|
||||
return None
|
||||
return (int(match.group(1)), int(match.group(2)), int(match.group(3)))
|
||||
|
||||
|
||||
def check_codex_binary(
|
||||
codex_bin: str = "codex", min_version: tuple[int, int, int] = MIN_CODEX_VERSION
|
||||
) -> tuple[bool, str]:
|
||||
"""Verify codex CLI is installed and meets minimum version.
|
||||
|
||||
Returns (ok, message). Used by setup wizard and runtime startup."""
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
[codex_bin, "--version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return False, (
|
||||
f"codex CLI not found at {codex_bin!r}. Install with: "
|
||||
f"npm i -g @openai/codex"
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
return False, "codex --version timed out"
|
||||
if proc.returncode != 0:
|
||||
return False, f"codex --version exited {proc.returncode}: {proc.stderr.strip()}"
|
||||
version = parse_codex_version(proc.stdout)
|
||||
if version is None:
|
||||
return False, f"could not parse codex version from: {proc.stdout!r}"
|
||||
if version < min_version:
|
||||
return False, (
|
||||
f"codex {'.'.join(map(str, version))} is older than required "
|
||||
f"{'.'.join(map(str, min_version))}. Run: npm i -g @openai/codex"
|
||||
)
|
||||
return True, ".".join(map(str, version))
|
||||
810
agent/transports/codex_app_server_session.py
Normal file
810
agent/transports/codex_app_server_session.py
Normal file
@@ -0,0 +1,810 @@
|
||||
"""Session adapter for codex app-server runtime.
|
||||
|
||||
Owns one Codex thread per Hermes session. Drives `turn/start`, consumes
|
||||
streaming notifications via CodexEventProjector, handles server-initiated
|
||||
approval requests (apply_patch, exec command), translates cancellation,
|
||||
and returns a clean turn result that AIAgent.run_conversation() can splice
|
||||
into its `messages` list.
|
||||
|
||||
Lifecycle:
|
||||
session = CodexAppServerSession(cwd="/home/x/proj")
|
||||
session.ensure_started() # spawns + handshake + thread/start
|
||||
result = session.run_turn(user_input="hello") # blocks until turn/completed
|
||||
# result.final_text → assistant text returned to caller
|
||||
# result.projected_messages → list of {role, content, ...} for messages list
|
||||
# result.tool_iterations → how many tool-shaped items completed (skill nudge counter)
|
||||
# result.interrupted → True if Ctrl+C / interrupt_requested fired mid-turn
|
||||
session.close() # tears down subprocess
|
||||
|
||||
Threading model: the adapter is single-threaded from the caller's perspective.
|
||||
The underlying CodexAppServerClient owns its own reader threads but exposes
|
||||
blocking-with-timeout queues that this adapter polls in a loop, so the run_turn
|
||||
call is synchronous and behaves like AIAgent's existing chat_completions loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from agent.redact import redact_sensitive_text
|
||||
from agent.transports.codex_app_server import (
|
||||
CodexAppServerClient,
|
||||
CodexAppServerError,
|
||||
)
|
||||
from agent.transports.codex_event_projector import CodexEventProjector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# How many tailing stderr lines from the codex subprocess to attach to a
|
||||
# user-facing error when we don't have a more specific classification (OAuth,
|
||||
# wedge watchdog, etc.). Small enough to keep error messages legible, large
|
||||
# enough to surface a config/provider/auth diagnostic.
|
||||
_STDERR_TAIL_LINES = 12
|
||||
|
||||
|
||||
# Permission profile mapping mirrors the docstring in PR proposal:
|
||||
# Hermes' tools.terminal.security_mode → Codex's permissions profile id.
|
||||
# Defaults if config is missing → workspace-write (matches Codex's own default).
|
||||
_HERMES_TO_CODEX_PERMISSION_PROFILE = {
|
||||
"auto": "workspace-write",
|
||||
"approval-required": "read-only-with-approval",
|
||||
"unrestricted": "full-access",
|
||||
# Backstop alias used by some skills/tests.
|
||||
"yolo": "full-access",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TurnResult:
|
||||
"""Result of one user→assistant→tool turn through the codex app-server."""
|
||||
|
||||
final_text: str = ""
|
||||
projected_messages: list[dict] = field(default_factory=list)
|
||||
tool_iterations: int = 0
|
||||
interrupted: bool = False
|
||||
error: Optional[str] = None # Set if turn ended in a non-recoverable error
|
||||
turn_id: Optional[str] = None
|
||||
thread_id: Optional[str] = None
|
||||
# Hint to the caller that the underlying codex subprocess is likely
|
||||
# wedged (turn-level timeout fired, post-tool watchdog tripped, or
|
||||
# token-refresh failure killed the child). The caller should retire
|
||||
# the session so the next turn respawns codex from scratch instead
|
||||
# of riding a CPU-spinning or auth-broken process. Mirrors openclaw
|
||||
# beta.8's "retire timed-out app-server clients" fix.
|
||||
should_retire: bool = False
|
||||
|
||||
|
||||
# Markers we accept as terminal even when codex never emits turn/completed.
|
||||
# Some codex versions stream `<turn_aborted>` as raw text in agentMessage
|
||||
# items when an interrupt or upstream error tears the turn down before the
|
||||
# normal completion path fires. Mirrors openclaw beta.8 fix.
|
||||
_TURN_ABORTED_MARKERS = ("<turn_aborted>", "<turn_aborted/>")
|
||||
|
||||
|
||||
# Substrings in codex stderr / JSON-RPC error messages that signal the
|
||||
# subprocess died because its OAuth credentials are no longer valid.
|
||||
# Kept conservative: we only redirect users to `codex login` when we're
|
||||
# reasonably sure that's the actual failure, otherwise we surface the
|
||||
# original error verbatim. Mirrors openclaw beta.8's auth-refresh
|
||||
# classification.
|
||||
_OAUTH_REFRESH_FAILURE_HINTS = (
|
||||
"invalid_grant",
|
||||
"invalid grant",
|
||||
"refresh token",
|
||||
"refresh_token",
|
||||
"token refresh",
|
||||
"token_refresh",
|
||||
"token has expired",
|
||||
"expired_token",
|
||||
"expired token",
|
||||
"not authenticated",
|
||||
"unauthenticated",
|
||||
"unauthorized",
|
||||
"401 unauthorized",
|
||||
"re-authenticate",
|
||||
"reauthenticate",
|
||||
"please log in",
|
||||
"please login",
|
||||
"auth profile",
|
||||
"no auth profile",
|
||||
"oauth",
|
||||
)
|
||||
|
||||
|
||||
def _classify_oauth_failure(*parts: str) -> Optional[str]:
|
||||
"""Return a user-friendly re-auth hint if any of the provided strings
|
||||
look like a codex OAuth/token-refresh failure; otherwise None.
|
||||
|
||||
Used for both `turn/start` JSON-RPC errors and post-mortem stderr
|
||||
inspection when the subprocess exits unexpectedly. Conservative on
|
||||
purpose — we only redirect users to `codex login` when the signal
|
||||
is strong, so unrelated runtime failures still surface verbatim.
|
||||
"""
|
||||
haystack = " ".join(p for p in parts if p).lower()
|
||||
if not haystack:
|
||||
return None
|
||||
for needle in _OAUTH_REFRESH_FAILURE_HINTS:
|
||||
if needle in haystack:
|
||||
return (
|
||||
"Codex authentication failed — your ChatGPT/Codex login "
|
||||
"looks expired or invalid. Run `codex login` to refresh, "
|
||||
"then retry. (Fall back to default runtime with "
|
||||
"`/codex-runtime auto` if the issue persists.)"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ServerRequestRouting:
|
||||
"""Default policies for codex-side approval requests when no interactive
|
||||
callback is wired in. These are only used by tests + cron / non-interactive
|
||||
contexts; the live CLI path passes an approval_callback that defers to
|
||||
tools.approval.prompt_dangerous_approval()."""
|
||||
|
||||
auto_approve_exec: bool = False
|
||||
auto_approve_apply_patch: bool = False
|
||||
|
||||
|
||||
class CodexAppServerSession:
|
||||
"""One Codex thread per Hermes session, lifetime owned by AIAgent.
|
||||
|
||||
Not thread-safe — one caller drives it at a time, matching how AIAgent's
|
||||
run_conversation() loop is structured today. The codex client itself can
|
||||
handle interleaved reads/writes via its own threads, but the adapter's
|
||||
state (projector, thread_id, turn counter) is owned by the caller thread.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cwd: Optional[str] = None,
|
||||
codex_bin: str = "codex",
|
||||
codex_home: Optional[str] = None,
|
||||
permission_profile: Optional[str] = None,
|
||||
approval_callback: Optional[Callable[..., str]] = None,
|
||||
on_event: Optional[Callable[[dict], None]] = None,
|
||||
request_routing: Optional[_ServerRequestRouting] = None,
|
||||
client_factory: Optional[Callable[..., CodexAppServerClient]] = None,
|
||||
) -> None:
|
||||
self._cwd = cwd or os.getcwd()
|
||||
self._codex_bin = codex_bin
|
||||
self._codex_home = codex_home
|
||||
self._permission_profile = (
|
||||
permission_profile or _HERMES_TO_CODEX_PERMISSION_PROFILE.get(
|
||||
os.environ.get("HERMES_TERMINAL_SECURITY_MODE", "auto"),
|
||||
"workspace-write",
|
||||
)
|
||||
)
|
||||
self._approval_callback = approval_callback
|
||||
self._on_event = on_event # Display hook (kawaii spinner ticks etc.)
|
||||
self._routing = request_routing or _ServerRequestRouting()
|
||||
self._client_factory = client_factory or CodexAppServerClient
|
||||
|
||||
self._client: Optional[CodexAppServerClient] = None
|
||||
self._thread_id: Optional[str] = None
|
||||
self._interrupt_event = threading.Event()
|
||||
# Pending file-change items, keyed by item id. Populated on
|
||||
# item/started for fileChange items; consumed by the approval
|
||||
# bridge when codex sends item/fileChange/requestApproval. The
|
||||
# approval params don't carry the changeset, so we cache here
|
||||
# to surface a real summary in the approval prompt (quirk #4).
|
||||
self._pending_file_changes: dict[str, str] = {}
|
||||
self._closed = False
|
||||
|
||||
# ---------- lifecycle ----------
|
||||
|
||||
def ensure_started(self) -> str:
|
||||
"""Spawn the subprocess, do the initialize handshake, and start a
|
||||
thread. Returns the codex thread id. Idempotent — repeated calls
|
||||
return the same thread id."""
|
||||
if self._thread_id is not None:
|
||||
return self._thread_id
|
||||
if self._client is None:
|
||||
self._client = self._client_factory(
|
||||
codex_bin=self._codex_bin, codex_home=self._codex_home
|
||||
)
|
||||
self._client.initialize(
|
||||
client_name="hermes",
|
||||
client_title="Hermes Agent",
|
||||
client_version=_get_hermes_version(),
|
||||
)
|
||||
# Permission selection is intentionally NOT sent on thread/start.
|
||||
# Two reasons (live-tested against codex 0.130.0):
|
||||
# 1. `thread/start.permissions` is gated behind the experimentalApi
|
||||
# capability on this codex version — we'd have to opt in during
|
||||
# initialize and accept the unstable surface.
|
||||
# 2. Even with experimentalApi declared and the correct shape
|
||||
# (`{"type": "profile", "id": "..."}`, not `{"profileId": ...}`),
|
||||
# codex requires a matching `[permissions]` table in
|
||||
# ~/.codex/config.toml or it fails the request with
|
||||
# 'default_permissions requires a [permissions] table'.
|
||||
# Letting codex pick its default (`:read-only` unless the user has
|
||||
# configured otherwise in their codex config.toml) is the standard
|
||||
# codex CLI workflow and avoids fighting codex's own validation.
|
||||
# Users who want a write-capable profile configure it in their
|
||||
# ~/.codex/config.toml the same way they would for any codex usage.
|
||||
params: dict[str, Any] = {"cwd": self._cwd}
|
||||
result = self._client.request("thread/start", params, timeout=15)
|
||||
# Cross-fill thread.id/sessionId — different codex versions have
|
||||
# serialized this under either key. Mirrors openclaw beta.8's
|
||||
# tolerance fix so future codex drops/renames don't KeyError us
|
||||
# at handshake time.
|
||||
thread_obj = result.get("thread") or {}
|
||||
thread_id = (
|
||||
thread_obj.get("id")
|
||||
or thread_obj.get("sessionId")
|
||||
or result.get("sessionId")
|
||||
or result.get("threadId")
|
||||
)
|
||||
if not thread_id:
|
||||
raise CodexAppServerError(
|
||||
code=-32603,
|
||||
message=(
|
||||
"codex thread/start returned no thread id "
|
||||
f"(payload keys: {sorted(result.keys())})"
|
||||
),
|
||||
)
|
||||
self._thread_id = thread_id
|
||||
logger.info(
|
||||
"codex app-server thread started: id=%s profile=%s cwd=%s",
|
||||
self._thread_id[:8],
|
||||
self._permission_profile,
|
||||
self._cwd,
|
||||
)
|
||||
return self._thread_id
|
||||
|
||||
def close(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if self._client is not None:
|
||||
try:
|
||||
self._client.close()
|
||||
except Exception: # pragma: no cover - best-effort cleanup
|
||||
pass
|
||||
self._client = None
|
||||
self._thread_id = None
|
||||
|
||||
def __enter__(self) -> "CodexAppServerSession":
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc: Any) -> None:
|
||||
self.close()
|
||||
|
||||
# ---------- interrupt ----------
|
||||
|
||||
def request_interrupt(self) -> None:
|
||||
"""Idempotent: signal the active turn loop to issue turn/interrupt
|
||||
and unwind. Called by AIAgent's _interrupt_requested path."""
|
||||
self._interrupt_event.set()
|
||||
|
||||
# ---------- diagnostics ----------
|
||||
|
||||
def _format_error_with_stderr(
|
||||
self,
|
||||
prefix: str,
|
||||
exc: Any = "",
|
||||
*,
|
||||
tail_lines: int = _STDERR_TAIL_LINES,
|
||||
) -> str:
|
||||
"""Build a user-facing error string for codex failures.
|
||||
|
||||
Appends the last few lines of codex's stderr buffer when available,
|
||||
passed through agent.redact with force=True so secrets in provider
|
||||
error responses (auth headers, query-string tokens, sk-* keys) never
|
||||
leak into chat output or trajectories. The codex CLI's own error
|
||||
text ('Internal error', 'turn/start failed: ...') is otherwise
|
||||
opaque and forces users to re-run with verbose flags to diagnose
|
||||
config / provider / auth-bridge problems.
|
||||
|
||||
Use this for the generic / catch-all branches. Specific
|
||||
classifications (OAuth via _classify_oauth_failure, post-tool wedge
|
||||
watchdog) already produce a clean hint and should be used instead.
|
||||
"""
|
||||
exc_str = str(exc) if exc != "" and exc is not None else ""
|
||||
base = f"{prefix}: {exc_str}" if exc_str else prefix
|
||||
if self._client is None:
|
||||
return base
|
||||
try:
|
||||
tail = self._client.stderr_tail(tail_lines)
|
||||
except Exception: # pragma: no cover - diagnostic best-effort
|
||||
return base
|
||||
if not tail:
|
||||
return base
|
||||
joined = "\n".join(line.rstrip() for line in tail if line)
|
||||
if not joined.strip():
|
||||
return base
|
||||
redacted = redact_sensitive_text(joined, force=True)
|
||||
return f"{base}\ncodex stderr (last {len(tail)} lines):\n{redacted}"
|
||||
|
||||
# ---------- per-turn ----------
|
||||
|
||||
def run_turn(
|
||||
self,
|
||||
user_input: str,
|
||||
*,
|
||||
turn_timeout: float = 600.0,
|
||||
notification_poll_timeout: float = 0.25,
|
||||
post_tool_quiet_timeout: float = 90.0,
|
||||
) -> TurnResult:
|
||||
"""Send a user message and block until turn/completed, while
|
||||
forwarding server-initiated approval requests and projecting items
|
||||
into Hermes' messages shape.
|
||||
|
||||
post_tool_quiet_timeout: if codex emits a tool completion and then
|
||||
goes quiet for this many seconds without emitting another item or
|
||||
`turn/completed`, fast-fail and mark the session for retirement.
|
||||
Mirrors openclaw beta.8's post-tool completion watchdog (#81697)
|
||||
so a wedged codex doesn't burn the full turn deadline.
|
||||
"""
|
||||
# Pre-create the result so startup failures (codex subprocess can't
|
||||
# spawn, initialize handshake rejects, thread/start blows up) surface
|
||||
# the same way per-turn failures do — with a TurnResult.error string
|
||||
# the caller can render — instead of bubbling raw codex exceptions
|
||||
# up to AIAgent.run_conversation.
|
||||
result = TurnResult()
|
||||
try:
|
||||
self.ensure_started()
|
||||
except (CodexAppServerError, TimeoutError) as exc:
|
||||
result.error = self._format_error_with_stderr(
|
||||
"codex app-server startup failed", exc
|
||||
)
|
||||
# Subprocess almost certainly unhealthy — retire so the next
|
||||
# turn re-spawns cleanly.
|
||||
result.should_retire = True
|
||||
return result
|
||||
assert self._client is not None and self._thread_id is not None
|
||||
result.thread_id = self._thread_id
|
||||
|
||||
self._interrupt_event.clear()
|
||||
projector = CodexEventProjector()
|
||||
|
||||
# Send turn/start with the user input. Text-only for now (codex
|
||||
# supports rich content but Hermes' text path is the common case).
|
||||
try:
|
||||
ts = self._client.request(
|
||||
"turn/start",
|
||||
{
|
||||
"threadId": self._thread_id,
|
||||
"input": [{"type": "text", "text": user_input}],
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
except CodexAppServerError as exc:
|
||||
# Classify auth/refresh failures so the user gets a clear
|
||||
# `codex login` pointer instead of a raw RPC error string.
|
||||
stderr_blob = "\n".join(self._client.stderr_tail(40))
|
||||
hint = _classify_oauth_failure(exc.message, stderr_blob)
|
||||
if hint is not None:
|
||||
result.error = hint
|
||||
# Subprocess is fine on a JSON-RPC level here, but the
|
||||
# token store is broken — retire so the next turn does a
|
||||
# clean handshake (and the user has a chance to re-auth
|
||||
# via `codex login` between turns).
|
||||
result.should_retire = True
|
||||
else:
|
||||
result.error = self._format_error_with_stderr(
|
||||
"turn/start failed", exc
|
||||
)
|
||||
return result
|
||||
except TimeoutError as exc:
|
||||
# turn/start hanging is a strong signal the subprocess is wedged.
|
||||
stderr_blob = "\n".join(self._client.stderr_tail(40))
|
||||
hint = _classify_oauth_failure(stderr_blob)
|
||||
result.error = hint or self._format_error_with_stderr(
|
||||
"turn/start timed out", exc
|
||||
)
|
||||
result.should_retire = True
|
||||
return result
|
||||
|
||||
result.turn_id = (ts.get("turn") or {}).get("id")
|
||||
deadline = time.time() + turn_timeout
|
||||
turn_complete = False
|
||||
# Post-tool watchdog state. last_tool_completion_at is set whenever
|
||||
# a tool-shaped item completes; if no further notification arrives
|
||||
# within post_tool_quiet_timeout and the turn hasn't completed, we
|
||||
# fast-fail and retire the session.
|
||||
last_tool_completion_at: Optional[float] = None
|
||||
|
||||
while time.time() < deadline and not turn_complete:
|
||||
if self._interrupt_event.is_set():
|
||||
self._issue_interrupt(result.turn_id)
|
||||
result.interrupted = True
|
||||
break
|
||||
|
||||
# Detect a dead subprocess between iterations. If codex exited
|
||||
# (e.g. crashed, segfaulted, or its auth refresh thread killed
|
||||
# the process), we won't get any more notifications — bail out
|
||||
# rather than waiting for the full turn deadline.
|
||||
if not self._client.is_alive():
|
||||
stderr_blob = "\n".join(self._client.stderr_tail(60))
|
||||
hint = _classify_oauth_failure(stderr_blob)
|
||||
if hint is not None:
|
||||
result.error = hint
|
||||
else:
|
||||
result.error = self._format_error_with_stderr(
|
||||
"codex app-server subprocess exited unexpectedly",
|
||||
tail_lines=20,
|
||||
)
|
||||
result.should_retire = True
|
||||
break
|
||||
|
||||
# Post-tool watchdog: if a tool completion was the most recent
|
||||
# signal and codex has been silent past the quiet timeout, give
|
||||
# up on this turn instead of waiting for the outer deadline.
|
||||
if (
|
||||
last_tool_completion_at is not None
|
||||
and (time.time() - last_tool_completion_at)
|
||||
> post_tool_quiet_timeout
|
||||
):
|
||||
self._issue_interrupt(result.turn_id)
|
||||
result.interrupted = True
|
||||
result.error = (
|
||||
f"codex went silent for "
|
||||
f"{post_tool_quiet_timeout:.0f}s after a tool result; "
|
||||
f"retiring app-server session."
|
||||
)
|
||||
result.should_retire = True
|
||||
break
|
||||
|
||||
# Drain any server-initiated requests (approvals) before
|
||||
# reading notifications, so the codex side isn't blocked.
|
||||
sreq = self._client.take_server_request(timeout=0)
|
||||
if sreq is not None:
|
||||
# Drain any pending notifications first so per-turn state
|
||||
# (e.g. _pending_file_changes for fileChange approvals) is
|
||||
# up to date when we make the approval decision. Bounded
|
||||
# to avoid starving the server-request response.
|
||||
for _ in range(8):
|
||||
pending = self._client.take_notification(timeout=0)
|
||||
if pending is None:
|
||||
break
|
||||
self._track_pending_file_change(pending)
|
||||
proj = projector.project(pending)
|
||||
if proj.messages:
|
||||
result.projected_messages.extend(proj.messages)
|
||||
if proj.is_tool_iteration:
|
||||
result.tool_iterations += 1
|
||||
last_tool_completion_at = time.time()
|
||||
if proj.final_text is not None:
|
||||
result.final_text = proj.final_text
|
||||
if _has_turn_aborted_marker(proj.final_text):
|
||||
turn_complete = True
|
||||
result.interrupted = True
|
||||
result.error = (
|
||||
result.error
|
||||
or "codex reported turn_aborted"
|
||||
)
|
||||
self._handle_server_request(sreq)
|
||||
# Activity counts as live signal — reset the post-tool
|
||||
# quiet timer so an approval round-trip doesn't trip it.
|
||||
last_tool_completion_at = None
|
||||
continue
|
||||
|
||||
note = self._client.take_notification(
|
||||
timeout=notification_poll_timeout
|
||||
)
|
||||
if note is None:
|
||||
continue
|
||||
|
||||
method = note.get("method", "")
|
||||
if self._on_event is not None:
|
||||
try:
|
||||
self._on_event(note)
|
||||
except Exception: # pragma: no cover - display callback
|
||||
logger.debug("on_event callback raised", exc_info=True)
|
||||
|
||||
# Track in-progress fileChange items so the approval bridge
|
||||
# can surface a real change summary when codex requests
|
||||
# approval (the approval params themselves don't carry the
|
||||
# changeset). Quirk #4 fix.
|
||||
self._track_pending_file_change(note)
|
||||
|
||||
# Project into messages
|
||||
projection = projector.project(note)
|
||||
if projection.messages:
|
||||
result.projected_messages.extend(projection.messages)
|
||||
if projection.is_tool_iteration:
|
||||
result.tool_iterations += 1
|
||||
# Arm/refresh the post-tool quiet watchdog whenever a
|
||||
# tool-shaped item completes.
|
||||
last_tool_completion_at = time.time()
|
||||
else:
|
||||
# Any non-tool projected activity (assistant message,
|
||||
# status update, etc.) means codex is still producing
|
||||
# output — clear the quiet timer so we don't fast-fail.
|
||||
if projection.messages or projection.final_text is not None:
|
||||
last_tool_completion_at = None
|
||||
if projection.final_text is not None:
|
||||
# Codex can emit multiple agentMessage items in one turn
|
||||
# (e.g. partial then final). Take the last one as canonical.
|
||||
result.final_text = projection.final_text
|
||||
# Some codex builds tear a turn down by emitting a
|
||||
# `<turn_aborted>` marker in the agent message text and
|
||||
# never sending turn/completed. Treat the marker itself
|
||||
# as terminal so we don't burn the full deadline.
|
||||
if _has_turn_aborted_marker(projection.final_text):
|
||||
turn_complete = True
|
||||
result.interrupted = True
|
||||
result.error = (
|
||||
result.error or "codex reported turn_aborted"
|
||||
)
|
||||
|
||||
if method == "turn/completed":
|
||||
turn_complete = True
|
||||
turn_status = (
|
||||
(note.get("params") or {}).get("turn") or {}
|
||||
).get("status")
|
||||
if turn_status and turn_status not in ("completed", "interrupted"):
|
||||
err_obj = (
|
||||
(note.get("params") or {}).get("turn") or {}
|
||||
).get("error")
|
||||
if err_obj:
|
||||
err_msg = err_obj.get("message") or str(err_obj)
|
||||
# If the turn failed for an auth/refresh reason,
|
||||
# rewrite the error into a re-auth hint AND mark
|
||||
# the session for retirement.
|
||||
stderr_blob = "\n".join(
|
||||
self._client.stderr_tail(40)
|
||||
)
|
||||
hint = _classify_oauth_failure(err_msg, stderr_blob)
|
||||
if hint is not None:
|
||||
result.error = hint
|
||||
result.should_retire = True
|
||||
else:
|
||||
result.error = self._format_error_with_stderr(
|
||||
f"turn ended status={turn_status}", err_msg
|
||||
)
|
||||
|
||||
if not turn_complete and not result.interrupted:
|
||||
# Hit the deadline. Issue interrupt to stop wasted compute, and
|
||||
# tell the caller to retire the session — a turn that never
|
||||
# finished is a strong sign codex is wedged in a way the next
|
||||
# turn shouldn't inherit.
|
||||
self._issue_interrupt(result.turn_id)
|
||||
result.interrupted = True
|
||||
if not result.error:
|
||||
result.error = self._format_error_with_stderr(
|
||||
f"turn timed out after {turn_timeout}s"
|
||||
)
|
||||
result.should_retire = True
|
||||
|
||||
return result
|
||||
|
||||
# ---------- internals ----------
|
||||
|
||||
def _issue_interrupt(self, turn_id: Optional[str]) -> None:
|
||||
if self._client is None or self._thread_id is None or turn_id is None:
|
||||
return
|
||||
try:
|
||||
self._client.request(
|
||||
"turn/interrupt",
|
||||
{"threadId": self._thread_id, "turnId": turn_id},
|
||||
timeout=5,
|
||||
)
|
||||
except CodexAppServerError as exc:
|
||||
# "no active turn to interrupt" is fine — already done.
|
||||
logger.debug("turn/interrupt non-fatal: %s", exc)
|
||||
except TimeoutError:
|
||||
logger.warning("turn/interrupt timed out")
|
||||
|
||||
def _handle_server_request(self, req: dict) -> None:
|
||||
"""Translate a codex server request (approval) into Hermes' approval
|
||||
flow, then send the response.
|
||||
|
||||
Method names verified live against codex 0.130.0 (Apr 2026):
|
||||
item/commandExecution/requestApproval — exec approvals
|
||||
item/fileChange/requestApproval — apply_patch approvals
|
||||
item/permissions/requestApproval — permissions changes
|
||||
(we decline; user controls
|
||||
permission profile in
|
||||
~/.codex/config.toml).
|
||||
"""
|
||||
if self._client is None:
|
||||
return
|
||||
method = req.get("method", "")
|
||||
rid = req.get("id")
|
||||
params = req.get("params") or {}
|
||||
|
||||
if method == "item/commandExecution/requestApproval":
|
||||
decision = self._decide_exec_approval(params)
|
||||
self._client.respond(rid, {"decision": decision})
|
||||
elif method == "item/fileChange/requestApproval":
|
||||
decision = self._decide_apply_patch_approval(params)
|
||||
self._client.respond(rid, {"decision": decision})
|
||||
elif method == "item/permissions/requestApproval":
|
||||
# Codex sometimes asks to escalate permissions mid-turn. We
|
||||
# always decline — the user already chose their permission
|
||||
# profile in ~/.codex/config.toml and surprise escalations
|
||||
# shouldn't be silently accepted.
|
||||
self._client.respond(rid, {"decision": "decline"})
|
||||
elif method == "mcpServer/elicitation/request":
|
||||
# Codex's MCP layer asks the user for structured input on
|
||||
# behalf of an MCP server (e.g. tool-call confirmation,
|
||||
# OAuth, form data). For our own hermes-tools callback we
|
||||
# auto-accept — the user already approved Hermes' tools
|
||||
# by enabling the runtime, and we never expose anything
|
||||
# codex's built-in shell can't already do. For other MCP
|
||||
# servers we decline so the user explicitly opts in via
|
||||
# codex's own auth flow.
|
||||
server_name = params.get("serverName") or ""
|
||||
if server_name == "hermes-tools":
|
||||
self._client.respond(
|
||||
rid,
|
||||
{"action": "accept", "content": None, "_meta": None},
|
||||
)
|
||||
else:
|
||||
self._client.respond(
|
||||
rid,
|
||||
{"action": "decline", "content": None, "_meta": None},
|
||||
)
|
||||
else:
|
||||
# Unknown server request — codex can extend this surface. Reject
|
||||
# cleanly so codex doesn't hang waiting for us.
|
||||
logger.warning("Unknown codex server request: %s", method)
|
||||
self._client.respond_error(
|
||||
rid, code=-32601, message=f"Unsupported method: {method}"
|
||||
)
|
||||
|
||||
def _decide_exec_approval(self, params: dict) -> str:
|
||||
if self._routing.auto_approve_exec:
|
||||
return "accept"
|
||||
command = params.get("command") or ""
|
||||
# Codex's CommandExecutionRequestApprovalParams has cwd as Optional —
|
||||
# fall back to the session's cwd when codex doesn't include it so the
|
||||
# approval prompt is never empty (quirk #10 fix).
|
||||
cwd = params.get("cwd") or self._cwd or "<unknown>"
|
||||
reason = params.get("reason")
|
||||
description = f"Codex requests exec in {cwd}"
|
||||
if reason:
|
||||
description += f" — {reason}"
|
||||
if self._approval_callback is not None:
|
||||
try:
|
||||
choice = self._approval_callback(
|
||||
command, description, allow_permanent=False
|
||||
)
|
||||
return _approval_choice_to_codex_decision(choice)
|
||||
except Exception:
|
||||
logger.exception("approval_callback raised on exec request")
|
||||
return "decline"
|
||||
return "decline" # fail-closed when no callback wired
|
||||
|
||||
def _decide_apply_patch_approval(self, params: dict) -> str:
|
||||
if self._routing.auto_approve_apply_patch:
|
||||
return "accept"
|
||||
if self._approval_callback is not None:
|
||||
# FileChangeRequestApprovalParams gives us reason + grantRoot.
|
||||
# The actual changeset lives on the corresponding fileChange
|
||||
# item which the projector has already cached for us — look it
|
||||
# up by item_id so the user sees what's actually changing.
|
||||
reason = params.get("reason")
|
||||
grant_root = params.get("grantRoot")
|
||||
item_id = params.get("itemId") or ""
|
||||
change_summary = self._lookup_pending_file_change(item_id)
|
||||
description_parts = []
|
||||
if reason:
|
||||
description_parts.append(reason)
|
||||
if change_summary:
|
||||
description_parts.append(change_summary)
|
||||
if grant_root:
|
||||
description_parts.append(f"grants write to {grant_root}")
|
||||
description = (
|
||||
"; ".join(description_parts)
|
||||
if description_parts
|
||||
else "Codex requests to apply a patch"
|
||||
)
|
||||
command_label = (
|
||||
f"apply_patch: {change_summary}" if change_summary
|
||||
else f"apply_patch: {reason}" if reason
|
||||
else "apply_patch"
|
||||
)
|
||||
try:
|
||||
choice = self._approval_callback(
|
||||
command_label,
|
||||
description,
|
||||
allow_permanent=False,
|
||||
)
|
||||
return _approval_choice_to_codex_decision(choice)
|
||||
except Exception:
|
||||
logger.exception("approval_callback raised on apply_patch")
|
||||
return "decline"
|
||||
return "decline"
|
||||
|
||||
def _track_pending_file_change(self, note: dict) -> None:
|
||||
"""Maintain self._pending_file_changes from item/started + item/completed
|
||||
notifications. Lets the apply_patch approval prompt show what's
|
||||
actually changing — codex's approval params don't carry the data."""
|
||||
method = note.get("method", "")
|
||||
params = note.get("params") or {}
|
||||
item = params.get("item") or {}
|
||||
if item.get("type") != "fileChange":
|
||||
return
|
||||
item_id = item.get("id") or ""
|
||||
if not item_id:
|
||||
return
|
||||
if method == "item/started":
|
||||
changes = item.get("changes") or []
|
||||
if not changes:
|
||||
self._pending_file_changes[item_id] = "1 change pending"
|
||||
return
|
||||
kinds: dict[str, int] = {}
|
||||
paths: list[str] = []
|
||||
for ch in changes:
|
||||
if not isinstance(ch, dict):
|
||||
continue
|
||||
kind = (ch.get("kind") or {}).get("type") or "update"
|
||||
kinds[kind] = kinds.get(kind, 0) + 1
|
||||
p = ch.get("path") or ""
|
||||
if p:
|
||||
paths.append(p)
|
||||
counts = ", ".join(f"{n} {k}" for k, n in sorted(kinds.items()))
|
||||
preview = ", ".join(paths[:3])
|
||||
if len(paths) > 3:
|
||||
preview += f", +{len(paths) - 3} more"
|
||||
self._pending_file_changes[item_id] = (
|
||||
f"{counts}: {preview}" if preview else counts
|
||||
)
|
||||
elif method == "item/completed":
|
||||
self._pending_file_changes.pop(item_id, None)
|
||||
|
||||
def _lookup_pending_file_change(self, item_id: str) -> Optional[str]:
|
||||
"""Look up an in-progress fileChange item by id and summarize its
|
||||
changes for the approval prompt. Returns None when we don't have
|
||||
the item cached (e.g. approval arrived before item/started, or
|
||||
fileChange item content not tracked yet)."""
|
||||
if not item_id:
|
||||
return None
|
||||
cached = self._pending_file_changes.get(item_id)
|
||||
if not cached:
|
||||
return None
|
||||
return cached
|
||||
|
||||
|
||||
def _approval_choice_to_codex_decision(choice: str) -> str:
|
||||
"""Map Hermes approval choices onto codex's CommandExecutionApprovalDecision
|
||||
/ FileChangeApprovalDecision wire values.
|
||||
|
||||
Hermes returns 'once', 'session', 'always', or 'deny'.
|
||||
Codex expects 'accept', 'acceptForSession', 'decline', or 'cancel'
|
||||
(verified against codex-rs/app-server-protocol/src/protocol/v2/item.rs
|
||||
on codex 0.130.0).
|
||||
"""
|
||||
if choice in ("once",):
|
||||
return "accept"
|
||||
if choice in ("session", "always"):
|
||||
return "acceptForSession"
|
||||
return "decline"
|
||||
|
||||
|
||||
def _has_turn_aborted_marker(text: str) -> bool:
|
||||
"""Return True if `text` contains any of the raw markers codex uses
|
||||
to signal a turn was aborted without emitting `turn/completed`.
|
||||
|
||||
Codex emits `<turn_aborted>` (and sometimes `<turn_aborted/>`) as raw
|
||||
text inside agentMessage items when an interrupt or upstream error
|
||||
tears the turn down before the normal completion path fires. Mirrors
|
||||
openclaw beta.8's terminal-marker fix so we don't burn the full turn
|
||||
deadline waiting for a turn/completed that never comes.
|
||||
"""
|
||||
if not text:
|
||||
return False
|
||||
for marker in _TURN_ABORTED_MARKERS:
|
||||
if marker in text:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _get_hermes_version() -> str:
|
||||
"""Best-effort Hermes version string for codex's userAgent line."""
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
return version("hermes-agent")
|
||||
except Exception: # pragma: no cover
|
||||
return "0.0.0"
|
||||
312
agent/transports/codex_event_projector.py
Normal file
312
agent/transports/codex_event_projector.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""Projects codex app-server events into Hermes' messages list.
|
||||
|
||||
The translator that lets Hermes' memory/skill review keep working under the
|
||||
Codex runtime: it converts Codex `item/*` notifications into the standard
|
||||
OpenAI-shaped `{role, content, tool_calls, tool_call_id}` entries that
|
||||
`agent/curator.py` already knows how to read.
|
||||
|
||||
Codex emits items with a discriminator field `type`:
|
||||
- userMessage → {role: "user", content}
|
||||
- agentMessage → {role: "assistant", content}
|
||||
- reasoning → stashed in the assistant's "reasoning" field
|
||||
- commandExecution → assistant tool_call(name="exec") + tool result
|
||||
- fileChange → assistant tool_call(name="apply_patch") + tool result
|
||||
- mcpToolCall → assistant tool_call(name=f"mcp.{server}.{tool}") + tool result
|
||||
- dynamicToolCall → assistant tool_call(name=tool) + tool result
|
||||
- plan/hookPrompt/collabAgentToolCall → recorded as opaque assistant notes
|
||||
|
||||
Each item maps to AT MOST one assistant entry + one tool entry, preserving
|
||||
Hermes' message-alternation invariants (system → user → assistant → user/tool
|
||||
→ assistant → ...). Multiple Codex tool calls within one Codex turn produce
|
||||
multiple consecutive (assistant, tool) pairs, which is the same shape Hermes
|
||||
already produces for parallel tool calls.
|
||||
|
||||
Counters tracked alongside projection:
|
||||
- tool_iterations: ticks once per completed tool-shaped item. Used by
|
||||
AIAgent._iters_since_skill (skill nudge gate, default threshold 10).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
def _deterministic_call_id(item_type: str, item_id: str) -> str:
|
||||
"""Stable id for tool_call message correlation.
|
||||
|
||||
Uses the codex item id directly when present (already a uuid); falls back
|
||||
to a content hash so replay produces the same id across sessions and
|
||||
prefix caches stay valid. See AGENTS.md Pitfall #16 (deterministic IDs in
|
||||
tool call history)."""
|
||||
if item_id:
|
||||
return f"codex_{item_type}_{item_id}"
|
||||
digest = hashlib.sha256(f"{item_type}".encode()).hexdigest()[:16]
|
||||
return f"codex_{item_type}_{digest}"
|
||||
|
||||
|
||||
def _format_tool_args(d: dict) -> str:
|
||||
"""Format a dict as JSON the way Hermes' existing tool_calls path does."""
|
||||
return json.dumps(d, ensure_ascii=False, sort_keys=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProjectionResult:
|
||||
"""Output of projecting one Codex item.
|
||||
|
||||
`messages` is a list because some Codex items produce two messages
|
||||
(assistant tool_call + tool result). Empty list = item ignored (e.g. a
|
||||
streaming `outputDelta` that doesn't materialize into messages until the
|
||||
`item/completed` event)."""
|
||||
|
||||
messages: list[dict] = field(default_factory=list)
|
||||
is_tool_iteration: bool = False
|
||||
final_text: Optional[str] = None # Set when an agentMessage completes
|
||||
|
||||
|
||||
class CodexEventProjector:
|
||||
"""Stateful projector consuming Codex notifications in arrival order.
|
||||
|
||||
Owns the in-progress reasoning content (codex emits reasoning as separate
|
||||
items but Hermes stashes it on the next assistant message)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pending_reasoning: list[str] = []
|
||||
|
||||
def project(self, notification: dict) -> ProjectionResult:
|
||||
"""Project a single notification. Idempotent for non-completion events;
|
||||
only `item/completed` and `turn/completed` materialize messages."""
|
||||
method = notification.get("method", "")
|
||||
params = notification.get("params", {}) or {}
|
||||
|
||||
# We only materialize messages on `item/completed`. Streaming deltas
|
||||
# (`item/<type>/outputDelta`, `item/<type>/delta`) are display-only and
|
||||
# don't enter the messages list — same way Hermes already only writes
|
||||
# the assistant message after the streaming completion event.
|
||||
if method != "item/completed":
|
||||
return ProjectionResult()
|
||||
|
||||
item = params.get("item") or {}
|
||||
item_type = item.get("type") or ""
|
||||
item_id = item.get("id") or ""
|
||||
|
||||
if item_type == "agentMessage":
|
||||
return self._project_agent_message(item)
|
||||
if item_type == "reasoning":
|
||||
self._pending_reasoning.extend(item.get("summary") or [])
|
||||
self._pending_reasoning.extend(item.get("content") or [])
|
||||
return ProjectionResult()
|
||||
if item_type == "commandExecution":
|
||||
return self._project_command(item, item_id)
|
||||
if item_type == "fileChange":
|
||||
return self._project_file_change(item, item_id)
|
||||
if item_type == "mcpToolCall":
|
||||
return self._project_mcp_tool_call(item, item_id)
|
||||
if item_type == "dynamicToolCall":
|
||||
return self._project_dynamic_tool_call(item, item_id)
|
||||
if item_type == "userMessage":
|
||||
return self._project_user_message(item)
|
||||
|
||||
# Unknown / rare items (plan, hookPrompt, collabAgentToolCall, etc.)
|
||||
# — record as opaque assistant note so memory review can still see
|
||||
# *something* happened, but don't fabricate tool_call structure.
|
||||
return self._project_opaque(item, item_type)
|
||||
|
||||
# ---------- per-type projections ----------
|
||||
|
||||
def _project_agent_message(self, item: dict) -> ProjectionResult:
|
||||
text = item.get("text") or ""
|
||||
msg: dict[str, Any] = {"role": "assistant", "content": text}
|
||||
if self._pending_reasoning:
|
||||
msg["reasoning"] = "\n".join(self._pending_reasoning)
|
||||
self._pending_reasoning = []
|
||||
return ProjectionResult(messages=[msg], final_text=text)
|
||||
|
||||
def _project_user_message(self, item: dict) -> ProjectionResult:
|
||||
# codex's userMessage content is a list of UserInput variants. For
|
||||
# projection purposes we flatten any text fragments and ignore
|
||||
# non-text parts (images, etc.) — Hermes' messages store text only.
|
||||
text_parts: list[str] = []
|
||||
for fragment in item.get("content") or []:
|
||||
if isinstance(fragment, dict):
|
||||
if fragment.get("type") == "text":
|
||||
text_parts.append(fragment.get("text") or "")
|
||||
elif "text" in fragment:
|
||||
text_parts.append(str(fragment["text"]))
|
||||
return ProjectionResult(
|
||||
messages=[{"role": "user", "content": "\n".join(text_parts)}]
|
||||
)
|
||||
|
||||
def _project_command(self, item: dict, item_id: str) -> ProjectionResult:
|
||||
call_id = _deterministic_call_id("exec", item_id)
|
||||
args = {
|
||||
"command": item.get("command") or "",
|
||||
"cwd": item.get("cwd") or "",
|
||||
}
|
||||
assistant_msg = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "exec_command",
|
||||
"arguments": _format_tool_args(args),
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
if self._pending_reasoning:
|
||||
assistant_msg["reasoning"] = "\n".join(self._pending_reasoning)
|
||||
self._pending_reasoning = []
|
||||
output = item.get("aggregatedOutput") or ""
|
||||
exit_code = item.get("exitCode")
|
||||
if exit_code is not None and exit_code != 0:
|
||||
output = f"[exit {exit_code}]\n{output}"
|
||||
tool_msg = {
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": output,
|
||||
}
|
||||
return ProjectionResult(
|
||||
messages=[assistant_msg, tool_msg], is_tool_iteration=True
|
||||
)
|
||||
|
||||
def _project_file_change(self, item: dict, item_id: str) -> ProjectionResult:
|
||||
call_id = _deterministic_call_id("apply_patch", item_id)
|
||||
# Reduce the codex changes array to a digest the agent loop will
|
||||
# find readable. We record per-file change kinds (Add/Update/Delete)
|
||||
# without inlining full file contents — those can be huge.
|
||||
changes_summary = []
|
||||
for change in item.get("changes") or []:
|
||||
kind = (change.get("kind") or {}).get("type") or "update"
|
||||
path = change.get("path") or ""
|
||||
changes_summary.append({"kind": kind, "path": path})
|
||||
args = {"changes": changes_summary}
|
||||
assistant_msg = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "apply_patch",
|
||||
"arguments": _format_tool_args(args),
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
if self._pending_reasoning:
|
||||
assistant_msg["reasoning"] = "\n".join(self._pending_reasoning)
|
||||
self._pending_reasoning = []
|
||||
status = item.get("status") or "unknown"
|
||||
n = len(changes_summary)
|
||||
tool_msg = {
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": f"apply_patch status={status}, {n} change(s)",
|
||||
}
|
||||
return ProjectionResult(
|
||||
messages=[assistant_msg, tool_msg], is_tool_iteration=True
|
||||
)
|
||||
|
||||
def _project_mcp_tool_call(self, item: dict, item_id: str) -> ProjectionResult:
|
||||
server = item.get("server") or "mcp"
|
||||
tool = item.get("tool") or "unknown"
|
||||
call_id = _deterministic_call_id(f"mcp_{server}_{tool}", item_id)
|
||||
args = item.get("arguments") or {}
|
||||
if not isinstance(args, dict):
|
||||
args = {"arguments": args}
|
||||
assistant_msg = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f"mcp.{server}.{tool}",
|
||||
"arguments": _format_tool_args(args),
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
if self._pending_reasoning:
|
||||
assistant_msg["reasoning"] = "\n".join(self._pending_reasoning)
|
||||
self._pending_reasoning = []
|
||||
result = item.get("result")
|
||||
error = item.get("error")
|
||||
if error:
|
||||
content = f"[error] {json.dumps(error, ensure_ascii=False)[:1000]}"
|
||||
elif result is not None:
|
||||
content = json.dumps(result, ensure_ascii=False)[:4000]
|
||||
else:
|
||||
content = ""
|
||||
tool_msg = {
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": content,
|
||||
}
|
||||
return ProjectionResult(
|
||||
messages=[assistant_msg, tool_msg], is_tool_iteration=True
|
||||
)
|
||||
|
||||
def _project_dynamic_tool_call(
|
||||
self, item: dict, item_id: str
|
||||
) -> ProjectionResult:
|
||||
tool = item.get("tool") or "unknown"
|
||||
call_id = _deterministic_call_id(f"dyn_{tool}", item_id)
|
||||
args = item.get("arguments") or {}
|
||||
if not isinstance(args, dict):
|
||||
args = {"arguments": args}
|
||||
assistant_msg = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool,
|
||||
"arguments": _format_tool_args(args),
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
if self._pending_reasoning:
|
||||
assistant_msg["reasoning"] = "\n".join(self._pending_reasoning)
|
||||
self._pending_reasoning = []
|
||||
content_items = item.get("contentItems") or []
|
||||
if isinstance(content_items, list) and content_items:
|
||||
content = json.dumps(content_items, ensure_ascii=False)[:4000]
|
||||
else:
|
||||
success = item.get("success")
|
||||
content = f"success={success}"
|
||||
tool_msg = {
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": content,
|
||||
}
|
||||
return ProjectionResult(
|
||||
messages=[assistant_msg, tool_msg], is_tool_iteration=True
|
||||
)
|
||||
|
||||
def _project_opaque(self, item: dict, item_type: str) -> ProjectionResult:
|
||||
# Record the existence of the item without inventing tool_calls.
|
||||
# Memory review will see this and may or may not save anything.
|
||||
try:
|
||||
payload = json.dumps(item, ensure_ascii=False)[:1500]
|
||||
except (TypeError, ValueError):
|
||||
payload = repr(item)[:1500]
|
||||
return ProjectionResult(
|
||||
messages=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"[codex {item_type}] {payload}",
|
||||
}
|
||||
]
|
||||
)
|
||||
225
agent/transports/hermes_tools_mcp_server.py
Normal file
225
agent/transports/hermes_tools_mcp_server.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Hermes-tools-as-MCP server for the codex_app_server runtime.
|
||||
|
||||
When the user runs `openai/*` turns through the codex app-server, codex
|
||||
owns the loop and builds its own tool list. By default, that means
|
||||
Hermes' richer tool surface — web search, browser automation,
|
||||
delegate_task subagents, vision analysis, persistent memory, skills,
|
||||
cross-session search, image generation, TTS — is unreachable.
|
||||
|
||||
This module exposes a curated subset of those Hermes tools to the
|
||||
spawned codex subprocess via stdio MCP. Codex registers it as a normal
|
||||
MCP server (per `~/.codex/config.toml [mcp_servers.hermes-tools]`) and
|
||||
the user gets full Hermes capability inside a Codex turn.
|
||||
|
||||
Scope (what we expose):
|
||||
- web_search, web_extract — Firecrawl, no codex equivalent
|
||||
- browser_navigate / _click / _type / — Camofox/Browserbase automation
|
||||
_snapshot / _screenshot / _scroll / _back / _press / _vision
|
||||
- delegate_task — Hermes subagents
|
||||
- vision_analyze — image inspection by vision model
|
||||
- image_generate — image generation
|
||||
- memory — Hermes' persistent memory store
|
||||
- skill_view, skills_list — Hermes' skill library
|
||||
- session_search — cross-session search
|
||||
- text_to_speech — TTS
|
||||
|
||||
What we DO NOT expose (codex has equivalents):
|
||||
- terminal / shell — codex's own shell tool
|
||||
- read_file / write_file / patch — codex's apply_patch + shell
|
||||
- search_files / process — codex's shell
|
||||
- clarify, todo — codex's own UX
|
||||
|
||||
Run with: python -m agent.transports.hermes_tools_mcp_server
|
||||
Spawned by: CodexAppServerSession.ensure_started() when the runtime is
|
||||
active and config opts in.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tools we expose. Each name MUST match a registered Hermes tool that
|
||||
# `model_tools.handle_function_call()` can dispatch.
|
||||
#
|
||||
# What we deliberately DO NOT expose:
|
||||
# - terminal / shell / read_file / write_file / patch / search_files /
|
||||
# process — codex's built-ins cover these and approval routes through
|
||||
# codex's own UI.
|
||||
# - delegate_task / memory / session_search / todo — these are
|
||||
# `_AGENT_LOOP_TOOLS` in Hermes (model_tools.py:493). They require
|
||||
# the running AIAgent context to dispatch (mid-loop state), so a
|
||||
# stateless MCP callback can't drive them. Hermes' default runtime
|
||||
# keeps these working; the codex_app_server runtime cannot.
|
||||
EXPOSED_TOOLS: tuple[str, ...] = (
|
||||
"web_search",
|
||||
"web_extract",
|
||||
"browser_navigate",
|
||||
"browser_click",
|
||||
"browser_type",
|
||||
"browser_press",
|
||||
"browser_snapshot",
|
||||
"browser_scroll",
|
||||
"browser_back",
|
||||
"browser_get_images",
|
||||
"browser_console",
|
||||
"browser_vision",
|
||||
"vision_analyze",
|
||||
"image_generate",
|
||||
"skill_view",
|
||||
"skills_list",
|
||||
"text_to_speech",
|
||||
# Kanban worker handoff tools — gated on HERMES_KANBAN_TASK env var
|
||||
# (set by the kanban dispatcher when spawning a worker). Without these
|
||||
# in the callback, a worker spawned with openai_runtime=codex_app_server
|
||||
# could do the work but couldn't report completion back to the kernel,
|
||||
# making it hang until timeout. Stateless dispatch — they just read
|
||||
# the env var and write to ~/.hermes/kanban.db.
|
||||
"kanban_complete",
|
||||
"kanban_block",
|
||||
"kanban_comment",
|
||||
"kanban_heartbeat",
|
||||
"kanban_show",
|
||||
"kanban_list",
|
||||
# NOTE: kanban_create / kanban_unblock / kanban_link are orchestrator-
|
||||
# only — the kanban tool gates them on HERMES_KANBAN_TASK being unset.
|
||||
# They're exposed here for orchestrator agents running on the codex
|
||||
# runtime that need to dispatch new tasks.
|
||||
"kanban_create",
|
||||
"kanban_unblock",
|
||||
"kanban_link",
|
||||
)
|
||||
|
||||
|
||||
def _build_server() -> Any:
|
||||
"""Create the FastMCP server with Hermes tools attached. Lazy imports
|
||||
so the module can be imported without the mcp package installed
|
||||
(we degrade to a clear error only when actually run)."""
|
||||
try:
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
except ImportError as exc: # pragma: no cover - install hint
|
||||
raise ImportError(
|
||||
f"hermes-tools MCP server requires the 'mcp' package: {exc}"
|
||||
) from exc
|
||||
|
||||
# Discover Hermes tools so dispatch works.
|
||||
from model_tools import (
|
||||
get_tool_definitions,
|
||||
handle_function_call,
|
||||
)
|
||||
|
||||
mcp = FastMCP(
|
||||
"hermes-tools",
|
||||
instructions=(
|
||||
"Hermes Agent's tool surface, exposed for use inside a Codex "
|
||||
"session. Use these for capabilities Codex's built-in toolset "
|
||||
"doesn't cover: web search/extract, browser automation, "
|
||||
"subagent delegation, vision, image generation, persistent "
|
||||
"memory, skills, and cross-session search."
|
||||
),
|
||||
)
|
||||
|
||||
# Pull authoritative Hermes tool schemas for the ones we expose, so
|
||||
# MCP clients see the same parameter docs Hermes gives the model.
|
||||
all_defs = {
|
||||
td["function"]["name"]: td["function"]
|
||||
for td in (get_tool_definitions(quiet_mode=True) or [])
|
||||
if isinstance(td, dict) and td.get("type") == "function"
|
||||
}
|
||||
|
||||
exposed_count = 0
|
||||
|
||||
for name in EXPOSED_TOOLS:
|
||||
spec = all_defs.get(name)
|
||||
if spec is None:
|
||||
logger.debug(
|
||||
"skipping %s — not registered in this Hermes process", name
|
||||
)
|
||||
continue
|
||||
|
||||
description = spec.get("description") or f"Hermes {name} tool"
|
||||
params_schema = spec.get("parameters") or {"type": "object", "properties": {}}
|
||||
|
||||
# FastMCP wants a Python callable. Build a closure that takes the
|
||||
# arguments dict, dispatches via handle_function_call, and returns
|
||||
# the result string. We use add_tool() for full control over the
|
||||
# input schema (FastMCP's @tool() decorator inspects type hints,
|
||||
# which we can't get from a JSON schema at runtime).
|
||||
def _make_handler(tool_name: str):
|
||||
def _dispatch(**kwargs: Any) -> str:
|
||||
try:
|
||||
return handle_function_call(tool_name, kwargs or {})
|
||||
except Exception as exc:
|
||||
logger.exception("tool %s raised", tool_name)
|
||||
return json.dumps({"error": str(exc), "tool": tool_name})
|
||||
_dispatch.__name__ = tool_name
|
||||
_dispatch.__doc__ = description
|
||||
return _dispatch
|
||||
|
||||
try:
|
||||
mcp.add_tool(
|
||||
_make_handler(name),
|
||||
name=name,
|
||||
description=description,
|
||||
# FastMCP accepts JSON schema directly via the
|
||||
# input_schema parameter on newer versions; older
|
||||
# versions use parameters_schema. Try both for compat.
|
||||
)
|
||||
except TypeError:
|
||||
# Older mcp SDK signature — fall back to decorator-style.
|
||||
handler = _make_handler(name)
|
||||
handler = mcp.tool(name=name, description=description)(handler)
|
||||
|
||||
exposed_count += 1
|
||||
|
||||
logger.info(
|
||||
"hermes-tools MCP server registered %d/%d tools",
|
||||
exposed_count,
|
||||
len(EXPOSED_TOOLS),
|
||||
)
|
||||
return mcp
|
||||
|
||||
|
||||
def main(argv: Optional[list[str]] = None) -> int:
|
||||
"""Entry point for `python -m agent.transports.hermes_tools_mcp_server`."""
|
||||
argv = argv or sys.argv[1:]
|
||||
verbose = "--verbose" in argv or "-v" in argv
|
||||
|
||||
log_level = logging.INFO if verbose else logging.WARNING
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
stream=sys.stderr, # MCP uses stdio for protocol — logs MUST go to stderr
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
)
|
||||
|
||||
# Quiet mode: keep Hermes' own banners off stdout (which is the MCP wire).
|
||||
os.environ.setdefault("HERMES_QUIET", "1")
|
||||
os.environ.setdefault("HERMES_REDACT_SECRETS", "true")
|
||||
|
||||
try:
|
||||
server = _build_server()
|
||||
except ImportError as exc:
|
||||
sys.stderr.write(f"hermes-tools MCP server cannot start: {exc}\n")
|
||||
return 2
|
||||
|
||||
# FastMCP runs with stdio transport by default when launched as a
|
||||
# subprocess.
|
||||
try:
|
||||
server.run()
|
||||
except KeyboardInterrupt:
|
||||
return 0
|
||||
except Exception as exc:
|
||||
logger.exception("hermes-tools MCP server crashed")
|
||||
sys.stderr.write(f"hermes-tools MCP server error: {exc}\n")
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -370,6 +370,17 @@ _OFFICIAL_DOCS_PRICING: Dict[tuple[str, str], PricingEntry] = {
|
||||
source_url="https://api-docs.deepseek.com/quick_start/pricing",
|
||||
pricing_version="deepseek-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"deepseek",
|
||||
"deepseek-v4-pro",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("1.74"),
|
||||
output_cost_per_million=Decimal("3.48"),
|
||||
cache_read_cost_per_million=Decimal("0.0145"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://api-docs.deepseek.com/quick_start/pricing",
|
||||
pricing_version="deepseek-pricing-2026-05-12",
|
||||
),
|
||||
# Google Gemini
|
||||
(
|
||||
"google",
|
||||
|
||||
299
agent/video_gen_provider.py
Normal file
299
agent/video_gen_provider.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
Video Generation Provider ABC
|
||||
=============================
|
||||
|
||||
Defines the pluggable-backend interface for video generation. Providers register
|
||||
instances via ``PluginContext.register_video_gen_provider()``; the active one
|
||||
(selected via ``video_gen.provider`` in ``config.yaml``) services every
|
||||
``video_generate`` tool call.
|
||||
|
||||
Providers live in ``<repo>/plugins/video_gen/<name>/`` (built-in, auto-loaded
|
||||
as ``kind: backend``) or ``~/.hermes/plugins/video_gen/<name>/`` (user, opt-in
|
||||
via ``plugins.enabled``).
|
||||
|
||||
Mirrors the ``image_gen`` provider design (``agent/image_gen_provider.py``) so
|
||||
the two surfaces stay learnable together.
|
||||
|
||||
Unified surface
|
||||
---------------
|
||||
One tool — ``video_generate`` — covers **text-to-video** and **image-to-video**.
|
||||
The router is the presence of ``image_url``: if it's set, the provider routes
|
||||
to its image-to-video endpoint; if it's omitted, the provider routes to
|
||||
text-to-video. Users pick one **model family** (e.g. Pixverse v6, Veo 3.1,
|
||||
Kling O3 Standard); the provider handles which underlying FAL/xAI endpoint
|
||||
to hit.
|
||||
|
||||
Video edit and video extend are intentionally NOT exposed in this surface —
|
||||
the inconsistency across backends is too large for one unified tool. If
|
||||
those use cases warrant attention later they can ship as separate tools.
|
||||
|
||||
Response shape
|
||||
--------------
|
||||
All providers return a dict built by :func:`success_response` /
|
||||
:func:`error_response`. Keys:
|
||||
|
||||
success bool
|
||||
video str | None URL or absolute file path
|
||||
model str provider-specific model identifier
|
||||
prompt str echoed prompt
|
||||
modality str "text" | "image" (which mode was used)
|
||||
aspect_ratio str provider-native (e.g. "16:9") or ""
|
||||
duration int seconds (0 if not applicable)
|
||||
provider str provider name (for diagnostics)
|
||||
error str only when success=False
|
||||
error_type str only when success=False
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import base64
|
||||
import datetime
|
||||
import logging
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Common aspect ratios across providers (Veo / Kling / xAI / Pixverse). The
|
||||
# tool schema advertises this set as an enum hint, but providers may accept
|
||||
# a narrower or wider set — they are responsible for clamping.
|
||||
COMMON_ASPECT_RATIOS: Tuple[str, ...] = ("16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3")
|
||||
DEFAULT_ASPECT_RATIO = "16:9"
|
||||
|
||||
COMMON_RESOLUTIONS: Tuple[str, ...] = ("480p", "540p", "720p", "1080p")
|
||||
DEFAULT_RESOLUTION = "720p"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ABC
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class VideoGenProvider(abc.ABC):
|
||||
"""Abstract base class for a video generation backend.
|
||||
|
||||
Subclasses must implement :meth:`generate`. Everything else has sane
|
||||
defaults — override only what your provider needs.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Stable short identifier used in ``video_gen.provider`` config.
|
||||
|
||||
Lowercase, no spaces. Examples: ``xai``, ``fal``, ``google``.
|
||||
"""
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""Human-readable label shown in ``hermes tools``. Defaults to ``name.title()``."""
|
||||
return self.name.title()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Return True when this provider can service calls.
|
||||
|
||||
Typically checks for a required API key and optional-dependency
|
||||
import. Default: True.
|
||||
"""
|
||||
return True
|
||||
|
||||
def list_models(self) -> List[Dict[str, Any]]:
|
||||
"""Return catalog entries for ``hermes tools`` model picker.
|
||||
|
||||
Each entry represents a **model family** that supports text-to-video
|
||||
and/or image-to-video routing internally::
|
||||
|
||||
{
|
||||
"id": "veo-3.1", # required
|
||||
"display": "Veo 3.1", # optional; defaults to id
|
||||
"speed": "~60s", # optional
|
||||
"strengths": "...", # optional
|
||||
"price": "$0.20/s", # optional
|
||||
"modalities": ["text", "image"], # optional, advisory
|
||||
}
|
||||
|
||||
Default: empty list (provider has no user-selectable models).
|
||||
"""
|
||||
return []
|
||||
|
||||
def get_setup_schema(self) -> Dict[str, Any]:
|
||||
"""Return provider metadata for the ``hermes tools`` picker."""
|
||||
return {
|
||||
"name": self.display_name,
|
||||
"badge": "",
|
||||
"tag": "",
|
||||
"env_vars": [],
|
||||
}
|
||||
|
||||
def default_model(self) -> Optional[str]:
|
||||
"""Return the default model id, or None if not applicable."""
|
||||
models = self.list_models()
|
||||
if models:
|
||||
return models[0].get("id")
|
||||
return None
|
||||
|
||||
def capabilities(self) -> Dict[str, Any]:
|
||||
"""Return what this provider supports.
|
||||
|
||||
Returned dict (all keys optional)::
|
||||
|
||||
{
|
||||
"modalities": ["text", "image"], # which inputs the backend accepts
|
||||
"aspect_ratios": ["16:9", "9:16", ...],
|
||||
"resolutions": ["720p", "1080p"],
|
||||
"max_duration": 15, # seconds
|
||||
"min_duration": 1,
|
||||
"supports_audio": True,
|
||||
"supports_negative_prompt": True,
|
||||
"max_reference_images": 7,
|
||||
}
|
||||
|
||||
Used by the tool layer for soft validation and by ``hermes tools``
|
||||
for the picker. Default: text-only.
|
||||
"""
|
||||
return {
|
||||
"modalities": ["text"],
|
||||
"aspect_ratios": list(COMMON_ASPECT_RATIOS),
|
||||
"resolutions": list(COMMON_RESOLUTIONS),
|
||||
"max_duration": 10,
|
||||
"min_duration": 1,
|
||||
"supports_audio": False,
|
||||
"supports_negative_prompt": False,
|
||||
"max_reference_images": 0,
|
||||
}
|
||||
|
||||
@abc.abstractmethod
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
model: Optional[str] = None,
|
||||
image_url: Optional[str] = None,
|
||||
reference_image_urls: Optional[List[str]] = None,
|
||||
duration: Optional[int] = None,
|
||||
aspect_ratio: str = DEFAULT_ASPECT_RATIO,
|
||||
resolution: str = DEFAULT_RESOLUTION,
|
||||
negative_prompt: Optional[str] = None,
|
||||
audio: Optional[bool] = None,
|
||||
seed: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate a video from a prompt (text-to-video) or animate an image
|
||||
(image-to-video).
|
||||
|
||||
Routing: if ``image_url`` is provided, the provider should route to
|
||||
its image-to-video endpoint; otherwise text-to-video. The plugin
|
||||
is responsible for picking the right underlying endpoint within
|
||||
the user's chosen model family.
|
||||
|
||||
Implementations should return the dict from :func:`success_response`
|
||||
or :func:`error_response`. ``kwargs`` may contain forward-compat
|
||||
parameters future versions of the schema will expose —
|
||||
implementations MUST ignore unknown keys (no TypeError).
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _videos_cache_dir() -> Path:
|
||||
"""Return ``$HERMES_HOME/cache/videos/``, creating parents as needed."""
|
||||
from hermes_constants import get_hermes_home
|
||||
|
||||
path = get_hermes_home() / "cache" / "videos"
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def save_b64_video(
|
||||
b64_data: str,
|
||||
*,
|
||||
prefix: str = "video",
|
||||
extension: str = "mp4",
|
||||
) -> Path:
|
||||
"""Decode base64 video data and write under ``$HERMES_HOME/cache/videos/``.
|
||||
|
||||
Returns the absolute :class:`Path` to the saved file.
|
||||
|
||||
Filename format: ``<prefix>_<YYYYMMDD_HHMMSS>_<short-uuid>.<ext>``.
|
||||
"""
|
||||
raw = base64.b64decode(b64_data)
|
||||
ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
short = uuid.uuid4().hex[:8]
|
||||
path = _videos_cache_dir() / f"{prefix}_{ts}_{short}.{extension}"
|
||||
path.write_bytes(raw)
|
||||
return path
|
||||
|
||||
|
||||
def save_bytes_video(
|
||||
raw: bytes,
|
||||
*,
|
||||
prefix: str = "video",
|
||||
extension: str = "mp4",
|
||||
) -> Path:
|
||||
"""Write raw video bytes (e.g. an HTTP download body) to the cache."""
|
||||
ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
short = uuid.uuid4().hex[:8]
|
||||
path = _videos_cache_dir() / f"{prefix}_{ts}_{short}.{extension}"
|
||||
path.write_bytes(raw)
|
||||
return path
|
||||
|
||||
|
||||
def success_response(
|
||||
*,
|
||||
video: str,
|
||||
model: str,
|
||||
prompt: str,
|
||||
modality: str = "text",
|
||||
aspect_ratio: str = "",
|
||||
duration: int = 0,
|
||||
provider: str,
|
||||
extra: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build a uniform success response dict.
|
||||
|
||||
``video`` may be an HTTP URL or an absolute filesystem path.
|
||||
``modality`` is ``"text"`` (text-to-video) or ``"image"`` (image-to-video) —
|
||||
indicates which endpoint was actually hit, useful for diagnostics.
|
||||
"""
|
||||
payload: Dict[str, Any] = {
|
||||
"success": True,
|
||||
"video": video,
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"modality": modality,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"duration": int(duration) if duration else 0,
|
||||
"provider": provider,
|
||||
}
|
||||
if extra:
|
||||
for k, v in extra.items():
|
||||
payload.setdefault(k, v)
|
||||
return payload
|
||||
|
||||
|
||||
def error_response(
|
||||
*,
|
||||
error: str,
|
||||
error_type: str = "provider_error",
|
||||
provider: str = "",
|
||||
model: str = "",
|
||||
prompt: str = "",
|
||||
aspect_ratio: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
"""Build a uniform error response dict."""
|
||||
return {
|
||||
"success": False,
|
||||
"video": None,
|
||||
"error": error,
|
||||
"error_type": error_type,
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"provider": provider,
|
||||
}
|
||||
117
agent/video_gen_registry.py
Normal file
117
agent/video_gen_registry.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
Video Generation Provider Registry
|
||||
==================================
|
||||
|
||||
Central map of registered providers. Populated by plugins at import-time via
|
||||
``PluginContext.register_video_gen_provider()``; consumed by the
|
||||
``video_generate`` tool to dispatch each call to the active backend.
|
||||
|
||||
Active selection
|
||||
----------------
|
||||
The active provider is chosen by ``video_gen.provider`` in ``config.yaml``.
|
||||
If unset, :func:`get_active_provider` applies fallback logic:
|
||||
|
||||
1. If exactly one provider is registered, use it.
|
||||
2. Otherwise return ``None`` (the tool surfaces a helpful error pointing
|
||||
the user at ``hermes tools``).
|
||||
|
||||
Mirrors ``agent/image_gen_registry.py`` so the two surfaces behave the
|
||||
same.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from agent.video_gen_provider import VideoGenProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_providers: Dict[str, VideoGenProvider] = {}
|
||||
_lock = threading.Lock()
|
||||
|
||||
|
||||
def register_provider(provider: VideoGenProvider) -> None:
|
||||
"""Register a video generation provider.
|
||||
|
||||
Re-registration (same ``name``) overwrites the previous entry and logs
|
||||
a debug message — this makes hot-reload scenarios (tests, dev loops)
|
||||
behave predictably.
|
||||
"""
|
||||
if not isinstance(provider, VideoGenProvider):
|
||||
raise TypeError(
|
||||
f"register_provider() expects a VideoGenProvider instance, "
|
||||
f"got {type(provider).__name__}"
|
||||
)
|
||||
name = provider.name
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
raise ValueError("Video gen provider .name must be a non-empty string")
|
||||
with _lock:
|
||||
existing = _providers.get(name)
|
||||
_providers[name] = provider
|
||||
if existing is not None:
|
||||
logger.debug("Video gen provider '%s' re-registered (was %r)", name, type(existing).__name__)
|
||||
else:
|
||||
logger.debug("Registered video gen provider '%s' (%s)", name, type(provider).__name__)
|
||||
|
||||
|
||||
def list_providers() -> List[VideoGenProvider]:
|
||||
"""Return all registered providers, sorted by name."""
|
||||
with _lock:
|
||||
items = list(_providers.values())
|
||||
return sorted(items, key=lambda p: p.name)
|
||||
|
||||
|
||||
def get_provider(name: str) -> Optional[VideoGenProvider]:
|
||||
"""Return the provider registered under *name*, or None."""
|
||||
if not isinstance(name, str):
|
||||
return None
|
||||
with _lock:
|
||||
return _providers.get(name.strip())
|
||||
|
||||
|
||||
def get_active_provider() -> Optional[VideoGenProvider]:
|
||||
"""Resolve the currently-active provider.
|
||||
|
||||
Reads ``video_gen.provider`` from config.yaml; falls back per the
|
||||
module docstring.
|
||||
"""
|
||||
configured: Optional[str] = None
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
cfg = load_config()
|
||||
section = cfg.get("video_gen") if isinstance(cfg, dict) else None
|
||||
if isinstance(section, dict):
|
||||
raw = section.get("provider")
|
||||
if isinstance(raw, str) and raw.strip():
|
||||
configured = raw.strip()
|
||||
except Exception as exc:
|
||||
logger.debug("Could not read video_gen.provider from config: %s", exc)
|
||||
|
||||
with _lock:
|
||||
snapshot = dict(_providers)
|
||||
|
||||
if configured:
|
||||
provider = snapshot.get(configured)
|
||||
if provider is not None:
|
||||
return provider
|
||||
logger.debug(
|
||||
"video_gen.provider='%s' configured but not registered; falling back",
|
||||
configured,
|
||||
)
|
||||
|
||||
# Fallback: single-provider case
|
||||
if len(snapshot) == 1:
|
||||
return next(iter(snapshot.values()))
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _reset_for_tests() -> None:
|
||||
"""Clear the registry. **Test-only.**"""
|
||||
with _lock:
|
||||
_providers.clear()
|
||||
221
agent/web_search_provider.py
Normal file
221
agent/web_search_provider.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""
|
||||
Web Search Provider ABC
|
||||
=======================
|
||||
|
||||
Defines the pluggable-backend interface for web search and content extraction.
|
||||
Providers register instances via ``PluginContext.register_web_search_provider()``;
|
||||
the active one (selected via ``web.search_backend`` / ``web.extract_backend`` /
|
||||
``web.backend`` in ``config.yaml``) services every ``web_search`` /
|
||||
``web_extract`` tool call.
|
||||
|
||||
Providers live in ``<repo>/plugins/web/<name>/`` (built-in, auto-loaded as
|
||||
``kind: backend``) or ``~/.hermes/plugins/web/<name>/`` (user, opt-in via
|
||||
``plugins.enabled``).
|
||||
|
||||
This ABC is the SINGLE plugin-facing surface for web providers — every
|
||||
provider in the tree (brave-free, ddgs, searxng, exa, parallel, tavily,
|
||||
firecrawl) implements it. The legacy in-tree ``tools.web_providers.base``
|
||||
ABCs were deleted in PR #25182 along with the per-vendor inline helpers
|
||||
in ``tools/web_tools.py``; the response-shape contract documented below
|
||||
is preserved bit-for-bit so the tool wrapper does not have to translate.
|
||||
|
||||
Response shape (preserved from the legacy contract):
|
||||
|
||||
Search results::
|
||||
|
||||
{
|
||||
"success": True,
|
||||
"data": {
|
||||
"web": [
|
||||
{"title": str, "url": str, "description": str, "position": int},
|
||||
...
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
Extract results::
|
||||
|
||||
{
|
||||
"success": True,
|
||||
"data": [
|
||||
{"url": str, "title": str, "content": str,
|
||||
"raw_content": str, "metadata": dict},
|
||||
...
|
||||
]
|
||||
}
|
||||
|
||||
On failure (either capability)::
|
||||
|
||||
{"success": False, "error": str}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ABC
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class WebSearchProvider(abc.ABC):
|
||||
"""Abstract base class for a web search/extract/crawl backend.
|
||||
|
||||
Subclasses must implement :meth:`is_available` and at least one of
|
||||
:meth:`search` / :meth:`extract` / :meth:`crawl`. The
|
||||
:meth:`supports_search` / :meth:`supports_extract` / :meth:`supports_crawl`
|
||||
capability flags let the registry route each tool call to the right
|
||||
provider, and let multi-capability providers (Firecrawl, Tavily, Exa,
|
||||
…) advertise multiple capabilities from a single class.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Stable short identifier used in ``web.search_backend`` /
|
||||
``web.extract_backend`` / ``web.backend`` config keys.
|
||||
|
||||
Lowercase, no spaces; hyphens permitted to preserve existing
|
||||
user-visible names. Examples: ``brave-free``, ``ddgs``,
|
||||
``searxng``, ``firecrawl``.
|
||||
"""
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""Human-readable label shown in ``hermes tools``. Defaults to ``name``."""
|
||||
return self.name
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_available(self) -> bool:
|
||||
"""Return True when this provider can service calls.
|
||||
|
||||
Typically a cheap check (env var present, optional Python dep
|
||||
importable, instance URL set). Must NOT make network calls — this
|
||||
runs at tool-registration time and on every ``hermes tools`` paint.
|
||||
"""
|
||||
|
||||
def supports_search(self) -> bool:
|
||||
"""Return True if this provider implements :meth:`search`."""
|
||||
return True
|
||||
|
||||
def supports_extract(self) -> bool:
|
||||
"""Return True if this provider implements :meth:`extract`.
|
||||
|
||||
Both sync and async :meth:`extract` implementations are valid — the
|
||||
dispatcher detects coroutine functions via
|
||||
:func:`inspect.iscoroutinefunction` and awaits as needed. Sync
|
||||
implementations that perform blocking I/O (HTTP, SDK calls) should
|
||||
ideally wrap in :func:`asyncio.to_thread` at the call site; small
|
||||
providers can keep their sync shape and let the dispatcher handle
|
||||
threading.
|
||||
"""
|
||||
return False
|
||||
|
||||
def supports_crawl(self) -> bool:
|
||||
"""Return True if this provider implements :meth:`crawl`.
|
||||
|
||||
Crawl differs from extract in that the agent provides a *seed URL*
|
||||
and the provider walks linked pages on its own — useful for
|
||||
documentation sites where the agent doesn't know all relevant
|
||||
URLs upfront. Tavily is the only built-in backend that natively
|
||||
crawls today; Firecrawl provides a similar capability that we
|
||||
don't currently surface as a tool.
|
||||
|
||||
Providers that don't crawl should leave this as False; the
|
||||
dispatcher in :func:`tools.web_tools.web_crawl_tool` will fall
|
||||
back to its auxiliary-model summarization path.
|
||||
"""
|
||||
return False
|
||||
|
||||
def search(self, query: str, limit: int = 5) -> Dict[str, Any]:
|
||||
"""Execute a web search.
|
||||
|
||||
Override when :meth:`supports_search` returns True. The default
|
||||
raises NotImplementedError; callers should gate on
|
||||
:meth:`supports_search` before calling.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.name} does not support search (override supports_search)"
|
||||
)
|
||||
|
||||
def extract(self, urls: List[str], **kwargs: Any) -> Any:
|
||||
"""Extract content from one or more URLs.
|
||||
|
||||
Override when :meth:`supports_extract` returns True. The default
|
||||
raises NotImplementedError; callers should gate on
|
||||
:meth:`supports_extract` before calling.
|
||||
|
||||
Return shape: a list of result dicts matching what the legacy
|
||||
:func:`tools.web_tools.web_extract_tool` post-processing pipeline
|
||||
expects::
|
||||
|
||||
[
|
||||
{
|
||||
"url": str,
|
||||
"title": str,
|
||||
"content": str,
|
||||
"raw_content": str,
|
||||
"metadata": dict, # optional
|
||||
"error": str, # optional, only on per-URL failure
|
||||
},
|
||||
...
|
||||
]
|
||||
|
||||
Implementations MAY be ``async def`` — the dispatcher detects
|
||||
coroutines via :func:`inspect.iscoroutinefunction` and awaits.
|
||||
|
||||
``kwargs`` may carry forward-compat fields (``format``, ``include_raw``,
|
||||
``max_chars``) — implementations should ignore unknown keys.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.name} does not support extract (override supports_extract)"
|
||||
)
|
||||
|
||||
def crawl(self, url: str, **kwargs: Any) -> Any:
|
||||
"""Crawl a seed URL and return results.
|
||||
|
||||
Override when :meth:`supports_crawl` returns True. The default
|
||||
raises NotImplementedError; callers should gate on
|
||||
:meth:`supports_crawl` before calling.
|
||||
|
||||
Return shape: ``{"results": [{"url": str, "title": str,
|
||||
"content": str, ...}, ...]}`` matching what
|
||||
:func:`tools.web_tools.web_crawl_tool` post-processing expects.
|
||||
|
||||
Implementations MAY be ``async def``.
|
||||
|
||||
``kwargs`` may carry forward-compat fields (e.g. ``max_depth``,
|
||||
``include_domains``) — implementations should ignore unknown keys.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.name} does not support crawl (override supports_crawl)"
|
||||
)
|
||||
|
||||
def get_setup_schema(self) -> Dict[str, Any]:
|
||||
"""Return provider metadata for the ``hermes tools`` picker.
|
||||
|
||||
Used by ``hermes_cli/tools_config.py`` to inject this provider as a
|
||||
row in the Web Search / Web Extract picker. Shape::
|
||||
|
||||
{
|
||||
"name": "Brave Search (Free)",
|
||||
"badge": "free",
|
||||
"tag": "No paid tier needed — uses Brave's free API.",
|
||||
"env_vars": [
|
||||
{"key": "BRAVE_SEARCH_API_KEY",
|
||||
"prompt": "Brave Search API key",
|
||||
"url": "https://brave.com/search/api/"},
|
||||
],
|
||||
}
|
||||
|
||||
Default: minimal entry derived from ``display_name``. Override to
|
||||
expose API key prompts, badges, and instance URL fields.
|
||||
"""
|
||||
return {
|
||||
"name": self.display_name,
|
||||
"badge": "",
|
||||
"tag": "",
|
||||
"env_vars": [],
|
||||
}
|
||||
262
agent/web_search_registry.py
Normal file
262
agent/web_search_registry.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""
|
||||
Web Search Provider Registry
|
||||
============================
|
||||
|
||||
Central map of registered web providers. Populated by plugins at import-time
|
||||
via :meth:`PluginContext.register_web_search_provider`; consumed by the
|
||||
``web_search`` and ``web_extract`` tool wrappers in :mod:`tools.web_tools` to
|
||||
dispatch each call to the active backend.
|
||||
|
||||
Active selection
|
||||
----------------
|
||||
The active provider is chosen by configuration with this precedence:
|
||||
|
||||
1. ``web.search_backend`` / ``web.extract_backend`` / ``web.crawl_backend``
|
||||
(per-capability override).
|
||||
2. ``web.backend`` (shared fallback).
|
||||
3. If exactly one capability-eligible provider is registered AND available,
|
||||
use it.
|
||||
4. Legacy preference order — ``firecrawl`` → ``parallel`` → ``tavily`` →
|
||||
``exa`` → ``searxng`` → ``brave-free`` → ``ddgs`` — filtered by
|
||||
availability. Matches the historic ``tools.web_tools._get_backend()``
|
||||
candidate order so installs that never set a config key keep landing
|
||||
on the same provider they did before the plugin migration.
|
||||
5. Otherwise ``None`` — the tool surfaces a helpful error pointing at
|
||||
``hermes tools``.
|
||||
|
||||
The capability filter (``supports_search`` / ``supports_extract`` /
|
||||
``supports_crawl``) is applied at every step so a search-only provider
|
||||
(``brave-free``) configured as ``web.extract_backend`` correctly falls
|
||||
through to an extract-capable backend.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from agent.web_search_provider import WebSearchProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_providers: Dict[str, WebSearchProvider] = {}
|
||||
_lock = threading.Lock()
|
||||
|
||||
|
||||
def register_provider(provider: WebSearchProvider) -> None:
|
||||
"""Register a web search/extract provider.
|
||||
|
||||
Re-registration (same ``name``) overwrites the previous entry and logs
|
||||
a debug message — makes hot-reload scenarios (tests, dev loops) behave
|
||||
predictably.
|
||||
"""
|
||||
if not isinstance(provider, WebSearchProvider):
|
||||
raise TypeError(
|
||||
f"register_provider() expects a WebSearchProvider instance, "
|
||||
f"got {type(provider).__name__}"
|
||||
)
|
||||
name = provider.name
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
raise ValueError("Web provider .name must be a non-empty string")
|
||||
with _lock:
|
||||
existing = _providers.get(name)
|
||||
_providers[name] = provider
|
||||
if existing is not None:
|
||||
logger.debug(
|
||||
"Web provider '%s' re-registered (was %r)",
|
||||
name, type(existing).__name__,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Registered web provider '%s' (%s)",
|
||||
name, type(provider).__name__,
|
||||
)
|
||||
|
||||
|
||||
def list_providers() -> List[WebSearchProvider]:
|
||||
"""Return all registered providers, sorted by name."""
|
||||
with _lock:
|
||||
items = list(_providers.values())
|
||||
return sorted(items, key=lambda p: p.name)
|
||||
|
||||
|
||||
def get_provider(name: str) -> Optional[WebSearchProvider]:
|
||||
"""Return the provider registered under *name*, or None."""
|
||||
if not isinstance(name, str):
|
||||
return None
|
||||
with _lock:
|
||||
return _providers.get(name.strip())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Active-provider resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _read_config_key(*path: str) -> Optional[str]:
|
||||
"""Resolve a dotted config key from ``config.yaml``. Returns None on miss."""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
cfg = load_config()
|
||||
cur = cfg
|
||||
for segment in path:
|
||||
if not isinstance(cur, dict):
|
||||
return None
|
||||
cur = cur.get(segment)
|
||||
if isinstance(cur, str) and cur.strip():
|
||||
return cur.strip()
|
||||
except Exception as exc:
|
||||
logger.debug("Could not read config %s: %s", ".".join(path), exc)
|
||||
return None
|
||||
|
||||
|
||||
# Legacy preference order — preserves behaviour for users who set no
|
||||
# ``web.backend`` / ``web.<capability>_backend`` config key at all. Matches
|
||||
# the historic candidate order in :func:`tools.web_tools._get_backend`
|
||||
# (paid providers first so existing paid setups don't get downgraded to
|
||||
# a free tier on upgrade). Filtered by ``is_available()`` at walk time so
|
||||
# we don't surface a provider the user has no credentials for.
|
||||
_LEGACY_PREFERENCE = (
|
||||
"firecrawl",
|
||||
"parallel",
|
||||
"tavily",
|
||||
"exa",
|
||||
"searxng",
|
||||
"brave-free",
|
||||
"ddgs",
|
||||
)
|
||||
|
||||
|
||||
def _resolve(configured: Optional[str], *, capability: str) -> Optional[WebSearchProvider]:
|
||||
"""Resolve the active provider for a capability ("search" | "extract" | "crawl").
|
||||
|
||||
Resolution rules (in order):
|
||||
|
||||
1. **Explicit config wins, ignoring availability.** If
|
||||
``web.{capability}_backend`` or ``web.backend`` names a registered
|
||||
provider that supports *capability*, return it even if its
|
||||
:meth:`is_available` returns False — the dispatcher will surface a
|
||||
precise "X_API_KEY is not set" error to the user instead of silently
|
||||
routing somewhere else. Matches legacy
|
||||
:func:`tools.web_tools._get_backend` behavior for configured names.
|
||||
|
||||
2. **Single-provider shortcut.** When only one registered provider
|
||||
supports *capability* AND ``is_available()`` reports True, return it.
|
||||
|
||||
3. **Legacy preference walk, filtered by availability.** Walk the
|
||||
:data:`_LEGACY_PREFERENCE` order (firecrawl → parallel → tavily →
|
||||
exa → searxng → brave-free → ddgs) looking for a provider whose
|
||||
``supports_<capability>()`` is True AND whose ``is_available()`` is
|
||||
True. Matches the historic ``tools.web_tools._get_backend()``
|
||||
candidate order so users with credentials but no explicit config
|
||||
key keep landing on the same provider as pre-migration. This is
|
||||
the path that fires when no config key is set — pick the
|
||||
highest-priority backend the user actually has credentials for.
|
||||
|
||||
Returns None when no provider is configured AND no available provider
|
||||
matches the legacy preference; the dispatcher then returns a "set up a
|
||||
provider" error to the user.
|
||||
"""
|
||||
with _lock:
|
||||
snapshot = dict(_providers)
|
||||
|
||||
def _capable(p: WebSearchProvider) -> bool:
|
||||
if capability == "search":
|
||||
return bool(p.supports_search())
|
||||
if capability == "extract":
|
||||
return bool(p.supports_extract())
|
||||
if capability == "crawl":
|
||||
return bool(p.supports_crawl())
|
||||
return False
|
||||
|
||||
def _is_available_safe(p: WebSearchProvider) -> bool:
|
||||
"""Wrap ``is_available()`` so a buggy provider doesn't kill resolution."""
|
||||
try:
|
||||
return bool(p.is_available())
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("provider %s.is_available() raised %s", p.name, exc)
|
||||
return False
|
||||
|
||||
# 1. Explicit config wins — return regardless of is_available() so the
|
||||
# user gets a precise downstream error message rather than a silent
|
||||
# backend switch. Matches _get_backend() in web_tools.py.
|
||||
if configured:
|
||||
provider = snapshot.get(configured)
|
||||
if provider is not None and _capable(provider):
|
||||
return provider
|
||||
if provider is None:
|
||||
logger.debug(
|
||||
"web backend '%s' configured but not registered; falling back",
|
||||
configured,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"web backend '%s' configured but does not support '%s'; falling back",
|
||||
configured, capability,
|
||||
)
|
||||
|
||||
# 2. + 3. Fallback path — filter by availability so we don't surface
|
||||
# a provider the user has no credentials for. Without this filter,
|
||||
# a registered-but-unconfigured provider could end up "active" on
|
||||
# a fresh install with no API keys at all.
|
||||
eligible = [
|
||||
p for p in snapshot.values()
|
||||
if _capable(p) and _is_available_safe(p)
|
||||
]
|
||||
if len(eligible) == 1:
|
||||
return eligible[0]
|
||||
|
||||
for legacy in _LEGACY_PREFERENCE:
|
||||
provider = snapshot.get(legacy)
|
||||
if (
|
||||
provider is not None
|
||||
and _capable(provider)
|
||||
and _is_available_safe(provider)
|
||||
):
|
||||
return provider
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_active_search_provider() -> Optional[WebSearchProvider]:
|
||||
"""Resolve the currently-active web search provider.
|
||||
|
||||
Reads ``web.search_backend`` (preferred) or ``web.backend`` (shared
|
||||
fallback) from config.yaml; falls back per the module docstring.
|
||||
"""
|
||||
explicit = _read_config_key("web", "search_backend") or _read_config_key("web", "backend")
|
||||
return _resolve(explicit, capability="search")
|
||||
|
||||
|
||||
def get_active_extract_provider() -> Optional[WebSearchProvider]:
|
||||
"""Resolve the currently-active web extract provider.
|
||||
|
||||
Reads ``web.extract_backend`` (preferred) or ``web.backend`` (shared
|
||||
fallback) from config.yaml; falls back per the module docstring.
|
||||
"""
|
||||
explicit = _read_config_key("web", "extract_backend") or _read_config_key("web", "backend")
|
||||
return _resolve(explicit, capability="extract")
|
||||
|
||||
|
||||
def get_active_crawl_provider() -> Optional[WebSearchProvider]:
|
||||
"""Resolve the currently-active web crawl provider.
|
||||
|
||||
Reads ``web.crawl_backend`` (preferred) or ``web.backend`` (shared
|
||||
fallback) from config.yaml; falls back per the module docstring.
|
||||
|
||||
Crawl is a niche capability — among built-in providers only Tavily and
|
||||
Firecrawl implement it. Callers should expect ``None`` and fall back to
|
||||
a different strategy (e.g. summarize-via-LLM) when neither is
|
||||
configured.
|
||||
"""
|
||||
explicit = _read_config_key("web", "crawl_backend") or _read_config_key("web", "backend")
|
||||
return _resolve(explicit, capability="crawl")
|
||||
|
||||
|
||||
def _reset_for_tests() -> None:
|
||||
"""Clear the registry. **Test-only.**"""
|
||||
with _lock:
|
||||
_providers.clear()
|
||||
@@ -364,6 +364,18 @@ compression:
|
||||
# compression of older turns.
|
||||
protect_last_n: 20
|
||||
|
||||
# Number of non-system messages to protect at the head of the transcript, in
|
||||
# ADDITION to the system prompt (which is always implicitly protected).
|
||||
# Head messages are NEVER summarized — they survive every compression
|
||||
# indefinitely. This gives stable early context for short/medium sessions,
|
||||
# but in long-running sessions that rely on rolling compaction the pinned
|
||||
# opening turns may not match how you want the session framed over time.
|
||||
# Set to 0 to preserve ONLY the system prompt (plus the rolling summary
|
||||
# and recent tail) — the cleanest configuration for long-running sessions.
|
||||
# Default 3 preserves the system prompt plus the first three non-system
|
||||
# head messages, matching the pre-feature behaviour.
|
||||
protect_first_n: 3
|
||||
|
||||
# To pin a specific model/provider for compression summaries, use the
|
||||
# auxiliary section below (auxiliary.compression.provider / model).
|
||||
|
||||
@@ -432,6 +444,10 @@ prompt_caching:
|
||||
# model: ""
|
||||
# timeout: 30
|
||||
# max_concurrency: 3 # Limit parallel summaries to reduce request-burst 429s
|
||||
# default_mode: "fast" # 'fast' | 'summary' — mode used when caller passes none.
|
||||
# # fast: FTS5 snippet hits, no LLM call. Default.
|
||||
# # summary: LLM-generated prose synthesis across hits.
|
||||
# # guided requires anchors and cannot be a default.
|
||||
# extra_body: {} # Provider-specific OpenAI-compatible request fields
|
||||
# # Example for providers that support request-body
|
||||
# # reasoning controls:
|
||||
@@ -445,7 +461,7 @@ prompt_caching:
|
||||
# Two stores: MEMORY.md (agent's notes) and USER.md (user profile).
|
||||
# Character limits keep the memory small and focused. The agent manages
|
||||
# pruning -- when at the limit, it must consolidate or replace entries.
|
||||
# Disabled by default in batch_runner and RL environments.
|
||||
# Disabled by default in batch_runner.
|
||||
#
|
||||
memory:
|
||||
# Agent's personal notes: environment facts, conventions, things learned
|
||||
@@ -669,6 +685,16 @@ platform_toolsets:
|
||||
# # allowed_chats: ["-1001234567890"]
|
||||
# extra:
|
||||
# disable_link_previews: false # Set true to suppress Telegram URL previews in bot messages
|
||||
#
|
||||
# Discord-specific settings (config.yaml top-level, not under platforms:):
|
||||
#
|
||||
# discord:
|
||||
# require_mention: true # Require @mention in server channels (default: true)
|
||||
# auto_thread: true # Auto-create thread on @mention (default: true)
|
||||
# free_response_channels: "" # Channel IDs where no mention is needed
|
||||
# reactions: true # Show processing reactions (default: true)
|
||||
# history_backfill: true # Recover missed channel messages on mention (default: true)
|
||||
# history_backfill_limit: 50 # Max messages to scan backwards (default: 50)
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Available toolsets (use these names in platform_toolsets or the toolsets list)
|
||||
@@ -693,10 +719,9 @@ platform_toolsets:
|
||||
# todo - todo (in-memory task planning, no deps)
|
||||
# tts - text_to_speech (Edge TTS free, or ELEVENLABS/OPENAI/MINIMAX/MISTRAL key)
|
||||
# cronjob - cronjob (create/list/update/pause/resume/run/remove scheduled tasks)
|
||||
# rl - rl_list_environments, rl_start_training, etc. (requires TINKER_API_KEY)
|
||||
#
|
||||
# PRESETS (curated bundles):
|
||||
# hermes-cli - All of the above except rl + send_message
|
||||
# hermes-cli - All of the above except send_message
|
||||
# hermes-telegram - terminal, file, web, vision, image_gen, tts, browser,
|
||||
# skills, todo, cronjob, send_message
|
||||
# hermes-discord - Same as hermes-telegram
|
||||
@@ -722,7 +747,6 @@ platform_toolsets:
|
||||
# session_search - Search and recall past conversations (FTS5 + Gemini Flash summarization)
|
||||
# tts - Text-to-speech (Edge TTS free, ElevenLabs, OpenAI, MiniMax, Mistral)
|
||||
# cronjob - Schedule and manage automated tasks (CLI-only)
|
||||
# rl - RL training tools (Tinker-Atropos)
|
||||
#
|
||||
# Composite toolsets:
|
||||
# debugging - terminal + web + file (for troubleshooting)
|
||||
|
||||
67
cron/jobs.py
67
cron/jobs.py
@@ -645,6 +645,44 @@ def get_job(job_id: str) -> Optional[Dict[str, Any]]:
|
||||
return None
|
||||
|
||||
|
||||
class AmbiguousJobReference(LookupError):
|
||||
"""Raised when a job name matches more than one job."""
|
||||
|
||||
def __init__(self, ref: str, matches: List[Dict[str, Any]]):
|
||||
self.ref = ref
|
||||
self.matches = matches
|
||||
ids = ", ".join(m["id"] for m in matches)
|
||||
super().__init__(
|
||||
f"Job name '{ref}' is ambiguous — matches {len(matches)} jobs: {ids}. "
|
||||
f"Use the job ID instead."
|
||||
)
|
||||
|
||||
|
||||
def resolve_job_ref(ref: str) -> Optional[Dict[str, Any]]:
|
||||
"""Resolve a job reference (ID or name) to a job record.
|
||||
|
||||
- Exact ID match wins (works even if a different job's name equals this ID).
|
||||
- Otherwise, case-insensitive name match.
|
||||
- If a name matches more than one job, raises AmbiguousJobReference so the
|
||||
caller can surface the matching IDs rather than silently picking one.
|
||||
"""
|
||||
if not ref:
|
||||
return None
|
||||
jobs = load_jobs()
|
||||
for job in jobs:
|
||||
if job["id"] == ref:
|
||||
return _normalize_job_record(job)
|
||||
ref_lower = ref.lower()
|
||||
name_matches = [j for j in jobs if (j.get("name") or "").lower() == ref_lower]
|
||||
if not name_matches:
|
||||
return None
|
||||
if len(name_matches) > 1:
|
||||
raise AmbiguousJobReference(
|
||||
ref, [_normalize_job_record(j) for j in name_matches]
|
||||
)
|
||||
return _normalize_job_record(name_matches[0])
|
||||
|
||||
|
||||
def list_jobs(include_disabled: bool = False) -> List[Dict[str, Any]]:
|
||||
"""List all jobs, optionally including disabled ones."""
|
||||
jobs = [_normalize_job_record(j) for j in load_jobs()]
|
||||
@@ -702,9 +740,12 @@ def update_job(job_id: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]
|
||||
|
||||
|
||||
def pause_job(job_id: str, reason: Optional[str] = None) -> Optional[Dict[str, Any]]:
|
||||
"""Pause a job without deleting it."""
|
||||
"""Pause a job without deleting it. Accepts a job ID or name."""
|
||||
job = resolve_job_ref(job_id)
|
||||
if not job:
|
||||
return None
|
||||
return update_job(
|
||||
job_id,
|
||||
job["id"],
|
||||
{
|
||||
"enabled": False,
|
||||
"state": "paused",
|
||||
@@ -715,14 +756,14 @@ def pause_job(job_id: str, reason: Optional[str] = None) -> Optional[Dict[str, A
|
||||
|
||||
|
||||
def resume_job(job_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Resume a paused job and compute the next future run from now."""
|
||||
job = get_job(job_id)
|
||||
"""Resume a paused job and compute the next future run from now. Accepts a job ID or name."""
|
||||
job = resolve_job_ref(job_id)
|
||||
if not job:
|
||||
return None
|
||||
|
||||
next_run_at = compute_next_run(job["schedule"])
|
||||
return update_job(
|
||||
job_id,
|
||||
job["id"],
|
||||
{
|
||||
"enabled": True,
|
||||
"state": "scheduled",
|
||||
@@ -734,12 +775,12 @@ def resume_job(job_id: str) -> Optional[Dict[str, Any]]:
|
||||
|
||||
|
||||
def trigger_job(job_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Schedule a job to run on the next scheduler tick."""
|
||||
job = get_job(job_id)
|
||||
"""Schedule a job to run on the next scheduler tick. Accepts a job ID or name."""
|
||||
job = resolve_job_ref(job_id)
|
||||
if not job:
|
||||
return None
|
||||
return update_job(
|
||||
job_id,
|
||||
job["id"],
|
||||
{
|
||||
"enabled": True,
|
||||
"state": "scheduled",
|
||||
@@ -751,14 +792,18 @@ def trigger_job(job_id: str) -> Optional[Dict[str, Any]]:
|
||||
|
||||
|
||||
def remove_job(job_id: str) -> bool:
|
||||
"""Remove a job by ID."""
|
||||
"""Remove a job by ID or name."""
|
||||
job = resolve_job_ref(job_id)
|
||||
if not job:
|
||||
return False
|
||||
canonical_id = job["id"]
|
||||
jobs = load_jobs()
|
||||
original_len = len(jobs)
|
||||
jobs = [j for j in jobs if j["id"] != job_id]
|
||||
jobs = [j for j in jobs if j["id"] != canonical_id]
|
||||
if len(jobs) < original_len:
|
||||
save_jobs(jobs)
|
||||
# Clean up output directory to prevent orphaned dirs accumulating
|
||||
job_output_dir = OUTPUT_DIR / job_id
|
||||
job_output_dir = OUTPUT_DIR / canonical_id
|
||||
if job_output_dir.exists():
|
||||
shutil.rmtree(job_output_dir)
|
||||
return True
|
||||
|
||||
@@ -111,6 +111,7 @@ _HOME_TARGET_ENV_VARS = {
|
||||
"weixin": "WEIXIN_HOME_CHANNEL",
|
||||
"bluebubbles": "BLUEBUBBLES_HOME_CHANNEL",
|
||||
"qqbot": "QQBOT_HOME_CHANNEL",
|
||||
"whatsapp": "WHATSAPP_HOME_CHANNEL",
|
||||
}
|
||||
|
||||
# Legacy env var names kept for back-compat. Each entry is the current
|
||||
|
||||
@@ -39,6 +39,10 @@ if [ "$(id -u)" = "0" ]; then
|
||||
# by the mapped user on the host side.
|
||||
chown -R hermes:hermes "$HERMES_HOME" 2>/dev/null || \
|
||||
echo "Warning: chown failed (rootless container?) — continuing anyway"
|
||||
# The .venv must also be re-chowned when UID is remapped, otherwise
|
||||
# lazy_deps.py cannot install platform packages (discord.py, etc.).
|
||||
chown -R hermes:hermes "$INSTALL_DIR/.venv" 2>/dev/null || \
|
||||
echo "Warning: chown .venv failed (rootless container?) — continuing anyway"
|
||||
fi
|
||||
|
||||
# Ensure config.yaml is readable by the hermes runtime user even if it was
|
||||
|
||||
@@ -1,324 +0,0 @@
|
||||
# Hermes-Agent Atropos Environments
|
||||
|
||||
This directory contains the integration layer between **hermes-agent's** tool-calling capabilities and the **Atropos** RL training framework. It provides everything needed to run agentic LLMs through multi-turn tool-calling loops, score their output with arbitrary reward functions, and feed results into Atropos for training or evaluation.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
Atropos Framework
|
||||
┌───────────────────────┐
|
||||
│ BaseEnv │ (atroposlib)
|
||||
│ - Server management │
|
||||
│ - Worker scheduling │
|
||||
│ - Wandb logging │
|
||||
│ - CLI (serve/process/ │
|
||||
│ evaluate) │
|
||||
└───────────┬───────────┘
|
||||
│ inherits
|
||||
┌───────────┴───────────┐
|
||||
│ HermesAgentBaseEnv │ hermes_base_env.py
|
||||
│ - Terminal backend │
|
||||
│ - Tool resolution │
|
||||
│ - Agent loop │
|
||||
│ - ToolContext │
|
||||
│ - Async patches │
|
||||
└───────────┬───────────┘
|
||||
│ inherits
|
||||
┌─────────────────┼─────────────────┐
|
||||
│ │ │
|
||||
TerminalTestEnv HermesSweEnv TerminalBench2EvalEnv
|
||||
(stack testing) (SWE training) (TB2 benchmark eval)
|
||||
```
|
||||
|
||||
### Inheritance Chain
|
||||
|
||||
**BaseEnv** (from `atroposlib`) is the Atropos base class. It provides:
|
||||
- Server management (OpenAI-compatible API servers, VLLM, SGLang)
|
||||
- Worker scheduling for parallel rollouts
|
||||
- Wandb integration for metrics and rollout logging
|
||||
- CLI interface with three subcommands: `serve`, `process`, `evaluate`
|
||||
- `evaluate_log()` for saving eval results to JSON + samples.jsonl
|
||||
|
||||
**HermesAgentBaseEnv** (`hermes_base_env.py`) extends BaseEnv with hermes-agent specifics:
|
||||
- Sets `os.environ["TERMINAL_ENV"]` to configure the terminal backend (local, docker, ssh, singularity, modal, daytona, vercel_sandbox)
|
||||
- Resolves hermes-agent toolsets via `_resolve_tools_for_group()` (calls `get_tool_definitions()` which queries `tools/registry.py`)
|
||||
- Implements `collect_trajectory()` which runs the full agent loop and computes rewards
|
||||
- Supports two-phase operation (Phase 1: OpenAI server, Phase 2: VLLM ManagedServer)
|
||||
- Applies monkey patches for async-safe tool operation at import time
|
||||
|
||||
Concrete environments inherit from `HermesAgentBaseEnv` and implement:
|
||||
- `setup()` -- Load dataset, initialize state
|
||||
- `get_next_item()` -- Return the next item for rollout
|
||||
- `format_prompt()` -- Convert a dataset item into the user message
|
||||
- `compute_reward()` -- Score the rollout using ToolContext
|
||||
- `evaluate()` -- Periodic evaluation logic
|
||||
|
||||
## Core Components
|
||||
|
||||
### Agent Loop (`agent_loop.py`)
|
||||
|
||||
`HermesAgentLoop` is the reusable multi-turn agent engine. It runs the same pattern as hermes-agent's `run_agent.py`:
|
||||
|
||||
1. Send messages + tools to the API via `server.chat_completion()`
|
||||
2. If the response contains `tool_calls`, execute each one via `handle_function_call()` (which delegates to `tools/registry.py`'s `dispatch()`)
|
||||
3. Append tool results to the conversation and go back to step 1
|
||||
4. If the response has no tool_calls, the agent is done
|
||||
|
||||
Tool calls are executed in a thread pool (`run_in_executor`) so backends that use `asyncio.run()` internally (Modal, Docker) don't deadlock inside Atropos's event loop.
|
||||
|
||||
Returns an `AgentResult` containing the full conversation history, turn count, reasoning content per turn, tool errors, and optional ManagedServer state (for Phase 2).
|
||||
|
||||
### Tool Context (`tool_context.py`)
|
||||
|
||||
`ToolContext` is a per-rollout handle that gives reward/verification functions direct access to **all** hermes-agent tools, scoped to the rollout's `task_id`. The same `task_id` means the terminal/browser session is the SAME one the model used during its rollout -- all state (files, processes, browser tabs) is preserved.
|
||||
|
||||
```python
|
||||
async def compute_reward(self, item, result, ctx: ToolContext):
|
||||
# Run tests in the model's terminal sandbox
|
||||
test = ctx.terminal("pytest -v")
|
||||
if test["exit_code"] == 0:
|
||||
return 1.0
|
||||
|
||||
# Check if a file was created
|
||||
content = ctx.read_file("/workspace/solution.py")
|
||||
if content.get("content"):
|
||||
return 0.5
|
||||
|
||||
# Download files locally for verification (binary-safe)
|
||||
ctx.download_file("/remote/output.bin", "/local/output.bin")
|
||||
|
||||
return 0.0
|
||||
```
|
||||
|
||||
Available methods:
|
||||
- **Terminal**: `terminal(command, timeout)` -- run shell commands
|
||||
- **Files**: `read_file(path)`, `write_file(path, content)`, `search(query, path)`
|
||||
- **Transfers**: `upload_file()`, `upload_dir()`, `download_file()`, `download_dir()` -- binary-safe file transfers between host and sandbox
|
||||
- **Web**: `web_search(query)`, `web_extract(urls)`
|
||||
- **Browser**: `browser_navigate(url)`, `browser_snapshot()`
|
||||
- **Generic**: `call_tool(name, args)` -- call any hermes-agent tool by name
|
||||
- **Cleanup**: `cleanup()` -- release all resources (called automatically after `compute_reward`)
|
||||
|
||||
### Patches (`patches.py`)
|
||||
|
||||
**Problem**: Some hermes-agent tools use `asyncio.run()` internally (e.g., the Modal backend). This crashes when called from inside Atropos's event loop because `asyncio.run()` cannot be nested.
|
||||
|
||||
**Solution**: `ModalEnvironment` uses a dedicated `_AsyncWorker` background thread with its own event loop. The calling code sees a sync interface, but internally all async Modal SDK calls happen on the worker thread so they don't conflict with Atropos's loop. This is built directly into `tools/environments/modal.py` — no monkey-patching required.
|
||||
|
||||
`patches.py` is now a no-op (kept for backward compatibility with imports).
|
||||
|
||||
### Tool Call Parsers (`tool_call_parsers/`)
|
||||
|
||||
Client-side parsers that extract structured `tool_calls` from raw model output text. Used in **Phase 2** (VLLM server type) where ManagedServer's `/generate` endpoint returns raw text without tool call parsing.
|
||||
|
||||
Each parser is a standalone reimplementation of the corresponding VLLM parser's `extract_tool_calls()` logic. No VLLM dependency -- only standard library (`re`, `json`, `uuid`) and `openai` types.
|
||||
|
||||
Available parsers:
|
||||
- `hermes` -- Hermes/ChatML `<tool_call>` XML format
|
||||
- `mistral` -- Mistral `[TOOL_CALLS]` format
|
||||
- `llama3_json` -- Llama 3 JSON tool calling
|
||||
- `qwen` -- Qwen tool calling format
|
||||
- `qwen3_coder` -- Qwen3 Coder format
|
||||
- `deepseek_v3` -- DeepSeek V3 format
|
||||
- `deepseek_v3_1` -- DeepSeek V3.1 format
|
||||
- `kimi_k2` -- Kimi K2 format
|
||||
- `longcat` -- Longcat format
|
||||
- `glm45` / `glm47` -- GLM model formats
|
||||
|
||||
Usage:
|
||||
```python
|
||||
from environments.tool_call_parsers import get_parser
|
||||
|
||||
parser = get_parser("hermes")
|
||||
content, tool_calls = parser.parse(raw_model_output)
|
||||
```
|
||||
|
||||
In Phase 1 (OpenAI server type), these parsers are not needed -- the server handles tool call parsing natively.
|
||||
|
||||
## Two-Phase Operation
|
||||
|
||||
### Phase 1: OpenAI Server (Evaluation / SFT Data Generation)
|
||||
|
||||
Uses `server.chat_completion()` with `tools=` parameter. The server (VLLM, SGLang, OpenRouter, OpenAI) handles tool call parsing natively. Returns `ChatCompletion` objects with structured `tool_calls`.
|
||||
|
||||
- Good for: evaluation, SFT data generation, testing
|
||||
- Run with: `serve` (with `run-api`), `process`, or `evaluate` subcommands
|
||||
- Placeholder tokens are created for the Atropos pipeline
|
||||
|
||||
### Phase 2: VLLM ManagedServer (Full RL Training)
|
||||
|
||||
Uses ManagedServer for exact token IDs + logprobs via `/generate`. Client-side tool call parser (from `tool_call_parsers/`) reconstructs structured `tool_calls` from raw output.
|
||||
|
||||
- Good for: full RL training with GRPO/PPO
|
||||
- Run with: `serve` subcommand
|
||||
- Real tokens, masks, and logprobs flow through the pipeline
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
environments/
|
||||
├── README.md # This file
|
||||
├── __init__.py # Package exports
|
||||
├── hermes_base_env.py # Abstract base (HermesAgentBaseEnv)
|
||||
├── agent_loop.py # Multi-turn agent engine (HermesAgentLoop)
|
||||
├── tool_context.py # Per-rollout tool access for reward functions
|
||||
├── patches.py # Async-safety patches for Modal backend
|
||||
│
|
||||
├── tool_call_parsers/ # Phase 2 client-side parsers
|
||||
│ ├── __init__.py # Registry + base class
|
||||
│ ├── hermes_parser.py
|
||||
│ ├── mistral_parser.py
|
||||
│ ├── llama_parser.py
|
||||
│ ├── qwen_parser.py
|
||||
│ ├── qwen3_coder_parser.py
|
||||
│ ├── deepseek_v3_parser.py
|
||||
│ ├── deepseek_v3_1_parser.py
|
||||
│ ├── kimi_k2_parser.py
|
||||
│ ├── longcat_parser.py
|
||||
│ ├── glm45_parser.py
|
||||
│ └── glm47_parser.py
|
||||
│
|
||||
├── terminal_test_env/ # Stack validation environment
|
||||
│ └── terminal_test_env.py
|
||||
│
|
||||
├── hermes_swe_env/ # SWE-bench style training environment
|
||||
│ └── hermes_swe_env.py
|
||||
│
|
||||
└── benchmarks/ # Evaluation benchmarks
|
||||
├── terminalbench_2/ # 89 terminal tasks, Modal sandboxes
|
||||
│ └── terminalbench2_env.py
|
||||
├── tblite/ # 100 calibrated tasks (fast TB2 proxy)
|
||||
│ └── tblite_env.py
|
||||
└── yc_bench/ # Long-horizon strategic benchmark
|
||||
└── yc_bench_env.py
|
||||
```
|
||||
|
||||
## Concrete Environments
|
||||
|
||||
### TerminalTestEnv (`terminal_test_env/`)
|
||||
|
||||
A self-contained environment with inline tasks (no external dataset needed) for validating the full stack end-to-end. Each task asks the model to create a file at a known path, and the verifier checks the content matches.
|
||||
|
||||
```bash
|
||||
# Serve mode (needs run-api)
|
||||
run-api
|
||||
python environments/terminal_test_env/terminal_test_env.py serve
|
||||
|
||||
# Process mode (no run-api, saves to JSONL)
|
||||
python environments/terminal_test_env/terminal_test_env.py process \
|
||||
--env.data_path_to_save_groups terminal_test_output.jsonl
|
||||
```
|
||||
|
||||
### HermesSweEnv (`hermes_swe_env/`)
|
||||
|
||||
SWE-bench style training environment. The model gets a coding task, uses terminal + file + web tools to solve it, and the reward function runs tests in the same Modal sandbox.
|
||||
|
||||
```bash
|
||||
python environments/hermes_swe_env/hermes_swe_env.py serve \
|
||||
--openai.model_name YourModel \
|
||||
--env.dataset_name bigcode/humanevalpack \
|
||||
--env.terminal_backend modal
|
||||
```
|
||||
|
||||
### TerminalBench2EvalEnv (`benchmarks/terminalbench_2/`)
|
||||
|
||||
**Eval-only** environment for the Terminal-Bench 2.0 benchmark (89 tasks). Each task gets a pre-built Docker Hub image, a natural language instruction, and a test suite. The agent uses terminal + file tools to solve the task, then the test suite verifies correctness.
|
||||
|
||||
Follows the standard Atropos eval pattern (like GPQA, MMLU, etc.):
|
||||
- Run via `evaluate` subcommand (no `run-api` needed)
|
||||
- `setup()` loads the dataset, `evaluate()` runs all tasks
|
||||
- `rollout_and_score_eval()` handles per-task agent loop + test verification
|
||||
- Downloads verifier output locally for reliable reward checking (Harbor pattern)
|
||||
|
||||
```bash
|
||||
# Run full benchmark
|
||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
--openai.model_name anthropic/claude-opus-4.6
|
||||
|
||||
# Run subset of tasks
|
||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
--openai.model_name anthropic/claude-opus-4.6 \
|
||||
--env.task_filter fix-git,git-multibranch
|
||||
|
||||
# Skip specific tasks
|
||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
--openai.model_name anthropic/claude-opus-4.6 \
|
||||
--env.skip_tasks heavy-task,slow-task
|
||||
```
|
||||
|
||||
## Creating a New Environment
|
||||
|
||||
### Training Environment
|
||||
|
||||
1. Create a new directory under `environments/`
|
||||
2. Create your env file inheriting from `HermesAgentBaseEnv`
|
||||
3. Implement the four abstract methods + `evaluate()`
|
||||
|
||||
```python
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
|
||||
class MyEnvConfig(HermesAgentEnvConfig):
|
||||
pass # Add custom fields as needed
|
||||
|
||||
class MyEnv(HermesAgentBaseEnv):
|
||||
name = "my-env"
|
||||
env_config_cls = MyEnvConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls):
|
||||
env_config = MyEnvConfig(
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
terminal_backend="modal",
|
||||
# ... other config
|
||||
)
|
||||
server_configs = [APIServerConfig(...)]
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
self.dataset = load_dataset(...)
|
||||
self.iter = 0
|
||||
|
||||
async def get_next_item(self):
|
||||
item = self.dataset[self.iter % len(self.dataset)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item):
|
||||
return item["instruction"]
|
||||
|
||||
async def compute_reward(self, item, result, ctx):
|
||||
# ctx gives you full tool access to the rollout's sandbox
|
||||
test = ctx.terminal("pytest -v")
|
||||
return 1.0 if test["exit_code"] == 0 else 0.0
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
# Periodic evaluation logic
|
||||
...
|
||||
|
||||
if __name__ == "__main__":
|
||||
MyEnv.cli()
|
||||
```
|
||||
|
||||
### Eval-Only Environment (Benchmark)
|
||||
|
||||
For eval benchmarks, follow the pattern in `terminalbench2_env.py`:
|
||||
1. Create under `environments/benchmarks/your-benchmark/`
|
||||
2. Inherit from `HermesAgentBaseEnv`
|
||||
3. Set eval-only config: `eval_handling=STOP_TRAIN`, `steps_per_eval=1`, `total_steps=1`
|
||||
4. Stub the training methods (`collect_trajectories`, `score`)
|
||||
5. Implement `rollout_and_score_eval()` and `evaluate()`
|
||||
6. Run with `evaluate` subcommand
|
||||
|
||||
## Key Config Fields
|
||||
|
||||
| Field | Description | Default |
|
||||
|-------|-------------|---------|
|
||||
| `enabled_toolsets` | Which hermes toolsets to enable | `None` (all) |
|
||||
| `disabled_toolsets` | Toolsets to disable | `None` |
|
||||
| `distribution` | Probabilistic toolset distribution name | `None` |
|
||||
| `max_agent_turns` | Max LLM calls per rollout | `30` |
|
||||
| `agent_temperature` | Sampling temperature | `1.0` |
|
||||
| `terminal_backend` | `local`, `docker`, `modal`, `daytona`, `ssh`, `singularity` | `local` |
|
||||
| `system_prompt` | System message for the agent | `None` |
|
||||
| `tool_call_parser` | Parser name for Phase 2 | `hermes` |
|
||||
| `eval_handling` | `STOP_TRAIN`, `LIMIT_TRAIN`, `NONE` | `STOP_TRAIN` |
|
||||
@@ -1,36 +0,0 @@
|
||||
"""
|
||||
Hermes-Agent Atropos Environments
|
||||
|
||||
Provides a layered integration between hermes-agent's tool-calling capabilities
|
||||
and the Atropos RL training framework.
|
||||
|
||||
Core layers:
|
||||
- agent_loop: Reusable multi-turn agent loop with standard OpenAI-spec tool calling
|
||||
- tool_context: Per-rollout tool access handle for reward/verification functions
|
||||
- hermes_base_env: Abstract base environment (BaseEnv subclass) for Atropos
|
||||
- tool_call_parsers: Client-side tool call parser registry for Phase 2 (VLLM /generate)
|
||||
|
||||
Concrete environments:
|
||||
- terminal_test_env/: Simple file-creation tasks for testing the stack
|
||||
- hermes_swe_env/: SWE-bench style tasks with Modal sandboxes
|
||||
|
||||
Benchmarks (eval-only):
|
||||
- benchmarks/terminalbench_2/: Terminal-Bench 2.0 evaluation
|
||||
"""
|
||||
|
||||
try:
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from environments.tool_context import ToolContext
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
except ImportError:
|
||||
# atroposlib not installed — environments are unavailable but
|
||||
# submodules like tool_call_parsers can still be imported directly.
|
||||
pass
|
||||
|
||||
__all__ = [
|
||||
"AgentResult",
|
||||
"HermesAgentLoop",
|
||||
"ToolContext",
|
||||
"HermesAgentBaseEnv",
|
||||
"HermesAgentEnvConfig",
|
||||
]
|
||||
@@ -1,534 +0,0 @@
|
||||
"""
|
||||
HermesAgentLoop -- Reusable Multi-Turn Agent Engine
|
||||
|
||||
Runs the hermes-agent tool-calling loop using standard OpenAI-spec tool calling.
|
||||
Works with any server that returns ChatCompletion objects with tool_calls:
|
||||
- Phase 1: OpenAI server type (VLLM, SGLang, OpenRouter, OpenAI API)
|
||||
- Phase 2: ManagedServer with client-side tool call parser
|
||||
|
||||
The loop passes tools= and checks response.choices[0].message.tool_calls,
|
||||
identical to hermes-agent's run_agent.py. Tool execution is dispatched via
|
||||
handle_function_call() from model_tools.py.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from model_tools import handle_function_call
|
||||
from tools.terminal_tool import get_active_env
|
||||
from tools.tool_result_storage import maybe_persist_tool_result, enforce_turn_budget
|
||||
|
||||
# Thread pool for running sync tool calls that internally use asyncio.run()
|
||||
# (e.g., the Modal/Docker/Daytona terminal backends). Running them in a separate
|
||||
# thread gives them a clean event loop so they don't deadlock inside Atropos's loop.
|
||||
# Size must be large enough for concurrent eval tasks (e.g., 89 TB2 tasks all
|
||||
# making tool calls). Too small = thread pool starvation, tasks queue for minutes.
|
||||
# Resized at runtime by HermesAgentBaseEnv.__init__ via resize_tool_pool().
|
||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=128)
|
||||
|
||||
|
||||
def resize_tool_pool(max_workers: int):
|
||||
"""
|
||||
Replace the global tool executor with a new one of the given size.
|
||||
|
||||
Called by HermesAgentBaseEnv.__init__ based on config.tool_pool_size.
|
||||
Safe to call before any tasks are submitted.
|
||||
"""
|
||||
global _tool_executor
|
||||
old_executor = _tool_executor
|
||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
|
||||
old_executor.shutdown(wait=False)
|
||||
logger.info("Tool thread pool resized to %d workers", max_workers)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolError:
|
||||
"""Record of a tool execution error during the agent loop."""
|
||||
|
||||
turn: int # Which turn the error occurred on
|
||||
tool_name: str # Which tool was called
|
||||
arguments: str # The arguments passed (truncated)
|
||||
error: str # The error message
|
||||
tool_result: str # The raw result returned to the model
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResult:
|
||||
"""Result of running the agent loop."""
|
||||
|
||||
# Full conversation history in OpenAI message format
|
||||
messages: List[Dict[str, Any]]
|
||||
# ManagedServer.get_state() if available (Phase 2), None otherwise
|
||||
managed_state: Optional[Dict[str, Any]] = None
|
||||
# How many LLM calls were made
|
||||
turns_used: int = 0
|
||||
# True if model stopped calling tools naturally (vs hitting max_turns)
|
||||
finished_naturally: bool = False
|
||||
# Extracted reasoning content per turn (from PR #297 helpers)
|
||||
reasoning_per_turn: List[Optional[str]] = field(default_factory=list)
|
||||
# Tool errors encountered during the loop
|
||||
tool_errors: List[ToolError] = field(default_factory=list)
|
||||
|
||||
|
||||
def _extract_reasoning_from_message(message) -> Optional[str]:
|
||||
"""
|
||||
Extract reasoning content from a ChatCompletion message.
|
||||
|
||||
Handles multiple provider formats:
|
||||
1. message.reasoning_content field (some providers)
|
||||
2. message.reasoning field (some providers)
|
||||
3. message.reasoning_details[].text (OpenRouter style)
|
||||
|
||||
Note: <think> block extraction from content is NOT done here -- that's
|
||||
handled by the response already in Phase 1 (server does it) or by
|
||||
ManagedServer's patch in Phase 2.
|
||||
|
||||
Args:
|
||||
message: The assistant message from ChatCompletion response
|
||||
|
||||
Returns:
|
||||
Extracted reasoning text, or None if not found
|
||||
"""
|
||||
# Check reasoning_content field (common across providers)
|
||||
if hasattr(message, "reasoning_content") and message.reasoning_content:
|
||||
return message.reasoning_content
|
||||
|
||||
# Check reasoning field
|
||||
if hasattr(message, "reasoning") and message.reasoning:
|
||||
return message.reasoning
|
||||
|
||||
# Check reasoning_details (OpenRouter style)
|
||||
if hasattr(message, "reasoning_details") and message.reasoning_details:
|
||||
for detail in message.reasoning_details:
|
||||
if hasattr(detail, "text") and detail.text:
|
||||
return detail.text
|
||||
if isinstance(detail, dict) and detail.get("text"):
|
||||
return detail["text"]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class HermesAgentLoop:
|
||||
"""
|
||||
Runs hermes-agent's tool-calling loop using standard OpenAI-spec tool calling.
|
||||
|
||||
Same pattern as run_agent.py:
|
||||
- Pass tools= to the API
|
||||
- Check response.choices[0].message.tool_calls
|
||||
- Dispatch via handle_function_call()
|
||||
|
||||
Works identically with any server type -- OpenAI, VLLM, SGLang, OpenRouter,
|
||||
or ManagedServer with a parser. The server determines how tool_calls get
|
||||
populated on the response.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server,
|
||||
tool_schemas: List[Dict[str, Any]],
|
||||
valid_tool_names: Set[str],
|
||||
max_turns: int = 30,
|
||||
task_id: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
budget_config: Optional["BudgetConfig"] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the agent loop.
|
||||
|
||||
Args:
|
||||
server: Server object with chat_completion() method (OpenAIServer,
|
||||
ManagedServer, ServerManager, etc.)
|
||||
tool_schemas: OpenAI-format tool definitions from get_tool_definitions()
|
||||
valid_tool_names: Set of tool names the model is allowed to call
|
||||
max_turns: Maximum number of LLM calls before stopping
|
||||
task_id: Unique ID for terminal/browser session isolation
|
||||
temperature: Sampling temperature for generation
|
||||
max_tokens: Max tokens per generation (None for server default)
|
||||
extra_body: Extra parameters passed to the OpenAI client's create() call.
|
||||
Used for OpenRouter provider preferences, transforms, etc.
|
||||
e.g. {"provider": {"ignore": ["DeepInfra"]}}
|
||||
budget_config: Tool result persistence budget. Controls per-tool
|
||||
thresholds, per-turn aggregate budget, and preview size.
|
||||
If None, uses DEFAULT_BUDGET (current hardcoded values).
|
||||
"""
|
||||
from tools.budget_config import DEFAULT_BUDGET
|
||||
self.server = server
|
||||
self.tool_schemas = tool_schemas
|
||||
self.valid_tool_names = valid_tool_names
|
||||
self.max_turns = max_turns
|
||||
self.task_id = task_id or str(uuid.uuid4())
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.extra_body = extra_body
|
||||
self.budget_config = budget_config or DEFAULT_BUDGET
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
|
||||
"""
|
||||
Execute the full agent loop using standard OpenAI tool calling.
|
||||
|
||||
Args:
|
||||
messages: Initial conversation messages (system + user).
|
||||
Modified in-place as the conversation progresses.
|
||||
|
||||
Returns:
|
||||
AgentResult with full conversation history, managed state, and metadata
|
||||
"""
|
||||
reasoning_per_turn = []
|
||||
tool_errors: List[ToolError] = []
|
||||
|
||||
# Per-loop TodoStore for the todo tool (ephemeral, dies with the loop)
|
||||
from tools.todo_tool import TodoStore, todo_tool as _todo_tool
|
||||
_todo_store = TodoStore()
|
||||
|
||||
# Extract user task from first user message for browser_snapshot context
|
||||
_user_task = None
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str) and content.strip():
|
||||
_user_task = content.strip()[:500] # Cap to avoid huge strings
|
||||
break
|
||||
|
||||
import time as _time
|
||||
|
||||
for turn in range(self.max_turns):
|
||||
turn_start = _time.monotonic()
|
||||
|
||||
# Build the chat_completion kwargs
|
||||
chat_kwargs = {
|
||||
"messages": messages,
|
||||
"n": 1,
|
||||
"temperature": self.temperature,
|
||||
}
|
||||
|
||||
# Only pass tools if we have them
|
||||
if self.tool_schemas:
|
||||
chat_kwargs["tools"] = self.tool_schemas
|
||||
|
||||
# Only pass max_tokens if explicitly set
|
||||
if self.max_tokens is not None:
|
||||
chat_kwargs["max_tokens"] = self.max_tokens
|
||||
|
||||
# Inject extra_body for provider-specific params (e.g., OpenRouter
|
||||
# provider preferences like banned/preferred providers, transforms)
|
||||
if self.extra_body:
|
||||
chat_kwargs["extra_body"] = self.extra_body
|
||||
|
||||
# Make the API call -- standard OpenAI spec
|
||||
api_start = _time.monotonic()
|
||||
try:
|
||||
response = await self.server.chat_completion(**chat_kwargs)
|
||||
except Exception as e:
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
logger.error("API call failed on turn %d (%.1fs): %s", turn + 1, api_elapsed, e)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
|
||||
if not response or not response.choices:
|
||||
logger.warning("Empty response on turn %d (api=%.1fs)", turn + 1, api_elapsed)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
|
||||
assistant_msg = response.choices[0].message
|
||||
|
||||
# Extract reasoning content from the response (all provider formats)
|
||||
reasoning = _extract_reasoning_from_message(assistant_msg)
|
||||
reasoning_per_turn.append(reasoning)
|
||||
|
||||
# Check for tool calls -- standard OpenAI spec.
|
||||
# Fallback: if response has no structured tool_calls but content
|
||||
# contains raw tool call tags (e.g. <tool_call>), parse them using
|
||||
# hermes-agent's standalone parsers. This handles the case where
|
||||
# ManagedServer's ToolCallTranslator couldn't parse because vLLM
|
||||
# isn't installed.
|
||||
if (
|
||||
not assistant_msg.tool_calls
|
||||
and assistant_msg.content
|
||||
and self.tool_schemas
|
||||
and "<tool_call>" in (assistant_msg.content or "")
|
||||
):
|
||||
try:
|
||||
from environments.tool_call_parsers import get_parser
|
||||
fallback_parser = get_parser("hermes")
|
||||
parsed_content, parsed_calls = fallback_parser.parse(
|
||||
assistant_msg.content
|
||||
)
|
||||
if parsed_calls:
|
||||
assistant_msg.tool_calls = parsed_calls
|
||||
if parsed_content is not None:
|
||||
assistant_msg.content = parsed_content
|
||||
logger.debug(
|
||||
"Fallback parser extracted %d tool calls from raw content",
|
||||
len(parsed_calls),
|
||||
)
|
||||
except Exception:
|
||||
pass # Fall through to no tool calls
|
||||
|
||||
if assistant_msg.tool_calls:
|
||||
# Normalize tool calls to dicts — they may come as objects
|
||||
# (OpenAI API) or dicts (vLLM ToolCallTranslator).
|
||||
def _tc_to_dict(tc):
|
||||
if isinstance(tc, dict):
|
||||
return {
|
||||
"id": tc.get("id", f"call_{uuid.uuid4().hex[:8]}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.get("function", {}).get("name", tc.get("name", "")),
|
||||
"arguments": tc.get("function", {}).get("arguments", tc.get("arguments", "{}")),
|
||||
},
|
||||
}
|
||||
return {
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments,
|
||||
},
|
||||
}
|
||||
|
||||
# Build the assistant message dict for conversation history
|
||||
msg_dict: Dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": assistant_msg.content or "",
|
||||
"tool_calls": [_tc_to_dict(tc) for tc in assistant_msg.tool_calls],
|
||||
}
|
||||
|
||||
# Preserve reasoning_content for multi-turn chat template handling
|
||||
# (e.g., Kimi-K2's template renders <think> blocks differently
|
||||
# for history vs. the latest turn based on this field)
|
||||
if reasoning:
|
||||
msg_dict["reasoning_content"] = reasoning
|
||||
|
||||
messages.append(msg_dict)
|
||||
|
||||
# Execute each tool call via hermes-agent's dispatch
|
||||
for tc in assistant_msg.tool_calls:
|
||||
# Handle both object (OpenAI) and dict (vLLM) formats
|
||||
if isinstance(tc, dict):
|
||||
tool_name = tc.get("function", {}).get("name", tc.get("name", ""))
|
||||
tool_args_raw = tc.get("function", {}).get("arguments", tc.get("arguments", "{}"))
|
||||
else:
|
||||
tool_name = tc.function.name
|
||||
tool_args_raw = tc.function.arguments
|
||||
|
||||
# Validate tool name
|
||||
if tool_name not in self.valid_tool_names:
|
||||
tool_result = json.dumps(
|
||||
{
|
||||
"error": f"Unknown tool '{tool_name}'. "
|
||||
f"Available tools: {sorted(self.valid_tool_names)}"
|
||||
}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"Unknown tool '{tool_name}'",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.warning(
|
||||
"Model called unknown tool '%s' on turn %d",
|
||||
tool_name, turn + 1,
|
||||
)
|
||||
else:
|
||||
# Parse arguments
|
||||
try:
|
||||
args = json.loads(tool_args_raw)
|
||||
except json.JSONDecodeError as e:
|
||||
args = None
|
||||
tool_result = json.dumps(
|
||||
{"error": f"Invalid JSON in tool arguments: {e}. Please retry with valid JSON."}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"Invalid JSON: {e}",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.warning(
|
||||
"Invalid JSON in tool call arguments for '%s': %s",
|
||||
tool_name, tool_args_raw[:200],
|
||||
)
|
||||
|
||||
# Dispatch tool only if arguments parsed successfully
|
||||
if args is not None:
|
||||
try:
|
||||
if tool_name == "terminal":
|
||||
backend = os.getenv("TERMINAL_ENV", "local")
|
||||
cmd_preview = args.get("command", "")[:80]
|
||||
logger.info(
|
||||
"[%s] $ %s", self.task_id[:8], cmd_preview,
|
||||
)
|
||||
|
||||
tool_submit_time = _time.monotonic()
|
||||
|
||||
# Todo tool -- handle locally (needs per-loop TodoStore)
|
||||
if tool_name == "todo":
|
||||
tool_result = _todo_tool(
|
||||
todos=args.get("todos"),
|
||||
merge=args.get("merge", False),
|
||||
store=_todo_store,
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "memory":
|
||||
tool_result = json.dumps({"error": "Memory is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "session_search":
|
||||
tool_result = json.dumps({"error": "Session search is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
else:
|
||||
# Run tool calls in a thread pool so backends that
|
||||
# use asyncio.run() internally (modal, docker, daytona) get
|
||||
# a clean event loop instead of deadlocking.
|
||||
loop = asyncio.get_running_loop()
|
||||
# Capture current tool_name/args for the lambda
|
||||
_tn, _ta, _tid = tool_name, args, self.task_id
|
||||
tool_result = await loop.run_in_executor(
|
||||
_tool_executor,
|
||||
lambda: handle_function_call(
|
||||
_tn, _ta, task_id=_tid,
|
||||
user_task=_user_task,
|
||||
),
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
|
||||
# Log slow tools and thread pool stats for debugging
|
||||
pool_active = _tool_executor._work_queue.qsize()
|
||||
if tool_elapsed > 30:
|
||||
logger.warning(
|
||||
"[%s] turn %d: %s took %.1fs (pool queue=%d)",
|
||||
self.task_id[:8], turn + 1, tool_name,
|
||||
tool_elapsed, pool_active,
|
||||
)
|
||||
except Exception as e:
|
||||
tool_result = json.dumps(
|
||||
{"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"{type(e).__name__}: {str(e)}",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.error(
|
||||
"Tool '%s' execution failed on turn %d: %s",
|
||||
tool_name, turn + 1, e,
|
||||
)
|
||||
|
||||
# Also check if the tool returned an error in its JSON result
|
||||
try:
|
||||
result_data = json.loads(tool_result)
|
||||
if isinstance(result_data, dict):
|
||||
err = result_data.get("error")
|
||||
exit_code = result_data.get("exit_code")
|
||||
if err and exit_code and exit_code < 0:
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=str(err),
|
||||
tool_result=tool_result[:500],
|
||||
))
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
tc_id = tc.get("id", "") if isinstance(tc, dict) else tc.id
|
||||
tool_result = maybe_persist_tool_result(
|
||||
content=tool_result,
|
||||
tool_name=tool_name,
|
||||
tool_use_id=tc_id,
|
||||
env=get_active_env(self.task_id),
|
||||
config=self.budget_config,
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc_id,
|
||||
"content": tool_result,
|
||||
}
|
||||
)
|
||||
|
||||
num_tcs = len(assistant_msg.tool_calls)
|
||||
if num_tcs > 0:
|
||||
enforce_turn_budget(
|
||||
messages[-num_tcs:],
|
||||
env=get_active_env(self.task_id),
|
||||
config=self.budget_config,
|
||||
)
|
||||
|
||||
turn_elapsed = _time.monotonic() - turn_start
|
||||
logger.info(
|
||||
"[%s] turn %d: api=%.1fs, %d tools, turn_total=%.1fs",
|
||||
self.task_id[:8], turn + 1, api_elapsed,
|
||||
len(assistant_msg.tool_calls), turn_elapsed,
|
||||
)
|
||||
|
||||
else:
|
||||
# No tool calls -- model is done
|
||||
msg_dict = {
|
||||
"role": "assistant",
|
||||
"content": assistant_msg.content or "",
|
||||
}
|
||||
if reasoning:
|
||||
msg_dict["reasoning_content"] = reasoning
|
||||
messages.append(msg_dict)
|
||||
|
||||
turn_elapsed = _time.monotonic() - turn_start
|
||||
logger.info(
|
||||
"[%s] turn %d: api=%.1fs, no tools (finished), turn_total=%.1fs",
|
||||
self.task_id[:8], turn + 1, api_elapsed, turn_elapsed,
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=True,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
|
||||
# Hit max turns without the model stopping
|
||||
logger.info("Agent hit max_turns (%d) without finishing", self.max_turns)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=self.max_turns,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
|
||||
def _get_managed_state(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get ManagedServer state if the server supports it.
|
||||
|
||||
Returns state dict with SequenceNodes containing tokens/logprobs/masks,
|
||||
or None if the server doesn't support get_state() (e.g., regular OpenAI server).
|
||||
"""
|
||||
if hasattr(self.server, "get_state"):
|
||||
return self.server.get_state()
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,73 +0,0 @@
|
||||
# OpenThoughts-TBLite Evaluation Environment
|
||||
|
||||
This environment evaluates terminal agents on the [OpenThoughts-TBLite](https://huggingface.co/datasets/open-thoughts/OpenThoughts-TBLite) benchmark, a difficulty-calibrated subset of [Terminal-Bench 2.0](https://www.tbench.ai/leaderboard/terminal-bench/2.0).
|
||||
|
||||
## Source
|
||||
|
||||
OpenThoughts-TBLite was created by the [OpenThoughts](https://www.openthoughts.ai/) Agent team in collaboration with [Snorkel AI](https://snorkel.ai/) and [Bespoke Labs](https://bespokelabs.ai/). The original dataset and documentation live at:
|
||||
|
||||
- **Dataset (source):** [open-thoughts/OpenThoughts-TBLite](https://huggingface.co/datasets/open-thoughts/OpenThoughts-TBLite)
|
||||
- **GitHub:** [open-thoughts/OpenThoughts-TBLite](https://github.com/open-thoughts/OpenThoughts-TBLite)
|
||||
- **Blog post:** [openthoughts.ai/blog/openthoughts-tblite](https://www.openthoughts.ai/blog/openthoughts-tblite)
|
||||
|
||||
## Our Dataset
|
||||
|
||||
We converted the source into the same schema used by our Terminal-Bench 2.0 environment (pre-built Docker Hub images, base64-encoded test tarballs, etc.) and published it as:
|
||||
|
||||
- **Dataset (ours):** [NousResearch/openthoughts-tblite](https://huggingface.co/datasets/NousResearch/openthoughts-tblite)
|
||||
- **Docker images:** `nousresearch/tblite-<task-name>:latest` on Docker Hub (100 images)
|
||||
|
||||
The conversion script is at `scripts/prepare_tblite_dataset.py`.
|
||||
|
||||
## Why TBLite?
|
||||
|
||||
Terminal-Bench 2.0 is one of the strongest frontier evaluations for terminal agents, but when a model scores near the floor (e.g., Qwen 3 8B at <1%), many changes look identical in aggregate score. TBLite addresses this by calibrating task difficulty using Claude Haiku 4.5 as a reference:
|
||||
|
||||
| Difficulty | Pass Rate Range | Tasks |
|
||||
|------------|----------------|-------|
|
||||
| Easy | >= 70% | 40 |
|
||||
| Medium | 40-69% | 26 |
|
||||
| Hard | 10-39% | 26 |
|
||||
| Extreme | < 10% | 8 |
|
||||
|
||||
This gives enough solvable tasks to detect small improvements quickly, while preserving enough hard tasks to avoid saturation. The correlation between TBLite and TB2 scores is **r = 0.911**.
|
||||
|
||||
TBLite also runs 2.6-8x faster than the full TB2, making it practical for iteration loops.
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Run the full benchmark
|
||||
python environments/benchmarks/tblite/tblite_env.py evaluate
|
||||
|
||||
# Filter to specific tasks
|
||||
python environments/benchmarks/tblite/tblite_env.py evaluate \
|
||||
--env.task_filter "broken-python,pandas-etl"
|
||||
|
||||
# Use a different model
|
||||
python environments/benchmarks/tblite/tblite_env.py evaluate \
|
||||
--server.model_name "qwen/qwen3-30b"
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
`TBLiteEvalEnv` is a thin subclass of `TerminalBench2EvalEnv`. All evaluation logic (agent loop, Docker sandbox management, test verification, metrics) is inherited. Only the defaults differ:
|
||||
|
||||
| Setting | TB2 | TBLite |
|
||||
|----------------|----------------------------------|-----------------------------------------|
|
||||
| Dataset | `NousResearch/terminal-bench-2` | `NousResearch/openthoughts-tblite` |
|
||||
| Tasks | 89 | 100 |
|
||||
| Task timeout | 1800s (30 min) | 1200s (20 min) |
|
||||
| Wandb name | `terminal-bench-2` | `openthoughts-tblite` |
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@software{OpenThoughts-TBLite,
|
||||
author = {OpenThoughts-Agent team, Snorkel AI, Bespoke Labs},
|
||||
month = Feb,
|
||||
title = {{OpenThoughts-TBLite: A High-Signal Benchmark for Iterating on Terminal Agents}},
|
||||
howpublished = {https://www.openthoughts.ai/blog/openthoughts-tblite},
|
||||
year = {2026}
|
||||
}
|
||||
```
|
||||
@@ -1,39 +0,0 @@
|
||||
# OpenThoughts-TBLite Evaluation -- Default Configuration
|
||||
#
|
||||
# Eval-only environment for the TBLite benchmark (100 difficulty-calibrated
|
||||
# terminal tasks, a faster proxy for Terminal-Bench 2.0).
|
||||
# Uses Modal terminal backend for per-task cloud-isolated sandboxes
|
||||
# and OpenRouter for inference.
|
||||
#
|
||||
# Usage:
|
||||
# python environments/benchmarks/tblite/tblite_env.py evaluate \
|
||||
# --config environments/benchmarks/tblite/default.yaml
|
||||
#
|
||||
# # Override model:
|
||||
# python environments/benchmarks/tblite/tblite_env.py evaluate \
|
||||
# --config environments/benchmarks/tblite/default.yaml \
|
||||
# --openai.model_name anthropic/claude-sonnet-4
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
max_agent_turns: 60
|
||||
max_token_length: 32000
|
||||
agent_temperature: 0.8
|
||||
terminal_backend: "modal"
|
||||
terminal_timeout: 300 # 5 min per command (builds, pip install)
|
||||
tool_pool_size: 128 # thread pool for 100 parallel tasks
|
||||
dataset_name: "NousResearch/openthoughts-tblite"
|
||||
test_timeout: 600
|
||||
task_timeout: 1200 # 20 min wall-clock per task (TBLite tasks are faster)
|
||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
use_wandb: true
|
||||
wandb_name: "openthoughts-tblite"
|
||||
ensure_scores_are_not_same: false
|
||||
data_dir_to_save_evals: "environments/benchmarks/evals/openthoughts-tblite"
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-opus-4.6"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
@@ -1,38 +0,0 @@
|
||||
# OpenThoughts-TBLite Evaluation -- Docker Backend (Local Compute)
|
||||
#
|
||||
# Runs tasks in Docker containers on the local machine.
|
||||
# Sandboxed like Modal but no cloud costs. Good for dev/testing.
|
||||
#
|
||||
# Usage:
|
||||
# python environments/benchmarks/tblite/tblite_env.py evaluate \
|
||||
# --config environments/benchmarks/tblite/local.yaml
|
||||
#
|
||||
# # Override concurrency:
|
||||
# python environments/benchmarks/tblite/tblite_env.py evaluate \
|
||||
# --config environments/benchmarks/tblite/local.yaml \
|
||||
# --env.eval_concurrency 4
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
max_agent_turns: 60
|
||||
max_token_length: 32000
|
||||
agent_temperature: 0.8
|
||||
terminal_backend: "docker"
|
||||
terminal_timeout: 300
|
||||
tool_pool_size: 16
|
||||
dataset_name: "NousResearch/openthoughts-tblite"
|
||||
test_timeout: 600
|
||||
task_timeout: 1200
|
||||
eval_concurrency: 8 # max 8 tasks at once
|
||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
use_wandb: false
|
||||
wandb_name: "openthoughts-tblite-local"
|
||||
ensure_scores_are_not_same: false
|
||||
data_dir_to_save_evals: "environments/benchmarks/evals/openthoughts-tblite-local"
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-sonnet-4"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
@@ -1,40 +0,0 @@
|
||||
# OpenThoughts-TBLite Evaluation -- Local vLLM Backend
|
||||
#
|
||||
# Runs against a local vLLM server with Docker sandboxes.
|
||||
#
|
||||
# Start the vLLM server from the atropos directory:
|
||||
# python -m example_trainer.vllm_api_server \
|
||||
# --model Qwen/Qwen3-4B-Instruct-2507 \
|
||||
# --port 9001 \
|
||||
# --gpu-memory-utilization 0.8 \
|
||||
# --max-model-len=32000
|
||||
#
|
||||
# Then run:
|
||||
# python environments/benchmarks/tblite/tblite_env.py evaluate \
|
||||
# --config environments/benchmarks/tblite/local_vllm.yaml
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
max_agent_turns: 60
|
||||
max_token_length: 16000
|
||||
agent_temperature: 0.6
|
||||
terminal_backend: "docker"
|
||||
terminal_timeout: 300
|
||||
tool_pool_size: 16
|
||||
dataset_name: "NousResearch/openthoughts-tblite"
|
||||
test_timeout: 600
|
||||
task_timeout: 1200
|
||||
eval_concurrency: 8
|
||||
tool_call_parser: "hermes"
|
||||
system_prompt: "You are an expert terminal agent. You MUST use the provided tools to complete tasks. Use the terminal tool to run shell commands, read_file to read files, write_file to write files, search_files to search, and patch to edit files. Do NOT write out solutions as text - execute them using the tools. Always start by exploring the environment with terminal commands."
|
||||
tokenizer_name: "Qwen/Qwen3-4B-Instruct-2507"
|
||||
use_wandb: false
|
||||
wandb_name: "tblite-qwen3-4b-instruct"
|
||||
ensure_scores_are_not_same: false
|
||||
data_dir_to_save_evals: "environments/benchmarks/evals/tblite-qwen3-4b-local"
|
||||
|
||||
openai:
|
||||
base_url: "http://localhost:9001"
|
||||
model_name: "Qwen/Qwen3-4B-Instruct-2507"
|
||||
server_type: "vllm"
|
||||
health_check: false
|
||||
@@ -1,42 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# OpenThoughts-TBLite Evaluation
|
||||
#
|
||||
# Run from repo root:
|
||||
# bash environments/benchmarks/tblite/run_eval.sh
|
||||
#
|
||||
# Override model:
|
||||
# bash environments/benchmarks/tblite/run_eval.sh \
|
||||
# --openai.model_name anthropic/claude-sonnet-4
|
||||
#
|
||||
# Run a subset:
|
||||
# bash environments/benchmarks/tblite/run_eval.sh \
|
||||
# --env.task_filter broken-python,pandas-etl
|
||||
#
|
||||
# All terminal settings (backend, timeout, lifetime, pool size) are
|
||||
# configured via env config fields -- no env vars needed.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
mkdir -p logs evals/openthoughts-tblite
|
||||
LOG_FILE="logs/tblite_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "OpenThoughts-TBLite Evaluation"
|
||||
echo "Log file: $LOG_FILE"
|
||||
echo ""
|
||||
|
||||
# Unbuffered python output so logs are written in real-time
|
||||
export PYTHONUNBUFFERED=1
|
||||
|
||||
# Show INFO-level agent loop timing (api/tool durations per turn)
|
||||
# These go to the log file; tqdm + [START]/[PASS]/[FAIL] go to terminal
|
||||
export LOGLEVEL=INFO
|
||||
|
||||
python tblite_env.py evaluate \
|
||||
--config default.yaml \
|
||||
"$@" \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo ""
|
||||
echo "Log saved to: $LOG_FILE"
|
||||
echo "Eval results: evals/openthoughts-tblite/"
|
||||
@@ -1,119 +0,0 @@
|
||||
"""
|
||||
OpenThoughts-TBLite Evaluation Environment
|
||||
|
||||
A lighter, faster alternative to Terminal-Bench 2.0 for iterating on terminal
|
||||
agents. Uses the same evaluation logic as TerminalBench2EvalEnv but defaults
|
||||
to the NousResearch/openthoughts-tblite dataset (100 difficulty-calibrated
|
||||
tasks vs TB2's 89 harder tasks).
|
||||
|
||||
TBLite tasks are a curated subset of TB2 with a difficulty distribution
|
||||
designed to give meaningful signal even for smaller models:
|
||||
- Easy (40 tasks): >= 70% pass rate with Claude Haiku 4.5
|
||||
- Medium (26 tasks): 40-69% pass rate
|
||||
- Hard (26 tasks): 10-39% pass rate
|
||||
- Extreme (8 tasks): < 10% pass rate
|
||||
|
||||
Usage:
|
||||
python environments/benchmarks/tblite/tblite_env.py evaluate
|
||||
|
||||
# Filter to specific tasks:
|
||||
python environments/benchmarks/tblite/tblite_env.py evaluate \\
|
||||
--env.task_filter "broken-python,pandas-etl"
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import EvalHandlingEnum
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
|
||||
from environments.benchmarks.terminalbench_2.terminalbench2_env import (
|
||||
TerminalBench2EvalConfig,
|
||||
TerminalBench2EvalEnv,
|
||||
)
|
||||
|
||||
|
||||
class TBLiteEvalConfig(TerminalBench2EvalConfig):
|
||||
"""Configuration for the OpenThoughts-TBLite evaluation environment.
|
||||
|
||||
Inherits all TB2 config fields. Only the dataset default and task timeout
|
||||
differ -- TBLite tasks are calibrated to be faster.
|
||||
"""
|
||||
|
||||
dataset_name: str = Field(
|
||||
default="NousResearch/openthoughts-tblite",
|
||||
description="HuggingFace dataset containing TBLite tasks.",
|
||||
)
|
||||
|
||||
task_timeout: int = Field(
|
||||
default=1200,
|
||||
description="Maximum wall-clock seconds per task. TBLite tasks are "
|
||||
"generally faster than TB2, so 20 minutes is usually sufficient.",
|
||||
)
|
||||
|
||||
|
||||
class TBLiteEvalEnv(TerminalBench2EvalEnv):
|
||||
"""OpenThoughts-TBLite evaluation environment.
|
||||
|
||||
Inherits all evaluation logic from TerminalBench2EvalEnv (agent loop,
|
||||
test verification, Docker image resolution, metrics, wandb logging).
|
||||
Only the default configuration differs.
|
||||
"""
|
||||
|
||||
name = "openthoughts-tblite"
|
||||
env_config_cls = TBLiteEvalConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[TBLiteEvalConfig, List[APIServerConfig]]:
|
||||
env_config = TBLiteEvalConfig(
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
|
||||
max_agent_turns=60,
|
||||
max_token_length=16000,
|
||||
agent_temperature=0.6,
|
||||
system_prompt=None,
|
||||
|
||||
terminal_backend="modal",
|
||||
terminal_timeout=300,
|
||||
|
||||
test_timeout=180,
|
||||
|
||||
# 100 tasks in parallel
|
||||
tool_pool_size=128,
|
||||
|
||||
eval_handling=EvalHandlingEnum.STOP_TRAIN,
|
||||
group_size=1,
|
||||
steps_per_eval=1,
|
||||
total_steps=1,
|
||||
|
||||
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
use_wandb=True,
|
||||
wandb_name="openthoughts-tblite",
|
||||
ensure_scores_are_not_same=False,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-sonnet-4",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TBLiteEvalEnv.cli()
|
||||
@@ -1,42 +0,0 @@
|
||||
# Terminal-Bench 2.0 Evaluation -- Default Configuration
|
||||
#
|
||||
# Eval-only environment for the TB2 benchmark (89 terminal tasks).
|
||||
# Uses Modal terminal backend for per-task cloud-isolated sandboxes
|
||||
# and OpenRouter for inference.
|
||||
#
|
||||
# Usage:
|
||||
# python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
# --config environments/benchmarks/terminalbench_2/default.yaml
|
||||
#
|
||||
# # Override model:
|
||||
# python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
# --config environments/benchmarks/terminalbench_2/default.yaml \
|
||||
# --openai.model_name anthropic/claude-sonnet-4
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
max_agent_turns: 60
|
||||
max_token_length: 32000
|
||||
agent_temperature: 0.8
|
||||
terminal_backend: "modal"
|
||||
terminal_timeout: 300 # 5 min per command (builds, pip install)
|
||||
tool_pool_size: 128 # thread pool for 89 parallel tasks
|
||||
dataset_name: "NousResearch/terminal-bench-2"
|
||||
test_timeout: 600
|
||||
task_timeout: 1800 # 30 min wall-clock per task, auto-FAIL if exceeded
|
||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
use_wandb: true
|
||||
wandb_name: "terminal-bench-2"
|
||||
ensure_scores_are_not_same: false
|
||||
data_dir_to_save_evals: "environments/benchmarks/evals/terminal-bench-2"
|
||||
# CRITICAL: Limit concurrent Modal sandbox creations to avoid deadlocks.
|
||||
# Modal's blocking calls (App.lookup, etc.) deadlock when too many sandboxes
|
||||
# are created simultaneously inside thread pool workers via asyncio.run().
|
||||
max_concurrent_tasks: 8
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-opus-4.6"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
@@ -1,42 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Terminal-Bench 2.0 Evaluation
|
||||
#
|
||||
# Run from repo root:
|
||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh
|
||||
#
|
||||
# Override model:
|
||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh \
|
||||
# --openai.model_name anthropic/claude-sonnet-4
|
||||
#
|
||||
# Run a subset:
|
||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh \
|
||||
# --env.task_filter fix-git,git-multibranch
|
||||
#
|
||||
# All terminal settings (backend, timeout, lifetime, pool size) are
|
||||
# configured via env config fields -- no env vars needed.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
mkdir -p logs evals/terminal-bench-2
|
||||
LOG_FILE="logs/terminalbench2_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "Terminal-Bench 2.0 Evaluation"
|
||||
echo "Log file: $LOG_FILE"
|
||||
echo ""
|
||||
|
||||
# Unbuffered python output so logs are written in real-time
|
||||
export PYTHONUNBUFFERED=1
|
||||
|
||||
# Show INFO-level agent loop timing (api/tool durations per turn)
|
||||
# These go to the log file; tqdm + [START]/[PASS]/[FAIL] go to terminal
|
||||
export LOGLEVEL=INFO
|
||||
|
||||
python terminalbench2_env.py evaluate \
|
||||
--config default.yaml \
|
||||
"$@" \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo ""
|
||||
echo "Log saved to: $LOG_FILE"
|
||||
echo "Eval results: evals/terminal-bench-2/"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,115 +0,0 @@
|
||||
# YC-Bench: Long-Horizon Agent Benchmark
|
||||
|
||||
[YC-Bench](https://github.com/collinear-ai/yc-bench) by [Collinear AI](https://collinear.ai/) is a deterministic, long-horizon benchmark that tests LLM agents' ability to act as a tech startup CEO. The agent manages a simulated company over 1-3 years, making compounding decisions about resource allocation, cash flow, task management, and prestige specialisation across 4 skill domains.
|
||||
|
||||
Unlike TerminalBench2 (which evaluates per-task coding ability with binary pass/fail), YC-Bench measures **long-term strategic coherence** — whether an agent can maintain consistent strategy, manage compounding consequences, and adapt plans over hundreds of turns.
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
# Install yc-bench (optional dependency)
|
||||
pip install "hermes-agent[yc-bench]"
|
||||
|
||||
# Or install from source
|
||||
git clone https://github.com/collinear-ai/yc-bench
|
||||
cd yc-bench && pip install -e .
|
||||
|
||||
# Verify
|
||||
yc-bench --help
|
||||
```
|
||||
|
||||
## Running
|
||||
|
||||
```bash
|
||||
# From the repo root:
|
||||
bash environments/benchmarks/yc_bench/run_eval.sh
|
||||
|
||||
# Or directly:
|
||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
--config environments/benchmarks/yc_bench/default.yaml
|
||||
|
||||
# Override model:
|
||||
bash environments/benchmarks/yc_bench/run_eval.sh \
|
||||
--openai.model_name anthropic/claude-opus-4-20250514
|
||||
|
||||
# Quick single-preset test:
|
||||
bash environments/benchmarks/yc_bench/run_eval.sh \
|
||||
--env.presets '["fast_test"]' --env.seeds '[1]'
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
### Architecture
|
||||
|
||||
```
|
||||
HermesAgentLoop (our agent)
|
||||
-> terminal tool -> subprocess("yc-bench company status") -> JSON output
|
||||
-> terminal tool -> subprocess("yc-bench task accept --task-id X") -> JSON
|
||||
-> terminal tool -> subprocess("yc-bench sim resume") -> JSON (advance time)
|
||||
-> ... (100-500 turns per run)
|
||||
```
|
||||
|
||||
The environment initialises the simulation via `yc-bench sim init` (NOT `yc-bench run`, which would start yc-bench's own built-in agent loop). Our `HermesAgentLoop` then drives all interaction through CLI commands.
|
||||
|
||||
### Simulation Mechanics
|
||||
|
||||
- **4 skill domains**: research, inference, data_environment, training
|
||||
- **Prestige system** (1.0-10.0): Gates access to higher-paying tasks
|
||||
- **Employee management**: Junior/Mid/Senior with domain-specific skill rates
|
||||
- **Throughput splitting**: `effective_rate = base_rate / N` active tasks per employee
|
||||
- **Financial pressure**: Monthly payroll, bankruptcy = game over
|
||||
- **Deterministic**: SHA256-based RNG — same seed + preset = same world
|
||||
|
||||
### Difficulty Presets
|
||||
|
||||
| Preset | Employees | Tasks | Focus |
|
||||
|-----------|-----------|-------|-------|
|
||||
| tutorial | 3 | 50 | Basic loop mechanics |
|
||||
| easy | 5 | 100 | Throughput awareness |
|
||||
| **medium**| 5 | 150 | Prestige climbing + domain specialisation |
|
||||
| **hard** | 7 | 200 | Precise ETA reasoning |
|
||||
| nightmare | 8 | 300 | Sustained perfection under payroll pressure |
|
||||
| fast_test | (varies) | (varies) | Quick validation (~50 turns) |
|
||||
|
||||
Default eval runs **fast_test + medium + hard** × 3 seeds = 9 runs.
|
||||
|
||||
### Scoring
|
||||
|
||||
```
|
||||
composite = 0.5 × survival + 0.5 × normalised_funds
|
||||
```
|
||||
|
||||
- **Survival** (binary): Did the company avoid bankruptcy?
|
||||
- **Normalised funds** (0.0-1.0): Log-scale relative to initial $250K capital
|
||||
|
||||
## Configuration
|
||||
|
||||
Key fields in `default.yaml`:
|
||||
|
||||
| Field | Default | Description |
|
||||
|-------|---------|-------------|
|
||||
| `presets` | `["fast_test", "medium", "hard"]` | Which presets to evaluate |
|
||||
| `seeds` | `[1, 2, 3]` | RNG seeds per preset |
|
||||
| `max_agent_turns` | 200 | Max LLM calls per run |
|
||||
| `run_timeout` | 3600 | Wall-clock timeout per run (seconds) |
|
||||
| `survival_weight` | 0.5 | Weight of survival in composite score |
|
||||
| `funds_weight` | 0.5 | Weight of normalised funds in composite |
|
||||
| `horizon_years` | null | Override horizon (null = auto from preset) |
|
||||
|
||||
## Cost & Time Estimates
|
||||
|
||||
Each run is 100-500 LLM turns. Approximate costs per run at typical API rates:
|
||||
|
||||
| Preset | Turns | Time | Est. Cost |
|
||||
|--------|-------|------|-----------|
|
||||
| fast_test | ~50 | 5-10 min | $1-5 |
|
||||
| medium | ~200 | 20-40 min | $5-15 |
|
||||
| hard | ~300 | 30-60 min | $10-25 |
|
||||
|
||||
Full default eval (9 runs): ~3-6 hours, $50-200 depending on model.
|
||||
|
||||
## References
|
||||
|
||||
- [collinear-ai/yc-bench](https://github.com/collinear-ai/yc-bench) — Official repository
|
||||
- [Collinear AI](https://collinear.ai/) — Company behind yc-bench
|
||||
- [TerminalBench2](../terminalbench_2/) — Per-task coding benchmark (complementary)
|
||||
@@ -1,43 +0,0 @@
|
||||
# YC-Bench Evaluation -- Default Configuration
|
||||
#
|
||||
# Long-horizon agent benchmark: agent plays CEO of an AI startup over
|
||||
# a simulated 1-3 year run, interacting via yc-bench CLI subcommands.
|
||||
#
|
||||
# Requires: pip install "hermes-agent[yc-bench]"
|
||||
#
|
||||
# Usage:
|
||||
# python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
# --config environments/benchmarks/yc_bench/default.yaml
|
||||
#
|
||||
# # Override model:
|
||||
# python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
# --config environments/benchmarks/yc_bench/default.yaml \
|
||||
# --openai.model_name anthropic/claude-opus-4-20250514
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal"]
|
||||
max_agent_turns: 200
|
||||
max_token_length: 32000
|
||||
agent_temperature: 0.0
|
||||
terminal_backend: "local"
|
||||
terminal_timeout: 60
|
||||
presets: ["fast_test", "medium", "hard"]
|
||||
seeds: [1, 2, 3]
|
||||
run_timeout: 3600 # 60 min wall-clock per run, auto-FAIL if exceeded
|
||||
survival_weight: 0.5 # weight of binary survival in composite score
|
||||
funds_weight: 0.5 # weight of normalised final funds in composite score
|
||||
db_dir: "/tmp/yc_bench_dbs"
|
||||
company_name: "BenchCo"
|
||||
start_date: "01/01/2025" # MM/DD/YYYY (yc-bench convention)
|
||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
use_wandb: true
|
||||
wandb_name: "yc-bench"
|
||||
ensure_scores_are_not_same: false
|
||||
data_dir_to_save_evals: "environments/benchmarks/evals/yc-bench"
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-sonnet-4.6"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
@@ -1,34 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# YC-Bench Evaluation
|
||||
#
|
||||
# Requires: pip install "hermes-agent[yc-bench]"
|
||||
#
|
||||
# Run from repo root:
|
||||
# bash environments/benchmarks/yc_bench/run_eval.sh
|
||||
#
|
||||
# Override model:
|
||||
# bash environments/benchmarks/yc_bench/run_eval.sh \
|
||||
# --openai.model_name anthropic/claude-opus-4-20250514
|
||||
#
|
||||
# Run a single preset:
|
||||
# bash environments/benchmarks/yc_bench/run_eval.sh \
|
||||
# --env.presets '["fast_test"]' --env.seeds '[1]'
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
mkdir -p logs evals/yc-bench
|
||||
LOG_FILE="logs/yc_bench_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "YC-Bench Evaluation"
|
||||
echo "Log: $LOG_FILE"
|
||||
echo ""
|
||||
|
||||
PYTHONUNBUFFERED=1 LOGLEVEL="${LOGLEVEL:-INFO}" \
|
||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
--config environments/benchmarks/yc_bench/default.yaml \
|
||||
"$@" \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo ""
|
||||
echo "Log saved to: $LOG_FILE"
|
||||
@@ -1,848 +0,0 @@
|
||||
"""
|
||||
YCBenchEvalEnv -- YC-Bench Long-Horizon Agent Benchmark Environment
|
||||
|
||||
Evaluates agentic LLMs on YC-Bench: a deterministic, long-horizon benchmark
|
||||
where the agent acts as CEO of an AI startup over a simulated 1-3 year run.
|
||||
The agent manages cash flow, employees, tasks, and prestige across 4 domains,
|
||||
interacting exclusively via CLI subprocess calls against a SQLite-backed
|
||||
discrete-event simulation.
|
||||
|
||||
Unlike TerminalBench2 (per-task binary pass/fail), YC-Bench measures sustained
|
||||
multi-turn strategic coherence -- whether an agent can manage compounding
|
||||
decisions over hundreds of turns without going bankrupt.
|
||||
|
||||
This is an eval-only environment. Run via:
|
||||
|
||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
--config environments/benchmarks/yc_bench/default.yaml
|
||||
|
||||
The evaluate flow:
|
||||
1. setup() -- Verifies yc-bench installed, builds eval matrix (preset x seed)
|
||||
2. evaluate() -- Iterates over all runs sequentially through:
|
||||
a. rollout_and_score_eval() -- Per-run agent loop
|
||||
- Initialises a fresh yc-bench simulation via `sim init` (NOT `run`)
|
||||
- Runs HermesAgentLoop with terminal tool only
|
||||
- Reads final SQLite DB to extract score
|
||||
- Returns survival (0/1) + normalised funds score
|
||||
b. Aggregates per-preset and overall metrics
|
||||
c. Logs results via evaluate_log() and wandb
|
||||
|
||||
Key features:
|
||||
- CLI-only interface: agent calls yc-bench subcommands via terminal tool
|
||||
- Deterministic: same seed + preset = same world (SHA256-based RNG)
|
||||
- Multi-dimensional scoring: survival + normalised final funds
|
||||
- Per-preset difficulty breakdown in results
|
||||
- Isolated SQLite DB per run (no cross-run state leakage)
|
||||
|
||||
Requires: pip install hermes-agent[yc-bench]
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sqlite3
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import EvalHandlingEnum
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
|
||||
from environments.agent_loop import HermesAgentLoop
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =============================================================================
|
||||
# System prompt
|
||||
# =============================================================================
|
||||
|
||||
YC_BENCH_SYSTEM_PROMPT = """\
|
||||
You are the autonomous CEO of an early-stage AI startup in a deterministic
|
||||
business simulation. You manage the company exclusively through the `yc-bench`
|
||||
CLI tool. Your primary goal is to **survive** until the simulation horizon ends
|
||||
without going bankrupt, while **maximising final funds**.
|
||||
|
||||
## Simulation Mechanics
|
||||
|
||||
- **Funds**: You start with $250,000 seed capital. Revenue comes from completing
|
||||
tasks. Rewards scale with your prestige: `base × (1 + scale × (prestige − 1))`.
|
||||
- **Domains**: There are 4 skill domains: **research**, **inference**,
|
||||
**data_environment**, and **training**. Each has its own prestige level
|
||||
(1.0-10.0). Higher prestige unlocks better-paying tasks.
|
||||
- **Employees**: You have employees (Junior/Mid/Senior) with domain-specific
|
||||
skill rates. **Throughput splits**: `effective_rate = base_rate / N` where N
|
||||
is the number of active tasks assigned to that employee. Focus beats breadth.
|
||||
- **Payroll**: Deducted automatically on the first business day of each month.
|
||||
Running out of funds = bankruptcy = game over.
|
||||
- **Time**: The simulation runs on business days (Mon-Fri), 09:00-18:00.
|
||||
Time only advances when you call `yc-bench sim resume`.
|
||||
|
||||
## Task Lifecycle
|
||||
|
||||
1. Browse market tasks with `market browse`
|
||||
2. Accept a task with `task accept` (this sets its deadline)
|
||||
3. Assign employees with `task assign`
|
||||
4. Dispatch with `task dispatch` to start work
|
||||
5. Call `sim resume` to advance time and let employees make progress
|
||||
6. Tasks complete when all domain requirements are fulfilled
|
||||
|
||||
**Penalties for failure vary by difficulty preset.** Completing a task on time
|
||||
earns full reward + prestige gain. Missing a deadline or cancelling a task
|
||||
incurs prestige penalties -- cancelling is always more costly than letting a
|
||||
task fail, so cancel only as a last resort.
|
||||
|
||||
## CLI Commands
|
||||
|
||||
### Observe
|
||||
- `yc-bench company status` -- funds, prestige, runway
|
||||
- `yc-bench employee list` -- skills, salary, active tasks
|
||||
- `yc-bench market browse [--domain D] [--required-prestige-lte N]` -- available tasks
|
||||
- `yc-bench task list [--status active|planned]` -- your tasks
|
||||
- `yc-bench task inspect --task-id UUID` -- progress, deadline, assignments
|
||||
- `yc-bench finance ledger [--category monthly_payroll|task_reward]` -- transaction history
|
||||
- `yc-bench report monthly` -- monthly P&L
|
||||
|
||||
### Act
|
||||
- `yc-bench task accept --task-id UUID` -- accept from market
|
||||
- `yc-bench task assign --task-id UUID --employee-id UUID` -- assign employee
|
||||
- `yc-bench task dispatch --task-id UUID` -- start work (needs >=1 assignment)
|
||||
- `yc-bench task cancel --task-id UUID --reason "text"` -- cancel (prestige penalty)
|
||||
- `yc-bench sim resume` -- advance simulation clock
|
||||
|
||||
### Memory (persists across context truncation)
|
||||
- `yc-bench scratchpad read` -- read your persistent notes
|
||||
- `yc-bench scratchpad write --content "text"` -- overwrite notes
|
||||
- `yc-bench scratchpad append --content "text"` -- append to notes
|
||||
- `yc-bench scratchpad clear` -- clear notes
|
||||
|
||||
## Strategy Guidelines
|
||||
|
||||
1. **Specialise in 2-3 domains** to climb the prestige ladder faster and unlock
|
||||
high-reward tasks. Don't spread thin across all 4 domains early on.
|
||||
2. **Focus employees** -- assigning one employee to many tasks halves their
|
||||
throughput per additional task. Keep assignments concentrated.
|
||||
3. **Use the scratchpad** to track your strategy, upcoming deadlines, and
|
||||
employee assignments. This persists even if conversation context is truncated.
|
||||
4. **Monitor runway** -- always know how many months of payroll you can cover.
|
||||
Accept high-reward tasks before payroll dates.
|
||||
5. **Don't over-accept** -- taking too many tasks and missing deadlines cascades
|
||||
into prestige loss, locking you out of profitable contracts.
|
||||
6. Use `finance ledger` and `report monthly` to track revenue trends.
|
||||
|
||||
## Your Turn
|
||||
|
||||
Each turn:
|
||||
1. Call `yc-bench company status` and `yc-bench task list` to orient yourself.
|
||||
2. Check for completed tasks and pending deadlines.
|
||||
3. Browse market for profitable tasks within your prestige level.
|
||||
4. Accept, assign, and dispatch tasks strategically.
|
||||
5. Call `yc-bench sim resume` to advance time.
|
||||
6. Repeat until the simulation ends.
|
||||
|
||||
Think step by step before acting."""
|
||||
|
||||
# Starting funds in cents ($250,000)
|
||||
INITIAL_FUNDS_CENTS = 25_000_000
|
||||
|
||||
# Default horizon per preset (years)
|
||||
_PRESET_HORIZONS = {
|
||||
"tutorial": 1,
|
||||
"easy": 1,
|
||||
"medium": 1,
|
||||
"hard": 1,
|
||||
"nightmare": 1,
|
||||
"fast_test": 1,
|
||||
"default": 3,
|
||||
"high_reward": 1,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
class YCBenchEvalConfig(HermesAgentEnvConfig):
|
||||
"""
|
||||
Configuration for the YC-Bench evaluation environment.
|
||||
|
||||
Extends HermesAgentEnvConfig with YC-Bench-specific settings for
|
||||
preset selection, seed control, scoring, and simulation parameters.
|
||||
"""
|
||||
|
||||
presets: List[str] = Field(
|
||||
default=["fast_test", "medium", "hard"],
|
||||
description="YC-Bench preset names to evaluate.",
|
||||
)
|
||||
seeds: List[int] = Field(
|
||||
default=[1, 2, 3],
|
||||
description="Random seeds -- each preset x seed = one run.",
|
||||
)
|
||||
run_timeout: int = Field(
|
||||
default=3600,
|
||||
description="Maximum wall-clock seconds per run. Default 60 minutes.",
|
||||
)
|
||||
survival_weight: float = Field(
|
||||
default=0.5,
|
||||
description="Weight of survival (0/1) in composite score.",
|
||||
)
|
||||
funds_weight: float = Field(
|
||||
default=0.5,
|
||||
description="Weight of normalised final funds in composite score.",
|
||||
)
|
||||
db_dir: str = Field(
|
||||
default="/tmp/yc_bench_dbs",
|
||||
description="Directory for per-run SQLite databases.",
|
||||
)
|
||||
horizon_years: Optional[int] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Simulation horizon in years. If None (default), inferred from "
|
||||
"preset name (1 year for most, 3 for 'default')."
|
||||
),
|
||||
)
|
||||
company_name: str = Field(
|
||||
default="BenchCo",
|
||||
description="Name of the simulated company.",
|
||||
)
|
||||
start_date: str = Field(
|
||||
default="01/01/2025",
|
||||
description="Simulation start date in MM/DD/YYYY format (yc-bench convention).",
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Scoring helpers
|
||||
# =============================================================================
|
||||
|
||||
def _read_final_score(db_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Read final game state from a YC-Bench SQLite database.
|
||||
|
||||
Returns dict with final_funds_cents (int), survived (bool),
|
||||
terminal_reason (str).
|
||||
|
||||
Note: yc-bench table names are plural -- 'companies' not 'company',
|
||||
'sim_events' not 'simulation_log'.
|
||||
"""
|
||||
if not os.path.exists(db_path):
|
||||
logger.warning("DB not found at %s", db_path)
|
||||
return {
|
||||
"final_funds_cents": 0,
|
||||
"survived": False,
|
||||
"terminal_reason": "db_missing",
|
||||
}
|
||||
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cur = conn.cursor()
|
||||
|
||||
# Read final funds from the 'companies' table
|
||||
cur.execute("SELECT funds_cents FROM companies LIMIT 1")
|
||||
row = cur.fetchone()
|
||||
funds = row[0] if row else 0
|
||||
|
||||
# Determine terminal reason from 'sim_events' table
|
||||
terminal_reason = "unknown"
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT event_type FROM sim_events "
|
||||
"WHERE event_type IN ('bankruptcy', 'horizon_end') "
|
||||
"ORDER BY scheduled_at DESC LIMIT 1"
|
||||
)
|
||||
event_row = cur.fetchone()
|
||||
if event_row:
|
||||
terminal_reason = event_row[0]
|
||||
except sqlite3.OperationalError:
|
||||
# Table may not exist if simulation didn't progress
|
||||
pass
|
||||
|
||||
survived = funds >= 0 and terminal_reason != "bankruptcy"
|
||||
return {
|
||||
"final_funds_cents": funds,
|
||||
"survived": survived,
|
||||
"terminal_reason": terminal_reason,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to read DB %s: %s", db_path, e)
|
||||
return {
|
||||
"final_funds_cents": 0,
|
||||
"survived": False,
|
||||
"terminal_reason": f"db_error: {e}",
|
||||
}
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
|
||||
def _compute_composite_score(
|
||||
final_funds_cents: int,
|
||||
survived: bool,
|
||||
survival_weight: float = 0.5,
|
||||
funds_weight: float = 0.5,
|
||||
initial_funds_cents: int = INITIAL_FUNDS_CENTS,
|
||||
) -> float:
|
||||
"""
|
||||
Compute composite score from survival and final funds.
|
||||
|
||||
Score = survival_weight * survival_score
|
||||
+ funds_weight * normalised_funds_score
|
||||
|
||||
Normalised funds uses log-scale relative to initial capital:
|
||||
- funds <= 0: 0.0
|
||||
- funds == initial: ~0.15
|
||||
- funds == 10x: ~0.52
|
||||
- funds == 100x: 1.0
|
||||
"""
|
||||
survival_score = 1.0 if survived else 0.0
|
||||
|
||||
if final_funds_cents <= 0:
|
||||
funds_score = 0.0
|
||||
else:
|
||||
max_ratio = 100.0
|
||||
ratio = final_funds_cents / max(initial_funds_cents, 1)
|
||||
funds_score = min(math.log1p(ratio) / math.log1p(max_ratio), 1.0)
|
||||
|
||||
return survival_weight * survival_score + funds_weight * funds_score
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main Environment
|
||||
# =============================================================================
|
||||
|
||||
class YCBenchEvalEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
YC-Bench long-horizon agent benchmark environment (eval-only).
|
||||
|
||||
Each eval item is a (preset, seed) pair. The environment initialises the
|
||||
simulation via ``yc-bench sim init`` (NOT ``yc-bench run`` which would start
|
||||
a competing built-in agent loop). The HermesAgentLoop then drives the
|
||||
interaction by calling individual yc-bench CLI commands via the terminal tool.
|
||||
|
||||
After the agent loop ends, the SQLite DB is read to extract the final score.
|
||||
|
||||
Scoring:
|
||||
composite = 0.5 * survival + 0.5 * normalised_funds
|
||||
"""
|
||||
|
||||
name = "yc-bench"
|
||||
env_config_cls = YCBenchEvalConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[YCBenchEvalConfig, List[APIServerConfig]]:
|
||||
env_config = YCBenchEvalConfig(
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
max_agent_turns=200,
|
||||
max_token_length=32000,
|
||||
agent_temperature=0.0,
|
||||
system_prompt=YC_BENCH_SYSTEM_PROMPT,
|
||||
terminal_backend="local",
|
||||
terminal_timeout=60,
|
||||
presets=["fast_test", "medium", "hard"],
|
||||
seeds=[1, 2, 3],
|
||||
run_timeout=3600,
|
||||
survival_weight=0.5,
|
||||
funds_weight=0.5,
|
||||
db_dir="/tmp/yc_bench_dbs",
|
||||
eval_handling=EvalHandlingEnum.STOP_TRAIN,
|
||||
group_size=1,
|
||||
steps_per_eval=1,
|
||||
total_steps=1,
|
||||
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
use_wandb=True,
|
||||
wandb_name="yc-bench",
|
||||
ensure_scores_are_not_same=False,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-sonnet-4.6",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
# =========================================================================
|
||||
# Setup
|
||||
# =========================================================================
|
||||
|
||||
async def setup(self):
|
||||
"""Verify yc-bench is installed and build the eval matrix."""
|
||||
# Verify yc-bench CLI is available
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["yc-bench", "--help"], capture_output=True, text=True, timeout=10
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise FileNotFoundError
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
raise RuntimeError(
|
||||
"yc-bench CLI not found. Install with:\n"
|
||||
' pip install "hermes-agent[yc-bench]"\n'
|
||||
"Or: git clone https://github.com/collinear-ai/yc-bench "
|
||||
"&& cd yc-bench && pip install -e ."
|
||||
)
|
||||
print("yc-bench CLI verified.")
|
||||
|
||||
# Build eval matrix: preset x seed
|
||||
self.all_eval_items = [
|
||||
{"preset": preset, "seed": seed}
|
||||
for preset in self.config.presets
|
||||
for seed in self.config.seeds
|
||||
]
|
||||
self.iter = 0
|
||||
|
||||
os.makedirs(self.config.db_dir, exist_ok=True)
|
||||
self.eval_metrics: List[Tuple[str, float]] = []
|
||||
|
||||
# Streaming JSONL log for crash-safe result persistence
|
||||
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
run_ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self._streaming_path = os.path.join(log_dir, f"samples_{run_ts}.jsonl")
|
||||
self._streaming_file = open(self._streaming_path, "w", encoding="utf-8")
|
||||
self._streaming_lock = threading.Lock()
|
||||
|
||||
print(f"\nYC-Bench eval matrix: {len(self.all_eval_items)} runs")
|
||||
for item in self.all_eval_items:
|
||||
print(f" preset={item['preset']!r} seed={item['seed']}")
|
||||
print(f"Streaming results to: {self._streaming_path}\n")
|
||||
|
||||
def _save_result(self, result: Dict[str, Any]):
|
||||
"""Write a single run result to the streaming JSONL file immediately."""
|
||||
if not hasattr(self, "_streaming_file") or self._streaming_file.closed:
|
||||
return
|
||||
with self._streaming_lock:
|
||||
self._streaming_file.write(
|
||||
json.dumps(result, ensure_ascii=False, default=str) + "\n"
|
||||
)
|
||||
self._streaming_file.flush()
|
||||
|
||||
# =========================================================================
|
||||
# Training pipeline stubs (eval-only -- not used)
|
||||
# =========================================================================
|
||||
|
||||
async def get_next_item(self):
|
||||
item = self.all_eval_items[self.iter % len(self.all_eval_items)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item: Dict[str, Any]) -> str:
|
||||
preset = item["preset"]
|
||||
seed = item["seed"]
|
||||
return (
|
||||
f"A new YC-Bench simulation has been initialized "
|
||||
f"(preset='{preset}', seed={seed}).\n"
|
||||
f"Your company '{self.config.company_name}' is ready.\n\n"
|
||||
"Begin by calling:\n"
|
||||
"1. `yc-bench company status` -- see your starting funds and prestige\n"
|
||||
"2. `yc-bench employee list` -- see your team and their skills\n"
|
||||
"3. `yc-bench market browse --required-prestige-lte 1` -- find tasks "
|
||||
"you can take\n\n"
|
||||
"Then accept 2-3 tasks, assign employees, dispatch them, and call "
|
||||
"`yc-bench sim resume` to advance time. Repeat this loop until the "
|
||||
"simulation ends (horizon reached or bankruptcy)."
|
||||
)
|
||||
|
||||
async def compute_reward(self, item, result, ctx) -> float:
|
||||
return 0.0
|
||||
|
||||
async def collect_trajectories(self, item):
|
||||
return None, []
|
||||
|
||||
async def score(self, rollout_group_data):
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Per-run evaluation
|
||||
# =========================================================================
|
||||
|
||||
async def rollout_and_score_eval(self, eval_item: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Evaluate a single (preset, seed) run.
|
||||
|
||||
1. Sets DATABASE_URL and YC_BENCH_EXPERIMENT env vars
|
||||
2. Initialises the simulation via ``yc-bench sim init`` (NOT ``run``)
|
||||
3. Runs HermesAgentLoop with terminal tool
|
||||
4. Reads SQLite DB to compute final score
|
||||
5. Returns result dict with survival, funds, and composite score
|
||||
"""
|
||||
preset = eval_item["preset"]
|
||||
seed = eval_item["seed"]
|
||||
run_id = str(uuid.uuid4())[:8]
|
||||
run_key = f"{preset}_seed{seed}_{run_id}"
|
||||
|
||||
from tqdm import tqdm
|
||||
tqdm.write(f" [START] preset={preset!r} seed={seed} (run_id={run_id})")
|
||||
run_start = time.time()
|
||||
|
||||
# Isolated DB per run -- prevents cross-run state leakage
|
||||
db_path = os.path.join(self.config.db_dir, f"yc_bench_{run_key}.db")
|
||||
os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"
|
||||
os.environ["YC_BENCH_EXPERIMENT"] = preset
|
||||
|
||||
# Determine horizon: explicit config override > preset lookup > default 1
|
||||
horizon = self.config.horizon_years or _PRESET_HORIZONS.get(preset, 1)
|
||||
|
||||
try:
|
||||
# ----------------------------------------------------------
|
||||
# Step 1: Initialise the simulation via CLI
|
||||
# IMPORTANT: We use `sim init`, NOT `yc-bench run`.
|
||||
# `yc-bench run` starts yc-bench's own LLM agent loop (via
|
||||
# LiteLLM), which would compete with our HermesAgentLoop.
|
||||
# `sim init` just sets up the world and returns.
|
||||
# ----------------------------------------------------------
|
||||
init_cmd = [
|
||||
"yc-bench", "sim", "init",
|
||||
"--seed", str(seed),
|
||||
"--start-date", self.config.start_date,
|
||||
"--company-name", self.config.company_name,
|
||||
"--horizon-years", str(horizon),
|
||||
]
|
||||
init_result = subprocess.run(
|
||||
init_cmd, capture_output=True, text=True, timeout=30,
|
||||
)
|
||||
if init_result.returncode != 0:
|
||||
error_msg = (init_result.stderr or init_result.stdout).strip()
|
||||
raise RuntimeError(f"yc-bench sim init failed: {error_msg}")
|
||||
|
||||
tqdm.write(f" Simulation initialized (horizon={horizon}yr)")
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# Step 2: Run the HermesAgentLoop
|
||||
# ----------------------------------------------------------
|
||||
tools, valid_names = self._resolve_tools_for_group()
|
||||
|
||||
messages: List[Dict[str, Any]] = [
|
||||
{"role": "system", "content": YC_BENCH_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": self.format_prompt(eval_item)},
|
||||
]
|
||||
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=run_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# Step 3: Read final score from the simulation DB
|
||||
# ----------------------------------------------------------
|
||||
score_data = _read_final_score(db_path)
|
||||
final_funds = score_data["final_funds_cents"]
|
||||
survived = score_data["survived"]
|
||||
terminal_reason = score_data["terminal_reason"]
|
||||
|
||||
composite = _compute_composite_score(
|
||||
final_funds_cents=final_funds,
|
||||
survived=survived,
|
||||
survival_weight=self.config.survival_weight,
|
||||
funds_weight=self.config.funds_weight,
|
||||
)
|
||||
|
||||
elapsed = time.time() - run_start
|
||||
status = "SURVIVED" if survived else "BANKRUPT"
|
||||
if final_funds >= 0:
|
||||
funds_str = f"${final_funds / 100:,.0f}"
|
||||
else:
|
||||
funds_str = f"-${abs(final_funds) / 100:,.0f}"
|
||||
|
||||
tqdm.write(
|
||||
f" [{status}] preset={preset!r} seed={seed} "
|
||||
f"funds={funds_str} score={composite:.3f} "
|
||||
f"turns={result.turns_used} ({elapsed:.0f}s)"
|
||||
)
|
||||
|
||||
out = {
|
||||
"preset": preset,
|
||||
"seed": seed,
|
||||
"survived": survived,
|
||||
"final_funds_cents": final_funds,
|
||||
"final_funds_usd": final_funds / 100,
|
||||
"terminal_reason": terminal_reason,
|
||||
"composite_score": composite,
|
||||
"turns_used": result.turns_used,
|
||||
"finished_naturally": result.finished_naturally,
|
||||
"elapsed_seconds": elapsed,
|
||||
"db_path": db_path,
|
||||
"messages": result.messages,
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.time() - run_start
|
||||
logger.error("Run %s failed: %s", run_key, e, exc_info=True)
|
||||
tqdm.write(
|
||||
f" [ERROR] preset={preset!r} seed={seed}: {e} ({elapsed:.0f}s)"
|
||||
)
|
||||
out = {
|
||||
"preset": preset,
|
||||
"seed": seed,
|
||||
"survived": False,
|
||||
"final_funds_cents": 0,
|
||||
"final_funds_usd": 0.0,
|
||||
"terminal_reason": f"error: {e}",
|
||||
"composite_score": 0.0,
|
||||
"turns_used": 0,
|
||||
"error": str(e),
|
||||
"elapsed_seconds": elapsed,
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
# =========================================================================
|
||||
# Evaluate
|
||||
# =========================================================================
|
||||
|
||||
async def _run_with_timeout(self, item: Dict[str, Any]) -> Dict:
|
||||
"""Wrap a single rollout with a wall-clock timeout."""
|
||||
preset = item["preset"]
|
||||
seed = item["seed"]
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self.rollout_and_score_eval(item),
|
||||
timeout=self.config.run_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
from tqdm import tqdm
|
||||
tqdm.write(
|
||||
f" [TIMEOUT] preset={preset!r} seed={seed} "
|
||||
f"(exceeded {self.config.run_timeout}s)"
|
||||
)
|
||||
out = {
|
||||
"preset": preset,
|
||||
"seed": seed,
|
||||
"survived": False,
|
||||
"final_funds_cents": 0,
|
||||
"final_funds_usd": 0.0,
|
||||
"terminal_reason": f"timeout ({self.config.run_timeout}s)",
|
||||
"composite_score": 0.0,
|
||||
"turns_used": 0,
|
||||
"error": "timeout",
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
Run YC-Bench evaluation over all (preset, seed) combinations.
|
||||
|
||||
Runs sequentially -- each run is 100-500 turns, parallelising would
|
||||
be prohibitively expensive and cause env var conflicts.
|
||||
"""
|
||||
start_time = time.time()
|
||||
from tqdm import tqdm
|
||||
|
||||
# --- tqdm-compatible logging handler (TB2 pattern) ---
|
||||
class _TqdmHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
try:
|
||||
tqdm.write(self.format(record))
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
root = logging.getLogger()
|
||||
handler = _TqdmHandler()
|
||||
handler.setFormatter(
|
||||
logging.Formatter("%(levelname)s %(name)s: %(message)s")
|
||||
)
|
||||
root.handlers = [handler]
|
||||
for noisy in ("httpx", "openai"):
|
||||
logging.getLogger(noisy).setLevel(logging.WARNING)
|
||||
|
||||
# --- Print config summary ---
|
||||
print(f"\n{'='*60}")
|
||||
print("Starting YC-Bench Evaluation")
|
||||
print(f"{'='*60}")
|
||||
print(f" Presets: {self.config.presets}")
|
||||
print(f" Seeds: {self.config.seeds}")
|
||||
print(f" Total runs: {len(self.all_eval_items)}")
|
||||
print(f" Max turns/run: {self.config.max_agent_turns}")
|
||||
print(f" Run timeout: {self.config.run_timeout}s")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
results = []
|
||||
pbar = tqdm(
|
||||
total=len(self.all_eval_items), desc="YC-Bench", dynamic_ncols=True
|
||||
)
|
||||
|
||||
try:
|
||||
for item in self.all_eval_items:
|
||||
result = await self._run_with_timeout(item)
|
||||
results.append(result)
|
||||
survived_count = sum(1 for r in results if r.get("survived"))
|
||||
pbar.set_postfix_str(
|
||||
f"survived={survived_count}/{len(results)}"
|
||||
)
|
||||
pbar.update(1)
|
||||
|
||||
except (KeyboardInterrupt, asyncio.CancelledError):
|
||||
tqdm.write("\n[INTERRUPTED] Stopping evaluation...")
|
||||
pbar.close()
|
||||
try:
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
cleanup_all_environments()
|
||||
except Exception:
|
||||
pass
|
||||
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
|
||||
self._streaming_file.close()
|
||||
return
|
||||
|
||||
pbar.close()
|
||||
end_time = time.time()
|
||||
|
||||
# --- Compute metrics ---
|
||||
valid = [r for r in results if r is not None]
|
||||
if not valid:
|
||||
print("Warning: No valid results.")
|
||||
return
|
||||
|
||||
total = len(valid)
|
||||
survived_total = sum(1 for r in valid if r.get("survived"))
|
||||
survival_rate = survived_total / total if total else 0.0
|
||||
avg_score = (
|
||||
sum(r.get("composite_score", 0) for r in valid) / total
|
||||
if total
|
||||
else 0.0
|
||||
)
|
||||
|
||||
preset_results: Dict[str, List[Dict]] = defaultdict(list)
|
||||
for r in valid:
|
||||
preset_results[r["preset"]].append(r)
|
||||
|
||||
eval_metrics = {
|
||||
"eval/survival_rate": survival_rate,
|
||||
"eval/avg_composite_score": avg_score,
|
||||
"eval/total_runs": total,
|
||||
"eval/survived_runs": survived_total,
|
||||
"eval/evaluation_time_seconds": end_time - start_time,
|
||||
}
|
||||
|
||||
for preset, items in sorted(preset_results.items()):
|
||||
ps = sum(1 for r in items if r.get("survived"))
|
||||
pt = len(items)
|
||||
pa = (
|
||||
sum(r.get("composite_score", 0) for r in items) / pt
|
||||
if pt
|
||||
else 0
|
||||
)
|
||||
key = preset.replace("-", "_")
|
||||
eval_metrics[f"eval/survival_rate_{key}"] = ps / pt if pt else 0
|
||||
eval_metrics[f"eval/avg_score_{key}"] = pa
|
||||
|
||||
self.eval_metrics = list(eval_metrics.items())
|
||||
|
||||
# --- Print summary ---
|
||||
print(f"\n{'='*60}")
|
||||
print("YC-Bench Evaluation Results")
|
||||
print(f"{'='*60}")
|
||||
print(
|
||||
f"Overall survival rate: {survival_rate:.1%} "
|
||||
f"({survived_total}/{total})"
|
||||
)
|
||||
print(f"Average composite score: {avg_score:.4f}")
|
||||
print(f"Evaluation time: {end_time - start_time:.1f}s")
|
||||
|
||||
print("\nPer-preset breakdown:")
|
||||
for preset, items in sorted(preset_results.items()):
|
||||
ps = sum(1 for r in items if r.get("survived"))
|
||||
pt = len(items)
|
||||
pa = (
|
||||
sum(r.get("composite_score", 0) for r in items) / pt
|
||||
if pt
|
||||
else 0
|
||||
)
|
||||
print(f" {preset}: {ps}/{pt} survived avg_score={pa:.4f}")
|
||||
for r in items:
|
||||
status = "SURVIVED" if r.get("survived") else "BANKRUPT"
|
||||
funds = r.get("final_funds_usd", 0)
|
||||
print(
|
||||
f" seed={r['seed']} [{status}] "
|
||||
f"${funds:,.0f} "
|
||||
f"score={r.get('composite_score', 0):.3f}"
|
||||
)
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# --- Log results ---
|
||||
samples = [
|
||||
{k: v for k, v in r.items() if k != "messages"} for r in valid
|
||||
]
|
||||
|
||||
try:
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
generation_parameters={
|
||||
"temperature": self.config.agent_temperature,
|
||||
"max_tokens": self.config.max_token_length,
|
||||
"max_agent_turns": self.config.max_agent_turns,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error logging results: {e}")
|
||||
|
||||
# --- Cleanup (TB2 pattern) ---
|
||||
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
|
||||
self._streaming_file.close()
|
||||
print(f"Results saved to: {self._streaming_path}")
|
||||
|
||||
try:
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
cleanup_all_environments()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
from environments.agent_loop import _tool_executor
|
||||
_tool_executor.shutdown(wait=False, cancel_futures=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# =========================================================================
|
||||
# Wandb logging
|
||||
# =========================================================================
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log YC-Bench-specific metrics to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
for k, v in self.eval_metrics:
|
||||
wandb_metrics[k] = v
|
||||
self.eval_metrics = []
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
YCBenchEvalEnv.cli()
|
||||
@@ -1,714 +0,0 @@
|
||||
"""
|
||||
HermesAgentBaseEnv -- Abstract Base Environment for Hermes-Agent + Atropos
|
||||
|
||||
Provides the Atropos integration plumbing that all hermes-agent environments share:
|
||||
- Two-mode operation (OpenAI server for Phase 1, VLLM ManagedServer for Phase 2)
|
||||
- Per-group toolset/distribution resolution
|
||||
- Agent loop orchestration via HermesAgentLoop
|
||||
- ToolContext creation for reward functions
|
||||
- ScoredDataGroup construction from ManagedServer state
|
||||
|
||||
Subclasses only need to implement:
|
||||
setup() -- Load dataset, initialize state
|
||||
get_next_item() -- Return the next item from the dataset
|
||||
format_prompt() -- Convert a dataset item into the user message
|
||||
compute_reward() -- Score the rollout (has full ToolContext access)
|
||||
evaluate() -- Periodic evaluation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
# Ensure the hermes-agent repo root is on sys.path so that imports like
|
||||
# `from model_tools import ...` and `from environments.X import ...` work
|
||||
# regardless of where the script is invoked from.
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
# Load API keys from hermes-agent/.env so all environments can access them
|
||||
_env_path = _repo_root / ".env"
|
||||
if _env_path.exists():
|
||||
load_dotenv(dotenv_path=_env_path)
|
||||
|
||||
# Apply monkey patches for async-safe tool operation inside Atropos's event loop.
|
||||
# This patches SwerexModalEnvironment to use a background thread instead of
|
||||
# asyncio.run(), which would deadlock inside Atropos. Safe for normal CLI too.
|
||||
from environments.patches import apply_patches
|
||||
apply_patches()
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
ScoredDataItem,
|
||||
)
|
||||
from atroposlib.envs.server_handling.server_manager import (
|
||||
APIServerConfig,
|
||||
ServerBaseline,
|
||||
ServerManager,
|
||||
)
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from environments.tool_context import ToolContext
|
||||
from tools.budget_config import (
|
||||
DEFAULT_RESULT_SIZE_CHARS,
|
||||
DEFAULT_TURN_BUDGET_CHARS,
|
||||
DEFAULT_PREVIEW_SIZE_CHARS,
|
||||
)
|
||||
|
||||
# Import hermes-agent toolset infrastructure
|
||||
from model_tools import get_tool_definitions
|
||||
from toolset_distributions import sample_toolsets_from_distribution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
"""
|
||||
Configuration for hermes-agent Atropos environments.
|
||||
|
||||
Extends BaseEnvConfig with agent-specific settings for toolsets,
|
||||
terminal backend, dataset loading, and tool call parsing.
|
||||
"""
|
||||
|
||||
# --- Toolset configuration ---
|
||||
# Mutually exclusive: use either enabled_toolsets OR distribution
|
||||
enabled_toolsets: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="Explicit list of hermes toolsets to enable (e.g., ['terminal', 'file', 'web']). "
|
||||
"If None and distribution is also None, all available toolsets are enabled.",
|
||||
)
|
||||
disabled_toolsets: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="Toolsets to disable. Applied as a filter on top of enabled_toolsets or distribution.",
|
||||
)
|
||||
distribution: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Name of a toolset distribution from toolset_distributions.py "
|
||||
"(e.g., 'development', 'terminal_tasks'). Sampled once per group. "
|
||||
"Mutually exclusive with enabled_toolsets.",
|
||||
)
|
||||
|
||||
# --- Agent loop configuration ---
|
||||
max_agent_turns: int = Field(
|
||||
default=30,
|
||||
description="Maximum number of LLM calls (tool-calling iterations) per rollout.",
|
||||
)
|
||||
system_prompt: Optional[str] = Field(
|
||||
default=None,
|
||||
description="System prompt for the agent. Tools are handled via the tools= parameter, "
|
||||
"not embedded in the prompt text.",
|
||||
)
|
||||
agent_temperature: float = Field(
|
||||
default=1.0,
|
||||
description="Sampling temperature for agent generation during rollouts.",
|
||||
)
|
||||
|
||||
# --- Terminal backend ---
|
||||
terminal_backend: str = Field(
|
||||
default="local",
|
||||
description="Terminal backend: 'local', 'docker', 'modal', 'daytona', 'ssh', 'singularity'. "
|
||||
"Modal or Daytona recommended for production RL (cloud isolation per rollout).",
|
||||
)
|
||||
terminal_timeout: int = Field(
|
||||
default=120,
|
||||
description="Per-command timeout in seconds for terminal tool calls. "
|
||||
"Commands exceeding this are killed. Increase for tasks with long-running "
|
||||
"commands (compilation, pip install, etc.).",
|
||||
)
|
||||
terminal_lifetime: int = Field(
|
||||
default=3600,
|
||||
description="Sandbox inactivity lifetime in seconds. The cleanup thread kills "
|
||||
"sandboxes that have been idle longer than this. Must be longer than "
|
||||
"the longest gap between tool calls (e.g., waiting for LLM response).",
|
||||
)
|
||||
|
||||
# --- Dataset ---
|
||||
dataset_name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="HuggingFace dataset name. Optional if tasks are defined inline.",
|
||||
)
|
||||
dataset_split: str = Field(
|
||||
default="train",
|
||||
description="Dataset split to use.",
|
||||
)
|
||||
prompt_field: str = Field(
|
||||
default="prompt",
|
||||
description="Which field in the dataset contains the prompt.",
|
||||
)
|
||||
|
||||
# --- Thread pool ---
|
||||
tool_pool_size: int = Field(
|
||||
default=128,
|
||||
description="Thread pool size for tool execution. Each concurrent task needs a "
|
||||
"thread for tool calls. Must be large enough for parallel evaluation. "
|
||||
"Too small = thread pool starvation.",
|
||||
)
|
||||
|
||||
# --- Phase 2: Tool call parsing ---
|
||||
tool_call_parser: str = Field(
|
||||
default="hermes",
|
||||
description="Tool call parser name for Phase 2 (VLLM server type). "
|
||||
"Ignored in Phase 1 (OpenAI server type where VLLM parses natively). "
|
||||
"Options: hermes, mistral, llama3_json, qwen, deepseek_v3, etc.",
|
||||
)
|
||||
|
||||
# --- Tool result budget ---
|
||||
# Defaults imported from tools.budget_config (single source of truth).
|
||||
default_result_size_chars: int = Field(
|
||||
default=DEFAULT_RESULT_SIZE_CHARS,
|
||||
description="Default per-tool threshold (chars) for persisting large results "
|
||||
"to sandbox. Results exceeding this are written to /tmp/hermes-results/ "
|
||||
"and replaced with a preview. Per-tool registry values take precedence "
|
||||
"unless overridden via tool_result_overrides.",
|
||||
)
|
||||
turn_budget_chars: int = Field(
|
||||
default=DEFAULT_TURN_BUDGET_CHARS,
|
||||
description="Aggregate char budget per assistant turn. If all tool results "
|
||||
"in a single turn exceed this, the largest are persisted to disk first.",
|
||||
)
|
||||
preview_size_chars: int = Field(
|
||||
default=DEFAULT_PREVIEW_SIZE_CHARS,
|
||||
description="Size of the inline preview shown after a tool result is persisted.",
|
||||
)
|
||||
tool_result_overrides: Optional[Dict[str, int]] = Field(
|
||||
default=None,
|
||||
description="Per-tool threshold overrides (chars). Keys are tool names, "
|
||||
"values are char thresholds. Overrides both the default and registry "
|
||||
"per-tool values. Example: {'terminal': 10000, 'search_files': 5000}. "
|
||||
"Note: read_file is pinned to infinity and cannot be overridden.",
|
||||
)
|
||||
|
||||
# --- Provider-specific parameters ---
|
||||
# Passed as extra_body to the OpenAI client's chat.completions.create() call.
|
||||
# Useful for OpenRouter provider preferences, transforms, route settings, etc.
|
||||
# Example YAML:
|
||||
# extra_body:
|
||||
# provider:
|
||||
# ignore: ["DeepInfra", "Fireworks"]
|
||||
# order: ["Together"]
|
||||
# transforms: ["middle-out"]
|
||||
extra_body: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Extra body parameters passed to the OpenAI client's "
|
||||
"chat.completions.create(). Used for OpenRouter provider preferences, "
|
||||
"transforms, and other provider-specific settings.",
|
||||
)
|
||||
|
||||
def build_budget_config(self):
|
||||
"""Build a BudgetConfig from env config fields."""
|
||||
from tools.budget_config import BudgetConfig
|
||||
return BudgetConfig(
|
||||
default_result_size=self.default_result_size_chars,
|
||||
turn_budget=self.turn_budget_chars,
|
||||
preview_size=self.preview_size_chars,
|
||||
tool_overrides=dict(self.tool_result_overrides) if self.tool_result_overrides else {},
|
||||
)
|
||||
|
||||
|
||||
class HermesAgentBaseEnv(BaseEnv):
|
||||
"""
|
||||
Abstract base environment for hermes-agent Atropos integration.
|
||||
|
||||
Handles two modes of operation:
|
||||
- Phase 1 (OpenAI server type): Uses server.chat_completion() directly.
|
||||
The server (VLLM, SGLang, OpenRouter, OpenAI) handles tool call parsing
|
||||
and reasoning extraction natively. DummyManagedServer provides placeholder
|
||||
tokens. Good for SFT data gen, verifier testing, evaluation.
|
||||
|
||||
- Phase 2 (VLLM server type): Uses ManagedServer for exact token IDs + logprobs
|
||||
via /generate. Client-side tool call parser reconstructs structured tool_calls
|
||||
from raw output. Full RL training capability.
|
||||
|
||||
Subclasses must implement:
|
||||
setup() -- Load dataset, initialize state
|
||||
get_next_item() -- Return the next item to roll out
|
||||
format_prompt() -- Convert a dataset item into the user message string
|
||||
compute_reward() -- Score the rollout using ToolContext
|
||||
evaluate() -- Periodic evaluation
|
||||
"""
|
||||
|
||||
name: Optional[str] = "hermes-agent"
|
||||
env_config_cls = HermesAgentEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: HermesAgentEnvConfig,
|
||||
server_configs: Union[ServerBaseline, List[APIServerConfig]],
|
||||
slurm=False,
|
||||
testing=False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
|
||||
# Set terminal environment variables so hermes tools pick them up.
|
||||
# These can all be overridden per-environment via config fields instead
|
||||
# of requiring users to set shell env vars.
|
||||
if config.terminal_backend:
|
||||
os.environ["TERMINAL_ENV"] = config.terminal_backend
|
||||
os.environ["TERMINAL_TIMEOUT"] = str(config.terminal_timeout)
|
||||
os.environ["TERMINAL_LIFETIME_SECONDS"] = str(config.terminal_lifetime)
|
||||
print(
|
||||
f"🖥️ Terminal: backend={config.terminal_backend}, "
|
||||
f"timeout={config.terminal_timeout}s, lifetime={config.terminal_lifetime}s"
|
||||
)
|
||||
|
||||
# Resize the agent loop's thread pool for tool execution.
|
||||
# This must be large enough for the number of concurrent tasks
|
||||
# (e.g., 89 parallel TB2 eval tasks each need a thread for tool calls).
|
||||
from environments.agent_loop import resize_tool_pool
|
||||
resize_tool_pool(config.tool_pool_size)
|
||||
|
||||
# Set tool_parser on the ServerManager so ManagedServer uses it
|
||||
# for bidirectional tool call translation (raw text ↔ OpenAI tool_calls).
|
||||
if hasattr(self.server, 'tool_parser'):
|
||||
self.server.tool_parser = config.tool_call_parser
|
||||
print(f"🔧 Tool parser: {config.tool_call_parser}")
|
||||
|
||||
# Current group's resolved tools (set in collect_trajectories)
|
||||
self._current_group_tools: Optional[Tuple[List[Dict], Set[str]]] = None
|
||||
|
||||
# Tool error tracking for wandb logging
|
||||
self._tool_error_buffer: List[Dict[str, Any]] = []
|
||||
|
||||
# =========================================================================
|
||||
# Toolset resolution (per-group)
|
||||
# =========================================================================
|
||||
|
||||
def _resolve_tools_for_group(self) -> Tuple[List[Dict[str, Any]], Set[str]]:
|
||||
"""
|
||||
Resolve toolsets for a group. Called once in collect_trajectories(),
|
||||
then shared by all collect_trajectory() calls in the group.
|
||||
|
||||
If distribution is set, samples probabilistically.
|
||||
If enabled_toolsets is set, uses that explicit list.
|
||||
disabled_toolsets is applied as a filter on top.
|
||||
|
||||
Returns:
|
||||
(tool_schemas, valid_tool_names) tuple
|
||||
"""
|
||||
config = self.config
|
||||
|
||||
if config.distribution:
|
||||
group_toolsets = sample_toolsets_from_distribution(config.distribution)
|
||||
logger.info("Sampled toolsets from '%s': %s", config.distribution, group_toolsets)
|
||||
else:
|
||||
group_toolsets = config.enabled_toolsets # None means "all available"
|
||||
if group_toolsets is None:
|
||||
logger.warning(
|
||||
"enabled_toolsets is None -- loading ALL tools including messaging. "
|
||||
"Set explicit enabled_toolsets for RL training."
|
||||
)
|
||||
|
||||
tools = get_tool_definitions(
|
||||
enabled_toolsets=group_toolsets,
|
||||
disabled_toolsets=config.disabled_toolsets,
|
||||
quiet_mode=True,
|
||||
)
|
||||
|
||||
valid_names = {t["function"]["name"] for t in tools} if tools else set()
|
||||
logger.info("Resolved %d tools for group: %s", len(valid_names), sorted(valid_names))
|
||||
return tools, valid_names
|
||||
|
||||
# =========================================================================
|
||||
# Server mode detection
|
||||
# =========================================================================
|
||||
|
||||
def _use_managed_server(self) -> bool:
|
||||
"""
|
||||
Determine if we should use ManagedServer (Phase 2) or direct server (Phase 1).
|
||||
|
||||
Phase 2 (ManagedServer) is used when the server type is 'vllm' or 'sglang',
|
||||
which go through the /generate endpoint for exact token tracking.
|
||||
|
||||
Phase 1 (direct server) is used for 'openai' server type, which uses
|
||||
/v1/chat/completions with native tool call parsing.
|
||||
"""
|
||||
if not self.server.servers:
|
||||
return False
|
||||
|
||||
server = self.server.servers[0]
|
||||
# If the server is an OpenAI server (not VLLM/SGLang), use direct mode
|
||||
from atroposlib.envs.server_handling.openai_server import OpenAIServer
|
||||
return not isinstance(server, OpenAIServer)
|
||||
|
||||
# =========================================================================
|
||||
# Core Atropos integration
|
||||
# =========================================================================
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item: Item
|
||||
) -> Tuple[
|
||||
Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]],
|
||||
List[Item],
|
||||
]:
|
||||
"""
|
||||
Override collect_trajectories to resolve toolsets once per group,
|
||||
then delegate to the standard group-level collection.
|
||||
|
||||
The default BaseEnv.collect_trajectories() calls collect_trajectory()
|
||||
group_size times in parallel. We resolve tools once here and store
|
||||
them for all those calls to use.
|
||||
"""
|
||||
# Resolve toolsets for this group (shared by all rollouts in the group)
|
||||
self._current_group_tools = self._resolve_tools_for_group()
|
||||
|
||||
# Delegate to the default implementation which calls collect_trajectory()
|
||||
# group_size times via asyncio.gather
|
||||
return await super().collect_trajectories(item)
|
||||
|
||||
# =========================================================================
|
||||
# Wandb rollout display -- format trajectories nicely
|
||||
# =========================================================================
|
||||
|
||||
@staticmethod
|
||||
def _format_trajectory_for_display(messages: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Format a conversation's messages into a readable trajectory string
|
||||
for wandb rollout tables. Shows tool calls, tool results, and reasoning
|
||||
in a structured way instead of raw token decoding.
|
||||
"""
|
||||
parts = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
|
||||
if role == "system":
|
||||
parts.append(f"[SYSTEM]\n{content}")
|
||||
|
||||
elif role == "user":
|
||||
parts.append(f"[USER]\n{content}")
|
||||
|
||||
elif role == "assistant":
|
||||
# Show reasoning if present
|
||||
reasoning = msg.get("reasoning_content", "")
|
||||
if reasoning:
|
||||
# Truncate long reasoning for display
|
||||
if len(reasoning) > 300:
|
||||
reasoning = reasoning[:300] + "..."
|
||||
parts.append(f"[ASSISTANT thinking]\n{reasoning}")
|
||||
|
||||
# Show content
|
||||
if content:
|
||||
parts.append(f"[ASSISTANT]\n{content}")
|
||||
|
||||
# Show tool calls
|
||||
tool_calls = msg.get("tool_calls", [])
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function", {})
|
||||
name = func.get("name", "?")
|
||||
args = func.get("arguments", "{}")
|
||||
# Truncate long arguments for display
|
||||
if len(args) > 200:
|
||||
args = args[:200] + "..."
|
||||
parts.append(f"[TOOL CALL] {name}({args})")
|
||||
|
||||
elif role == "tool":
|
||||
tool_id = msg.get("tool_call_id", "")
|
||||
result = content
|
||||
# Truncate long tool results for display
|
||||
if len(result) > 500:
|
||||
result = result[:500] + "..."
|
||||
parts.append(f"[TOOL RESULT] {result}")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
async def add_rollouts_for_wandb(
|
||||
self,
|
||||
scored_data,
|
||||
item=None,
|
||||
):
|
||||
"""
|
||||
Override to show formatted trajectories with tool calls visible,
|
||||
instead of raw token decoding which loses all structure.
|
||||
"""
|
||||
num_keep = self.config.num_rollouts_per_group_for_logging
|
||||
if num_keep == -1:
|
||||
num_keep = self.config.group_size
|
||||
|
||||
group = []
|
||||
for i in range(min(num_keep, len(scored_data.get("scores", [])))):
|
||||
score = scored_data["scores"][i]
|
||||
|
||||
# Use messages if available for rich display
|
||||
messages = None
|
||||
if scored_data.get("messages") and i < len(scored_data["messages"]):
|
||||
messages = scored_data["messages"][i]
|
||||
|
||||
if messages:
|
||||
text = self._format_trajectory_for_display(messages)
|
||||
elif scored_data.get("tokens") and i < len(scored_data["tokens"]):
|
||||
text = self.tokenizer.decode(scored_data["tokens"][i])
|
||||
else:
|
||||
text = "(no data)"
|
||||
|
||||
group.append((text, score))
|
||||
|
||||
self.rollouts_for_wandb.append(group)
|
||||
if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
|
||||
self.rollouts_for_wandb.pop(0)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log base metrics including tool errors to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
# Log tool error stats
|
||||
if self._tool_error_buffer:
|
||||
wandb_metrics["train/tool_errors_count"] = len(self._tool_error_buffer)
|
||||
|
||||
# Log error details as a summary string (tables can crash wandb on tmp cleanup)
|
||||
error_summaries = []
|
||||
for err in self._tool_error_buffer:
|
||||
error_summaries.append(
|
||||
f"[turn {err['turn']}] {err['tool']}({err['args'][:80]}) -> {err['error'][:150]}"
|
||||
)
|
||||
wandb_metrics["train/tool_error_details"] = "\n".join(error_summaries)
|
||||
|
||||
# Also print to stdout for immediate visibility
|
||||
for summary in error_summaries:
|
||||
print(f" Tool Error: {summary}")
|
||||
|
||||
self._tool_error_buffer = []
|
||||
else:
|
||||
wandb_metrics["train/tool_errors_count"] = 0
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
async def collect_trajectory(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
|
||||
"""
|
||||
Run a single rollout: agent loop + reward computation.
|
||||
|
||||
This is called group_size times in parallel by collect_trajectories().
|
||||
Each call gets its own task_id for terminal/browser session isolation.
|
||||
"""
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# Get group-level tools (resolved once in collect_trajectories)
|
||||
if self._current_group_tools is None:
|
||||
# Fallback: resolve per-trajectory if called outside collect_trajectories
|
||||
tools, valid_names = self._resolve_tools_for_group()
|
||||
else:
|
||||
tools, valid_names = self._current_group_tools
|
||||
|
||||
# Build initial messages
|
||||
messages: List[Dict[str, Any]] = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
messages.append({"role": "user", "content": self.format_prompt(item)})
|
||||
|
||||
# Run the agent loop
|
||||
result: AgentResult
|
||||
if self._use_managed_server():
|
||||
# Phase 2: ManagedServer with ToolCallTranslator -- exact tokens + logprobs
|
||||
# tool_parser is set on ServerManager in __init__ and passed through
|
||||
# to ManagedServer, which uses ToolCallTranslator for bidirectional
|
||||
# translation between raw text and OpenAI tool_calls.
|
||||
try:
|
||||
async with self.server.managed_server(
|
||||
tokenizer=self.tokenizer,
|
||||
preserve_think_blocks=bool(self.config.thinking_mode),
|
||||
) as managed:
|
||||
agent = HermesAgentLoop(
|
||||
server=managed,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
except NotImplementedError:
|
||||
# DummyManagedServer not allowed -- fall back to Phase 1
|
||||
logger.warning(
|
||||
"ManagedServer not available (OpenAI server?). "
|
||||
"Falling back to direct server mode."
|
||||
)
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
else:
|
||||
# Phase 1: OpenAI server -- native tool_calls, placeholder tokens
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Skip reward computation if the agent loop produced no meaningful work
|
||||
# (e.g., API call failed on turn 1). No point spinning up a Modal sandbox
|
||||
# just to verify files that were never created.
|
||||
only_system_and_user = all(
|
||||
msg.get("role") in {"system", "user"} for msg in result.messages
|
||||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
logger.warning(
|
||||
"Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
|
||||
result.turns_used, len(result.messages),
|
||||
)
|
||||
reward = 0.0
|
||||
else:
|
||||
# Compute reward using ToolContext (gives verifier full tool access)
|
||||
ctx = ToolContext(task_id)
|
||||
try:
|
||||
reward = await self.compute_reward(item, result, ctx)
|
||||
except Exception as e:
|
||||
logger.error("compute_reward failed: %s", e)
|
||||
reward = 0.0
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
|
||||
# Track tool errors for wandb logging
|
||||
if result.tool_errors:
|
||||
for err in result.tool_errors:
|
||||
self._tool_error_buffer.append({
|
||||
"turn": err.turn,
|
||||
"tool": err.tool_name,
|
||||
"args": err.arguments[:150],
|
||||
"error": err.error[:300],
|
||||
"result": err.tool_result[:300],
|
||||
})
|
||||
|
||||
# Build ScoredDataItem from ManagedServer state
|
||||
# Phase 2: real tokens/masks/logprobs from SequenceNodes
|
||||
# Phase 1: placeholder tokens (still need a valid ScoredDataItem for the pipeline)
|
||||
nodes = (result.managed_state or {}).get("nodes", [])
|
||||
|
||||
if nodes:
|
||||
# Phase 2 (or DummyManagedServer): use actual node data
|
||||
node = nodes[-1] # Final sequence node = full trajectory
|
||||
scored_item: Dict[str, Any] = {
|
||||
"tokens": node.tokens,
|
||||
"masks": node.masked_tokens,
|
||||
"scores": reward,
|
||||
}
|
||||
|
||||
# Include logprobs if available (Phase 2)
|
||||
if hasattr(node, "logprobs") and node.logprobs:
|
||||
scored_item["advantages"] = None # Computed by trainer
|
||||
scored_item["ref_logprobs"] = None
|
||||
else:
|
||||
# Phase 1 with no managed state: create placeholder tokens
|
||||
# so the data pipeline doesn't break. These are NOT suitable
|
||||
# for training but allow process mode (SFT data gen) to work.
|
||||
# Tokenize the full conversation to get approximate tokens.
|
||||
full_text = "\n".join(
|
||||
msg.get("content", "") for msg in result.messages if msg.get("content")
|
||||
)
|
||||
if self.tokenizer:
|
||||
tokens = self.tokenizer.encode(full_text, add_special_tokens=True)
|
||||
else:
|
||||
tokens = list(range(min(len(full_text) // 4, 128)))
|
||||
|
||||
scored_item = {
|
||||
"tokens": tokens,
|
||||
"masks": [-100] + tokens[1:], # Mask first token as prompt
|
||||
"scores": reward,
|
||||
}
|
||||
|
||||
# Always include messages for wandb rollout display and data logging
|
||||
scored_item["messages"] = result.messages
|
||||
|
||||
return scored_item, []
|
||||
|
||||
# =========================================================================
|
||||
# Abstract methods -- subclasses must implement
|
||||
# =========================================================================
|
||||
|
||||
@abstractmethod
|
||||
async def setup(self):
|
||||
"""
|
||||
Load dataset, initialize state.
|
||||
|
||||
Called once when the environment starts. Typical implementation:
|
||||
self.dataset = load_dataset(self.config.dataset_name, split=self.config.dataset_split)
|
||||
self.iter = 0
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_next_item(self) -> Item:
|
||||
"""
|
||||
Return the next item from the dataset for rollout.
|
||||
|
||||
Called by the base env's main loop to get items for workers.
|
||||
Should cycle through the dataset.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def format_prompt(self, item: Item) -> str:
|
||||
"""
|
||||
Convert a dataset item into the user message for the agent.
|
||||
|
||||
Args:
|
||||
item: Dataset item (dict, tuple, etc.)
|
||||
|
||||
Returns:
|
||||
The prompt string to send to the agent
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def compute_reward(
|
||||
self, item: Item, result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Score the rollout. Has full access to:
|
||||
- item: the original dataset item (ground truth, test commands, etc.)
|
||||
- result: AgentResult with full messages, turn count, reasoning, etc.
|
||||
- ctx: ToolContext -- call ANY hermes-agent tool (terminal, file, web,
|
||||
browser, vision...) scoped to this rollout's sandbox. Nothing
|
||||
is off-limits.
|
||||
|
||||
Args:
|
||||
item: The dataset item that was rolled out
|
||||
result: The agent's rollout result
|
||||
ctx: ToolContext with full tool access for verification
|
||||
|
||||
Returns:
|
||||
Reward float (typically 0.0 to 1.0, but any float is valid)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
Periodic evaluation. Called every steps_per_eval steps.
|
||||
|
||||
Typical implementation runs the agent on a held-out eval set
|
||||
and logs metrics via wandb/evaluate_log.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -1,34 +0,0 @@
|
||||
# SWE Environment -- Default Configuration
|
||||
#
|
||||
# SWE-bench style tasks with Modal sandboxes for cloud isolation.
|
||||
# Uses terminal + file + web toolsets.
|
||||
#
|
||||
# Usage:
|
||||
# python environments/hermes_swe_env/hermes_swe_env.py serve \
|
||||
# --config environments/hermes_swe_env/default.yaml
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file", "web"]
|
||||
max_agent_turns: 30
|
||||
max_token_length: 4096
|
||||
group_size: 4
|
||||
terminal_backend: "modal"
|
||||
tool_call_parser: "hermes"
|
||||
tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
||||
dataset_name: "bigcode/humanevalpack"
|
||||
dataset_split: "test"
|
||||
prompt_field: "prompt"
|
||||
steps_per_eval: 50
|
||||
total_steps: 500
|
||||
use_wandb: true
|
||||
wandb_name: "hermes-swe"
|
||||
system_prompt: >
|
||||
You are a skilled software engineer. You have access to a terminal,
|
||||
file tools, and web search. Use these tools to complete the coding task.
|
||||
Write clean, working code and verify it runs correctly before finishing.
|
||||
|
||||
openai:
|
||||
base_url: "http://localhost:8000/v1"
|
||||
model_name: "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
||||
server_type: "openai"
|
||||
api_key: ""
|
||||
@@ -1,229 +0,0 @@
|
||||
"""
|
||||
HermesSweEnv -- SWE-Bench Style Environment with Modal Sandboxes
|
||||
|
||||
A concrete environment for software engineering tasks where the model writes code
|
||||
and the reward function runs tests to verify correctness. Uses Modal terminal
|
||||
backend for cloud-isolated sandboxes per rollout.
|
||||
|
||||
The reward function uses ToolContext.terminal() to run test commands in the same
|
||||
Modal sandbox the model used during its agentic loop. All filesystem state from
|
||||
the model's tool calls is preserved for verification.
|
||||
|
||||
Usage:
|
||||
# Phase 1: OpenAI server type
|
||||
vllm serve YourModel --tool-parser hermes
|
||||
run-api
|
||||
python environments/hermes_swe_env.py serve \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel \\
|
||||
--openai.server_type openai \\
|
||||
--env.dataset_name bigcode/humanevalpack \\
|
||||
--env.terminal_backend modal
|
||||
|
||||
# Phase 2: VLLM server type (full RL training)
|
||||
python environments/hermes_swe_env.py serve \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel \\
|
||||
--openai.server_type vllm \\
|
||||
--env.tool_call_parser hermes \\
|
||||
--env.terminal_backend modal
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path for imports
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HermesSweEnvConfig(HermesAgentEnvConfig):
|
||||
"""Config with defaults for SWE-bench style tasks."""
|
||||
|
||||
pass # Inherits all fields, overrides defaults in config_init
|
||||
|
||||
|
||||
class HermesSweEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
SWE-bench style environment using Modal terminal backend.
|
||||
|
||||
The model gets a coding task, uses terminal + file + web tools to solve it,
|
||||
and the reward function runs tests in the same Modal sandbox to verify.
|
||||
|
||||
Subclass this for specific SWE datasets (HumanEval, SWE-bench, etc.)
|
||||
and customize format_prompt() and compute_reward() as needed.
|
||||
"""
|
||||
|
||||
name = "hermes-swe"
|
||||
env_config_cls = HermesSweEnvConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[HermesSweEnvConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Default configuration for the SWE environment.
|
||||
|
||||
Uses Modal terminal backend for cloud isolation and terminal + file + web toolsets.
|
||||
"""
|
||||
env_config = HermesSweEnvConfig(
|
||||
# Toolsets: terminal for running code, file for reading/writing, web for docs
|
||||
enabled_toolsets=["terminal", "file", "web"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
# Agent settings -- SWE tasks need more turns
|
||||
max_agent_turns=30,
|
||||
max_token_length=4096,
|
||||
agent_temperature=1.0,
|
||||
system_prompt=(
|
||||
"You are a skilled software engineer. You have access to a terminal, "
|
||||
"file tools, and web search. Use these tools to complete the coding task. "
|
||||
"Write clean, working code and verify it runs correctly before finishing."
|
||||
),
|
||||
# Modal backend for cloud-isolated sandboxes
|
||||
terminal_backend="modal",
|
||||
# Dataset -- override via CLI for your specific SWE dataset
|
||||
dataset_name="bigcode/humanevalpack",
|
||||
dataset_split="test",
|
||||
prompt_field="prompt",
|
||||
# Atropos settings
|
||||
group_size=4,
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
tool_call_parser="hermes",
|
||||
steps_per_eval=50,
|
||||
total_steps=500,
|
||||
use_wandb=True,
|
||||
wandb_name="hermes-swe",
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="http://localhost:8000/v1",
|
||||
model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
server_type="openai", # Phase 1; switch to "vllm" for Phase 2
|
||||
api_key="",
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
"""Load the SWE dataset."""
|
||||
if self.config.dataset_name:
|
||||
self.dataset = load_dataset(
|
||||
self.config.dataset_name, split=self.config.dataset_split
|
||||
)
|
||||
else:
|
||||
# Placeholder if no dataset specified
|
||||
self.dataset = []
|
||||
self.iter = 0
|
||||
self.reward_buffer: List[float] = []
|
||||
|
||||
async def get_next_item(self) -> Dict[str, Any]:
|
||||
"""Cycle through the SWE dataset."""
|
||||
if not self.dataset:
|
||||
raise ValueError("No dataset loaded. Set dataset_name in config.")
|
||||
item = self.dataset[self.iter % len(self.dataset)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Format the SWE task prompt.
|
||||
|
||||
Override this in subclasses for different dataset formats.
|
||||
Default assumes the dataset has a 'prompt' field and optionally a 'test' field.
|
||||
"""
|
||||
prompt = item.get(self.config.prompt_field, "")
|
||||
|
||||
# If the dataset has test information, include it in the prompt
|
||||
test_info = item.get("test", item.get("test_code", item.get("tests", "")))
|
||||
if test_info:
|
||||
prompt += f"\n\nTests to pass:\n{test_info}"
|
||||
|
||||
return prompt
|
||||
|
||||
async def compute_reward(
|
||||
self, item: Dict[str, Any], result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Score by running tests in the model's Modal sandbox.
|
||||
|
||||
Default implementation:
|
||||
- If the dataset item has a 'test' or 'test_code' field, run it
|
||||
- Check exit code: 0 = pass, non-zero = fail
|
||||
- Partial credit for file creation
|
||||
|
||||
Override this in subclasses for more sophisticated reward logic.
|
||||
"""
|
||||
# Find the test command from the dataset item
|
||||
test_code = item.get("test", item.get("test_code", item.get("tests", "")))
|
||||
|
||||
if test_code:
|
||||
# Run the test in the model's sandbox
|
||||
test_result = ctx.terminal(
|
||||
f'cd /workspace && python3 -c "{test_code}"', timeout=60
|
||||
)
|
||||
|
||||
if test_result["exit_code"] == 0:
|
||||
self.reward_buffer.append(1.0)
|
||||
return 1.0
|
||||
|
||||
# Partial credit: check if the model created any Python files
|
||||
file_check = ctx.terminal("find /workspace -name '*.py' -newer /tmp/.start_marker 2>/dev/null | head -5")
|
||||
if file_check["exit_code"] == 0 and file_check.get("output", "").strip():
|
||||
self.reward_buffer.append(0.1)
|
||||
return 0.1
|
||||
|
||||
self.reward_buffer.append(0.0)
|
||||
return 0.0
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
Run evaluation on a held-out set.
|
||||
|
||||
Override for dataset-specific evaluation logic.
|
||||
"""
|
||||
start_time = time.time()
|
||||
end_time = time.time()
|
||||
|
||||
eval_metrics = {"eval/placeholder": 0.0}
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log SWE-specific metrics."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.reward_buffer:
|
||||
wandb_metrics["train/avg_reward"] = sum(self.reward_buffer) / len(
|
||||
self.reward_buffer
|
||||
)
|
||||
wandb_metrics["train/pass_rate"] = sum(
|
||||
1 for r in self.reward_buffer if r == 1.0
|
||||
) / len(self.reward_buffer)
|
||||
self.reward_buffer = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
HermesSweEnv.cli()
|
||||
@@ -1,35 +0,0 @@
|
||||
"""
|
||||
Monkey patches for making hermes-agent tools work inside async frameworks (Atropos).
|
||||
|
||||
Problem:
|
||||
Some tools use asyncio.run() internally (e.g., Modal backend via SWE-ReX,
|
||||
web_extract). This crashes when called from inside Atropos's event loop because
|
||||
asyncio.run() can't be nested.
|
||||
|
||||
Solution:
|
||||
The Modal environment (tools/environments/modal.py) now uses a dedicated
|
||||
_AsyncWorker thread internally, making it safe for both CLI and Atropos use.
|
||||
No monkey-patching is required.
|
||||
|
||||
This module is kept for backward compatibility. apply_patches() is a no-op.
|
||||
|
||||
Usage:
|
||||
Call apply_patches() once at import time (done automatically by hermes_base_env.py).
|
||||
This is idempotent and safe to call multiple times.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_patches_applied = False
|
||||
|
||||
|
||||
def apply_patches():
|
||||
"""Apply all monkey patches needed for Atropos compatibility."""
|
||||
global _patches_applied
|
||||
if _patches_applied:
|
||||
return
|
||||
|
||||
logger.debug("apply_patches() called; no patches needed (async safety is built-in)")
|
||||
_patches_applied = True
|
||||
@@ -1,34 +0,0 @@
|
||||
# Terminal Test Environment -- Default Configuration
|
||||
#
|
||||
# Simple file-creation tasks for validating the full Atropos + hermes-agent stack.
|
||||
# Uses Modal terminal backend and OpenRouter (Claude) for inference.
|
||||
# API keys loaded from ~/hermes-agent/.env
|
||||
#
|
||||
# Usage:
|
||||
# run-api
|
||||
# python environments/terminal_test_env/terminal_test_env.py serve \
|
||||
# --config environments/terminal_test_env/default.yaml
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
max_agent_turns: 10
|
||||
max_token_length: 2048
|
||||
group_size: 3
|
||||
total_steps: 3
|
||||
steps_per_eval: 3
|
||||
terminal_backend: "modal"
|
||||
tool_call_parser: "hermes"
|
||||
tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
||||
ensure_scores_are_not_same: false
|
||||
use_wandb: false
|
||||
system_prompt: >
|
||||
You are a helpful assistant with access to a terminal and file tools.
|
||||
Complete the user's request by using the available tools.
|
||||
Be precise and follow instructions exactly.
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-opus-4.6"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
@@ -1,292 +0,0 @@
|
||||
"""
|
||||
TerminalTestEnv -- Simple Test Environment for Validating the Stack
|
||||
|
||||
A self-contained environment with inline tasks (no external dataset needed).
|
||||
Each task asks the model to create a file at a known path with specific content.
|
||||
The reward verifier cats the file and checks if the content matches.
|
||||
|
||||
Enables only terminal + file toolsets. Uses Modal terminal backend with
|
||||
OpenRouter (Claude) by default.
|
||||
|
||||
Training tasks (3):
|
||||
1. Create ~/greeting.txt with "Hello from Hermes Agent"
|
||||
2. Create ~/count.txt with numbers 1-5, one per line
|
||||
3. Create ~/answer.txt with the result of 123 + 456
|
||||
|
||||
Eval task (1):
|
||||
1. Create ~/result.txt with the result of 6 * 7
|
||||
|
||||
Usage:
|
||||
# Start Atropos API server
|
||||
run-api
|
||||
|
||||
# Run environment (uses OpenRouter + Modal by default)
|
||||
python environments/terminal_test_env.py serve
|
||||
|
||||
# Process mode (no run-api needed, saves to JSONL)
|
||||
python environments/terminal_test_env.py process \\
|
||||
--env.data_path_to_save_groups terminal_test_output.jsonl
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path for imports
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Inline task definitions -- no external dataset needed
|
||||
# =============================================================================
|
||||
|
||||
TRAIN_TASKS = [
|
||||
{
|
||||
"prompt": "Create a file at ~/greeting.txt containing exactly the text: Hello from Hermes Agent",
|
||||
"verify_path": "~/greeting.txt",
|
||||
"expected_content": "Hello from Hermes Agent",
|
||||
},
|
||||
{
|
||||
"prompt": "Create a file at ~/count.txt containing the numbers 1 through 5, one per line",
|
||||
"verify_path": "~/count.txt",
|
||||
"expected_content": "1\n2\n3\n4\n5",
|
||||
},
|
||||
{
|
||||
"prompt": "Create a file at ~/answer.txt containing the result of 123 + 456",
|
||||
"verify_path": "~/answer.txt",
|
||||
"expected_content": "579",
|
||||
},
|
||||
]
|
||||
|
||||
EVAL_TASKS = [
|
||||
{
|
||||
"prompt": "Create a file at ~/result.txt containing the result of 6 * 7",
|
||||
"verify_path": "~/result.txt",
|
||||
"expected_content": "42",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class TerminalTestEnvConfig(HermesAgentEnvConfig):
|
||||
"""Config with defaults suitable for terminal testing."""
|
||||
|
||||
pass # Inherits all fields, overrides defaults in config_init
|
||||
|
||||
|
||||
class TerminalTestEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
Simple test environment with inline file-creation tasks.
|
||||
|
||||
All tasks follow the same pattern: "create a file at ~/X.txt with content Y".
|
||||
The verifier runs `cat ~/X.txt` in the rollout's terminal and checks the output
|
||||
against the expected string. Same verifier logic for all tasks.
|
||||
|
||||
This environment is designed to validate the full stack end-to-end:
|
||||
- Agent loop executes tool calls (terminal/file)
|
||||
- ToolContext provides terminal access to the reward function
|
||||
- Reward function verifies file content via cat
|
||||
- Scored data flows through the Atropos pipeline
|
||||
"""
|
||||
|
||||
name = "terminal-test"
|
||||
env_config_cls = TerminalTestEnvConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[TerminalTestEnvConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Default configuration for the terminal test environment.
|
||||
|
||||
Uses Modal terminal backend for cloud isolation and OpenRouter with
|
||||
Claude for inference. API keys loaded from ~/hermes-agent/.env.
|
||||
"""
|
||||
env_config = TerminalTestEnvConfig(
|
||||
# Terminal + file tools only
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
# Agent settings
|
||||
max_agent_turns=10, # Simple tasks, don't need many turns
|
||||
max_token_length=16000,
|
||||
agent_temperature=1.0,
|
||||
system_prompt=(
|
||||
"You are a helpful assistant with access to a terminal and file tools. "
|
||||
"Complete the user's request by using the available tools. "
|
||||
"Be precise and follow instructions exactly."
|
||||
),
|
||||
# Modal terminal backend for cloud-isolated sandboxes per rollout
|
||||
terminal_backend="modal",
|
||||
# Atropos settings
|
||||
group_size=3, # 3 rollouts per group
|
||||
tokenizer_name="NousResearch/q-30b-t-h45-e1",
|
||||
tool_call_parser="hermes",
|
||||
steps_per_eval=3, # Eval after all 3 steps
|
||||
total_steps=3, # 3 groups total (1 group per step)
|
||||
use_wandb=True,
|
||||
wandb_name="terminal-test",
|
||||
ensure_scores_are_not_same=False, # Allow all-same scores for simple tasks
|
||||
# No external dataset
|
||||
dataset_name=None,
|
||||
)
|
||||
|
||||
# OpenRouter with Claude -- API key loaded from .env (OPENROUTER_API_KEY)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-opus-4.6",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False, # OpenRouter doesn't have a /health endpoint
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
"""Initialize inline task lists."""
|
||||
self.train_tasks = list(TRAIN_TASKS)
|
||||
self.eval_tasks = list(EVAL_TASKS)
|
||||
self.iter = 0
|
||||
# Track reward stats for wandb logging
|
||||
self.reward_buffer: List[float] = []
|
||||
|
||||
async def get_next_item(self) -> Dict[str, str]:
|
||||
"""Cycle through training tasks."""
|
||||
item = self.train_tasks[self.iter % len(self.train_tasks)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item: Dict[str, str]) -> str:
|
||||
"""The prompt is directly in the task item."""
|
||||
return item["prompt"]
|
||||
|
||||
async def compute_reward(
|
||||
self, item: Dict[str, str], result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Verify by cat-ing the expected file path and checking content matches.
|
||||
Same verifier for all tasks -- they all write a file at a known path.
|
||||
|
||||
Scoring:
|
||||
1.0 = exact match
|
||||
0.5 = expected content is present but has extra stuff
|
||||
0.0 = file doesn't exist or content doesn't match
|
||||
"""
|
||||
verify_result = ctx.terminal(f"cat {item['verify_path']}")
|
||||
|
||||
# File doesn't exist or can't be read
|
||||
if verify_result["exit_code"] != 0:
|
||||
self.reward_buffer.append(0.0)
|
||||
return 0.0
|
||||
|
||||
actual = verify_result.get("output", "").strip()
|
||||
expected = item["expected_content"].strip()
|
||||
|
||||
# Exact match
|
||||
if actual == expected:
|
||||
self.reward_buffer.append(1.0)
|
||||
return 1.0
|
||||
|
||||
# Partial credit: expected content is present but has extra stuff
|
||||
if expected in actual:
|
||||
self.reward_buffer.append(0.5)
|
||||
return 0.5
|
||||
|
||||
self.reward_buffer.append(0.0)
|
||||
return 0.0
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
Run eval tasks using the agent loop and verify results.
|
||||
Logs accuracy metrics.
|
||||
"""
|
||||
start_time = time.time()
|
||||
correct = 0
|
||||
total = len(self.eval_tasks)
|
||||
samples = []
|
||||
|
||||
for eval_item in self.eval_tasks:
|
||||
try:
|
||||
# For eval, we do a simple single-turn completion (not full agent loop)
|
||||
# to keep eval fast. The agent loop is tested via training.
|
||||
completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": self.config.system_prompt or ""},
|
||||
{"role": "user", "content": eval_item["prompt"]},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
|
||||
response_content = (
|
||||
completion.choices[0].message.content if completion.choices else ""
|
||||
)
|
||||
|
||||
samples.append(
|
||||
{
|
||||
"prompt": eval_item["prompt"],
|
||||
"response": response_content,
|
||||
"expected": eval_item["expected_content"],
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Eval failed for item: %s", e)
|
||||
samples.append(
|
||||
{
|
||||
"prompt": eval_item["prompt"],
|
||||
"response": f"ERROR: {e}",
|
||||
"expected": eval_item["expected_content"],
|
||||
}
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
eval_metrics = {
|
||||
"eval/num_samples": total,
|
||||
}
|
||||
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log training metrics including reward stats and accuracy."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.reward_buffer:
|
||||
total = len(self.reward_buffer)
|
||||
correct = sum(1 for r in self.reward_buffer if r == 1.0)
|
||||
partial = sum(1 for r in self.reward_buffer if r == 0.5)
|
||||
|
||||
wandb_metrics["train/avg_reward"] = sum(self.reward_buffer) / total
|
||||
wandb_metrics["train/accuracy"] = correct / total
|
||||
wandb_metrics["train/partial_match_rate"] = partial / total
|
||||
wandb_metrics["train/total_rollouts"] = total
|
||||
self.reward_buffer = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TerminalTestEnv.cli()
|
||||
@@ -1,120 +0,0 @@
|
||||
"""
|
||||
Tool Call Parser Registry
|
||||
|
||||
Client-side parsers that extract structured tool_calls from raw model output text.
|
||||
Used in Phase 2 (VLLM server type) where ManagedServer's /generate endpoint returns
|
||||
raw text without tool call parsing.
|
||||
|
||||
Each parser is a standalone reimplementation of the corresponding VLLM parser's
|
||||
non-streaming extract_tool_calls() logic. No VLLM dependency -- only standard library
|
||||
(re, json, uuid) and openai types.
|
||||
|
||||
Usage:
|
||||
from environments.tool_call_parsers import get_parser
|
||||
|
||||
parser = get_parser("hermes")
|
||||
content, tool_calls = parser.parse(raw_model_output)
|
||||
# content = text with tool call markup stripped
|
||||
# tool_calls = list of ChatCompletionMessageToolCall objects, or None
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type alias for parser return value
|
||||
ParseResult = Tuple[Optional[str], Optional[List[ChatCompletionMessageToolCall]]]
|
||||
|
||||
|
||||
class ToolCallParser(ABC):
|
||||
"""
|
||||
Base class for tool call parsers.
|
||||
|
||||
Each parser knows how to extract structured tool_calls from a specific
|
||||
model family's raw output text format.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
"""
|
||||
Parse raw model output text for tool calls.
|
||||
|
||||
Args:
|
||||
text: Raw decoded text from the model's completion
|
||||
|
||||
Returns:
|
||||
Tuple of (content, tool_calls) where:
|
||||
- content: text with tool call markup stripped (the message 'content' field),
|
||||
or None if the entire output was tool calls
|
||||
- tool_calls: list of ChatCompletionMessageToolCall objects,
|
||||
or None if no tool calls were found
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Global parser registry: name -> parser class
|
||||
PARSER_REGISTRY: Dict[str, Type[ToolCallParser]] = {}
|
||||
|
||||
|
||||
def register_parser(name: str):
|
||||
"""
|
||||
Decorator to register a parser class under a given name.
|
||||
|
||||
Usage:
|
||||
@register_parser("hermes")
|
||||
class HermesToolCallParser(ToolCallParser):
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(cls: Type[ToolCallParser]) -> Type[ToolCallParser]:
|
||||
PARSER_REGISTRY[name] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_parser(name: str) -> ToolCallParser:
|
||||
"""
|
||||
Get a parser instance by name.
|
||||
|
||||
Args:
|
||||
name: Parser name (e.g., "hermes", "mistral", "llama3_json")
|
||||
|
||||
Returns:
|
||||
Instantiated parser
|
||||
|
||||
Raises:
|
||||
KeyError: If parser name is not found in registry
|
||||
"""
|
||||
if name not in PARSER_REGISTRY:
|
||||
available = sorted(PARSER_REGISTRY.keys())
|
||||
raise KeyError(
|
||||
f"Tool call parser '{name}' not found. Available parsers: {available}"
|
||||
)
|
||||
return PARSER_REGISTRY[name]()
|
||||
|
||||
|
||||
def list_parsers() -> List[str]:
|
||||
"""Return sorted list of registered parser names."""
|
||||
return sorted(PARSER_REGISTRY.keys())
|
||||
|
||||
|
||||
# Import all parser modules to trigger registration via @register_parser decorators
|
||||
# Each module registers itself when imported
|
||||
from environments.tool_call_parsers.hermes_parser import HermesToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.longcat_parser import LongcatToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.mistral_parser import MistralToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.llama_parser import LlamaToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.qwen_parser import QwenToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.deepseek_v3_parser import DeepSeekV3ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.deepseek_v3_1_parser import DeepSeekV31ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.kimi_k2_parser import KimiK2ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.glm45_parser import Glm45ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.glm47_parser import Glm47ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.qwen3_coder_parser import Qwen3CoderToolCallParser # noqa: E402, F401
|
||||
@@ -1,72 +0,0 @@
|
||||
"""
|
||||
DeepSeek V3.1 tool call parser.
|
||||
|
||||
Similar to V3 but with a slightly different format:
|
||||
<|tool▁call▁begin|>function_name<|tool▁sep|>arguments<|tool▁call▁end|>
|
||||
|
||||
Note: V3 has type+name before the separator, V3.1 has name before and args after.
|
||||
|
||||
Based on VLLM's DeepSeekV31ToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("deepseek_v3_1")
|
||||
@register_parser("deepseek_v31")
|
||||
class DeepSeekV31ToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for DeepSeek V3.1 tool calls.
|
||||
|
||||
Slightly different regex than V3: function_name comes before the separator,
|
||||
arguments come after (no type field, no json code block wrapper).
|
||||
"""
|
||||
|
||||
START_TOKEN = "<|tool▁calls▁begin|>"
|
||||
|
||||
# Regex captures: function_name, function_arguments
|
||||
PATTERN = re.compile(
|
||||
r"<|tool▁call▁begin|>(?P<function_name>.*?)<|tool▁sep|>(?P<function_arguments>.*?)<|tool▁call▁end|>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.START_TOKEN not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
func_name, func_args = match
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=func_name.strip(),
|
||||
arguments=func_args.strip(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
content = text[: text.find(self.START_TOKEN)].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
@@ -1,89 +0,0 @@
|
||||
"""
|
||||
DeepSeek V3 tool call parser.
|
||||
|
||||
Format uses special unicode tokens:
|
||||
<|tool▁calls▁begin|>
|
||||
<|tool▁call▁begin|>type<|tool▁sep|>function_name
|
||||
```json
|
||||
{"arg": "value"}
|
||||
```
|
||||
<|tool▁call▁end|>
|
||||
<|tool▁calls▁end|>
|
||||
|
||||
Fixes Issue #989: Support for multiple simultaneous tool calls.
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@register_parser("deepseek_v3")
|
||||
class DeepSeekV3ToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for DeepSeek V3 tool calls.
|
||||
|
||||
Uses special unicode tokens with fullwidth angle brackets and block elements.
|
||||
Extracts type, function name, and JSON arguments from the structured format.
|
||||
Ensures all tool calls are captured when the model executes multiple actions.
|
||||
"""
|
||||
|
||||
START_TOKEN = "<|tool▁calls▁begin|>"
|
||||
|
||||
# Updated PATTERN: Using \s* instead of literal \n for increased robustness
|
||||
# against variations in model formatting (Issue #989).
|
||||
PATTERN = re.compile(
|
||||
r"<|tool▁call▁begin|>(?P<type>.*?)<|tool▁sep|>(?P<function_name>.*?)\s*```json\s*(?P<function_arguments>.*?)\s*```\s*<|tool▁call▁end|>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
"""
|
||||
Parses the input text and extracts all available tool calls.
|
||||
"""
|
||||
if self.START_TOKEN not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
# Using finditer to capture ALL tool calls in the sequence
|
||||
matches = list(self.PATTERN.finditer(text))
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
|
||||
for match in matches:
|
||||
func_name = match.group("function_name").strip()
|
||||
func_args = match.group("function_arguments").strip()
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=func_name,
|
||||
arguments=func_args,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if tool_calls:
|
||||
# Content is text before the first tool call block
|
||||
content_index = text.find(self.START_TOKEN)
|
||||
content = text[:content_index].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
return text, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing DeepSeek V3 tool calls: {e}")
|
||||
return text, None
|
||||
@@ -1,109 +0,0 @@
|
||||
"""
|
||||
GLM 4.5 (GLM-4-MoE) tool call parser.
|
||||
|
||||
Format uses custom arg_key/arg_value tags rather than standard JSON:
|
||||
<tool_call>function_name
|
||||
<arg_key>param1</arg_key><arg_value>value1</arg_value>
|
||||
<arg_key>param2</arg_key><arg_value>value2</arg_value>
|
||||
</tool_call>
|
||||
|
||||
Values are deserialized using json.loads -> ast.literal_eval -> raw string fallback.
|
||||
|
||||
Based on VLLM's Glm4MoeModelToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
def _deserialize_value(value: str) -> Any:
|
||||
"""
|
||||
Try to deserialize a string value to its native Python type.
|
||||
Attempts json.loads, then ast.literal_eval, then returns raw string.
|
||||
"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
try:
|
||||
return ast.literal_eval(value)
|
||||
except (ValueError, SyntaxError, TypeError):
|
||||
pass
|
||||
|
||||
return value
|
||||
|
||||
|
||||
@register_parser("glm45")
|
||||
class Glm45ToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for GLM 4.5 (GLM-4-MoE) tool calls.
|
||||
|
||||
Uses <tool_call>...</tool_call> tags with <arg_key>/<arg_value> pairs
|
||||
instead of standard JSON arguments.
|
||||
"""
|
||||
|
||||
FUNC_CALL_REGEX = re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL)
|
||||
FUNC_DETAIL_REGEX = re.compile(r"<tool_call>([^\n]*)\n(.*)</tool_call>", re.DOTALL)
|
||||
FUNC_ARG_REGEX = re.compile(
|
||||
r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL
|
||||
)
|
||||
|
||||
START_TOKEN = "<tool_call>"
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.START_TOKEN not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matched_calls = self.FUNC_CALL_REGEX.findall(text)
|
||||
if not matched_calls:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
|
||||
for match in matched_calls:
|
||||
detail = self.FUNC_DETAIL_REGEX.search(match)
|
||||
if not detail:
|
||||
continue
|
||||
|
||||
func_name = detail.group(1).strip()
|
||||
func_args_raw = detail.group(2)
|
||||
|
||||
# Parse arg_key/arg_value pairs
|
||||
pairs = self.FUNC_ARG_REGEX.findall(func_args_raw) if func_args_raw else []
|
||||
arg_dict: Dict[str, Any] = {}
|
||||
for key, value in pairs:
|
||||
arg_key = key.strip()
|
||||
arg_val = _deserialize_value(value.strip())
|
||||
arg_dict[arg_key] = arg_val
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=func_name,
|
||||
arguments=json.dumps(arg_dict, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
content = text[: text.find(self.START_TOKEN)].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
@@ -1,35 +0,0 @@
|
||||
"""
|
||||
GLM 4.7 tool call parser.
|
||||
|
||||
Same as GLM 4.5 but with slightly different regex patterns.
|
||||
The tool_call tags may wrap differently and arg parsing handles
|
||||
newlines between key/value pairs.
|
||||
|
||||
Based on VLLM's Glm47MoeModelToolParser (extends Glm4MoeModelToolParser).
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, register_parser
|
||||
from environments.tool_call_parsers.glm45_parser import Glm45ToolCallParser
|
||||
|
||||
|
||||
@register_parser("glm47")
|
||||
class Glm47ToolCallParser(Glm45ToolCallParser):
|
||||
"""
|
||||
Parser for GLM 4.7 tool calls.
|
||||
Extends GLM 4.5 with updated regex patterns.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# GLM 4.7 uses a slightly different detail regex that includes
|
||||
# the <tool_call> wrapper and optional arg_key content
|
||||
self.FUNC_DETAIL_REGEX = re.compile(
|
||||
r"<tool_call>(.*?)(<arg_key>.*?)?</tool_call>", re.DOTALL
|
||||
)
|
||||
# GLM 4.7 handles newlines between arg_key and arg_value tags
|
||||
self.FUNC_ARG_REGEX = re.compile(
|
||||
r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
|
||||
re.DOTALL,
|
||||
)
|
||||
@@ -1,75 +0,0 @@
|
||||
"""
|
||||
Hermes tool call parser.
|
||||
|
||||
Format: <tool_call>{"name": "func", "arguments": {...}}</tool_call>
|
||||
Based on VLLM's Hermes2ProToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("hermes")
|
||||
class HermesToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Hermes-format tool calls.
|
||||
|
||||
Matches <tool_call>...</tool_call> tags containing JSON with "name" and "arguments".
|
||||
Also handles unclosed <tool_call> at end-of-string (truncated generation).
|
||||
"""
|
||||
|
||||
# Matches both closed and unclosed tool_call tags
|
||||
PATTERN = re.compile(
|
||||
r"<tool_call>\s*(.*?)\s*</tool_call>|<tool_call>\s*(.*)", re.DOTALL
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if "<tool_call>" not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
# match is a tuple: (closed_content, unclosed_content)
|
||||
raw_json = match[0] if match[0] else match[1]
|
||||
if not raw_json.strip():
|
||||
continue
|
||||
|
||||
tc_data = json.loads(raw_json)
|
||||
if "name" not in tc_data:
|
||||
continue
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc_data["name"],
|
||||
arguments=json.dumps(
|
||||
tc_data.get("arguments", {}), ensure_ascii=False
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Content is everything before the first <tool_call> tag
|
||||
content = text[: text.find("<tool_call>")].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
@@ -1,93 +0,0 @@
|
||||
"""
|
||||
Kimi K2 tool call parser.
|
||||
|
||||
Format:
|
||||
<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|>function_id:0<|tool_call_argument_begin|>{"arg": "val"}<|tool_call_end|>
|
||||
<|tool_calls_section_end|>
|
||||
|
||||
The function_id format is typically "functions.func_name:index" or "func_name:index".
|
||||
|
||||
Based on VLLM's KimiK2ToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("kimi_k2")
|
||||
class KimiK2ToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Kimi K2 tool calls.
|
||||
|
||||
Uses section begin/end tokens wrapping individual tool call begin/end tokens.
|
||||
The tool_call_id contains the function name (after last dot, before colon).
|
||||
"""
|
||||
|
||||
# Support both singular and plural variants
|
||||
START_TOKENS = [
|
||||
"<|tool_calls_section_begin|>",
|
||||
"<|tool_call_section_begin|>",
|
||||
]
|
||||
|
||||
# Regex captures: tool_call_id (e.g., "functions.get_weather:0"), function_arguments
|
||||
PATTERN = re.compile(
|
||||
r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[^<]+:\d+)\s*"
|
||||
r"<\|tool_call_argument_begin\|>\s*"
|
||||
r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
|
||||
r"<\|tool_call_end\|>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
# Check for any variant of the start token
|
||||
has_start = any(token in text for token in self.START_TOKENS)
|
||||
if not has_start:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
function_id, function_args = match
|
||||
|
||||
# Extract function name from ID format: "functions.get_weather:0" -> "get_weather"
|
||||
function_name = function_id.split(":")[0].split(".")[-1]
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=function_id, # Preserve the original ID format
|
||||
type="function",
|
||||
function=Function(
|
||||
name=function_name,
|
||||
arguments=function_args.strip(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Content is everything before the tool calls section
|
||||
earliest_start = len(text)
|
||||
for token in self.START_TOKENS:
|
||||
idx = text.find(token)
|
||||
if idx >= 0 and idx < earliest_start:
|
||||
earliest_start = idx
|
||||
|
||||
content = text[:earliest_start].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
@@ -1,96 +0,0 @@
|
||||
"""
|
||||
Llama 3.x / 4 tool call parser.
|
||||
|
||||
Format: The model outputs JSON objects with "name" and "arguments" (or "parameters") keys.
|
||||
May be preceded by <|python_tag|> token. Supports multiple JSON objects separated
|
||||
by content or semicolons.
|
||||
|
||||
Based on VLLM's Llama3JsonToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("llama3_json")
|
||||
@register_parser("llama4_json")
|
||||
class LlamaToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Llama 3.x and 4 JSON-format tool calls.
|
||||
|
||||
Finds JSON objects containing "name" + ("arguments" or "parameters") keys.
|
||||
Uses Python's json.JSONDecoder.raw_decode for robust extraction of
|
||||
JSON objects from mixed text.
|
||||
"""
|
||||
|
||||
BOT_TOKEN = "<|python_tag|>"
|
||||
|
||||
# Regex to find the start of potential JSON objects
|
||||
JSON_START = re.compile(r"\{")
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
# Quick check: need either the bot token or a JSON brace
|
||||
if self.BOT_TOKEN not in text and "{" not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
decoder = json.JSONDecoder()
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
end_index = -1 # Track where the last parsed JSON ended
|
||||
|
||||
for match in self.JSON_START.finditer(text):
|
||||
start = match.start()
|
||||
# Skip if this brace is inside a previously parsed JSON object
|
||||
if start <= end_index:
|
||||
continue
|
||||
|
||||
try:
|
||||
obj, json_end = decoder.raw_decode(text[start:])
|
||||
end_index = start + json_end
|
||||
|
||||
# Must have "name" and either "arguments" or "parameters"
|
||||
name = obj.get("name")
|
||||
args = obj.get("arguments", obj.get("parameters"))
|
||||
|
||||
if not name or args is None:
|
||||
continue
|
||||
|
||||
# Normalize arguments to JSON string
|
||||
if isinstance(args, dict):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
elif not isinstance(args, str):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(name=name, arguments=args),
|
||||
)
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError, ValueError):
|
||||
continue
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Content is everything before the first tool call JSON
|
||||
# Find where the first tool call starts in the text
|
||||
first_tc_start = text.find("{")
|
||||
if self.BOT_TOKEN in text:
|
||||
first_tc_start = text.find(self.BOT_TOKEN)
|
||||
content = text[:first_tc_start].strip() if first_tc_start > 0 else None
|
||||
|
||||
return content, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
@@ -1,69 +0,0 @@
|
||||
"""
|
||||
Longcat Flash Chat tool call parser.
|
||||
|
||||
Same as Hermes but uses <longcat_tool_call> tags instead of <tool_call>.
|
||||
Based on VLLM's LongcatFlashToolParser (extends Hermes2ProToolParser).
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("longcat")
|
||||
class LongcatToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Longcat Flash Chat tool calls.
|
||||
Identical logic to Hermes, just different tag names.
|
||||
"""
|
||||
|
||||
PATTERN = re.compile(
|
||||
r"<longcat_tool_call>\s*(.*?)\s*</longcat_tool_call>|<longcat_tool_call>\s*(.*)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if "<longcat_tool_call>" not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
raw_json = match[0] if match[0] else match[1]
|
||||
if not raw_json.strip():
|
||||
continue
|
||||
|
||||
tc_data = json.loads(raw_json)
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc_data["name"],
|
||||
arguments=json.dumps(
|
||||
tc_data.get("arguments", {}), ensure_ascii=False
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
content = text[: text.find("<longcat_tool_call>")].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
@@ -1,137 +0,0 @@
|
||||
"""
|
||||
Mistral tool call parser.
|
||||
|
||||
Supports two formats depending on tokenizer version:
|
||||
- Pre-v11: content[TOOL_CALLS] [{"name": ..., "arguments": {...}}, ...]
|
||||
- v11+: content[TOOL_CALLS]tool_name1{"arg": "val"}[TOOL_CALLS]tool_name2{"arg": "val"}
|
||||
|
||||
Based on VLLM's MistralToolParser.extract_tool_calls()
|
||||
The [TOOL_CALLS] token is the bot_token used by Mistral models.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
def _generate_mistral_id() -> str:
|
||||
"""Mistral tool call IDs are 9-char alphanumeric strings."""
|
||||
import random
|
||||
import string
|
||||
|
||||
return "".join(random.choices(string.ascii_letters + string.digits, k=9))
|
||||
|
||||
|
||||
@register_parser("mistral")
|
||||
class MistralToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Mistral-format tool calls.
|
||||
|
||||
Detects format by checking if the content after [TOOL_CALLS] starts with '['
|
||||
(pre-v11 JSON array) or with a tool name (v11+ format).
|
||||
"""
|
||||
|
||||
# The [TOOL_CALLS] token -- may appear as different strings depending on tokenizer
|
||||
BOT_TOKEN = "[TOOL_CALLS]"
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.BOT_TOKEN not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
parts = text.split(self.BOT_TOKEN)
|
||||
content = parts[0].strip()
|
||||
raw_tool_calls = parts[1:]
|
||||
|
||||
# Detect format: if the first raw part starts with '[', it's pre-v11
|
||||
first_raw = raw_tool_calls[0].strip() if raw_tool_calls else ""
|
||||
is_pre_v11 = first_raw.startswith("[") or first_raw.startswith("{")
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
|
||||
if not is_pre_v11:
|
||||
# v11+ format: [TOOL_CALLS]tool_name{args}[TOOL_CALLS]tool_name2{args2}
|
||||
for raw in raw_tool_calls:
|
||||
raw = raw.strip()
|
||||
if not raw or "{" not in raw:
|
||||
continue
|
||||
|
||||
brace_idx = raw.find("{")
|
||||
tool_name = raw[:brace_idx].strip()
|
||||
args_str = raw[brace_idx:]
|
||||
|
||||
# Validate and clean the JSON arguments
|
||||
try:
|
||||
parsed_args = json.loads(args_str)
|
||||
args_str = json.dumps(parsed_args, ensure_ascii=False)
|
||||
except json.JSONDecodeError:
|
||||
pass # Keep raw if parsing fails
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=_generate_mistral_id(),
|
||||
type="function",
|
||||
function=Function(name=tool_name, arguments=args_str),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Pre-v11 format: [TOOL_CALLS] [{"name": ..., "arguments": {...}}]
|
||||
try:
|
||||
parsed = json.loads(first_raw)
|
||||
if isinstance(parsed, dict):
|
||||
parsed = [parsed]
|
||||
|
||||
for tc in parsed:
|
||||
if "name" not in tc:
|
||||
continue
|
||||
args = tc.get("arguments", {})
|
||||
if isinstance(args, dict):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=_generate_mistral_id(),
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc["name"], arguments=args
|
||||
),
|
||||
)
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# Fallback: extract JSON objects using raw_decode
|
||||
decoder = json.JSONDecoder()
|
||||
idx = 0
|
||||
while idx < len(first_raw):
|
||||
try:
|
||||
obj, end_idx = decoder.raw_decode(first_raw, idx)
|
||||
if isinstance(obj, dict) and "name" in obj:
|
||||
args = obj.get("arguments", {})
|
||||
if isinstance(args, dict):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=_generate_mistral_id(),
|
||||
type="function",
|
||||
function=Function(
|
||||
name=obj["name"], arguments=args
|
||||
),
|
||||
)
|
||||
)
|
||||
idx = end_idx
|
||||
except json.JSONDecodeError:
|
||||
idx += 1
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
@@ -1,163 +0,0 @@
|
||||
"""
|
||||
Qwen3-Coder tool call parser.
|
||||
|
||||
Format uses XML-style nested tags:
|
||||
<tool_call>
|
||||
<function=function_name>
|
||||
<parameter=param_name>value</parameter>
|
||||
<parameter=param_name2>value2</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
|
||||
Parameters are extracted from <parameter=name>value</parameter> tags and
|
||||
type-converted using the schema if available, otherwise treated as strings.
|
||||
|
||||
Based on VLLM's Qwen3CoderToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
def _try_convert_value(value: str) -> Any:
|
||||
"""
|
||||
Try to convert a parameter value string to a native Python type.
|
||||
Handles null, numbers, booleans, JSON objects/arrays, and falls back to string.
|
||||
"""
|
||||
stripped = value.strip()
|
||||
|
||||
# Handle null
|
||||
if stripped.lower() == "null":
|
||||
return None
|
||||
|
||||
# Try JSON first (handles objects, arrays, strings, numbers, booleans)
|
||||
try:
|
||||
return json.loads(stripped)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Try Python literal eval (handles tuples, etc.)
|
||||
try:
|
||||
return ast.literal_eval(stripped)
|
||||
except (ValueError, SyntaxError, TypeError):
|
||||
pass
|
||||
|
||||
# Return as string
|
||||
return stripped
|
||||
|
||||
|
||||
@register_parser("qwen3_coder")
|
||||
class Qwen3CoderToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Qwen3-Coder XML-format tool calls.
|
||||
|
||||
Uses nested XML tags: <tool_call><function=name><parameter=key>val</parameter></function></tool_call>
|
||||
"""
|
||||
|
||||
START_TOKEN = "<tool_call>"
|
||||
FUNCTION_PREFIX = "<function="
|
||||
|
||||
# Find complete tool_call blocks (or unclosed at end)
|
||||
TOOL_CALL_REGEX = re.compile(
|
||||
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
|
||||
)
|
||||
|
||||
# Find function blocks within a tool_call
|
||||
FUNCTION_REGEX = re.compile(
|
||||
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
|
||||
)
|
||||
|
||||
# Find parameter blocks within a function
|
||||
PARAMETER_REGEX = re.compile(
|
||||
r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def _parse_function_call(self, function_str: str) -> Optional[ChatCompletionMessageToolCall]:
|
||||
"""Parse a single <function=name>...</function> block into a ToolCall."""
|
||||
try:
|
||||
# Extract function name: everything before the first '>'
|
||||
gt_idx = function_str.index(">")
|
||||
func_name = function_str[:gt_idx].strip()
|
||||
params_str = function_str[gt_idx + 1:]
|
||||
|
||||
# Extract parameters
|
||||
param_dict: Dict[str, Any] = {}
|
||||
for match_text in self.PARAMETER_REGEX.findall(params_str):
|
||||
if ">" not in match_text:
|
||||
continue
|
||||
eq_idx = match_text.index(">")
|
||||
param_name = match_text[:eq_idx].strip()
|
||||
param_value = match_text[eq_idx + 1:]
|
||||
|
||||
# Clean up whitespace
|
||||
if param_value.startswith("\n"):
|
||||
param_value = param_value[1:]
|
||||
if param_value.endswith("\n"):
|
||||
param_value = param_value[:-1]
|
||||
|
||||
param_dict[param_name] = _try_convert_value(param_value)
|
||||
|
||||
return ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:24]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=func_name,
|
||||
arguments=json.dumps(param_dict, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.FUNCTION_PREFIX not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
# Find all tool_call blocks
|
||||
tc_matches = self.TOOL_CALL_REGEX.findall(text)
|
||||
raw_blocks = [m[0] if m[0] else m[1] for m in tc_matches]
|
||||
|
||||
# Fallback: if no tool_call tags, try the whole text
|
||||
if not raw_blocks:
|
||||
raw_blocks = [text]
|
||||
|
||||
# Find function blocks within each tool_call
|
||||
function_strs: List[str] = []
|
||||
for block in raw_blocks:
|
||||
func_matches = self.FUNCTION_REGEX.findall(block)
|
||||
function_strs.extend(m[0] if m[0] else m[1] for m in func_matches)
|
||||
|
||||
if not function_strs:
|
||||
return text, None
|
||||
|
||||
# Parse each function call
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for func_str in function_strs:
|
||||
tc = self._parse_function_call(func_str)
|
||||
if tc is not None:
|
||||
tool_calls.append(tc)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Content before tool calls
|
||||
first_tc = text.find(self.START_TOKEN)
|
||||
if first_tc < 0:
|
||||
first_tc = text.find(self.FUNCTION_PREFIX)
|
||||
content = text[:first_tc].strip() if first_tc > 0 else None
|
||||
|
||||
return content, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
@@ -1,19 +0,0 @@
|
||||
"""
|
||||
Qwen 2.5 tool call parser.
|
||||
|
||||
Uses the same <tool_call> format as Hermes.
|
||||
Registered as a separate parser name for clarity when using --tool-parser=qwen.
|
||||
"""
|
||||
|
||||
from environments.tool_call_parsers import register_parser
|
||||
from environments.tool_call_parsers.hermes_parser import HermesToolCallParser
|
||||
|
||||
|
||||
@register_parser("qwen")
|
||||
class QwenToolCallParser(HermesToolCallParser):
|
||||
"""
|
||||
Parser for Qwen 2.5 tool calls.
|
||||
Same <tool_call>{"name": ..., "arguments": ...}</tool_call> format as Hermes.
|
||||
"""
|
||||
|
||||
pass # Identical format -- inherits everything from Hermes
|
||||
@@ -1,473 +0,0 @@
|
||||
"""
|
||||
ToolContext -- Unrestricted Tool Access for Reward Functions
|
||||
|
||||
A per-rollout handle that gives reward/verification functions direct access to
|
||||
ALL hermes-agent tools, scoped to the rollout's task_id. The same task_id means
|
||||
the terminal/browser session is the SAME one the model used during its rollout --
|
||||
all state (files, processes, browser tabs) is preserved.
|
||||
|
||||
The verifier author decides which tools to use. Nothing is hardcoded or gated.
|
||||
|
||||
Example usage in a compute_reward():
|
||||
async def compute_reward(self, item, result, ctx):
|
||||
# Run tests in the model's terminal sandbox
|
||||
test = ctx.terminal("pytest -v")
|
||||
if test["exit_code"] == 0:
|
||||
return 1.0
|
||||
|
||||
# Check if a file was created
|
||||
content = ctx.read_file("/workspace/solution.py")
|
||||
if content.get("content"):
|
||||
return 0.5
|
||||
|
||||
return 0.0
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
|
||||
from model_tools import handle_function_call
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
from tools.browser_tool import cleanup_browser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Thread pool for running sync tool calls that internally use asyncio.run()
|
||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_tool_in_thread(tool_name: str, arguments: Dict[str, Any], task_id: str) -> str:
|
||||
"""
|
||||
Run a tool call in a thread pool executor so backends that use asyncio.run()
|
||||
internally (modal, docker, daytona) get a clean event loop.
|
||||
|
||||
If we're already in an async context, executes handle_function_call() in a
|
||||
disposable worker thread and blocks for the result.
|
||||
If not (e.g., called from sync code), runs directly.
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# We're in an async context -- need to run in thread
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(
|
||||
handle_function_call, tool_name, arguments, task_id
|
||||
)
|
||||
return future.result(timeout=300)
|
||||
except RuntimeError:
|
||||
# No running event loop -- safe to call directly
|
||||
return handle_function_call(tool_name, arguments, task_id)
|
||||
|
||||
|
||||
class ToolContext:
|
||||
"""
|
||||
Open-ended access to all hermes-agent tools for a specific rollout.
|
||||
|
||||
Passed to compute_reward() so verifiers can use any tool they need:
|
||||
terminal commands, file reads/writes, web searches, browser automation, etc.
|
||||
All calls share the rollout's task_id for session isolation.
|
||||
"""
|
||||
|
||||
def __init__(self, task_id: str):
|
||||
self.task_id = task_id
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Terminal tools
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def terminal(self, command: str, timeout: int = 180) -> Dict[str, Any]:
|
||||
"""
|
||||
Run a command in the rollout's terminal session.
|
||||
|
||||
Args:
|
||||
command: Shell command to execute
|
||||
timeout: Command timeout in seconds
|
||||
|
||||
Returns:
|
||||
Dict with 'exit_code' (int) and 'output' (str)
|
||||
"""
|
||||
import os
|
||||
backend = os.getenv("TERMINAL_ENV", "local")
|
||||
logger.debug("ToolContext.terminal [%s backend] task=%s: %s", backend, self.task_id[:8], command[:100])
|
||||
|
||||
# Run via thread helper so modal/docker/daytona backends' asyncio.run() doesn't deadlock
|
||||
result = _run_tool_in_thread(
|
||||
"terminal",
|
||||
{"command": command, "timeout": timeout},
|
||||
self.task_id,
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"exit_code": -1, "output": result}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# File tools
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def read_file(self, path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Read a file from the rollout's filesystem.
|
||||
|
||||
Args:
|
||||
path: File path to read
|
||||
|
||||
Returns:
|
||||
Dict with file content or error
|
||||
"""
|
||||
result = handle_function_call(
|
||||
"read_file", {"path": path}, task_id=self.task_id
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
def write_file(self, path: str, content: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Write a TEXT file in the rollout's filesystem.
|
||||
|
||||
Uses a shell heredoc under the hood, so this is only safe for text content.
|
||||
For binary files (images, compiled artifacts, etc.), use upload_file() instead.
|
||||
|
||||
Args:
|
||||
path: File path to write
|
||||
content: Text content to write
|
||||
|
||||
Returns:
|
||||
Dict with success status or error
|
||||
"""
|
||||
result = handle_function_call(
|
||||
"write_file", {"path": path, "content": content}, task_id=self.task_id
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
def upload_file(self, local_path: str, remote_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Upload a local file to the rollout's sandbox (binary-safe).
|
||||
|
||||
Unlike write_file() which passes content through a shell heredoc (text-only),
|
||||
this method base64-encodes the file and decodes it inside the sandbox.
|
||||
Safe for any file type: binaries, images, archives, etc.
|
||||
|
||||
For large files (>1MB), the content is split into chunks to avoid
|
||||
hitting shell command-length limits.
|
||||
|
||||
Args:
|
||||
local_path: Path to a local file on the host
|
||||
remote_path: Destination path inside the sandbox
|
||||
|
||||
Returns:
|
||||
Dict with 'exit_code' and 'output'
|
||||
"""
|
||||
import base64
|
||||
from pathlib import Path as _Path
|
||||
|
||||
local = _Path(local_path)
|
||||
if not local.exists():
|
||||
return {"exit_code": -1, "output": f"Local file not found: {local_path}"}
|
||||
|
||||
raw = local.read_bytes()
|
||||
b64 = base64.b64encode(raw).decode("ascii")
|
||||
|
||||
# Ensure parent directory exists in the sandbox
|
||||
parent = str(_Path(remote_path).parent)
|
||||
if parent not in {".", "/"}:
|
||||
self.terminal(f"mkdir -p {parent}", timeout=10)
|
||||
|
||||
# For small files, single command is fine
|
||||
chunk_size = 60_000 # ~60KB per chunk (well within shell limits)
|
||||
if len(b64) <= chunk_size:
|
||||
result = self.terminal(
|
||||
f"printf '%s' '{b64}' | base64 -d > {remote_path}",
|
||||
timeout=30,
|
||||
)
|
||||
else:
|
||||
# For larger files, write base64 in chunks then decode
|
||||
tmp_b64 = "/tmp/_hermes_upload.b64"
|
||||
self.terminal(f": > {tmp_b64}", timeout=5) # truncate
|
||||
for i in range(0, len(b64), chunk_size):
|
||||
chunk = b64[i : i + chunk_size]
|
||||
self.terminal(f"printf '%s' '{chunk}' >> {tmp_b64}", timeout=15)
|
||||
result = self.terminal(
|
||||
f"base64 -d {tmp_b64} > {remote_path} && rm -f {tmp_b64}",
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def upload_dir(self, local_dir: str, remote_dir: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Upload an entire local directory to the rollout's sandbox (binary-safe).
|
||||
|
||||
Recursively uploads all files, preserving directory structure.
|
||||
|
||||
Args:
|
||||
local_dir: Path to a local directory on the host
|
||||
remote_dir: Destination directory inside the sandbox
|
||||
|
||||
Returns:
|
||||
List of results, one per file uploaded
|
||||
"""
|
||||
from pathlib import Path as _Path
|
||||
|
||||
local = _Path(local_dir)
|
||||
if not local.exists() or not local.is_dir():
|
||||
return [{"exit_code": -1, "output": f"Local directory not found: {local_dir}"}]
|
||||
|
||||
results = []
|
||||
for file_path in sorted(local.rglob("*")):
|
||||
if file_path.is_file():
|
||||
relative = file_path.relative_to(local)
|
||||
target = f"{remote_dir}/{relative}"
|
||||
results.append(self.upload_file(str(file_path), target))
|
||||
return results
|
||||
|
||||
def download_file(self, remote_path: str, local_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Download a file from the rollout's sandbox to the host (binary-safe).
|
||||
|
||||
The inverse of upload_file(). Base64-encodes the file inside the sandbox,
|
||||
reads the encoded data through the terminal, and decodes it locally.
|
||||
Safe for any file type.
|
||||
|
||||
Args:
|
||||
remote_path: Path to the file inside the sandbox
|
||||
local_path: Destination path on the host
|
||||
|
||||
Returns:
|
||||
Dict with 'success' (bool) and 'bytes' (int) or 'error' (str)
|
||||
"""
|
||||
import base64
|
||||
from pathlib import Path as _Path
|
||||
|
||||
# Base64-encode the file inside the sandbox and capture output
|
||||
result = self.terminal(
|
||||
f"base64 {remote_path} 2>/dev/null",
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
if result.get("exit_code", -1) != 0:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Failed to read remote file: {result.get('output', '')}",
|
||||
}
|
||||
|
||||
b64_data = result.get("output", "").strip()
|
||||
if not b64_data:
|
||||
return {"success": False, "error": f"Remote file is empty or missing: {remote_path}"}
|
||||
|
||||
try:
|
||||
raw = base64.b64decode(b64_data)
|
||||
except Exception as e:
|
||||
return {"success": False, "error": f"Base64 decode failed: {e}"}
|
||||
|
||||
# Write to local host filesystem
|
||||
local = _Path(local_path)
|
||||
local.parent.mkdir(parents=True, exist_ok=True)
|
||||
local.write_bytes(raw)
|
||||
|
||||
return {"success": True, "bytes": len(raw)}
|
||||
|
||||
def download_dir(self, remote_dir: str, local_dir: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Download a directory from the rollout's sandbox to the host (binary-safe).
|
||||
|
||||
Lists all files in the remote directory, then downloads each one.
|
||||
Preserves directory structure.
|
||||
|
||||
Args:
|
||||
remote_dir: Path to the directory inside the sandbox
|
||||
local_dir: Destination directory on the host
|
||||
|
||||
Returns:
|
||||
List of results, one per file downloaded
|
||||
"""
|
||||
from pathlib import Path as _Path
|
||||
|
||||
# List files in the remote directory
|
||||
ls_result = self.terminal(
|
||||
f"find {remote_dir} -type f 2>/dev/null",
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if ls_result.get("exit_code", -1) != 0:
|
||||
return [{"success": False, "error": f"Failed to list remote dir: {remote_dir}"}]
|
||||
|
||||
file_list = ls_result.get("output", "").strip()
|
||||
if not file_list:
|
||||
return [{"success": False, "error": f"Remote directory is empty or missing: {remote_dir}"}]
|
||||
|
||||
results = []
|
||||
for remote_file in file_list.splitlines():
|
||||
remote_file = remote_file.strip()
|
||||
if not remote_file:
|
||||
continue
|
||||
# Compute the relative path to preserve directory structure
|
||||
if remote_file.startswith(remote_dir):
|
||||
relative = remote_file[len(remote_dir):].lstrip("/")
|
||||
else:
|
||||
relative = _Path(remote_file).name
|
||||
local_file = str(_Path(local_dir) / relative)
|
||||
results.append(self.download_file(remote_file, local_file))
|
||||
|
||||
return results
|
||||
|
||||
def search(self, query: str, path: str = ".") -> Dict[str, Any]:
|
||||
"""
|
||||
Search for text in the rollout's filesystem.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
path: Directory to search in
|
||||
|
||||
Returns:
|
||||
Dict with search results
|
||||
"""
|
||||
result = handle_function_call(
|
||||
"search_files", {"pattern": query, "path": path}, task_id=self.task_id
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Web tools
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def web_search(self, query: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Search the web.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
Dict with search results
|
||||
"""
|
||||
result = handle_function_call("web_search", {"query": query})
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
def web_extract(self, urls: List[str]) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract content from URLs.
|
||||
|
||||
Args:
|
||||
urls: List of URLs to extract content from
|
||||
|
||||
Returns:
|
||||
Dict with extracted content
|
||||
"""
|
||||
result = handle_function_call("web_extract", {"urls": urls})
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Browser tools
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def browser_navigate(self, url: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Navigate the rollout's browser session to a URL.
|
||||
|
||||
Args:
|
||||
url: URL to navigate to
|
||||
|
||||
Returns:
|
||||
Dict with page snapshot or error
|
||||
"""
|
||||
result = handle_function_call(
|
||||
"browser_navigate", {"url": url}, task_id=self.task_id
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
def browser_snapshot(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Take a snapshot of the current browser page.
|
||||
|
||||
Returns:
|
||||
Dict with page content/accessibility snapshot
|
||||
"""
|
||||
result = handle_function_call(
|
||||
"browser_snapshot", {}, task_id=self.task_id
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Generic tool access
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Call any hermes-agent tool by name.
|
||||
|
||||
This is the generic escape hatch -- if a tool doesn't have a convenience
|
||||
wrapper above, you can call it directly here.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool (e.g., "vision_analyze", "skills_list")
|
||||
arguments: Dict of arguments for the tool
|
||||
|
||||
Returns:
|
||||
Raw JSON string result from the tool
|
||||
"""
|
||||
return _run_tool_in_thread(tool_name, arguments, self.task_id)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Cleanup
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
Release all resources (terminal VMs, browser sessions, background processes)
|
||||
for this rollout.
|
||||
|
||||
Called automatically by the base environment via try/finally after
|
||||
compute_reward() completes. You generally don't need to call this yourself.
|
||||
"""
|
||||
# Kill any background processes from this rollout (safety net)
|
||||
try:
|
||||
from tools.process_registry import process_registry
|
||||
killed = process_registry.kill_all(task_id=self.task_id)
|
||||
if killed:
|
||||
logger.debug("Process cleanup for task %s: killed %d process(es)", self.task_id, killed)
|
||||
except Exception as e:
|
||||
logger.debug("Process cleanup for task %s: %s", self.task_id, e)
|
||||
|
||||
try:
|
||||
cleanup_vm(self.task_id)
|
||||
except Exception as e:
|
||||
logger.debug("VM cleanup for task %s: %s", self.task_id, e)
|
||||
|
||||
# Suppress browser_tool's noisy debug prints during cleanup.
|
||||
# The cleanup still runs (safe), it just doesn't spam the console.
|
||||
_prev_quiet = os.environ.get("HERMES_QUIET")
|
||||
os.environ["HERMES_QUIET"] = "1"
|
||||
try:
|
||||
cleanup_browser(self.task_id)
|
||||
except Exception as e:
|
||||
logger.debug("Browser cleanup for task %s: %s", self.task_id, e)
|
||||
finally:
|
||||
if _prev_quiet is None:
|
||||
os.environ.pop("HERMES_QUIET", None)
|
||||
else:
|
||||
os.environ["HERMES_QUIET"] = _prev_quiet
|
||||
@@ -1,719 +0,0 @@
|
||||
"""
|
||||
WebResearchEnv — RL Environment for Multi-Step Web Research
|
||||
============================================================
|
||||
|
||||
Trains models to do accurate, efficient, multi-source web research.
|
||||
|
||||
Reward signals:
|
||||
- Answer correctness (LLM judge, 0.0–1.0)
|
||||
- Source diversity (used ≥2 distinct domains)
|
||||
- Efficiency (penalizes excessive tool calls)
|
||||
- Tool usage (bonus for actually using web tools)
|
||||
|
||||
Dataset: FRAMES benchmark (Google, 2024) — multi-hop factual questions
|
||||
HuggingFace: google/frames-benchmark
|
||||
Fallback: built-in sample questions (no HF token needed)
|
||||
|
||||
Usage:
|
||||
# Phase 1 (OpenAI-compatible server)
|
||||
python environments/web_research_env.py serve \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel \\
|
||||
--openai.server_type openai
|
||||
|
||||
# Process mode (offline data generation)
|
||||
python environments/web_research_env.py process \\
|
||||
--env.data_path_to_save_groups data/web_research.jsonl
|
||||
|
||||
# Standalone eval
|
||||
python environments/web_research_env.py evaluate \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel
|
||||
|
||||
Built by: github.com/jackx707
|
||||
Inspired by: GroceryMind — production Hermes agent doing live web research
|
||||
across German grocery stores (firecrawl + hermes-agent)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
# Ensure hermes-agent root is on path
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Optional HuggingFace datasets import
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
HF_AVAILABLE = True
|
||||
except ImportError:
|
||||
HF_AVAILABLE = False
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fallback sample dataset (used when HuggingFace is unavailable)
|
||||
# Multi-hop questions requiring real web search to answer.
|
||||
# ---------------------------------------------------------------------------
|
||||
SAMPLE_QUESTIONS = [
|
||||
{
|
||||
"question": "What is the current population of the capital city of the country that won the 2022 FIFA World Cup?",
|
||||
"answer": "Buenos Aires has approximately 3 million people in the city proper, or around 15 million in the greater metro area.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "Who is the CEO of the company that makes the most widely used open-source container orchestration platform?",
|
||||
"answer": "The Linux Foundation oversees Kubernetes. CNCF (Cloud Native Computing Foundation) is the specific body — it does not have a traditional CEO but has an executive director.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "What programming language was used to write the original version of the web framework used by Instagram?",
|
||||
"answer": "Django, which Instagram was built on, is written in Python.",
|
||||
"difficulty": "easy",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "In what year was the university founded where the inventor of the World Wide Web currently holds a professorship?",
|
||||
"answer": "Tim Berners-Lee holds a professorship at MIT (founded 1861) and the University of Southampton (founded 1952).",
|
||||
"difficulty": "hard",
|
||||
"hops": 3,
|
||||
},
|
||||
{
|
||||
"question": "What is the latest stable version of the programming language that ranks #1 on the TIOBE index as of this year?",
|
||||
"answer": "Python is currently #1 on TIOBE. The latest stable version should be verified via the official python.org site.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "How many employees does the parent company of Instagram have?",
|
||||
"answer": "Meta Platforms (parent of Instagram) employs approximately 70,000+ people as of recent reports.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "What is the current interest rate set by the central bank of the country where the Eiffel Tower is located?",
|
||||
"answer": "The European Central Bank sets rates for France/eurozone. The current rate should be verified — it has changed frequently in 2023-2025.",
|
||||
"difficulty": "hard",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "Which company acquired the startup founded by the creator of Oculus VR?",
|
||||
"answer": "Palmer Luckey founded Oculus VR, which was acquired by Facebook (now Meta). He later founded Anduril Industries.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "What is the market cap of the company that owns the most popular search engine in Russia?",
|
||||
"answer": "Yandex (now split into separate entities after 2024 restructuring). Current market cap should be verified via financial sources.",
|
||||
"difficulty": "hard",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "What was the GDP growth rate of the country that hosted the most recent Summer Olympics?",
|
||||
"answer": "Paris, France hosted the 2024 Summer Olympics. France's recent GDP growth should be verified via World Bank or IMF data.",
|
||||
"difficulty": "hard",
|
||||
"hops": 2,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WebResearchEnvConfig(HermesAgentEnvConfig):
|
||||
"""Configuration for the web research RL environment."""
|
||||
|
||||
# Reward weights
|
||||
correctness_weight: float = Field(
|
||||
default=0.6,
|
||||
description="Weight for answer correctness in reward (LLM judge score).",
|
||||
)
|
||||
tool_usage_weight: float = Field(
|
||||
default=0.2,
|
||||
description="Weight for tool usage signal (did the model actually use web tools?).",
|
||||
)
|
||||
efficiency_weight: float = Field(
|
||||
default=0.2,
|
||||
description="Weight for efficiency signal (penalizes excessive tool calls).",
|
||||
)
|
||||
diversity_bonus: float = Field(
|
||||
default=0.1,
|
||||
description="Bonus reward for citing ≥2 distinct domains.",
|
||||
)
|
||||
|
||||
# Efficiency thresholds
|
||||
efficient_max_calls: int = Field(
|
||||
default=5,
|
||||
description="Maximum tool calls before efficiency penalty begins.",
|
||||
)
|
||||
heavy_penalty_calls: int = Field(
|
||||
default=10,
|
||||
description="Tool call count where efficiency penalty steepens.",
|
||||
)
|
||||
|
||||
# Eval
|
||||
eval_size: int = Field(
|
||||
default=20,
|
||||
description="Number of held-out items for evaluation.",
|
||||
)
|
||||
eval_split_ratio: float = Field(
|
||||
default=0.1,
|
||||
description="Fraction of dataset to hold out for evaluation (0.0–1.0).",
|
||||
)
|
||||
|
||||
# Dataset
|
||||
dataset_name: str = Field(
|
||||
default="google/frames-benchmark",
|
||||
description="HuggingFace dataset name for research questions.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Environment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WebResearchEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
RL environment for training multi-step web research skills.
|
||||
|
||||
The model is given a factual question requiring 2-3 hops of web research
|
||||
and must use web_search / web_extract tools to find and synthesize the answer.
|
||||
|
||||
Reward is multi-signal:
|
||||
60% — answer correctness (LLM judge)
|
||||
20% — tool usage (did the model actually search the web?)
|
||||
20% — efficiency (penalizes >5 tool calls)
|
||||
|
||||
Bonus +0.1 for source diversity (≥2 distinct domains cited).
|
||||
"""
|
||||
|
||||
name = "web-research"
|
||||
env_config_cls = WebResearchEnvConfig
|
||||
|
||||
# Default toolsets for this environment — web + file for saving notes
|
||||
default_toolsets = ["web", "file"]
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[WebResearchEnvConfig, List[APIServerConfig]]:
|
||||
"""Default configuration for the web research environment."""
|
||||
env_config = WebResearchEnvConfig(
|
||||
enabled_toolsets=["web", "file"],
|
||||
max_agent_turns=15,
|
||||
agent_temperature=1.0,
|
||||
system_prompt=(
|
||||
"You are a highly capable research agent. When asked a factual question, "
|
||||
"always use web_search to find current, accurate information before answering. "
|
||||
"Cite at least 2 sources. Be concise and accurate."
|
||||
),
|
||||
group_size=4,
|
||||
total_steps=1000,
|
||||
steps_per_eval=100,
|
||||
use_wandb=True,
|
||||
wandb_name="web-research",
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-sonnet-4.5",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._items: list[dict] = []
|
||||
self._eval_items: list[dict] = []
|
||||
self._index: int = 0
|
||||
|
||||
# Metrics tracking for wandb
|
||||
self._reward_buffer: list[float] = []
|
||||
self._correctness_buffer: list[float] = []
|
||||
self._tool_usage_buffer: list[float] = []
|
||||
self._efficiency_buffer: list[float] = []
|
||||
self._diversity_buffer: list[float] = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Setup — load dataset
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def setup(self) -> None:
|
||||
"""Load the FRAMES benchmark or fall back to built-in samples."""
|
||||
if HF_AVAILABLE:
|
||||
try:
|
||||
logger.info("Loading FRAMES benchmark from HuggingFace...")
|
||||
ds = load_dataset(self.config.dataset_name, split="test")
|
||||
self._items = [
|
||||
{
|
||||
"question": row["Prompt"],
|
||||
"answer": row["Answer"],
|
||||
"difficulty": row.get("reasoning_types", "unknown"),
|
||||
"hops": 2,
|
||||
}
|
||||
for row in ds
|
||||
]
|
||||
# Hold out for eval
|
||||
eval_size = max(
|
||||
self.config.eval_size,
|
||||
int(len(self._items) * self.config.eval_split_ratio),
|
||||
)
|
||||
random.shuffle(self._items)
|
||||
self._eval_items = self._items[:eval_size]
|
||||
self._items = self._items[eval_size:]
|
||||
logger.info(
|
||||
f"Loaded {len(self._items)} train / {len(self._eval_items)} eval items "
|
||||
f"from FRAMES benchmark."
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load FRAMES from HuggingFace: {e}. Using built-in samples.")
|
||||
|
||||
# Fallback
|
||||
random.shuffle(SAMPLE_QUESTIONS)
|
||||
split = max(1, len(SAMPLE_QUESTIONS) * 8 // 10)
|
||||
self._items = SAMPLE_QUESTIONS[:split]
|
||||
self._eval_items = SAMPLE_QUESTIONS[split:]
|
||||
logger.info(
|
||||
f"Using built-in sample dataset: {len(self._items)} train / "
|
||||
f"{len(self._eval_items)} eval items."
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2. get_next_item — return the next question
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_next_item(self) -> dict:
|
||||
"""Return the next item, cycling through the dataset."""
|
||||
if not self._items:
|
||||
raise RuntimeError("Dataset is empty. Did you call setup()?")
|
||||
item = self._items[self._index % len(self._items)]
|
||||
self._index += 1
|
||||
return item
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. format_prompt — build the user-facing prompt
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def format_prompt(self, item: dict) -> str:
|
||||
"""Format the research question as a task prompt."""
|
||||
return (
|
||||
f"Research the following question thoroughly using web search. "
|
||||
f"You MUST search the web to find current, accurate information — "
|
||||
f"do not rely solely on your training data.\n\n"
|
||||
f"Question: {item['question']}\n\n"
|
||||
f"Requirements:\n"
|
||||
f"- Use web_search and/or web_extract tools to find information\n"
|
||||
f"- Search at least 2 different sources\n"
|
||||
f"- Provide a concise, accurate answer (2-4 sentences)\n"
|
||||
f"- Cite the sources you used"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4. compute_reward — multi-signal scoring
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def compute_reward(
|
||||
self,
|
||||
item: dict,
|
||||
result: AgentResult,
|
||||
ctx: ToolContext,
|
||||
) -> float:
|
||||
"""
|
||||
Multi-signal reward function:
|
||||
|
||||
correctness_weight * correctness — LLM judge comparing answer to ground truth
|
||||
tool_usage_weight * tool_used — binary: did the model use web tools?
|
||||
efficiency_weight * efficiency — penalizes wasteful tool usage
|
||||
+ diversity_bonus — source diversity (≥2 distinct domains)
|
||||
"""
|
||||
# Extract final response from messages (last assistant message with content)
|
||||
final_response = ""
|
||||
tools_used: list[str] = []
|
||||
for msg in reversed(result.messages):
|
||||
if msg.get("role") == "assistant" and msg.get("content") and not final_response:
|
||||
final_response = msg["content"]
|
||||
# Collect tool names from tool call messages
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
|
||||
name = fn.get("name", "")
|
||||
if name:
|
||||
tools_used.append(name)
|
||||
tool_call_count: int = result.turns_used or len(tools_used)
|
||||
|
||||
cfg = self.config
|
||||
|
||||
# ---- Signal 1: Answer correctness (LLM judge) ----------------
|
||||
correctness = await self._llm_judge(
|
||||
question=item["question"],
|
||||
expected=item["answer"],
|
||||
model_answer=final_response,
|
||||
)
|
||||
|
||||
# ---- Signal 2: Web tool usage --------------------------------
|
||||
web_tools = {"web_search", "web_extract", "search", "firecrawl"}
|
||||
tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0
|
||||
|
||||
# ---- Signal 3: Efficiency ------------------------------------
|
||||
if tool_call_count <= cfg.efficient_max_calls:
|
||||
efficiency = 1.0
|
||||
elif tool_call_count <= cfg.heavy_penalty_calls:
|
||||
efficiency = 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.08
|
||||
else:
|
||||
efficiency = max(0.0, 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.12)
|
||||
|
||||
# ---- Bonus: Source diversity ---------------------------------
|
||||
domains = self._extract_domains(final_response)
|
||||
diversity = cfg.diversity_bonus if len(domains) >= 2 else 0.0
|
||||
|
||||
# ---- Combine ------------------------------------------------
|
||||
reward = (
|
||||
cfg.correctness_weight * correctness
|
||||
+ cfg.tool_usage_weight * tool_used
|
||||
+ cfg.efficiency_weight * efficiency
|
||||
+ diversity
|
||||
)
|
||||
reward = min(1.0, max(0.0, reward)) # clamp to [0, 1]
|
||||
|
||||
# Track for wandb
|
||||
self._reward_buffer.append(reward)
|
||||
self._correctness_buffer.append(correctness)
|
||||
self._tool_usage_buffer.append(tool_used)
|
||||
self._efficiency_buffer.append(efficiency)
|
||||
self._diversity_buffer.append(diversity)
|
||||
|
||||
logger.debug(
|
||||
f"Reward breakdown — correctness={correctness:.2f}, "
|
||||
f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, "
|
||||
f"diversity={diversity:.1f} → total={reward:.3f}"
|
||||
)
|
||||
|
||||
return reward
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 5. evaluate — run on held-out eval split
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
"""Run evaluation on the held-out split using the full agent loop with tools.
|
||||
|
||||
Each eval item runs through the same agent loop as training —
|
||||
the model can use web_search, web_extract, etc. to research answers.
|
||||
This measures actual agentic research capability, not just knowledge.
|
||||
"""
|
||||
import time
|
||||
import uuid
|
||||
from environments.agent_loop import HermesAgentLoop
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
items = self._eval_items
|
||||
if not items:
|
||||
logger.warning("No eval items available.")
|
||||
return
|
||||
|
||||
eval_size = min(self.config.eval_size, len(items))
|
||||
eval_items = items[:eval_size]
|
||||
|
||||
logger.info(f"Running eval on {len(eval_items)} questions (with agent loop + tools)...")
|
||||
start_time = time.time()
|
||||
samples = []
|
||||
|
||||
# Resolve tools once for all eval items
|
||||
tools, valid_names = self._resolve_tools_for_group()
|
||||
|
||||
for i, item in enumerate(eval_items):
|
||||
task_id = str(uuid.uuid4())
|
||||
logger.info(f"Eval [{i+1}/{len(eval_items)}]: {item['question'][:80]}...")
|
||||
|
||||
try:
|
||||
# Build messages
|
||||
messages: List[Dict[str, Any]] = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
messages.append({"role": "user", "content": self.format_prompt(item)})
|
||||
|
||||
# Run the full agent loop with tools
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=0.0, # Deterministic for eval
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Extract final response and tool usage from messages
|
||||
final_response = ""
|
||||
tool_call_count = 0
|
||||
for msg in reversed(result.messages):
|
||||
if msg.get("role") == "assistant" and msg.get("content") and not final_response:
|
||||
final_response = msg["content"]
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
tool_call_count += len(msg["tool_calls"])
|
||||
|
||||
# Compute reward (includes LLM judge for correctness)
|
||||
# Temporarily save buffer lengths so we can extract the
|
||||
# correctness score without calling judge twice, and avoid
|
||||
# polluting training metric buffers with eval data.
|
||||
buf_len = len(self._correctness_buffer)
|
||||
ctx = ToolContext(task_id)
|
||||
try:
|
||||
reward = await self.compute_reward(item, result, ctx)
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
|
||||
# Extract correctness from the buffer (compute_reward appended it)
|
||||
# then remove eval entries from training buffers
|
||||
correctness = (
|
||||
self._correctness_buffer[buf_len]
|
||||
if len(self._correctness_buffer) > buf_len
|
||||
else 0.0
|
||||
)
|
||||
# Roll back buffers to avoid polluting training metrics
|
||||
for buf in (
|
||||
self._reward_buffer, self._correctness_buffer,
|
||||
self._tool_usage_buffer, self._efficiency_buffer,
|
||||
self._diversity_buffer,
|
||||
):
|
||||
if len(buf) > buf_len:
|
||||
buf.pop()
|
||||
|
||||
samples.append({
|
||||
"prompt": item["question"],
|
||||
"response": final_response[:500],
|
||||
"expected": item["answer"],
|
||||
"correctness": correctness,
|
||||
"reward": reward,
|
||||
"tool_calls": tool_call_count,
|
||||
"turns": result.turns_used,
|
||||
})
|
||||
|
||||
logger.info(
|
||||
f" → correctness={correctness:.2f}, reward={reward:.3f}, "
|
||||
f"tools={tool_call_count}, turns={result.turns_used}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Eval error on item: {e}")
|
||||
samples.append({
|
||||
"prompt": item["question"],
|
||||
"response": f"ERROR: {e}",
|
||||
"expected": item["answer"],
|
||||
"correctness": 0.0,
|
||||
"reward": 0.0,
|
||||
"tool_calls": 0,
|
||||
"turns": 0,
|
||||
})
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
# Compute aggregate metrics
|
||||
correctness_scores = [s["correctness"] for s in samples]
|
||||
rewards = [s["reward"] for s in samples]
|
||||
tool_counts = [s["tool_calls"] for s in samples]
|
||||
n = len(samples)
|
||||
|
||||
eval_metrics = {
|
||||
"eval/mean_correctness": sum(correctness_scores) / n if n else 0.0,
|
||||
"eval/mean_reward": sum(rewards) / n if n else 0.0,
|
||||
"eval/mean_tool_calls": sum(tool_counts) / n if n else 0.0,
|
||||
"eval/tool_usage_rate": sum(1 for t in tool_counts if t > 0) / n if n else 0.0,
|
||||
"eval/n_items": n,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Eval complete — correctness={eval_metrics['eval/mean_correctness']:.3f}, "
|
||||
f"reward={eval_metrics['eval/mean_reward']:.3f}, "
|
||||
f"tool_usage={eval_metrics['eval/tool_usage_rate']:.0%}"
|
||||
)
|
||||
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 6. wandb_log — custom metrics
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
|
||||
"""Log reward breakdown metrics to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self._reward_buffer:
|
||||
n = len(self._reward_buffer)
|
||||
wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
|
||||
wandb_metrics["train/mean_correctness"] = sum(self._correctness_buffer) / n
|
||||
wandb_metrics["train/mean_tool_usage"] = sum(self._tool_usage_buffer) / n
|
||||
wandb_metrics["train/mean_efficiency"] = sum(self._efficiency_buffer) / n
|
||||
wandb_metrics["train/mean_diversity"] = sum(self._diversity_buffer) / n
|
||||
wandb_metrics["train/total_rollouts"] = n
|
||||
|
||||
# Accuracy buckets
|
||||
wandb_metrics["train/correct_rate"] = (
|
||||
sum(1 for c in self._correctness_buffer if c >= 0.7) / n
|
||||
)
|
||||
wandb_metrics["train/tool_usage_rate"] = (
|
||||
sum(1 for t in self._tool_usage_buffer if t > 0) / n
|
||||
)
|
||||
|
||||
# Clear buffers
|
||||
self._reward_buffer.clear()
|
||||
self._correctness_buffer.clear()
|
||||
self._tool_usage_buffer.clear()
|
||||
self._efficiency_buffer.clear()
|
||||
self._diversity_buffer.clear()
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _llm_judge(
|
||||
self,
|
||||
question: str,
|
||||
expected: str,
|
||||
model_answer: str,
|
||||
) -> float:
|
||||
"""
|
||||
Use the server's LLM to judge answer correctness.
|
||||
Falls back to keyword heuristic if LLM call fails.
|
||||
"""
|
||||
if not model_answer or not model_answer.strip():
|
||||
return 0.0
|
||||
|
||||
judge_prompt = (
|
||||
"You are an impartial judge evaluating the quality of an AI research answer.\n\n"
|
||||
f"Question: {question}\n\n"
|
||||
f"Reference answer: {expected}\n\n"
|
||||
f"Model answer: {model_answer}\n\n"
|
||||
"Score the model answer on a scale from 0.0 to 1.0 where:\n"
|
||||
" 1.0 = fully correct and complete\n"
|
||||
" 0.7 = mostly correct with minor gaps\n"
|
||||
" 0.4 = partially correct\n"
|
||||
" 0.1 = mentions relevant topic but wrong or very incomplete\n"
|
||||
" 0.0 = completely wrong or no answer\n\n"
|
||||
"Consider: factual accuracy, completeness, and relevance.\n"
|
||||
'Respond with ONLY a JSON object: {"score": <float>, "reason": "<one sentence>"}'
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.server.chat_completion(
|
||||
messages=[{"role": "user", "content": judge_prompt}],
|
||||
n=1,
|
||||
max_tokens=150,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
text = response.choices[0].message.content if response.choices else ""
|
||||
parsed = self._parse_judge_json(text)
|
||||
if parsed is not None:
|
||||
return float(parsed)
|
||||
except Exception as e:
|
||||
logger.debug(f"LLM judge failed: {e}. Using heuristic.")
|
||||
|
||||
return self._heuristic_score(expected, model_answer)
|
||||
|
||||
@staticmethod
|
||||
def _parse_judge_json(text: str) -> Optional[float]:
|
||||
"""Extract the score float from LLM judge JSON response."""
|
||||
try:
|
||||
clean = re.sub(r"```(?:json)?|```", "", text).strip()
|
||||
data = json.loads(clean)
|
||||
score = float(data.get("score", -1))
|
||||
if 0.0 <= score <= 1.0:
|
||||
return score
|
||||
except Exception:
|
||||
match = re.search(r'"score"\s*:\s*([0-9.]+)', text)
|
||||
if match:
|
||||
score = float(match.group(1))
|
||||
if 0.0 <= score <= 1.0:
|
||||
return score
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _heuristic_score(expected: str, model_answer: str) -> float:
|
||||
"""Lightweight keyword overlap score as fallback."""
|
||||
stopwords = {
|
||||
"the", "a", "an", "is", "are", "was", "were", "of", "in", "on",
|
||||
"at", "to", "for", "with", "and", "or", "but", "it", "its",
|
||||
"this", "that", "as", "by", "from", "be", "has", "have", "had",
|
||||
}
|
||||
|
||||
def tokenize(text: str) -> set:
|
||||
tokens = re.findall(r'\b\w+\b', text.lower())
|
||||
return {t for t in tokens if t not in stopwords and len(t) > 2}
|
||||
|
||||
expected_tokens = tokenize(expected)
|
||||
answer_tokens = tokenize(model_answer)
|
||||
|
||||
if not expected_tokens:
|
||||
return 0.5
|
||||
|
||||
overlap = len(expected_tokens & answer_tokens)
|
||||
union = len(expected_tokens | answer_tokens)
|
||||
|
||||
jaccard = overlap / union if union > 0 else 0.0
|
||||
recall = overlap / len(expected_tokens)
|
||||
return min(1.0, 0.4 * jaccard + 0.6 * recall)
|
||||
|
||||
@staticmethod
|
||||
def _extract_domains(text: str) -> set:
|
||||
"""Extract unique domains from URLs cited in the response."""
|
||||
urls = re.findall(r'https?://[^\s\)>\]"\']+', text)
|
||||
domains = set()
|
||||
for url in urls:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
domain = parsed.netloc.lower().lstrip("www.")
|
||||
if domain:
|
||||
domains.add(domain)
|
||||
except Exception:
|
||||
pass
|
||||
return domains
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
WebResearchEnv.cli()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user