Compare commits
98 Commits
| SHA1 | Author | Date | |
|---|---|---|---|
| ce0f4838b0 | |||
| 2ecad49113 | |||
| 8245173d61 | |||
| 327e577acf | |||
| b5996b6451 | |||
| ef10d2e7c9 | |||
| d5416284f1 | |||
| af1ea1f4ed | |||
| db84a78e61 | |||
| f199cd9f84 | |||
| 77276070f5 | |||
| 274217316e | |||
| 13c72fb486 | |||
| 6af9942327 | |||
| 8373956850 | |||
| 94bdc63ff5 | |||
| eacb398f75 | |||
| 5301cc212b | |||
| c4a21d7831 | |||
| 59c7cc64f0 | |||
| 55f3262e78 | |||
| 5360b54244 | |||
| 647cc0bb0d | |||
| 4f8aaf1046 | |||
| b6e07417c5 | |||
| 47614dbfca | |||
| 09d9724a09 | |||
| 85782a4ed7 | |||
| 9f57f2286d | |||
| 6682f91b80 | |||
| 05d9f641c0 | |||
| 9329e06696 | |||
| 04b1fdaecf | |||
| 681778a0b7 | |||
| 0161d4bb6c | |||
| 814c60092b | |||
| 23ac522d37 | |||
| e0e7397c32 | |||
| e0e4856d46 | |||
| 0086cdaf93 | |||
| fc2754dbdf | |||
| 3df26b925c | |||
| 80efe664ce | |||
| d57a4b3eb5 | |||
| 6bdad1f3b2 | |||
| f9ad7400e3 | |||
| 965ae7fa97 | |||
| cbd1f8e4be | |||
| f8745f59c2 | |||
| bcca5ed34d | |||
| c8c6ce1731 | |||
| 5af672c753 | |||
| d364132114 | |||
| 4c94396206 | |||
| e8b9f5ff9a | |||
| d3d5916089 | |||
| eabd8c1fd1 | |||
| 4695d2716f | |||
| 8ed2ef6f46 | |||
| 1702a94c88 | |||
| 55622b5525 | |||
| 74e47c081f | |||
| d6c488f2dc | |||
| 09d970160b | |||
| db82c453b9 | |||
| 38ea2a57a5 | |||
| 0854640537 | |||
| 19071529f6 | |||
| ed84637d11 | |||
| 4abfb6bc24 | |||
| e84fe483bc | |||
| ccb5aae0d2 | |||
| 34fc94d1f4 | |||
| 4813aaf0ba | |||
| 5ce0067c08 | |||
| 29575b3712 | |||
| 71558e753d | |||
| 4f7e64c845 | |||
| 2cbf0631a5 | |||
| 659af123c3 | |||
| f4c43f0886 | |||
| b54b246071 | |||
| 1a00d730eb | |||
| 76f40e6449 | |||
| 2bed2124a4 | |||
| 8709e1ebec | |||
| 54d817f882 | |||
| 74fdfe6b50 | |||
| 02a54e01ce | |||
| 8a31985e4f | |||
| 41c13ba71d | |||
| 36c5b188b5 | |||
| 1e29fa8865 | |||
| e74a682b0f | |||
| 2b606d20e2 | |||
| 3ac750ec07 | |||
| aa2d3e2ee1 | |||
| 7d628eaa3d |
+7
-18
@@ -281,6 +281,13 @@ BROWSER_SESSION_TIMEOUT=300
|
||||
# Browser sessions are automatically closed after this period of no activity
|
||||
BROWSER_INACTIVITY_TIMEOUT=120
|
||||
|
||||
# Extra Chromium launch flags passed to agent-browser, comma- or newline-separated.
|
||||
# Hermes auto-injects "--no-sandbox,--disable-dev-shm-usage" when it detects root
|
||||
# or AppArmor-restricted unprivileged user namespaces (Ubuntu 23.10+, DGX Spark,
|
||||
# many container images), so leave this unset unless you need extra flags.
|
||||
# Setting this disables the auto-injection.
|
||||
# AGENT_BROWSER_ARGS=--no-sandbox
|
||||
|
||||
# Camofox local anti-detection browser (Camoufox-based Firefox).
|
||||
# Set CAMOFOX_URL to route the browser tools through a local Camofox server
|
||||
# instead of agent-browser/Browserbase. See docs/user-guide/features/browser.md.
|
||||
@@ -387,24 +394,6 @@ IMAGE_TOOLS_DEBUG=false
|
||||
# CONTEXT_COMPRESSION_THRESHOLD=0.85 # Compress at 85% of context limit
|
||||
# Model is set via compression.summary_model in config.yaml (default: google/gemini-3-flash-preview)
|
||||
|
||||
# =============================================================================
|
||||
# RL TRAINING (Tinker + Atropos)
|
||||
# =============================================================================
|
||||
# Run reinforcement learning training on language models using the Tinker API.
|
||||
# Requires the rl-server to be running (from tinker-atropos package).
|
||||
|
||||
# Tinker API Key - RL training service
|
||||
# Get at: https://tinker-console.thinkingmachines.ai/keys
|
||||
# TINKER_API_KEY=
|
||||
|
||||
# Weights & Biases API Key - Experiment tracking and metrics
|
||||
# Get at: https://wandb.ai/authorize
|
||||
# WANDB_API_KEY=
|
||||
|
||||
# RL API Server URL (default: http://localhost:8080)
|
||||
# Change if running the rl-server on a different host/port
|
||||
# RL_API_URL=http://localhost:8080
|
||||
|
||||
# =============================================================================
|
||||
# SKILLS HUB (GitHub integration for skill search/install/publish)
|
||||
# =============================================================================
|
||||
|
||||
@@ -11,6 +11,7 @@ on:
|
||||
- '**/sitecustomize.py'
|
||||
- '**/usercustomize.py'
|
||||
- '**/__init__.pth'
|
||||
- 'pyproject.toml'
|
||||
|
||||
permissions:
|
||||
pull-requests: write
|
||||
@@ -137,3 +138,68 @@ jobs:
|
||||
run: |
|
||||
echo "::error::CRITICAL supply chain risk patterns detected in this PR. See the PR comment for details."
|
||||
exit 1
|
||||
|
||||
dep-bounds:
|
||||
name: Check PyPI dependency upper bounds
|
||||
runs-on: ubuntu-latest
|
||||
if: contains(github.event.pull_request.changed_files_url, 'pyproject.toml') || true
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Check for unbounded PyPI deps
|
||||
id: bounds
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
BASE="${{ github.event.pull_request.base.sha }}"
|
||||
HEAD="${{ github.event.pull_request.head.sha }}"
|
||||
|
||||
# Only check added lines in pyproject.toml
|
||||
ADDED=$(git diff "$BASE".."$HEAD" -- pyproject.toml | grep '^+' | grep -v '^+++' || true)
|
||||
|
||||
if [ -z "$ADDED" ]; then
|
||||
echo "found=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Match PyPI dep specs that have >= but no < ceiling.
|
||||
# Pattern: "package>=version" without a following ",<" bound.
|
||||
# Excludes git+ URLs (which use commit SHAs) and comments.
|
||||
UNBOUNDED=$(echo "$ADDED" | grep -oE '"[a-zA-Z0-9_-]+(\[[^\]]*\])?>=[ 0-9.]+"' | grep -v ',<' || true)
|
||||
|
||||
if [ -n "$UNBOUNDED" ]; then
|
||||
echo "found=true" >> "$GITHUB_OUTPUT"
|
||||
echo "$UNBOUNDED" > /tmp/unbounded.txt
|
||||
else
|
||||
echo "found=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Post unbounded dep warning
|
||||
if: steps.bounds.outputs.found == 'true'
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
BODY="## ⚠️ Unbounded PyPI Dependency Detected
|
||||
|
||||
This PR adds PyPI dependencies without a \`<next_major\` upper bound. Per our [supply chain policy](../blob/main/CONTRIBUTING.md#dependency-pinning-policy-supply-chain-hardening), all PyPI deps must be pinned as \`>=floor,<next_major\`.
|
||||
|
||||
**Unbounded specs found:**
|
||||
\`\`\`
|
||||
$(cat /tmp/unbounded.txt)
|
||||
\`\`\`
|
||||
|
||||
**Fix:** Add an upper bound, e.g. \`\"package>=1.2.0,<2\"\`
|
||||
|
||||
---
|
||||
*See PR #2810 and CONTRIBUTING.md for the full policy rationale.*"
|
||||
|
||||
gh pr comment "${{ github.event.pull_request.number }}" --body "$BODY" || echo "::warning::Could not post PR comment (expected for fork PRs)"
|
||||
|
||||
- name: Fail on unbounded deps
|
||||
if: steps.bounds.outputs.found == 'true'
|
||||
run: |
|
||||
echo "::error::PyPI dependencies without upper bounds detected. Add <next_major ceiling per CONTRIBUTING.md policy."
|
||||
exit 1
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
name: Publish to PyPI
|
||||
|
||||
# Triggered by CalVer tag pushes from scripts/release.py (e.g. v2026.5.15)
|
||||
# Can also be triggered manually from the Actions tab as an escape hatch.
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v20*' # CalVer tags: v2026.5.15, v2026.5.15.2, etc.
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
confirm_tag:
|
||||
description: 'Tag to publish (e.g. v2026.5.15). Must already exist.'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
# Restrict default token to read-only; each job escalates as needed.
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
# Prevent overlapping publishes (e.g. two same-day tags pushed quickly).
|
||||
concurrency:
|
||||
group: pypi-publish
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build distribution 📦
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
# On workflow_dispatch, check out the confirmed tag.
|
||||
ref: ${{ inputs.confirm_tag || github.ref }}
|
||||
fetch-tags: true
|
||||
|
||||
- name: Validate tag exists
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
run: |
|
||||
if ! git tag -l "${{ inputs.confirm_tag }}" | grep -q .; then
|
||||
echo "::error::Tag '${{ inputs.confirm_tag }}' does not exist in the repo"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
||||
with:
|
||||
python-version: '3.13'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@d0cc045d04ccac9d8b7881df0226f9e82c39688e # v6
|
||||
|
||||
- name: Build wheel and sdist
|
||||
run: uv build --sdist --wheel
|
||||
|
||||
- name: Upload distribution artifacts
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
|
||||
with:
|
||||
name: python-package-distributions
|
||||
path: dist/
|
||||
|
||||
publish:
|
||||
name: Publish to PyPI
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/hermes-agent
|
||||
permissions:
|
||||
id-token: write # OIDC trusted publishing
|
||||
|
||||
steps:
|
||||
- name: Download distribution artifacts
|
||||
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4
|
||||
with:
|
||||
name: python-package-distributions
|
||||
path: dist/
|
||||
|
||||
- name: Publish to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@cef221092ed1bacb1cc03d23a2d87d1d172e277b # v1.14.0
|
||||
with:
|
||||
skip-existing: true
|
||||
|
||||
sign:
|
||||
name: Sign and attach to GitHub Release
|
||||
# Only runs on tag pushes — release.py creates the GitHub Release,
|
||||
# and workflow_dispatch won't have a matching release to attach to.
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
needs: publish
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write # attach assets to the existing release
|
||||
id-token: write # sigstore signing
|
||||
|
||||
steps:
|
||||
- name: Download distribution artifacts
|
||||
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4
|
||||
with:
|
||||
name: python-package-distributions
|
||||
path: dist/
|
||||
|
||||
- name: Wait for GitHub Release to exist
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
# release.py creates the GitHub Release after pushing the tag,
|
||||
# but this workflow starts from the tag push — wait for it.
|
||||
run: |
|
||||
for i in $(seq 1 30); do
|
||||
if gh release view "$GITHUB_REF_NAME" --repo "$GITHUB_REPOSITORY" >/dev/null 2>&1; then
|
||||
echo "Release $GITHUB_REF_NAME found"
|
||||
exit 0
|
||||
fi
|
||||
echo "Waiting for release... ($i/30)"
|
||||
sleep 10
|
||||
done
|
||||
echo "::warning::Release $GITHUB_REF_NAME not found after 5 minutes — skipping signature upload"
|
||||
echo "skip_sign=true" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Sign with Sigstore
|
||||
if: env.skip_sign != 'true'
|
||||
uses: sigstore/gh-action-sigstore-python@f514d46b907ebcd5bedc05145c03b69c1edd8b46 # v3.0.0
|
||||
with:
|
||||
inputs: >-
|
||||
./dist/*.tar.gz
|
||||
./dist/*.whl
|
||||
|
||||
- name: Attach signed artifacts to GitHub Release
|
||||
if: env.skip_sign != 'true'
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
# release.py already created the GitHub Release — just upload
|
||||
# the Sigstore signatures alongside the existing assets.
|
||||
run: >-
|
||||
gh release upload
|
||||
"$GITHUB_REF_NAME" dist/*.sigstore.json
|
||||
--repo "$GITHUB_REPOSITORY"
|
||||
--clobber
|
||||
@@ -1,3 +0,0 @@
|
||||
[submodule "tinker-atropos"]
|
||||
path = tinker-atropos
|
||||
url = https://github.com/nousresearch/tinker-atropos
|
||||
@@ -56,7 +56,6 @@ hermes-agent/
|
||||
├── tui_gateway/ # Python JSON-RPC backend for the TUI
|
||||
├── acp_adapter/ # ACP server (VS Code / Zed / JetBrains integration)
|
||||
├── cron/ # Scheduler — jobs.py, scheduler.py
|
||||
├── environments/ # RL training environments (Atropos)
|
||||
├── scripts/ # run_tests.sh, release.py, auxiliary scripts
|
||||
├── website/ # Docusaurus docs site
|
||||
└── tests/ # Pytest suite (~17k tests across ~900 files as of May 2026)
|
||||
@@ -309,6 +308,29 @@ The registry handles schema collection, dispatch, availability checking, and err
|
||||
|
||||
---
|
||||
|
||||
## Dependency Pinning Policy
|
||||
|
||||
All dependencies must have upper bounds to limit supply-chain attack surface.
|
||||
This policy was established after the litellm compromise (PR #2796, #2810) and
|
||||
reinforced after the Mini Shai-Hulud worm campaign (May 2026).
|
||||
|
||||
| Source type | Treatment | Example |
|---|---|---|
| PyPI package | `>=floor,<next_major` | `"httpx>=0.28.1,<1"` |
| Git URL | Commit SHA | `git+https://...@<40-char-sha>` |
| GitHub Actions | Commit SHA + comment | `uses: actions/checkout@<sha> # v4` |
| CI-only pip | `==exact` | `pyyaml==6.0.2` |
|
||||
|
||||
**When adding a new dependency to `pyproject.toml`:**
|
||||
1. Pin to `>=current_version,<next_major` for post-1.0 (e.g. `>=1.5.0,<2`).
|
||||
2. For pre-1.0 packages, use `<0.(current_minor + 3)`, which leaves two minor versions of headroom (e.g. `>=0.29,<0.32`).
|
||||
3. Never commit a bare `>=X.Y.Z` without a ceiling — CI and reviewers will reject it.
|
||||
4. Run `uv lock` to regenerate `uv.lock` with hashes.
|
||||
|
||||
Reference: #2810 (bounds pass), #9801 (SHA pinning + audit CI).
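
The "no bare `>=X.Y.Z`" rule above can be checked locally before pushing. The following is a minimal sketch, not the project's CI check: the function name `find_unbounded` and its regex are hypothetical, and it only looks at the version clauses of each spec string.

```python
import re

# Hypothetical helper (not part of the repository): flags dependency specs
# that set a ">=" floor without any "<" ceiling.
_NAME_AND_EXTRAS = re.compile(r"\s*[A-Za-z0-9_.-]+(\[[^\]]*\])?(.*)")


def find_unbounded(specs):
    """Return the specs that have a >= floor but no < (or <=) ceiling."""
    unbounded = []
    for spec in specs:
        if "git+" in spec:
            continue  # Git URLs are pinned by commit SHA, not by version range
        requirement = spec.split(";", 1)[0]  # drop environment markers
        match = _NAME_AND_EXTRAS.match(requirement)
        if not match:
            continue
        clauses = [c.strip() for c in match.group(2).split(",") if c.strip()]
        has_floor = any(c.startswith(">=") for c in clauses)
        has_ceiling = any(c.startswith("<") for c in clauses)
        if has_floor and not has_ceiling:
            unbounded.append(spec)
    return unbounded


if __name__ == "__main__":
    deps = ["httpx>=0.28.1,<1", "some-package>=1.2.3", "pyyaml==6.0.2"]
    for spec in find_unbounded(deps):
        print(f"missing upper bound: {spec}")
```

Exact pins (`==`) and bare names are deliberately not flagged here, since the rule above targets floors without ceilings.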
|
||||
|
||||
---
|
||||
|
||||
## Adding Configuration
|
||||
|
||||
### config.yaml options:
|
||||
|
||||
+41
-4
@@ -91,9 +91,6 @@ export VIRTUAL_ENV="$(pwd)/venv"
|
||||
# Install with all extras (messaging, cron, CLI menus, dev tools)
|
||||
uv pip install -e ".[all,dev]"
|
||||
|
||||
# Optional: RL training submodule
|
||||
# git submodule update --init tinker-atropos && uv pip install -e "./tinker-atropos"
|
||||
|
||||
# Optional: browser tools
|
||||
npm install
|
||||
```
|
||||
@@ -196,7 +193,6 @@ hermes-agent/
|
||||
│
|
||||
├── skills/ # Bundled skills (copied to ~/.hermes/skills/ on install)
|
||||
├── optional-skills/ # Official optional skills (discoverable via hub, not activated by default)
|
||||
├── environments/ # RL training environments (Atropos integration)
|
||||
├── tests/ # Test suite
|
||||
├── website/ # Documentation site (hermes-agent.nousresearch.com)
|
||||
│
|
||||
@@ -804,6 +800,47 @@ Hermes has terminal access. Security matters.
|
||||
|
||||
If your PR affects security, note it explicitly in the description.
|
||||
|
||||
### Dependency pinning policy (supply chain hardening)
|
||||
|
||||
After the [litellm supply chain compromise](https://github.com/BerriAI/litellm/issues/24512) in March 2026 and the [Mini Shai-Hulud worm campaign](https://socket.dev/blog/tanstack-npm-packages-compromised-mini-shai-hulud-supply-chain-attack) in May 2026, all dependencies must follow these rules:
|
||||
|
||||
| Source type | Required treatment | Rationale |
|---|---|---|
| **PyPI package** | `>=floor,<next_major` | PyPI versions are immutable once published, but new versions can be pushed into your range. A `<next_major` ceiling stops a 1.x install from upgrading to a malicious 2.0.0. |
| **Git URL** (atroposlib, tinker, yc-bench, Baileys) | Full commit SHA | Branches and tags are mutable refs; SHA is content-addressed. |
| **GitHub Actions** | Full commit SHA + version comment | Action tags are mutable refs (e.g. tj-actions/changed-files March 2025). Pin as `uses: owner/action@<sha> # vX.Y.Z` |
| **CI-only pip installs** | `==exact` | Hermetic CI builds; churn is acceptable. |
|
||||
|
||||
**Every new PyPI dependency in a PR must have a `<next_major` upper bound.** PRs adding unbounded `>=X.Y.Z` specs will be rejected by reviewers. The `supply-chain-audit.yml` CI workflow also flags dependency manifest changes for manual review.
|
||||
|
||||
**How to determine the ceiling:**
|
||||
- If the package is at version `1.x.y`, use `<2`.
|
||||
- If the package is at version `0.x.y` (pre-1.0), use `<0.(current_minor + 3)`: e.g. if current is `0.29.x`, use `<0.32`. This allows the current minor plus two more (~2 minor versions of headroom) while keeping the window small enough that a hostile takeover version is unlikely to land inside it.
|
||||
- Exception: packages with very stable APIs (e.g. `aiohttp-socks`) can use `<1` at reviewer discretion.
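
The ceiling rules above can be sketched as a small helper. This is an illustration under stated assumptions, not project tooling: the names `upper_bound` and `pinned_spec` are hypothetical, versions are assumed to be plain dotted numbers, and the pre-1.0 branch follows the worked example (`0.29.x` gives `<0.32`). The reviewer-discretion `<1` exception is not modelled.

```python
def upper_bound(current_version: str) -> str:
    """Return the exclusive ceiling clause for a plain dotted version string."""
    parts = current_version.split(".")
    major = int(parts[0])
    minor = int(parts[1]) if len(parts) > 1 else 0
    if major >= 1:
        # Post-1.0: cap at the next major version.
        return f"<{major + 1}"
    # Pre-1.0: follow the worked example (0.29.x -> <0.32), i.e. the current
    # minor plus two more minors are allowed.
    return f"<0.{minor + 3}"


def pinned_spec(name: str, current_version: str) -> str:
    """Build a '>=floor,<ceiling' spec string for pyproject.toml."""
    return f"{name}>={current_version},{upper_bound(current_version)}"


if __name__ == "__main__":
    print(pinned_spec("openai", "2.21.0"))
    print(pinned_spec("asyncpg", "0.29"))
```

A tighter window than the helper produces is still acceptable under the policy; the helper only gives the widest range the rule allows.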
|
||||
|
||||
**Examples:**
|
||||
```toml
# ✅ Correct: post-1.0
"openai>=2.21.0,<3"
"pydantic>=2.12.5,<3"

# ✅ Correct: pre-1.0 (tight minor window)
"asyncpg>=0.29,<0.32"
"aiosqlite>=0.20,<0.23"
"hindsight-client>=0.4.22,<0.5"

# ❌ Rejected: no upper bound
"some-package>=1.2.3"

# ❌ Rejected: too tight (blocks legitimate patches)
"some-package==1.2.3"

# ❌ Rejected: too loose for pre-1.0 (allows 80 minor versions)
"some-package>=0.20,<1"
```
|
||||
|
||||
**Reference PRs:** #2796 (litellm removal), #2810 (upper bounds pass), #9801 (SHA pinning + supply-chain-audit CI).
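
The GitHub Actions row of the policy table (full commit SHA plus a version comment) also lends itself to a mechanical check. The sketch below is an assumption-laden illustration rather than the `supply-chain-audit.yml` logic: `unpinned_actions` is a hypothetical name, and it scans workflow text line by line instead of parsing YAML.

```python
import re

# Hypothetical helper: reports `uses:` references that are not pinned to a
# full 40-character commit SHA.
_USES = re.compile(r"^\s*(?:-\s*)?uses:\s*(\S+)")
_FULL_SHA = re.compile(r"[0-9a-f]{40}")


def unpinned_actions(workflow_text: str) -> list[str]:
    """Return action references whose ref is a tag or branch, not a full SHA."""
    unpinned = []
    for line in workflow_text.splitlines():
        match = _USES.match(line)
        if not match:
            continue
        target = match.group(1)
        if target.startswith("./") or target.startswith("docker://"):
            continue  # local and Docker references are out of scope here
        _, _, ref = target.partition("@")
        if not _FULL_SHA.fullmatch(ref):
            unpinned.append(target)
    return unpinned


if __name__ == "__main__":
    sample = (
        "steps:\n"
        "  - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4\n"
        "  - uses: actions/setup-python@v5\n"
    )
    for ref in unpinned_actions(sample):
        print(f"not pinned to a commit SHA: {ref}")
```

The check does not verify that the trailing version comment matches the pinned SHA; that still needs a reviewer.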
|
||||
|
||||
---
|
||||
|
||||
## Pull Request Process
|
||||
|
||||
@@ -23,7 +23,7 @@ Use any model you want — [Nous Portal](https://portal.nousresearch.com), [Open
|
||||
<tr><td><b>Scheduled automations</b></td><td>Built-in cron scheduler with delivery to any platform. Daily reports, nightly backups, weekly audits — all in natural language, running unattended.</td></tr>
|
||||
<tr><td><b>Delegates and parallelizes</b></td><td>Spawn isolated subagents for parallel workstreams. Write Python scripts that call tools via RPC, collapsing multi-step pipelines into zero-context-cost turns.</td></tr>
|
||||
<tr><td><b>Runs anywhere, not just your laptop</b></td><td>Seven terminal backends — local, Docker, SSH, Singularity, Modal, Daytona, and Vercel Sandbox. Daytona and Modal offer serverless persistence — your agent's environment hibernates when idle and wakes on demand, costing nearly nothing between sessions. Run it on a $5 VPS or a GPU cluster.</td></tr>
|
||||
<tr><td><b>Research-ready</b></td><td>Batch trajectory generation, Atropos RL environments, trajectory compression for training the next generation of tool-calling models.</td></tr>
|
||||
<tr><td><b>Research-ready</b></td><td>Batch trajectory generation, trajectory compression for training the next generation of tool-calling models.</td></tr>
|
||||
</table>
|
||||
|
||||
---
|
||||
@@ -175,8 +175,6 @@ uv pip install -e ".[all,dev]"
|
||||
scripts/run_tests.sh
|
||||
```
|
||||
|
||||
> **RL Training (optional):** The RL/Atropos integration (`environments/`) — see [`CONTRIBUTING.md`](https://github.com/NousResearch/hermes-agent/blob/main/CONTRIBUTING.md#development-setup) for the full setup.
|
||||
|
||||
---
|
||||
|
||||
## Community
|
||||
|
||||
+1
-7
@@ -23,7 +23,7 @@
|
||||
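<tr><td><b>Scheduled automations</b></td><td>Built-in cron scheduler with delivery to any platform. Daily reports, nightly backups, weekly audits: all described in natural language and run unattended.</td></tr>
<tr><td><b>Delegates and parallelizes</b></td><td>Spawn isolated subagents for parallel workstreams. Write Python scripts that call tools via RPC, collapsing multi-step pipelines into zero-context-cost turns.</td></tr>
<tr><td><b>Runs anywhere</b></td><td>Six terminal backends: local, Docker, SSH, Daytona, Singularity, and Modal. Daytona and Modal offer serverless persistence: the agent's environment hibernates when idle and wakes on demand, costing almost nothing while idle. Runs on a $5 VPS or a GPU cluster.</td></tr>
<tr><td><b>Research-ready</b></td><td>Batch trajectory generation, Atropos RL environments, and trajectory compression for training the next generation of tool-calling models.</td></tr>
<tr><td><b>Research-ready</b></td><td>Batch trajectory generation and trajectory compression for training the next generation of tool-calling models.</td></tr>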
|
||||
</table>
|
||||
|
||||
---
|
||||
@@ -161,12 +161,6 @@ uv pip install -e ".[all,dev]"
|
||||
python -m pytest tests/ -q
|
||||
```
|
||||
|
||||
> **RL training (optional):** To work on the RL/Tinker-Atropos integration:
|
||||
> ```bash
|
||||
> git submodule update --init tinker-atropos
|
||||
> uv pip install -e "./tinker-atropos"
|
||||
> ```
|
||||
|
||||
---
|
||||
|
||||
## Community
|
||||
|
||||
+46
-2
@@ -1,8 +1,11 @@
|
||||
"""ACP auth helpers — detect the currently configured Hermes provider."""
|
||||
"""ACP auth helpers — detect and advertise Hermes authentication methods."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
TERMINAL_SETUP_AUTH_METHOD_ID = "hermes-setup"
|
||||
|
||||
|
||||
def detect_provider() -> Optional[str]:
|
||||
@@ -22,3 +25,44 @@ def detect_provider() -> Optional[str]:
|
||||
def has_provider() -> bool:
|
||||
"""Return True if Hermes can resolve any runtime provider credentials."""
|
||||
return detect_provider() is not None
|
||||
|
||||
|
||||
def build_auth_methods() -> list[Any]:
|
||||
"""Return registry-compatible ACP auth methods for Hermes.
|
||||
|
||||
The official ACP registry validates that agents advertise at least one
|
||||
usable auth method during the initial handshake. A fresh Zed install may
|
||||
not have Hermes provider credentials configured yet, so Hermes always
|
||||
advertises a terminal setup method. When credentials are already present,
|
||||
it also advertises the resolved provider as the default agent-managed
|
||||
runtime credential method.
|
||||
"""
|
||||
from acp.schema import AuthMethodAgent, TerminalAuthMethod
|
||||
|
||||
methods: list[Any] = []
|
||||
provider = detect_provider()
|
||||
if provider:
|
||||
methods.append(
|
||||
AuthMethodAgent(
|
||||
id=provider,
|
||||
name=f"{provider} runtime credentials",
|
||||
description=(
|
||||
"Authenticate Hermes using the currently configured "
|
||||
f"{provider} runtime credentials."
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
methods.append(
|
||||
TerminalAuthMethod(
|
||||
id=TERMINAL_SETUP_AUTH_METHOD_ID,
|
||||
name="Configure Hermes provider",
|
||||
description=(
|
||||
"Open Hermes' interactive model/provider setup in a terminal. "
|
||||
"Use this when Hermes has not been configured on this machine yet."
|
||||
),
|
||||
type="terminal",
|
||||
args=["--setup"],
|
||||
)
|
||||
)
|
||||
return methods
|
||||
|
||||
@@ -0,0 +1,288 @@
|
||||
# bootstrap_browser_tools.ps1 — install agent-browser + Playwright Chromium
|
||||
# into ~/.hermes/node/ for use by Hermes Agent's browser tools on Windows.
|
||||
#
|
||||
# Targets the registry-install path: users who got Hermes via
|
||||
# `uvx --from 'hermes-agent[acp]==X' hermes-acp` don't have a repo clone,
|
||||
# so the install.ps1 `npm install`-in-repo flow doesn't apply. This script
|
||||
# is a self-contained, idempotent slice of install.ps1's browser block.
|
||||
#
|
||||
# Usage:
|
||||
# .\bootstrap_browser_tools.ps1 # use defaults
|
||||
# .\bootstrap_browser_tools.ps1 -Yes # accept Chromium download
|
||||
# .\bootstrap_browser_tools.ps1 -SkipChromium # Node + agent-browser only
|
||||
#
|
||||
# Idempotent: re-running this is safe and fast.
|
||||
|
||||
[CmdletBinding()]
|
||||
param(
|
||||
[switch]$Yes,
|
||||
[switch]$SkipChromium
|
||||
)
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
$NodeVersion = "22"
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Logging
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
function Write-Info { param([string]$msg) Write-Host "[*] $msg" -ForegroundColor Cyan }
|
||||
function Write-Success { param([string]$msg) Write-Host "[+] $msg" -ForegroundColor Green }
|
||||
function Write-Warn { param([string]$msg) Write-Host "[!] $msg" -ForegroundColor Yellow }
|
||||
function Write-Err { param([string]$msg) Write-Host "[x] $msg" -ForegroundColor Red }
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Paths
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
$HermesHome = $env:HERMES_HOME
|
||||
if (-not $HermesHome) {
|
||||
$HermesHome = Join-Path $env:USERPROFILE ".hermes"
|
||||
}
|
||||
$NodePrefix = Join-Path $HermesHome "node"
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Step 1: Node.js
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
function Resolve-NpmExe {
|
||||
# Same gotcha as install.ps1: prefer npm.cmd over npm.ps1 so the
|
||||
# PowerShell execution policy doesn't block us.
|
||||
$cmd = Get-Command npm -ErrorAction SilentlyContinue
|
||||
if (-not $cmd) { return $null }
|
||||
$npmExe = $cmd.Source
|
||||
if ($npmExe -like "*.ps1") {
|
||||
$sibling = Join-Path (Split-Path $npmExe -Parent) "npm.cmd"
|
||||
if (Test-Path $sibling) { return $sibling }
|
||||
}
|
||||
return $npmExe
|
||||
}
|
||||
|
||||
function Resolve-NpxExe {
|
||||
$cmd = Get-Command npx -ErrorAction SilentlyContinue
|
||||
if (-not $cmd) { return $null }
|
||||
$npxExe = $cmd.Source
|
||||
if ($npxExe -like "*.ps1") {
|
||||
$sibling = Join-Path (Split-Path $npxExe -Parent) "npx.cmd"
|
||||
if (Test-Path $sibling) { return $sibling }
|
||||
}
|
||||
return $npxExe
|
||||
}
|
||||
|
||||
function Ensure-Node {
|
||||
# System Node on PATH?
|
||||
$sysNode = Get-Command node -ErrorAction SilentlyContinue
|
||||
if ($sysNode) {
|
||||
try {
|
||||
$v = & $sysNode.Source --version
|
||||
$major = [int]($v -replace '^v(\d+).*', '$1')
|
||||
if ($major -ge 20) {
|
||||
Write-Success "Node.js $v found on PATH"
|
||||
return
|
||||
}
|
||||
Write-Warn "Node.js $v is older than v20 — installing managed Node."
|
||||
} catch {
|
||||
Write-Warn "Failed to query Node version: $_"
|
||||
}
|
||||
}
|
||||
|
||||
# Hermes-managed Node?
|
||||
$managedNode = Join-Path $NodePrefix "node.exe"
|
||||
if (Test-Path $managedNode) {
|
||||
$v = & $managedNode --version
|
||||
Write-Success "Node.js $v found (Hermes-managed at $NodePrefix)"
|
||||
# Prepend to current-process PATH so subsequent npm/npx calls find it.
|
||||
$env:PATH = "$NodePrefix;$env:PATH"
|
||||
return
|
||||
}
|
||||
|
||||
Write-Info "Installing Node.js $NodeVersion LTS into $NodePrefix ..."
|
||||
|
||||
$arch = if ([Environment]::Is64BitOperatingSystem) { "x64" } else { "x86" }
|
||||
$indexUrl = "https://nodejs.org/dist/latest-v${NodeVersion}.x/"
|
||||
|
||||
try {
|
||||
$indexPage = Invoke-WebRequest -Uri $indexUrl -UseBasicParsing
|
||||
# Use a dedicated name: $Matches is a PowerShell automatic variable.
$zipMatches = [regex]::Matches($indexPage.Content, "node-v${NodeVersion}\.\d+\.\d+-win-${arch}\.zip")
if ($zipMatches.Count -eq 0) {
Write-Err "Could not locate Node.js $NodeVersion zip for win-$arch"
throw "no matching zip"
}
$zipName = $zipMatches[0].Value
|
||||
$zipUrl = "$indexUrl$zipName"
|
||||
|
||||
$tmpDir = Join-Path $env:TEMP "hermes-node-$([guid]::NewGuid().ToString('N'))"
|
||||
New-Item -ItemType Directory -Force -Path $tmpDir | Out-Null
|
||||
$zipPath = Join-Path $tmpDir $zipName
|
||||
|
||||
Write-Info "Downloading $zipName ..."
|
||||
Invoke-WebRequest -Uri $zipUrl -OutFile $zipPath -UseBasicParsing
|
||||
|
||||
Expand-Archive -Path $zipPath -DestinationPath $tmpDir -Force
|
||||
$extracted = Get-ChildItem -Path $tmpDir -Directory | Where-Object { $_.Name -like "node-v*" } | Select-Object -First 1
|
||||
|
||||
if (-not $extracted) { Write-Err "Node.js extraction failed"; throw "extract" }
|
||||
|
||||
if (Test-Path $NodePrefix) { Remove-Item -Recurse -Force $NodePrefix }
|
||||
New-Item -ItemType Directory -Force -Path $HermesHome | Out-Null
|
||||
Move-Item -Path $extracted.FullName -Destination $NodePrefix
|
||||
|
||||
Remove-Item -Recurse -Force $tmpDir -ErrorAction SilentlyContinue
|
||||
|
||||
$env:PATH = "$NodePrefix;$env:PATH"
|
||||
$v = & "$NodePrefix\node.exe" --version
|
||||
Write-Success "Node.js $v installed to $NodePrefix"
|
||||
} catch {
|
||||
Write-Err "Node.js install failed: $_"
|
||||
Write-Info "Install Node 20+ manually from https://nodejs.org/en/download/ and re-run."
|
||||
throw
|
||||
}
|
||||
}
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Step 2: agent-browser
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
function Ensure-AgentBrowser {
|
||||
$npmExe = Resolve-NpmExe
|
||||
if (-not $npmExe) {
|
||||
Write-Err "npm not on PATH after Node install — aborting"
|
||||
throw "npm missing"
|
||||
}
|
||||
|
||||
# Already installed?
|
||||
$existing = Get-Command agent-browser -ErrorAction SilentlyContinue
|
||||
if ($existing) {
|
||||
Write-Success "agent-browser already installed at $($existing.Source)"
|
||||
return
|
||||
}
|
||||
|
||||
# When the user has system Node (winget / installer-based), `npm install
|
||||
# -g` writes to a directory that may require admin rights. Force the
|
||||
# prefix to the user-writable Hermes-managed Node directory so we never
|
||||
# need elevation and the agent can always find the result. Mirrors the
|
||||
# bash bootstrap's `--prefix $NODE_PREFIX` strategy.
|
||||
New-Item -ItemType Directory -Force -Path $NodePrefix | Out-Null
|
||||
|
||||
Write-Info "Installing agent-browser (npm, prefix=$NodePrefix)..."
|
||||
& $npmExe install -g --prefix $NodePrefix --silent `
|
||||
"agent-browser@^0.26.0" "@askjo/camofox-browser@^1.5.2"
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Err "npm install -g agent-browser failed (exit $LASTEXITCODE)"
|
||||
throw "npm install"
|
||||
}
|
||||
|
||||
# Windows npm global installs drop shims at $NodePrefix\ root (not bin/).
|
||||
# Prepend to PATH so any subsequent npx call resolves them.
|
||||
$env:PATH = "$NodePrefix;$env:PATH"
|
||||
|
||||
Write-Success "agent-browser installed to $NodePrefix"
|
||||
}
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Step 3: Playwright Chromium
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
function Find-SystemBrowser {
|
||||
$candidates = @(
|
||||
"C:\Program Files\Google\Chrome\Application\chrome.exe",
|
||||
"C:\Program Files (x86)\Google\Chrome\Application\chrome.exe",
|
||||
"C:\Program Files\Chromium\Application\chromium.exe",
|
||||
"${env:LOCALAPPDATA}\Google\Chrome\Application\chrome.exe",
|
||||
"${env:LOCALAPPDATA}\Chromium\Application\chromium.exe"
|
||||
)
|
||||
foreach ($p in $candidates) {
|
||||
if (Test-Path $p) { return $p }
|
||||
}
|
||||
# Edge — Chromium-based, agent-browser can use it
|
||||
foreach ($p in @(
|
||||
"C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe",
|
||||
"C:\Program Files\Microsoft\Edge\Application\msedge.exe"
|
||||
)) {
|
||||
if (Test-Path $p) { return $p }
|
||||
}
|
||||
return $null
|
||||
}
|
||||
|
||||
function Write-BrowserEnv {
|
||||
param([string]$BrowserPath)
|
||||
$envFile = Join-Path $HermesHome ".env"
|
||||
New-Item -ItemType Directory -Force -Path $HermesHome | Out-Null
|
||||
if (Test-Path $envFile) {
|
||||
$existing = Get-Content $envFile -Raw -ErrorAction SilentlyContinue
|
||||
if ($existing -and ($existing -match "(?m)^AGENT_BROWSER_EXECUTABLE_PATH=")) {
|
||||
return
|
||||
}
|
||||
}
|
||||
Add-Content -Path $envFile -Value ""
|
||||
Add-Content -Path $envFile -Value "# Hermes Agent browser tools — use the system Chrome/Chromium/Edge binary."
|
||||
Add-Content -Path $envFile -Value "AGENT_BROWSER_EXECUTABLE_PATH=$BrowserPath"
|
||||
Write-Success "Configured browser tools to use $BrowserPath"
|
||||
}
|
||||
|
||||
function Confirm-ChromiumDownload {
|
||||
if ($Yes) { return $true }
|
||||
if (-not [Environment]::UserInteractive) {
|
||||
Write-Warn "Non-interactive shell — skipping Chromium prompt."
|
||||
Write-Info "Re-run with -Yes to install Chromium (~400 MB download)."
|
||||
return $false
|
||||
}
|
||||
$reply = Read-Host "Install Playwright Chromium (~400 MB download)? [y/N]"
|
||||
return ($reply -match "^(y|yes)$")
|
||||
}
|
||||
|
||||
function Ensure-Chromium {
    if ($SkipChromium) {
        Write-Info "Skipping Chromium install (-SkipChromium)"
        return
    }

    # agent-browser on Windows expects a Playwright-managed Chromium under
    # %LOCALAPPDATA%\ms-playwright. The system-browser shortcut from the
    # Linux/macOS path doesn't apply the same way on Windows — Playwright's
    # default launch path won't pick up a stock Chrome install without an
    # explicit AGENT_BROWSER_EXECUTABLE_PATH. We still offer it as a
    # fallback when the user doesn't want the download.

    if (-not (Confirm-ChromiumDownload)) {
        $sys = Find-SystemBrowser
        if ($sys) {
            Write-Info "Using system browser at $sys (Chromium download skipped)."
            Write-BrowserEnv -BrowserPath $sys
        } else {
            Write-Info "Chromium install skipped. Browser tools won't launch until"
            Write-Info "Chromium is installed or AGENT_BROWSER_EXECUTABLE_PATH is set."
        }
        return
    }

    $npxExe = Resolve-NpxExe
    if (-not $npxExe) {
        Write-Err "npx not on PATH — cannot install Playwright Chromium"
        throw "npx missing"
    }

    Write-Info "Installing Playwright Chromium (~400 MB) ..."
    & $npxExe --yes playwright install chromium
    if ($LASTEXITCODE -ne 0) {
        Write-Err "Playwright Chromium install failed (exit $LASTEXITCODE)"
        Write-Info "Try again later: npx --yes playwright install chromium"
        throw "playwright"
    }
    Write-Success "Playwright Chromium installed"
}
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Main
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
Write-Info "Hermes Agent: bootstrapping browser tools"
|
||||
Write-Info " HERMES_HOME = $HermesHome"
|
||||
Write-Info " OS = Windows"
|
||||
|
||||
Ensure-Node
|
||||
Ensure-AgentBrowser
|
||||
Ensure-Chromium
|
||||
|
||||
Write-Success "Browser tools setup complete."
|
||||
Write-Info "Hermes Agent will pick up agent-browser from $NodePrefix on next launch."
|
||||
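The branch order in Ensure-Chromium above may be easier to review as a pure function. The Python sketch below is illustrative only: `plan_chromium_step` and its string labels are hypothetical names that do not appear in the PowerShell script, and the three inputs stand in for `$SkipChromium`, the result of `Confirm-ChromiumDownload`, and the result of `Find-SystemBrowser`.

```python
from typing import Optional


def plan_chromium_step(
    skip_chromium: bool,
    confirmed: bool,
    system_browser: Optional[str],
) -> str:
    """Return a label for the branch the Windows flow would take.

    The labels are made up for this sketch; the real script logs and
    returns instead of producing a value.
    """
    if skip_chromium:
        # -SkipChromium wins over everything else.
        return "skip"
    if not confirmed:
        # The download was declined (or the shell is non-interactive):
        # fall back to a system browser when one was found.
        return "use-system-browser" if system_browser else "no-browser"
    # Confirmed: the script goes on to run `npx --yes playwright install chromium`.
    return "install-playwright-chromium"


if __name__ == "__main__":
    print(plan_chromium_step(False, False, r"C:\Program Files\Microsoft\Edge\Application\msedge.exe"))
    print(plan_chromium_step(False, True, None))
```

Written this way, it is visible that on Windows the system browser is only consulted after the download has been declined.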
+399
@@ -0,0 +1,399 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# bootstrap_browser_tools.sh — install agent-browser + Playwright Chromium
|
||||
# into ~/.hermes/node/ for use by Hermes Agent's browser tools.
|
||||
#
|
||||
# Targets the registry-install path: users who got Hermes via
|
||||
# `uvx --from 'hermes-agent[acp]==X' hermes-acp` don't have a repo clone,
|
||||
# so the install.sh `npm install`-in-repo flow doesn't apply. This script
|
||||
# is a self-contained, idempotent slice of install.sh's browser block —
|
||||
# safe to run from `hermes-acp --setup-browser`, from a fresh terminal,
|
||||
# or from install.sh itself (it's a no-op when everything is already in place).
|
||||
#
|
||||
# Usage:
|
||||
# bootstrap_browser_tools.sh # use defaults
|
||||
# bootstrap_browser_tools.sh --yes # accept the ~400MB Chromium download
|
||||
# bootstrap_browser_tools.sh --skip-chromium # only install Node + agent-browser
|
||||
# HERMES_HOME=/custom/path bootstrap_browser_tools.sh
|
||||
#
|
||||
# Idempotent: re-running this is safe and fast. Each step checks whether
|
||||
# the work is already done.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Config
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
NODE_VERSION="22"
|
||||
HERMES_HOME="${HERMES_HOME:-$HOME/.hermes}"
|
||||
NODE_PREFIX="$HERMES_HOME/node"
|
||||
|
||||
SKIP_CHROMIUM=false
|
||||
ASSUME_YES=false
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Logging
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
if [ -t 1 ]; then
|
||||
C_GREEN='\033[0;32m'
|
||||
C_YELLOW='\033[0;33m'
|
||||
C_BLUE='\033[0;34m'
|
||||
C_RED='\033[0;31m'
|
||||
C_RESET='\033[0m'
|
||||
else
|
||||
C_GREEN='' ; C_YELLOW='' ; C_BLUE='' ; C_RED='' ; C_RESET=''
|
||||
fi
|
||||
|
||||
log_info() { printf "${C_BLUE}[*]${C_RESET} %s\n" "$*"; }
|
||||
log_success() { printf "${C_GREEN}[✓]${C_RESET} %s\n" "$*"; }
|
||||
log_warn() { printf "${C_YELLOW}[!]${C_RESET} %s\n" "$*" >&2; }
|
||||
log_error() { printf "${C_RED}[✗]${C_RESET} %s\n" "$*" >&2; }
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Arg parsing
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
while [ $# -gt 0 ]; do
|
||||
case "$1" in
|
||||
--skip-chromium) SKIP_CHROMIUM=true ;;
|
||||
--yes|-y) ASSUME_YES=true ;;
|
||||
-h|--help)
|
||||
cat <<EOF
|
||||
Bootstrap Hermes Agent browser tools.
|
||||
|
||||
Installs Node.js (into ~/.hermes/node/), the agent-browser npm package,
|
||||
and the Playwright Chromium browser engine.
|
||||
|
||||
Options:
|
||||
--skip-chromium Install Node + agent-browser but skip Chromium download
|
||||
--yes, -y Accept the ~400 MB Chromium download without prompting
|
||||
-h, --help Show this help
|
||||
|
||||
Environment:
|
||||
HERMES_HOME Override Hermes data dir (default: \$HOME/.hermes)
|
||||
EOF
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
log_error "Unknown option: $1"
|
||||
exit 2
|
||||
;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# OS / arch detection
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
OS="unknown"
|
||||
case "$(uname -s)" in
|
||||
Linux*) OS="linux" ;;
|
||||
Darwin*) OS="macos" ;;
|
||||
*)
|
||||
log_error "Unsupported OS: $(uname -s)"
|
||||
log_info "Windows users: run scripts/bootstrap_browser_tools.ps1 in PowerShell."
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
NODE_ARCH=""
|
||||
case "$(uname -m)" in
|
||||
x86_64) NODE_ARCH="x64" ;;
|
||||
aarch64|arm64) NODE_ARCH="arm64" ;;
|
||||
armv7l) NODE_ARCH="armv7l" ;;
|
||||
*)
|
||||
log_error "Unsupported architecture: $(uname -m)"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
NODE_OS=""
|
||||
case "$OS" in
|
||||
linux) NODE_OS="linux" ;;
|
||||
macos) NODE_OS="darwin" ;;
|
||||
esac
|
||||
|
||||
DISTRO=""
|
||||
if [ -f /etc/os-release ]; then
|
||||
# shellcheck disable=SC1091
|
||||
. /etc/os-release
|
||||
DISTRO="${ID:-}"
|
||||
fi
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Step 1: Node.js
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
ensure_node() {
|
||||
# Already on PATH and recent enough?
|
||||
if command -v node >/dev/null 2>&1; then
|
||||
local found_ver major
|
||||
found_ver=$(node --version 2>/dev/null)
|
||||
major=$(echo "$found_ver" | sed -E 's/^v([0-9]+).*/\1/')
|
||||
if [ -n "$major" ] && [ "$major" -ge 20 ]; then
|
||||
log_success "Node.js $found_ver found on PATH"
|
||||
return 0
|
||||
fi
|
||||
log_warn "Node.js $found_ver is older than v20 — installing managed Node."
|
||||
fi
|
||||
|
||||
if [ -x "$NODE_PREFIX/bin/node" ]; then
|
||||
local found_ver
|
||||
found_ver=$("$NODE_PREFIX/bin/node" --version 2>/dev/null || echo "?")
|
||||
export PATH="$NODE_PREFIX/bin:$PATH"
|
||||
log_success "Node.js $found_ver found (Hermes-managed at $NODE_PREFIX)"
|
||||
return 0
|
||||
fi
|
||||
|
||||
log_info "Installing Node.js $NODE_VERSION LTS into $NODE_PREFIX ..."
|
||||
|
||||
local index_url="https://nodejs.org/dist/latest-v${NODE_VERSION}.x/"
|
||||
local tarball_name
|
||||
tarball_name=$(curl -fsSL "$index_url" \
|
||||
| grep -oE "node-v${NODE_VERSION}\.[0-9]+\.[0-9]+-${NODE_OS}-${NODE_ARCH}\.tar\.xz" \
|
||||
| head -1)
|
||||
|
||||
if [ -z "$tarball_name" ]; then
|
||||
tarball_name=$(curl -fsSL "$index_url" \
|
||||
| grep -oE "node-v${NODE_VERSION}\.[0-9]+\.[0-9]+-${NODE_OS}-${NODE_ARCH}\.tar\.gz" \
|
||||
| head -1)
|
||||
fi
|
||||
|
||||
if [ -z "$tarball_name" ]; then
|
||||
log_error "Could not locate Node.js $NODE_VERSION tarball for $NODE_OS-$NODE_ARCH"
|
||||
log_info "Install Node 20+ manually: https://nodejs.org/en/download/"
|
||||
return 1
|
||||
fi
|
||||
|
||||
local tmp_dir
|
||||
tmp_dir=$(mktemp -d)
|
||||
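# Clear the trap once it fires: a RETURN trap set inside a function can stay
# armed and run again when the caller returns, where tmp_dir is unset.
trap 'rm -rf "$tmp_dir"; trap - RETURN' RETURN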
|
||||
|
||||
log_info "Downloading $tarball_name ..."
|
||||
if ! curl -fsSL "${index_url}${tarball_name}" -o "$tmp_dir/$tarball_name"; then
|
||||
log_error "Node.js download failed"
|
||||
return 1
|
||||
fi
|
||||
|
||||
if [[ "$tarball_name" == *.tar.xz ]]; then
|
||||
tar xf "$tmp_dir/$tarball_name" -C "$tmp_dir"
|
||||
else
|
||||
tar xzf "$tmp_dir/$tarball_name" -C "$tmp_dir"
|
||||
fi
|
||||
|
||||
local extracted_dir
|
||||
extracted_dir=$(ls -d "$tmp_dir"/node-v* 2>/dev/null | head -1)
|
||||
if [ ! -d "$extracted_dir" ]; then
|
||||
log_error "Node.js extraction failed"
|
||||
return 1
|
||||
fi
|
||||
|
||||
mkdir -p "$HERMES_HOME"
|
||||
rm -rf "$NODE_PREFIX"
|
||||
mv "$extracted_dir" "$NODE_PREFIX"
|
||||
|
||||
export PATH="$NODE_PREFIX/bin:$PATH"
|
||||
|
||||
local installed_ver
|
||||
installed_ver=$("$NODE_PREFIX/bin/node" --version 2>/dev/null || echo "?")
|
||||
log_success "Node.js $installed_ver installed to $NODE_PREFIX"
|
||||
}
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Step 2: agent-browser + @askjo/camofox-browser via global npm install
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
ensure_agent_browser() {
|
||||
if ! command -v npm >/dev/null 2>&1; then
|
||||
log_error "npm not on PATH after Node install — aborting"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# _find_agent_browser() in tools/browser_tool.py walks ~/.hermes/node/bin
|
||||
# plus a few standard prefixes, so installing globally into the managed
|
||||
# Node prefix is enough — no PATH manipulation needed from the agent side.
|
||||
if [ -x "$NODE_PREFIX/bin/agent-browser" ] || command -v agent-browser >/dev/null 2>&1; then
|
||||
log_success "agent-browser already installed"
|
||||
return 0
|
||||
fi
|
||||
|
||||
# When the system's `npm` resolves to a root-owned prefix (e.g.
|
||||
# /usr/lib/node_modules), `npm install -g` fails with EACCES without
|
||||
# sudo. Force the prefix to the user-writable Hermes-managed Node
|
||||
# directory so we never need sudo and the agent can always find the
|
||||
# result. If we installed Node ourselves above, this is a no-op
|
||||
# (managed Node already uses $NODE_PREFIX). If the user has system
|
||||
# Node, we still drop agent-browser under $NODE_PREFIX/bin/ — which
|
||||
# is exactly where _browser_candidate_path_dirs() looks first.
|
||||
mkdir -p "$NODE_PREFIX"
|
||||
|
||||
log_info "Installing agent-browser (npm, prefix=$NODE_PREFIX)..."
|
||||
if ! npm install -g --prefix "$NODE_PREFIX" --silent \
|
||||
agent-browser@^0.26.0 \
|
||||
"@askjo/camofox-browser@^1.5.2"; then
|
||||
log_error "npm install -g agent-browser failed"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# macOS/Linux global installs place the shim into $NODE_PREFIX/bin/.
|
||||
# Add it to PATH for any subsequent steps (npx playwright).
|
||||
export PATH="$NODE_PREFIX/bin:$PATH"
|
||||
|
||||
log_success "agent-browser installed to $NODE_PREFIX/bin/"
|
||||
}
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Step 3: Playwright Chromium
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
confirm_chromium_download() {
    if [ "$ASSUME_YES" = true ]; then return 0; fi
    if [ ! -t 0 ]; then
        log_warn "Non-interactive shell — skipping Chromium prompt."
        log_info "Re-run with --yes to install Chromium (~400 MB download)."
        return 1
    fi
    printf "Install Playwright Chromium (~400 MB download)? [y/N] "
    local reply=""
    read -r reply || reply=""
    case "$reply" in
        y|Y|yes|YES) return 0 ;;
        *) return 1 ;;
    esac
}
|
||||
|
||||
# Detect a usable system Chrome/Chromium. agent-browser's Chrome engine can
# use it instead of downloading Playwright's bundled Chromium, saving the
# download cost. Returns the path or empty string.
find_system_browser() {
    local candidate
    for candidate in google-chrome google-chrome-stable chromium chromium-browser chrome; do
        if command -v "$candidate" >/dev/null 2>&1; then
            command -v "$candidate"
            return 0
        fi
    done
    # macOS app-bundle locations
    if [ "$OS" = "macos" ]; then
        for candidate in \
            "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome" \
            "/Applications/Chromium.app/Contents/MacOS/Chromium" ; do
            if [ -x "$candidate" ]; then
                echo "$candidate"
                return 0
            fi
        done
    fi
    return 1
}
|
||||
|
||||
write_browser_env() {
    local browser_path="$1"
    local env_file="$HERMES_HOME/.env"
    mkdir -p "$HERMES_HOME"
    if [ -f "$env_file" ] && grep -q "^AGENT_BROWSER_EXECUTABLE_PATH=" "$env_file"; then
        return 0
    fi
    {
        echo ""
        echo "# Hermes Agent browser tools — use the system Chrome/Chromium binary."
        echo "AGENT_BROWSER_EXECUTABLE_PATH=$browser_path"
    } >> "$env_file"
    log_success "Configured browser tools to use $browser_path"
}
|
||||
|
||||
ensure_chromium() {
    if [ "$SKIP_CHROMIUM" = true ]; then
        log_info "Skipping Chromium install (--skip-chromium)"
        return 0
    fi

    local system_browser
    system_browser="$(find_system_browser 2>/dev/null || true)"
    if [ -n "$system_browser" ]; then
        log_success "Found system browser: $system_browser"
        log_info "Skipping Playwright Chromium download; agent-browser will use it."
        write_browser_env "$system_browser"
        return 0
    fi

    if ! confirm_chromium_download; then
        log_info "Chromium install skipped. Browser tools will only work if you"
        log_info "set AGENT_BROWSER_EXECUTABLE_PATH or install Chromium later."
        return 0
    fi

    if ! command -v npx >/dev/null 2>&1; then
        log_error "npx not on PATH — cannot install Playwright Chromium"
        return 1
    fi

    log_info "Installing Playwright Chromium (~400 MB) ..."

    # On apt-based distros, --with-deps requires sudo. Try non-interactively
    # only — never prompt — and fall back to the bare browser-only install.
    local installed=false
    if [ "$OS" = "linux" ]; then
        case "$DISTRO" in
            ubuntu|debian|raspbian|pop|linuxmint|elementary|zorin|kali|parrot)
                if [ "$(id -u)" -eq 0 ] || (command -v sudo >/dev/null 2>&1 && sudo -n true 2>/dev/null); then
                    log_info "Installing system deps with --with-deps (sudo available)"
                    if npx --yes playwright install --with-deps chromium; then
                        installed=true
                    fi
                else
                    log_warn "sudo not available non-interactively — installing Chromium without system deps."
                    log_info "If browser tools fail to launch, an administrator should run:"
                    log_info "  sudo npx playwright install-deps chromium"
                fi
                ;;
            arch|manjaro|cachyos|endeavouros|garuda)
                log_info "Arch-family system dependencies are not auto-installed."
                log_info "If launch fails, run: sudo pacman -S nss atk at-spi2-core cups libdrm libxkbcommon mesa pango cairo alsa-lib"
                ;;
            fedora|rhel|centos|rocky|alma)
                log_info "Fedora/RHEL system dependencies are not auto-installed."
                log_info "If launch fails, run: sudo dnf install nss atk at-spi2-core cups-libs libdrm libxkbcommon mesa-libgbm pango cairo alsa-lib"
                ;;
            opensuse*|sles)
                log_info "openSUSE system dependencies are not auto-installed."
                ;;
        esac
    fi

    if [ "$installed" = false ]; then
        if npx --yes playwright install chromium; then
            installed=true
        fi
    fi

    if [ "$installed" = true ]; then
        log_success "Playwright Chromium installed"
    else
        log_error "Playwright Chromium install failed"
        log_info "Try again later: npx --yes playwright install chromium"
        return 1
    fi
}
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Main
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
main() {
|
||||
log_info "Hermes Agent: bootstrapping browser tools"
|
||||
log_info " HERMES_HOME = $HERMES_HOME"
|
||||
log_info " OS / arch = $NODE_OS-$NODE_ARCH ${DISTRO:+($DISTRO)}"
|
||||
|
||||
ensure_node
|
||||
ensure_agent_browser
|
||||
ensure_chromium
|
||||
|
||||
log_success "Browser tools setup complete."
|
||||
log_info "Hermes Agent will pick up agent-browser from $NODE_PREFIX/bin/ on next launch."
|
||||
}
|
||||
|
||||
main
|
||||
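The idempotency claim for `write_browser_env` above rests on one anchored check: the key is appended only when no line in `.env` already starts with `AGENT_BROWSER_EXECUTABLE_PATH=`. A minimal Python sketch of the same rule follows, assuming a throwaway directory in place of `$HERMES_HOME`. The Python function and its boolean return value are illustrative and not part of the shell script.

```python
import re
import tempfile
from pathlib import Path

ENV_KEY = "AGENT_BROWSER_EXECUTABLE_PATH"


def write_browser_env(hermes_home: Path, browser_path: str) -> bool:
    """Append ENV_KEY=browser_path to <hermes_home>/.env unless the key is set.

    Returns True when lines were appended and False when the file was left
    untouched. The return value is an addition made for this sketch.
    """
    hermes_home.mkdir(parents=True, exist_ok=True)
    env_file = hermes_home / ".env"
    if env_file.is_file():
        existing = env_file.read_text(encoding="utf-8")
        # Same anchored test as `grep -q "^KEY="`: only a line that starts
        # with the key counts, so a commented-out entry does not block us.
        if re.search(rf"(?m)^{re.escape(ENV_KEY)}=", existing):
            return False
    with env_file.open("a", encoding="utf-8") as fh:
        fh.write("\n# Hermes Agent browser tools: use the system Chrome/Chromium binary.\n")
        fh.write(f"{ENV_KEY}={browser_path}\n")
    return True


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        home = Path(tmp) / ".hermes"
        print(write_browser_env(home, "/usr/bin/chromium"))       # first call appends
        print(write_browser_env(home, "/usr/bin/google-chrome"))  # second call is a no-op
```

One consequence of this rule is that a second run never overwrites an existing value, so a user who later switches browsers has to edit `.env` by hand.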
+144
-1
@@ -24,6 +24,7 @@ except ModuleNotFoundError:
|
||||
# means UTF-8 stdio setup is skipped on Windows; POSIX is unaffected.
|
||||
pass
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
@@ -107,8 +108,150 @@ def _load_env() -> None:
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="hermes-acp",
|
||||
description="Run Hermes Agent as an ACP stdio server.",
|
||||
)
|
||||
parser.add_argument("--version", action="store_true", help="Print Hermes version and exit")
|
||||
parser.add_argument(
|
||||
"--check",
|
||||
action="store_true",
|
||||
help="Verify ACP dependencies and adapter imports, then exit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--setup",
|
||||
action="store_true",
|
||||
help="Run interactive Hermes provider/model setup for ACP terminal auth",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--setup-browser",
|
||||
action="store_true",
|
||||
help="Install agent-browser + Playwright Chromium into ~/.hermes/node/ "
|
||||
"for browser tool support. Idempotent.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--yes",
|
||||
"-y",
|
||||
action="store_true",
|
||||
dest="assume_yes",
|
||||
help="Accept all prompts (currently used by --setup-browser to skip the "
|
||||
"~400 MB Chromium download confirmation).",
|
||||
)
|
||||
return parser.parse_args(argv)
|
||||
|
||||
|
||||
def _print_version() -> None:
|
||||
from hermes_cli import __version__ as hermes_version
|
||||
|
||||
print(hermes_version)
|
||||
|
||||
|
||||
def _run_check() -> None:
|
||||
import acp # noqa: F401
|
||||
from acp_adapter.server import HermesACPAgent # noqa: F401
|
||||
|
||||
print("Hermes ACP check OK")
|
||||
|
||||
|
||||
def _run_setup() -> None:
|
||||
from hermes_cli.main import main as hermes_main
|
||||
|
||||
old_argv = sys.argv[:]
|
||||
try:
|
||||
sys.argv = [old_argv[0] if old_argv else "hermes", "model"]
|
||||
hermes_main()
|
||||
finally:
|
||||
sys.argv = old_argv
|
||||
|
||||
# Offer browser-tools install as a follow-up. The terminal auth method
|
||||
# is the one supported first-run UX for registry installs, so this is
|
||||
# the natural moment to ask. Skip silently if stdin isn't a TTY (the
|
||||
# answer can't be collected anyway).
|
||||
if not sys.stdin.isatty():
|
||||
return
|
||||
try:
|
||||
reply = input(
|
||||
"\nInstall browser tools? Downloads agent-browser (npm) and "
|
||||
"optionally Playwright Chromium (~400 MB). [y/N] "
|
||||
).strip().lower()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return
|
||||
if reply in {"y", "yes"}:
|
||||
_run_setup_browser(assume_yes=False)
|
||||
|
||||
|
||||
def _run_setup_browser(assume_yes: bool = False) -> int:
    """Bootstrap agent-browser + Playwright Chromium for the registry-install path.

    Shells out to the bundled platform-specific bootstrap script
    (acp_adapter/bootstrap/bootstrap_browser_tools.{sh,ps1}) so the install
    logic lives in one place — readable, debuggable, and shareable with
    install.sh / install.ps1 if we ever want to call it from there too.

    Returns the script's exit code (0 on success), or 1 when the script or
    its interpreter (bash / powershell.exe) cannot be found.
    """
    import platform
    import subprocess

    bootstrap_dir = Path(__file__).resolve().parent / "bootstrap"

    if platform.system() == "Windows":
        script = bootstrap_dir / "bootstrap_browser_tools.ps1"
        if not script.is_file():
            print(
                f"Bootstrap script not found at {script} — wheel may be incomplete.",
                file=sys.stderr,
            )
            return 1
        cmd = [
            "powershell.exe",
            "-NoProfile",
            "-ExecutionPolicy", "Bypass",
            "-File", str(script),
        ]
        if assume_yes:
            cmd.append("-Yes")
    else:
        script = bootstrap_dir / "bootstrap_browser_tools.sh"
        if not script.is_file():
            print(
                f"Bootstrap script not found at {script} — wheel may be incomplete.",
                file=sys.stderr,
            )
            return 1
        cmd = ["bash", str(script)]
        if assume_yes:
            cmd.append("--yes")

    # stdio is inherited so the user sees the bootstrap's progress live.
    try:
        result = subprocess.run(cmd, check=False)
    except FileNotFoundError as exc:
        # bash / powershell.exe not on PATH
        print(f"Could not launch browser bootstrap: {exc}", file=sys.stderr)
        return 1
    return result.returncode
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> None:
|
||||
"""Entry point: load env, configure logging, run the ACP agent."""
|
||||
args = _parse_args(argv)
|
||||
if args.version:
|
||||
_print_version()
|
||||
return
|
||||
if args.check:
|
||||
_run_check()
|
||||
return
|
||||
if args.setup:
|
||||
_run_setup()
|
||||
return
|
||||
if args.setup_browser:
|
||||
rc = _run_setup_browser(assume_yes=args.assume_yes)
|
||||
if rc != 0:
|
||||
sys.exit(rc)
|
||||
return
|
||||
|
||||
_setup_logging()
|
||||
_load_env()
|
||||
|
||||
|
||||
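The command construction inside `_run_setup_browser` can be checked without launching a shell if it is pulled out into a pure function. The sketch below shows one way to do that. `build_bootstrap_cmd` is a hypothetical helper name that is not in the diff, and the `system` argument stands in for the value of `platform.system()`.

```python
from pathlib import Path
from typing import List


def build_bootstrap_cmd(bootstrap_dir: Path, system: str, assume_yes: bool) -> List[str]:
    """Return the argv used to launch the platform-specific bootstrap script.

    Mirrors the two branches in _run_setup_browser; it does not check that
    the script exists and does not run anything.
    """
    if system == "Windows":
        cmd = [
            "powershell.exe",
            "-NoProfile",
            "-ExecutionPolicy", "Bypass",
            "-File", str(bootstrap_dir / "bootstrap_browser_tools.ps1"),
        ]
        if assume_yes:
            cmd.append("-Yes")  # PowerShell switch parameter
        return cmd
    cmd = ["bash", str(bootstrap_dir / "bootstrap_browser_tools.sh")]
    if assume_yes:
        cmd.append("--yes")  # long option parsed by the bash script
    return cmd


if __name__ == "__main__":
    print(build_bootstrap_cmd(Path("bootstrap"), "Linux", assume_yes=True))
    print(build_bootstrap_cmd(Path("bootstrap"), "Windows", assume_yes=False))
```

Keeping the flag spelling per platform in one place makes the `-Yes` versus `--yes` difference hard to miss.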
+13
-20
@@ -57,13 +57,7 @@ from acp.schema import (
|
||||
UserMessageChunk,
|
||||
)
|
||||
|
||||
# AuthMethodAgent was renamed from AuthMethod in agent-client-protocol 0.9.0
|
||||
try:
|
||||
from acp.schema import AuthMethodAgent
|
||||
except ImportError:
|
||||
from acp.schema import AuthMethod as AuthMethodAgent # type: ignore[attr-defined]
|
||||
|
||||
from acp_adapter.auth import detect_provider
|
||||
from acp_adapter.auth import TERMINAL_SETUP_AUTH_METHOD_ID, build_auth_methods, detect_provider
|
||||
from acp_adapter.events import (
|
||||
make_message_cb,
|
||||
make_step_cb,
|
||||
@@ -744,16 +738,7 @@ class HermesACPAgent(acp.Agent):
|
||||
resolved_protocol_version = (
|
||||
protocol_version if isinstance(protocol_version, int) else acp.PROTOCOL_VERSION
|
||||
)
|
||||
provider = detect_provider()
|
||||
auth_methods = None
|
||||
if provider:
|
||||
auth_methods = [
|
||||
AuthMethodAgent(
|
||||
id=provider,
|
||||
name=f"{provider} runtime credentials",
|
||||
description=f"Authenticate Hermes using the currently configured {provider} runtime credentials.",
|
||||
)
|
||||
]
|
||||
auth_methods = build_auth_methods()
|
||||
|
||||
client_name = client_info.name if client_info else "unknown"
|
||||
logger.info(
|
||||
@@ -784,10 +769,18 @@ class HermesACPAgent(acp.Agent):
|
||||
# server has provider credentials configured — harmless under
|
||||
# Hermes' threat model (ACP is stdio-only, local-trust), but poor
|
||||
# API hygiene and confusing if ACP ever grows multi-method auth.
|
||||
provider = detect_provider()
|
||||
if not provider:
|
||||
if not isinstance(method_id, str):
|
||||
return None
|
||||
if not isinstance(method_id, str) or method_id.strip().lower() != provider:
|
||||
normalized_method = method_id.strip().lower()
|
||||
provider = detect_provider()
|
||||
|
||||
if normalized_method == TERMINAL_SETUP_AUTH_METHOD_ID:
|
||||
# Terminal auth launches Hermes setup/model selection out-of-band.
|
||||
# Only report success once that flow has produced usable runtime
|
||||
# credentials for the normal ACP session.
|
||||
return AuthenticateResponse() if provider else None
|
||||
|
||||
if not provider or normalized_method != provider:
|
||||
return None
|
||||
return AuthenticateResponse()
|
||||
|
||||
|
||||
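The new authenticate logic in the hunk above reduces to a small truth table. The Python sketch below restates it as a boolean function so the cases can be enumerated. `authenticate_ok` is a hypothetical name, and the string assigned to `TERMINAL_SETUP_AUTH_METHOD_ID` here is a placeholder because the real value is defined in `acp_adapter.auth` and is not shown in this diff.

```python
from typing import Optional

# Placeholder value for illustration; the real constant lives in acp_adapter.auth.
TERMINAL_SETUP_AUTH_METHOD_ID = "terminal-setup"


def authenticate_ok(method_id: object, provider: Optional[str]) -> bool:
    """Return True where the handler would return AuthenticateResponse().

    `provider` stands in for the result of detect_provider(); None means no
    runtime credentials are configured.
    """
    if not isinstance(method_id, str):
        return False
    normalized = method_id.strip().lower()
    if normalized == TERMINAL_SETUP_AUTH_METHOD_ID:
        # Terminal setup only counts as authenticated once it has produced
        # usable credentials, i.e. once a provider is detectable.
        return bool(provider)
    # Any other method must name the currently configured provider.
    return bool(provider) and normalized == provider


if __name__ == "__main__":
    print(authenticate_ok("terminal-setup", None))
    print(authenticate_ok(" OpenRouter ", "openrouter"))
```

The provider names used in the example are arbitrary strings chosen for the demonstration.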
+12
-8
@@ -1,12 +1,16 @@
|
||||
{
|
||||
"schema_version": 1,
|
||||
"name": "hermes-agent",
|
||||
"display_name": "Hermes Agent",
|
||||
"description": "AI agent by Nous Research with 90+ tools, persistent memory, and multi-platform support",
|
||||
"icon": "icon.svg",
|
||||
"id": "hermes-agent",
|
||||
"name": "Hermes Agent",
|
||||
"version": "0.13.0",
|
||||
"description": "Self-improving open-source AI agent by Nous Research with ACP editor integration, persistent memory, skills, and rich tool support.",
|
||||
"repository": "https://github.com/NousResearch/hermes-agent",
|
||||
"website": "https://hermes-agent.nousresearch.com/docs/user-guide/features/acp",
|
||||
"authors": ["Nous Research"],
|
||||
"license": "MIT",
|
||||
"distribution": {
|
||||
"type": "command",
|
||||
"command": "hermes",
|
||||
"args": ["acp"]
|
||||
"uvx": {
|
||||
"package": "hermes-agent[acp]==0.13.0",
|
||||
"args": ["hermes-acp"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+7
-24
@@ -1,25 +1,8 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 64 64" width="64" height="64">
|
||||
<defs>
|
||||
<linearGradient id="gold" x1="0%" y1="0%" x2="0%" y2="100%">
|
||||
<stop offset="0%" style="stop-color:#F5C542;stop-opacity:1" />
|
||||
<stop offset="100%" style="stop-color:#D4961C;stop-opacity:1" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
<!-- Staff -->
|
||||
<rect x="30" y="10" width="4" height="46" rx="2" fill="url(#gold)" />
|
||||
<!-- Wings (left) -->
|
||||
<path d="M30 18 C24 14, 14 14, 10 18 C14 16, 22 16, 28 20" fill="#F5C542" opacity="0.9" />
|
||||
<path d="M30 22 C26 19, 18 19, 14 22 C18 20, 24 20, 28 24" fill="#D4961C" opacity="0.8" />
|
||||
<!-- Wings (right) -->
|
||||
<path d="M34 18 C40 14, 50 14, 54 18 C50 16, 42 16, 36 20" fill="#F5C542" opacity="0.9" />
|
||||
<path d="M34 22 C38 19, 46 19, 50 22 C46 20, 40 20, 36 24" fill="#D4961C" opacity="0.8" />
|
||||
<!-- Left serpent -->
|
||||
<path d="M32 48 C22 44, 20 38, 26 34 C20 36, 18 42, 24 46 C18 40, 22 30, 30 28 C24 32, 22 38, 28 42"
|
||||
fill="none" stroke="#F5C542" stroke-width="2.5" stroke-linecap="round" />
|
||||
<!-- Right serpent -->
|
||||
<path d="M32 48 C42 44, 44 38, 38 34 C44 36, 46 42, 40 46 C46 40, 42 30, 34 28 C40 32, 42 38, 36 42"
|
||||
fill="none" stroke="#D4961C" stroke-width="2.5" stroke-linecap="round" />
|
||||
<!-- Orb at top -->
|
||||
<circle cx="32" cy="10" r="4" fill="#F5C542" />
|
||||
<circle cx="32" cy="10" r="2" fill="#FFF8E1" opacity="0.7" />
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" width="16" height="16" fill="none">
|
||||
<path d="M8 1.5v13" stroke="currentColor" stroke-width="1.5" stroke-linecap="round"/>
|
||||
<path d="M8 3.25c-2.35-1.4-4.7-.95-6.25.35 1.85-.2 3.8.2 5.55 1.55" stroke="currentColor" stroke-width="1.1" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M8 3.25c2.35-1.4 4.7-.95 6.25.35-1.85-.2-3.8.2-5.55 1.55" stroke="currentColor" stroke-width="1.1" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M8 13.25c-2.3-1-3.05-2.65-1.35-4.15-2 .8-2.35 2.95-.35 4" stroke="currentColor" stroke-width="1.1" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M8 13.25c2.3-1 3.05-2.65 1.35-4.15 2 .8 2.35 2.95.35 4" stroke="currentColor" stroke-width="1.1" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<circle cx="8" cy="1.8" r="1.1" fill="currentColor"/>
|
||||
</svg>
|
||||
|
||||
|
Before Width: | Height: | Size: 1.4 KiB After Width: | Height: | Size: 882 B |
@@ -1456,8 +1456,21 @@ def _try_nous(vision: bool = False) -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
nous = _read_nous_auth()
|
||||
runtime = _resolve_nous_runtime_api(force_refresh=False)
|
||||
if runtime is None and not nous:
|
||||
logger.warning(
|
||||
"Auxiliary Nous client unavailable: no Nous authentication found "
|
||||
"(run: hermes auth)."
|
||||
)
|
||||
_mark_provider_unhealthy("nous", ttl=60)
|
||||
return None, None
|
||||
if runtime is None and nous:
|
||||
# Runtime credential mint failed but stored Nous auth is still present.
|
||||
# Falls back to the raw stored token below; surface a debug line so
|
||||
# operators investigating expired/invalid sessions have a breadcrumb,
|
||||
# without blocking the fallback path the rest of this function relies on.
|
||||
logger.debug(
|
||||
"Auxiliary Nous: runtime credential mint failed; falling back to "
|
||||
"stored auth.json token."
|
||||
)
|
||||
global auxiliary_is_nous
|
||||
auxiliary_is_nous = True
|
||||
logger.debug("Auxiliary client: Nous Portal")
|
||||
|
||||
@@ -1429,15 +1429,23 @@ The user has requested that this compaction PRIORITISE preserving all informatio
|
||||
return messages
|
||||
|
||||
turns_to_summarize = messages[compress_start:compress_end]
|
||||
# A persisted handoff summary can sit in the protected head after a
|
||||
# resume (commonly immediately after the system prompt). Search from
|
||||
# the first non-system message through the compression window so we can
|
||||
# rehydrate iterative-summary state without serializing that handoff as
|
||||
# a new turn. Protected messages after the handoff remain live context,
|
||||
# so only summarize messages that are both after the handoff and inside
|
||||
# the current compression window.
|
||||
summary_search_start = 1 if messages and messages[0].get("role") == "system" else 0
|
||||
summary_idx, summary_body = self._find_latest_context_summary(
|
||||
messages,
|
||||
compress_start,
|
||||
summary_search_start,
|
||||
compress_end,
|
||||
)
|
||||
if summary_idx is not None:
|
||||
if summary_body and not self._previous_summary:
|
||||
self._previous_summary = summary_body
|
||||
turns_to_summarize = messages[summary_idx + 1:compress_end]
|
||||
turns_to_summarize = messages[max(compress_start, summary_idx + 1):compress_end]
|
||||
|
||||
if not self.quiet_mode:
|
||||
logger.info(
|
||||
|
||||
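The effect of clamping the slice start with `max(compress_start, summary_idx + 1)` in the hunk above is easiest to see on a toy message list. The sketch below is a simplified stand-in: `select_turns` is a hypothetical function, and the messages are plain dicts with made-up content.

```python
from typing import Dict, List, Optional


def select_turns(
    messages: List[Dict[str, str]],
    compress_start: int,
    compress_end: int,
    summary_idx: Optional[int],
) -> List[Dict[str, str]]:
    """Pick the turns to summarize, never reaching back into the protected head."""
    if summary_idx is None:
        return messages[compress_start:compress_end]
    # A handoff summary found before the window must not drag the protected
    # messages between it and compress_start into the summary input.
    return messages[max(compress_start, summary_idx + 1):compress_end]


if __name__ == "__main__":
    msgs = [
        {"role": "system", "content": "sys"},
        {"role": "user", "content": "handoff summary"},  # index 1, in the protected head
        {"role": "user", "content": "protected A"},
        {"role": "assistant", "content": "protected B"},
        {"role": "user", "content": "old turn 1"},        # compression window starts here
        {"role": "assistant", "content": "old turn 2"},
        {"role": "user", "content": "recent"},
    ]
    picked = select_turns(msgs, compress_start=4, compress_end=6, summary_idx=1)
    print([m["content"] for m in picked])
```

With the earlier unclamped slice `messages[summary_idx + 1:compress_end]`, the same inputs would also have selected the two protected messages.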
@@ -240,21 +240,6 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int | None = None) -
|
||||
msg = msg[:17] + "..."
|
||||
return f"to {target}: \"{msg}\""
|
||||
|
||||
if tool_name.startswith("rl_"):
|
||||
rl_previews = {
|
||||
"rl_list_environments": "listing envs",
|
||||
"rl_select_environment": args.get("name", ""),
|
||||
"rl_get_current_config": "reading config",
|
||||
"rl_edit_config": f"{args.get('field', '')}={args.get('value', '')}",
|
||||
"rl_start_training": "starting",
|
||||
"rl_check_status": args.get("run_id", "")[:16],
|
||||
"rl_stop_training": f"stopping {args.get('run_id', '')[:16]}",
|
||||
"rl_get_results": args.get("run_id", "")[:16],
|
||||
"rl_list_runs": "listing runs",
|
||||
"rl_test_inference": f"{args.get('num_steps', 3)} steps",
|
||||
}
|
||||
return rl_previews.get(tool_name)
|
||||
|
||||
key = primary_args.get(tool_name)
|
||||
if not key:
|
||||
for fallback_key in ("query", "text", "command", "path", "name", "prompt", "code", "goal"):
|
||||
@@ -981,15 +966,6 @@ def get_cute_tool_message(
|
||||
if action == "list":
|
||||
return _wrap(f"┊ ⏰ cron listing {dur}")
|
||||
return _wrap(f"┊ ⏰ cron {action} {args.get('job_id', '')} {dur}")
|
||||
if tool_name.startswith("rl_"):
|
||||
rl = {
|
||||
"rl_list_environments": "list envs", "rl_select_environment": f"select {args.get('name', '')}",
|
||||
"rl_get_current_config": "get config", "rl_edit_config": f"set {args.get('field', '?')}",
|
||||
"rl_start_training": "start training", "rl_check_status": f"status {args.get('run_id', '?')[:12]}",
|
||||
"rl_stop_training": f"stop {args.get('run_id', '?')[:12]}", "rl_get_results": f"results {args.get('run_id', '?')[:12]}",
|
||||
"rl_list_runs": "list runs", "rl_test_inference": "test inference",
|
||||
}
|
||||
return _wrap(f"┊ 🧪 rl {rl.get(tool_name, tool_name.replace('rl_', ''))} {dur}")
|
||||
if tool_name == "execute_code":
|
||||
code = args.get("code", "")
|
||||
first_line = code.strip().split("\n")[0] if code.strip() else ""
|
||||
|
||||
+35
-3
@@ -40,7 +40,7 @@ import os
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import Future as ConcurrentFuture
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from agent.lsp import eventlog
|
||||
from agent.lsp.client import (
|
||||
@@ -305,6 +305,7 @@ class LSPService:
|
||||
*,
|
||||
delta: bool = True,
|
||||
timeout: Optional[float] = None,
|
||||
line_shift: Optional[Callable[[int], Optional[int]]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Synchronously open ``file_path`` in the right server, wait for
|
||||
diagnostics, return them.
|
||||
@@ -314,6 +315,18 @@ class LSPService:
|
||||
Diagnostics present in the baseline are removed so the caller
|
||||
only sees errors introduced by the current edit.
|
||||
|
||||
When ``line_shift`` is provided, baseline diagnostics are
|
||||
remapped through it before the set-difference. This handles
|
||||
the case where the edit deleted or inserted lines, causing
|
||||
pre-existing diagnostics below the edit point to surface at
|
||||
different line numbers in the post-edit snapshot — without
|
||||
the shift, they'd all look "introduced by this edit". Pass
|
||||
a callable built by
|
||||
:func:`agent.lsp.range_shift.build_line_shift` (pre_text,
|
||||
post_text). Omit when pre/post content isn't available;
|
||||
the unshifted comparison still catches diagnostics that
|
||||
didn't move.
|
||||
|
||||
Returns an empty list when LSP is disabled, when no workspace
|
||||
can be detected, when no server matches, or when the server
|
||||
can't be spawned. Never raises.
|
||||
@@ -344,6 +357,14 @@ class LSPService:
|
||||
if delta:
|
||||
baseline = self._delta_baseline.get(abs_path) or []
|
||||
if baseline:
|
||||
if line_shift is not None:
|
||||
# Remap baseline diagnostics into post-edit
|
||||
# coordinates so shifted-but-otherwise-identical
|
||||
# entries hash equal under _diag_key. Entries
|
||||
# that mapped into a deleted region drop out
|
||||
# silently — they no longer apply.
|
||||
from agent.lsp.range_shift import shift_baseline
|
||||
baseline = shift_baseline(baseline, line_shift)
|
||||
seen = {_diag_key(d) for d in baseline}
|
||||
diags = [d for d in diags if _diag_key(d) not in seen]
|
||||
# Roll baseline forward — next call returns deltas relative
|
||||
@@ -585,8 +606,19 @@ class LSPService:
|
||||
|
||||
|
||||
def _diag_key(d: Dict[str, Any]) -> str:
|
||||
"""Content equality key used for delta filtering. Mirrors
|
||||
:func:`agent.lsp.client._diagnostic_key`."""
|
||||
"""Content equality key used for cross-edit delta filtering.
|
||||
|
||||
Includes the diagnostic's position range — when used together
|
||||
with :func:`agent.lsp.range_shift.shift_baseline`, the baseline
|
||||
is line-shifted into post-edit coordinates BEFORE this key is
|
||||
computed, so identical-but-shifted diagnostics hash equal. Two
|
||||
genuinely distinct diagnostics at different lines (e.g. the same
|
||||
error class introduced at a second site) hash differently and
|
||||
are surfaced as new.
|
||||
|
||||
Mirrors :func:`agent.lsp.client._diagnostic_key`; intentionally
|
||||
identical so the two layers agree on diagnostic identity.
|
||||
"""
|
||||
rng = d.get("range") or {}
|
||||
start = rng.get("start") or {}
|
||||
end = rng.get("end") or {}
|
||||
|
||||
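The delta filter described in the hunk above can be illustrated with a minimal sketch. The names `diag_key`, `new_diagnostics`, and `make_diag` are hypothetical and only stand in for the service's private helpers; the key here is a tuple rather than the string the real `_diag_key` returns. It shows why a baseline entry that merely moved is reported as new unless the baseline is first remapped into post-edit coordinates.

```python
from typing import Any, Dict, List, Tuple

Diagnostic = Dict[str, Any]


def diag_key(d: Diagnostic) -> Tuple[Any, ...]:
    """Hypothetical identity key: content fields plus the full range."""
    rng = d.get("range") or {}
    start = rng.get("start") or {}
    end = rng.get("end") or {}
    return (
        d.get("severity"),
        d.get("code"),
        d.get("source"),
        d.get("message"),
        start.get("line"), start.get("character"),
        end.get("line"), end.get("character"),
    )


def new_diagnostics(baseline: List[Diagnostic], current: List[Diagnostic]) -> List[Diagnostic]:
    """Return the entries of ``current`` whose key is absent from ``baseline``."""
    seen = {diag_key(d) for d in baseline}
    return [d for d in current if diag_key(d) not in seen]


def make_diag(line: int, message: str) -> Diagnostic:
    """Build a single-line diagnostic for the demonstration below."""
    return {
        "severity": 1,
        "source": "demo",
        "message": message,
        "range": {"start": {"line": line, "character": 0},
                  "end": {"line": line, "character": 5}},
    }


baseline = [make_diag(10, "undefined name 'x'")]
# The same error after two lines were inserted above it: now on line 12.
current = [make_diag(12, "undefined name 'x'")]

# Without remapping, the moved diagnostic is reported as new.
unshifted = new_diagnostics(baseline, current)

# With the baseline already remapped to line 12, it is filtered out.
remapped_baseline = [make_diag(12, "undefined name 'x'")]
shifted = new_diagnostics(remapped_baseline, current)
```

Keeping the range in the key is what lets a second, genuinely new instance of the same message at another line survive the filter.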
@@ -0,0 +1,149 @@
|
||||
"""Diff-aware line-shift map for cross-edit LSP delta filtering.
|
||||
|
||||
When an edit deletes or inserts lines in the middle of a file, every
|
||||
diagnostic below the edit point shifts to a new line number. The
|
||||
LSPService delta filter subtracts the pre-edit baseline from the
|
||||
post-edit diagnostics keyed on ``(severity, code, source, message,
|
||||
range)`` — without an adjustment, the shifted-but-otherwise-identical
|
||||
diagnostics look brand-new and the agent gets flooded with noise.
|
||||
|
||||
The fix used here is the same trick git's blame and unified diff use:
|
||||
build a piecewise-linear map from pre-edit line numbers to post-edit
|
||||
line numbers, then apply that map to baseline diagnostics before the
|
||||
set-difference. Diagnostics whose pre-edit line is in a region the
|
||||
edit deleted return ``None`` and are dropped from the baseline (they
|
||||
genuinely no longer apply).
|
||||
|
||||
Trade-off vs. dropping range from the key entirely (the previous
|
||||
fix): preserves the "new instance of an identical error at a
|
||||
different line" signal — if the model introduces a second instance
|
||||
of the same error class at a different location, that one will be
|
||||
surfaced as new instead of swallowed by content-only dedup.
|
||||
|
||||
The map is derived from ``difflib.SequenceMatcher.get_opcodes()`` and
|
||||
exposed as a single callable so callers don't have to reason about
|
||||
diff regions.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import difflib
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
|
||||
def build_line_shift(pre_text: str, post_text: str) -> Callable[[int], Optional[int]]:
|
||||
"""Build a function mapping pre-edit line numbers to post-edit line numbers.
|
||||
|
||||
Lines are 0-indexed to match the LSP wire format
|
||||
(``range.start.line`` is 0-indexed).
|
||||
|
||||
The returned callable takes a pre-edit 0-indexed line number and
|
||||
returns the corresponding post-edit 0-indexed line number, or
|
||||
``None`` if that line was deleted by the edit (no post-edit
|
||||
counterpart exists).
|
||||
|
||||
Cost: one ``SequenceMatcher.get_opcodes()`` call up front; the
|
||||
returned closure is O(k) per call (linear scan over the k opcode
|
||||
regions). Cheap enough to call once per write/patch and apply to
|
||||
every baseline diagnostic.
|
||||
"""
|
||||
pre_lines = pre_text.splitlines() if pre_text else []
|
||||
post_lines = post_text.splitlines() if post_text else []
|
||||
|
||||
# Trivial case: identical content or no content — identity map.
|
||||
if pre_lines == post_lines:
|
||||
return lambda line: line
|
||||
|
||||
# SequenceMatcher.get_opcodes() returns a list of
|
||||
# (tag, i1, i2, j1, j2) where tag is 'equal', 'replace', 'delete',
|
||||
# or 'insert'. i1:i2 is the range in pre, j1:j2 is the range in
|
||||
# post. We keep the opcode list as-is and scan it linearly by i
|
||||
# for each lookup.
|
||||
sm = difflib.SequenceMatcher(a=pre_lines, b=post_lines, autojunk=False)
|
||||
opcodes = sm.get_opcodes()
|
||||
|
||||
def shift(line: int) -> Optional[int]:
|
||||
# Find the opcode region whose i1 <= line < i2.
|
||||
# Linear scan is fine — typical opcode count is small (single
|
||||
# digits for a typical patch-tool edit).
|
||||
for tag, i1, i2, j1, j2 in opcodes:
|
||||
if i1 <= line < i2:
|
||||
if tag == "equal":
|
||||
# Pre-line N → post-line (N - i1 + j1).
|
||||
return line - i1 + j1
|
||||
if tag == "delete":
|
||||
# Pre-line is in a deleted region — no post counterpart.
|
||||
return None
|
||||
if tag == "replace":
|
||||
# Replace == delete + insert; the pre-line has no
|
||||
# post counterpart in any meaningful sense. Drop.
|
||||
return None
|
||||
# 'insert' has i1 == i2 so line < i2 can't be hit.
|
||||
if line < i1:
|
||||
# Opcodes are ordered by i1, so no later region can contain this line.
|
||||
break
|
||||
# Past the last opcode region (line >= len(pre_lines)).
|
||||
# Anchor at end of post.
|
||||
return max(0, len(post_lines) - 1) if post_lines else None
|
||||
|
||||
return shift
|
||||
|
||||
|
||||
def shift_diagnostic_range(diag: Dict[str, Any],
|
||||
shift: Callable[[int], Optional[int]]) -> Optional[Dict[str, Any]]:
|
||||
"""Return a copy of ``diag`` with its line range remapped through ``shift``.
|
||||
|
||||
Returns ``None`` if the diagnostic's start line maps to ``None``
|
||||
(the line was deleted by the edit) — caller drops it from the
|
||||
baseline since the diagnostic no longer applies.
|
||||
|
||||
Both ``start.line`` and ``end.line`` are remapped independently;
|
||||
when only the end maps to ``None`` (rare, multi-line diagnostic
|
||||
straddling the edit boundary) we collapse to a single-line range
|
||||
at the shifted start to keep the diagnostic in the baseline.
|
||||
|
||||
The original ``diag`` is not mutated.
|
||||
"""
|
||||
rng = diag.get("range") or {}
|
||||
start = rng.get("start") or {}
|
||||
end = rng.get("end") or {}
|
||||
|
||||
pre_start_line = int(start.get("line", 0))
|
||||
pre_end_line = int(end.get("line", pre_start_line))
|
||||
|
||||
new_start_line = shift(pre_start_line)
|
||||
if new_start_line is None:
|
||||
return None
|
||||
|
||||
new_end_line = shift(pre_end_line)
|
||||
if new_end_line is None:
|
||||
# Diagnostic straddled the deletion — collapse to start.
|
||||
new_end_line = new_start_line
|
||||
|
||||
shifted = dict(diag)
|
||||
shifted["range"] = {
|
||||
"start": {
|
||||
"line": new_start_line,
|
||||
"character": int(start.get("character", 0)),
|
||||
},
|
||||
"end": {
|
||||
"line": new_end_line,
|
||||
"character": int(end.get("character", 0)),
|
||||
},
|
||||
}
|
||||
return shifted
|
||||
|
||||
|
||||
def shift_baseline(baseline: List[Dict[str, Any]],
|
||||
shift: Callable[[int], Optional[int]]) -> List[Dict[str, Any]]:
|
||||
"""Apply ``shift`` to every diagnostic in ``baseline``, dropping deleted entries."""
|
||||
out: List[Dict[str, Any]] = []
|
||||
for d in baseline:
|
||||
if not isinstance(d, dict):
|
||||
continue
|
||||
shifted = shift_diagnostic_range(d, shift)
|
||||
if shifted is not None:
|
||||
out.append(shifted)
|
||||
return out
|
||||
|
||||
|
||||
__all__ = ["build_line_shift", "shift_diagnostic_range", "shift_baseline"]
|
||||
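As a rough illustration of the line-shift idea in the module above, the following self-contained sketch rebuilds a simplified map directly from `difflib.SequenceMatcher.get_opcodes()`. The function name `line_shift_sketch` is hypothetical, and unlike the module it returns `None` for lines past the end of the pre-edit text; that is a simplification made for this example only.

```python
import difflib
from typing import Callable, Optional


def line_shift_sketch(pre_text: str, post_text: str) -> Callable[[int], Optional[int]]:
    """Map 0-indexed pre-edit lines to post-edit lines.

    Returns None for lines that were deleted or replaced.
    """
    pre = pre_text.splitlines()
    post = post_text.splitlines()
    opcodes = difflib.SequenceMatcher(a=pre, b=post, autojunk=False).get_opcodes()

    def shift(line: int) -> Optional[int]:
        for tag, i1, i2, j1, _j2 in opcodes:
            if i1 <= line < i2:
                # Only unchanged ('equal') regions have a post-edit counterpart.
                return line - i1 + j1 if tag == "equal" else None
        return None  # outside the pre-edit file (sketch-only choice)

    return shift


pre = "a\nb\nc\nd\n"
post = "a\nNEW1\nNEW2\nb\nd\n"  # two lines inserted after "a", and "c" deleted
shift = line_shift_sketch(pre, post)

# Collect the mapping for every pre-edit line: "b" moves down, "c" disappears.
mapping = {line: shift(line) for line in range(4)}
```

A diagnostic that sat on "b" (line 1) would be compared at line 3 after the edit, while one on the deleted "c" would drop out of the baseline.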
+16
-4
@@ -444,6 +444,10 @@ prompt_caching:
|
||||
# model: ""
|
||||
# timeout: 30
|
||||
# max_concurrency: 3 # Limit parallel summaries to reduce request-burst 429s
|
||||
# default_mode: "fast" # 'fast' | 'summary' — mode used when caller passes none.
|
||||
# # fast: FTS5 snippet hits, no LLM call. Default.
|
||||
# # summary: LLM-generated prose synthesis across hits.
|
||||
# # guided requires anchors and cannot be a default.
|
||||
# extra_body: {} # Provider-specific OpenAI-compatible request fields
|
||||
# # Example for providers that support request-body
|
||||
# # reasoning controls:
|
||||
@@ -457,7 +461,7 @@ prompt_caching:
|
||||
# Two stores: MEMORY.md (agent's notes) and USER.md (user profile).
|
||||
# Character limits keep the memory small and focused. The agent manages
|
||||
# pruning -- when at the limit, it must consolidate or replace entries.
|
||||
# Disabled by default in batch_runner and RL environments.
|
||||
# Disabled by default in batch_runner.
|
||||
#
|
||||
memory:
|
||||
# Agent's personal notes: environment facts, conventions, things learned
|
||||
@@ -681,6 +685,16 @@ platform_toolsets:
|
||||
# # allowed_chats: ["-1001234567890"]
|
||||
# extra:
|
||||
# disable_link_previews: false # Set true to suppress Telegram URL previews in bot messages
|
||||
#
|
||||
# Discord-specific settings (config.yaml top-level, not under platforms:):
|
||||
#
|
||||
# discord:
|
||||
# require_mention: true # Require @mention in server channels (default: true)
|
||||
# auto_thread: true # Auto-create thread on @mention (default: true)
|
||||
# free_response_channels: "" # Channel IDs where no mention is needed
|
||||
# reactions: true # Show processing reactions (default: true)
|
||||
# history_backfill: true # Recover missed channel messages on mention (default: true)
|
||||
# history_backfill_limit: 50 # Max messages to scan backwards (default: 50)
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Available toolsets (use these names in platform_toolsets or the toolsets list)
|
||||
@@ -705,10 +719,9 @@ platform_toolsets:
|
||||
# todo - todo (in-memory task planning, no deps)
|
||||
# tts - text_to_speech (Edge TTS free, or ELEVENLABS/OPENAI/MINIMAX/MISTRAL key)
|
||||
# cronjob - cronjob (create/list/update/pause/resume/run/remove scheduled tasks)
|
||||
# rl - rl_list_environments, rl_start_training, etc. (requires TINKER_API_KEY)
|
||||
#
|
||||
# PRESETS (curated bundles):
|
||||
# hermes-cli - All of the above except rl + send_message
|
||||
# hermes-cli - All of the above except send_message
|
||||
# hermes-telegram - terminal, file, web, vision, image_gen, tts, browser,
|
||||
# skills, todo, cronjob, send_message
|
||||
# hermes-discord - Same as hermes-telegram
|
||||
@@ -734,7 +747,6 @@ platform_toolsets:
|
||||
# session_search - Search and recall past conversations (FTS5 + Gemini Flash summarization)
|
||||
# tts - Text-to-speech (Edge TTS free, ElevenLabs, OpenAI, MiniMax, Mistral)
|
||||
# cronjob - Schedule and manage automated tasks (CLI-only)
|
||||
# rl - RL training tools (Tinker-Atropos)
|
||||
#
|
||||
# Composite toolsets:
|
||||
# debugging - terminal + web + file (for troubleshooting)
|
||||
|
||||
@@ -1242,7 +1242,13 @@ _STREAM_PAD = " " # 4-space indent for streamed response text (matches Panel
|
||||
|
||||
|
||||
def _hex_to_ansi(hex_color: str, *, bold: bool = False) -> str:
|
||||
"""Convert a hex color like '#268bd2' to a true-color ANSI escape."""
|
||||
"""Convert a hex color like '#268bd2' to a true-color ANSI escape.
|
||||
|
||||
Auto-remaps known dark-mode-tuned colors to readable light-mode
|
||||
equivalents when running on a light terminal (see
|
||||
_maybe_remap_for_light_mode + _LIGHT_MODE_REMAP).
|
||||
"""
|
||||
hex_color = _maybe_remap_for_light_mode(hex_color)
|
||||
try:
|
||||
r = int(hex_color[1:3], 16)
|
||||
g = int(hex_color[3:5], 16)
|
||||
@@ -1253,6 +1259,250 @@ def _hex_to_ansi(hex_color: str, *, bold: bool = False) -> str:
|
||||
return _ACCENT_ANSI_DEFAULT if bold else "\033[38;2;184;134;11m"
|
||||
|
||||
|
||||
# ────────────────────────────────────────────────────────────────────────
|
||||
# Light/dark terminal mode detection.
|
||||
#
|
||||
# Mirrors ui-tui/src/theme.ts detectLightMode(). Used to decide whether
|
||||
# to remap "near-white" skin colors (e.g. #FFF8DC banner_text, #B8860B
|
||||
# banner_dim) to darker equivalents that are readable on a light
|
||||
# Terminal.app / iTerm2 background.
|
||||
#
|
||||
# Detection priority:
|
||||
# 1. HERMES_LIGHT / HERMES_TUI_LIGHT env (true/false) — explicit override
|
||||
# 2. HERMES_TUI_THEME=light|dark — explicit theme
|
||||
# 3. HERMES_TUI_BACKGROUND=#RRGGBB — explicit bg hint
|
||||
# 4. COLORFGBG env (set by xterm/Konsole/urxvt) — bg slot 7/15 = light
|
||||
# 5. OSC 11 query (\x1b]11;?\x1b\\) — ask the terminal directly
|
||||
# 6. TERM_PROGRAM allow-list (currently empty), then default to dark (the legacy Hermes assumption)
|
||||
#
|
||||
# Cached after first call so we don't query the terminal repeatedly.
|
||||
_LIGHT_MODE_CACHE: bool | None = None
|
||||
_TRUE_RE = re.compile(r"^(1|true|on|yes|y)$")
|
||||
_FALSE_RE = re.compile(r"^(0|false|off|no|n)$")
|
||||
_LIGHT_DEFAULT_TERM_PROGRAMS = frozenset() # Apple_Terminal doesn't reliably indicate; require explicit
|
||||
|
||||
|
||||
def _luminance_from_hex(hex_str: str) -> float | None:
|
||||
s = (hex_str or "").strip().lstrip("#")
|
||||
if len(s) == 3:
|
||||
s = "".join(c * 2 for c in s)
|
||||
if len(s) != 6 or not all(c in "0123456789abcdefABCDEF" for c in s):
|
||||
return None
|
||||
try:
|
||||
r, g, b = int(s[0:2], 16), int(s[2:4], 16), int(s[4:6], 16)
|
||||
except ValueError:
|
||||
return None
|
||||
# Rec.709 luma
|
||||
return (0.2126 * r + 0.7152 * g + 0.0722 * b) / 255.0
|
||||
|
||||
|
||||
def _query_osc11_background() -> str | None:
|
||||
"""Ask the terminal for its background color via OSC 11.
|
||||
|
||||
Most modern terminals reply with \x1b]11;rgb:RRRR/GGGG/BBBB\x1b\\
|
||||
within a few ms. We wait up to 100ms total before giving up.
|
||||
Returns "#RRGGBB" or None on timeout / non-tty.
|
||||
"""
|
||||
if not sys.stdin.isatty() or not sys.stdout.isatty():
|
||||
return None
|
||||
try:
|
||||
import termios
|
||||
import tty
|
||||
fd = sys.stdin.fileno()
|
||||
old = termios.tcgetattr(fd)
|
||||
except Exception:
|
||||
return None
|
||||
try:
|
||||
try:
|
||||
tty.setcbreak(fd)
|
||||
except Exception:
|
||||
return None
|
||||
try:
|
||||
sys.stdout.write("\x1b]11;?\x1b\\")
|
||||
sys.stdout.flush()
|
||||
except Exception:
|
||||
return None
|
||||
# Wait up to ~100ms for the response (matches the deadline below)
|
||||
import select
|
||||
deadline = time.monotonic() + 0.1
|
||||
buf = b""
|
||||
while time.monotonic() < deadline:
|
||||
r, _, _ = select.select([fd], [], [], max(0.0, deadline - time.monotonic()))
|
||||
if not r:
|
||||
continue
|
||||
try:
|
||||
chunk = os.read(fd, 64)
|
||||
except OSError:
|
||||
break
|
||||
if not chunk:
|
||||
break
|
||||
buf += chunk
|
||||
if b"\x1b\\" in buf or b"\x07" in buf:
|
||||
break
|
||||
# Parse: \x1b]11;rgb:RRRR/GGGG/BBBB\x1b\\
|
||||
m = re.search(rb"rgb:([0-9a-fA-F]+)/([0-9a-fA-F]+)/([0-9a-fA-F]+)", buf)
|
||||
if not m:
|
||||
return None
|
||||
# Each component is 1-4 hex digits — normalize to 8-bit
|
||||
def norm(h: bytes) -> int:
|
||||
v = int(h, 16)
|
||||
# Scale to 0-255 based on hex length
|
||||
bits = len(h) * 4
|
||||
return (v * 255) // ((1 << bits) - 1) if bits else 0
|
||||
r, g, b = norm(m.group(1)), norm(m.group(2)), norm(m.group(3))
|
||||
return f"#{r:02X}{g:02X}{b:02X}"
|
||||
finally:
|
||||
try:
|
||||
termios.tcsetattr(fd, termios.TCSANOW, old)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _detect_light_mode() -> bool:
|
||||
global _LIGHT_MODE_CACHE
|
||||
if _LIGHT_MODE_CACHE is not None:
|
||||
return _LIGHT_MODE_CACHE
|
||||
result = False
|
||||
try:
|
||||
# 1. Explicit env override
|
||||
for var in ("HERMES_LIGHT", "HERMES_TUI_LIGHT"):
|
||||
v = (os.environ.get(var) or "").strip().lower()
|
||||
if _TRUE_RE.match(v):
|
||||
result = True
|
||||
_LIGHT_MODE_CACHE = result
|
||||
return result
|
||||
if _FALSE_RE.match(v):
|
||||
_LIGHT_MODE_CACHE = result
|
||||
return result
|
||||
# 2. Theme hint
|
||||
theme = (os.environ.get("HERMES_TUI_THEME") or "").strip().lower()
|
||||
if theme == "light":
|
||||
result = True
|
||||
_LIGHT_MODE_CACHE = result
|
||||
return result
|
||||
if theme == "dark":
|
||||
_LIGHT_MODE_CACHE = result
|
||||
return result
|
||||
# 3. Explicit bg hex
|
||||
bg_hint = os.environ.get("HERMES_TUI_BACKGROUND") or ""
|
||||
bg_lum = _luminance_from_hex(bg_hint)
|
||||
if bg_lum is not None:
|
||||
result = bg_lum >= 0.5
|
||||
_LIGHT_MODE_CACHE = result
|
||||
return result
|
||||
# 4. COLORFGBG (xterm/Konsole/urxvt)
|
||||
cfgbg = (os.environ.get("COLORFGBG") or "").strip()
|
||||
if cfgbg:
|
||||
last = cfgbg.split(";")[-1] if ";" in cfgbg else cfgbg
|
||||
if last.isdigit():
|
||||
bg = int(last)
|
||||
if bg in (7, 15):
|
||||
result = True
|
||||
_LIGHT_MODE_CACHE = result
|
||||
return result
|
||||
if 0 <= bg < 16:
|
||||
_LIGHT_MODE_CACHE = result
|
||||
return result
|
||||
# 5. OSC 11 query (best-effort, only when stdin/stdout are TTY)
|
||||
bg_color = _query_osc11_background()
|
||||
if bg_color:
|
||||
lum = _luminance_from_hex(bg_color)
|
||||
if lum is not None:
|
||||
result = lum >= 0.5
|
||||
_LIGHT_MODE_CACHE = result
|
||||
return result
|
||||
# 6. TERM_PROGRAM allow-list (currently empty)
|
||||
tp = (os.environ.get("TERM_PROGRAM") or "").strip()
|
||||
if tp in _LIGHT_DEFAULT_TERM_PROGRAMS:
|
||||
result = True
|
||||
except Exception:
|
||||
result = False
|
||||
_LIGHT_MODE_CACHE = result
|
||||
return result
|
||||
|
||||
|
||||
# Light-mode equivalents of skin colors that are unreadable on cream
|
||||
# Terminal.app backgrounds. Used by _SkinAwareAnsi to remap colors
|
||||
# at resolution time when light mode is detected.
|
||||
#
|
||||
# IMPORTANT: only remap colors that are used as STANDALONE foregrounds
|
||||
# on the terminal's background. Don't remap colors that are paired
|
||||
# with a dark bg (e.g. status bar text on bg:#1a1a2e) — those would
|
||||
# become invisible the OTHER direction (dark gray on dark navy).
|
||||
_LIGHT_MODE_REMAP: dict[str, str] = {
|
||||
# Original (dark-mode) -> Light-mode replacement (darker, readable)
|
||||
"#FFF8DC": "#1A1A1A", # cornsilk -> near-black
|
||||
"#FFD700": "#9A6B00", # gold -> dark goldenrod (readable on cream)
|
||||
"#FFBF00": "#8A5A00", # amber -> dark amber
|
||||
"#B8860B": "#5C4500", # dark goldenrod -> deeper brown (more contrast)
|
||||
"#DAA520": "#6B4F00", # goldenrod -> dark olive
|
||||
"#F1E6CF": "#1A1A1A", # cream -> near-black
|
||||
"#c9d1d9": "#24292F", # github-light fg
|
||||
"#EAF7FF": "#0F1B26", # ice
|
||||
"#F5F5F5": "#1A1A1A",
|
||||
"#FFF0D4": "#1A1A1A",
|
||||
"#CD7F32": "#8A4F1A", # bronze -> darker bronze
|
||||
"#FFEFB5": "#3A2A00",
|
||||
# NOTE: skipping #C0C0C0/#888888/#555555/#8B8682 — those are
|
||||
# status-bar foregrounds paired with dark navy bg, where dark
|
||||
# remap values would become invisible.
|
||||
}
|
||||
|
||||
|
||||
def _maybe_remap_for_light_mode(hex_color: str) -> str:
|
||||
"""If we're in light mode, remap a dark-mode-tuned color to a
|
||||
higher-contrast equivalent. No-op in dark mode."""
|
||||
if not _detect_light_mode():
|
||||
return hex_color
|
||||
if not hex_color or not hex_color.startswith("#"):
|
||||
return hex_color
|
||||
# Case-insensitive lookup
|
||||
upper = hex_color.upper()
|
||||
if upper in _LIGHT_MODE_REMAP_UPPER:
|
||||
return _LIGHT_MODE_REMAP_UPPER[upper]
|
||||
return hex_color
|
||||
|
||||
|
||||
# Pre-uppercased lookup table for case-insensitive remapping
|
||||
_LIGHT_MODE_REMAP_UPPER = {k.upper(): v for k, v in _LIGHT_MODE_REMAP.items()}
|
||||
|
||||
|
||||
def _install_skin_light_mode_hook() -> None:
|
||||
"""Wrap SkinConfig.get_color at import time so EVERY skin color read goes
|
||||
through the light-mode remap. Idempotent."""
|
||||
try:
|
||||
from hermes_cli.skin_engine import SkinConfig # type: ignore[import]
|
||||
except Exception:
|
||||
return
|
||||
if getattr(SkinConfig, "_hermes_light_mode_hook_installed", False):
|
||||
return
|
||||
_orig_get_color = SkinConfig.get_color
|
||||
|
||||
def _wrapped_get_color(self, key, fallback=""):
|
||||
value = _orig_get_color(self, key, fallback)
|
||||
try:
|
||||
return _maybe_remap_for_light_mode(value)
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
SkinConfig.get_color = _wrapped_get_color # type: ignore[method-assign]
|
||||
SkinConfig._hermes_light_mode_hook_installed = True # type: ignore[attr-defined]
|
||||
|
||||
|
||||
_install_skin_light_mode_hook()
|
||||
|
||||
|
||||
# Prime the light-mode detection cache early (at module load) when
|
||||
# we're running interactively so OSC 11 happens before prompt_toolkit grabs the
|
||||
# tty. Skip for non-tty contexts (subagents, gateway, tests).
|
||||
try:
|
||||
if sys.stdin.isatty() and sys.stdout.isatty():
|
||||
_detect_light_mode()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class _SkinAwareAnsi:
|
||||
"""Lazy ANSI escape that resolves from the skin engine on first use.
|
||||
|
||||
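The OSC 11 parsing and luminance threshold used by the detection code above can be exercised without a real terminal. The sketch below is a standalone approximation: `parse_osc11_reply` and `relative_luma` are hypothetical names, and the sample replies are hand-written byte strings in the `rgb:RRRR/GGGG/BBBB` shape the hunk describes, not captured terminal output.

```python
import re
from typing import Optional


def parse_osc11_reply(reply: bytes) -> Optional[str]:
    """Extract '#RRGGBB' from an OSC 11 reply, or None if it does not parse."""
    m = re.search(
        rb"rgb:([0-9a-fA-F]{1,4})/([0-9a-fA-F]{1,4})/([0-9a-fA-F]{1,4})", reply
    )
    if not m:
        return None

    def to_8bit(component: bytes) -> int:
        # Components carry 1-4 hex digits; scale each to the 0-255 range.
        bits = len(component) * 4
        return (int(component, 16) * 255) // ((1 << bits) - 1)

    r, g, b = (to_8bit(m.group(i)) for i in (1, 2, 3))
    return f"#{r:02X}{g:02X}{b:02X}"


def relative_luma(hex_color: str) -> float:
    """Rec.709-weighted luma of '#RRGGBB', normalized to 0.0-1.0."""
    s = hex_color.lstrip("#")
    r, g, b = (int(s[i:i + 2], 16) for i in (0, 2, 4))
    return (0.2126 * r + 0.7152 * g + 0.0722 * b) / 255.0


# Hand-written sample replies; real terminals end with either ST or BEL.
light_reply = b"\x1b]11;rgb:ffff/f8f8/dcdc\x1b\\"
dark_reply = b"\x1b]11;rgb:1a1a/1a1a/2e2e\x07"

light_hex = parse_osc11_reply(light_reply)
dark_hex = parse_osc11_reply(dark_reply)

# Same 0.5 threshold as the detection code: at or above means light mode.
light_is_light = relative_luma(light_hex) >= 0.5
dark_is_light = relative_luma(dark_hex) >= 0.5
```

Separating the parsing from the tty handling keeps the fragile part (raw terminal I/O) small and makes the color math testable in non-interactive contexts.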
@@ -1290,7 +1540,12 @@ class _SkinAwareAnsi:
|
||||
|
||||
|
||||
_ACCENT = _SkinAwareAnsi("response_border", "#FFD700", bold=True)
|
||||
_DIM = _SkinAwareAnsi("banner_dim", "#B8860B")
|
||||
# Use ANSI dim+italic attributes (\x1b[2;3m) instead of a hardcoded
|
||||
# hex color so dim/thinking text inherits the terminal's default
|
||||
# foreground color and stays readable in both light and dark
|
||||
# Terminal.app modes. Hardcoded skin colors like #B8860B
|
||||
# (dark goldenrod) become invisible against light cream backgrounds.
|
||||
_DIM = "\x1b[2;3m"
|
||||
|
||||
|
||||
def _accent_hex() -> str:
|
||||
@@ -1710,43 +1965,7 @@ def _resolve_attachment_path(raw_path: str) -> Path | None:
|
||||
return resolved
|
||||
|
||||
|
||||
def _format_process_notification(evt: dict) -> "str | None":
|
||||
"""Format a process notification event into a [IMPORTANT: ...] message.
|
||||
|
||||
Handles both completion events (notify_on_complete) and watch pattern
|
||||
match events from the unified completion_queue.
|
||||
"""
|
||||
evt_type = evt.get("type", "completion")
|
||||
_sid = evt.get("session_id", "unknown")
|
||||
_cmd = evt.get("command", "unknown")
|
||||
|
||||
if evt_type == "watch_disabled":
|
||||
return f"[IMPORTANT: {evt.get('message', '')}]"
|
||||
|
||||
if evt_type == "watch_match":
|
||||
_pat = evt.get("pattern", "?")
|
||||
_out = evt.get("output", "")
|
||||
_sup = evt.get("suppressed", 0)
|
||||
text = (
|
||||
f"[IMPORTANT: Background process {_sid} matched "
|
||||
f"watch pattern \"{_pat}\".\n"
|
||||
f"Command: {_cmd}\n"
|
||||
f"Matched output:\n{_out}"
|
||||
)
|
||||
if _sup:
|
||||
text += f"\n({_sup} earlier matches were suppressed by rate limit)"
|
||||
text += "]"
|
||||
return text
|
||||
|
||||
# Default: completion event
|
||||
_exit = evt.get("exit_code", "?")
|
||||
_out = evt.get("output", "")
|
||||
return (
|
||||
f"[IMPORTANT: Background process {_sid} completed "
|
||||
f"(exit code {_exit}).\n"
|
||||
f"Command: {_cmd}\n"
|
||||
f"Output:\n{_out}]"
|
||||
)
|
||||
|
||||
|
||||
def _detect_file_drop(user_input: str) -> "dict | None":
|
||||
@@ -2980,25 +3199,27 @@ class HermesCLI:
|
||||
|
||||
@staticmethod
|
||||
def _scrollback_box_width(width: Optional[int] = None) -> int:
|
||||
"""Return a resize-safe width for printed scrollback box rules.
|
||||
"""Return the full viewport width for printed scrollback box rules.
|
||||
|
||||
Lines already printed to terminal scrollback are reflowed by the
|
||||
terminal emulator when the column count shrinks. A full-width response
|
||||
border drawn at, say, 200 columns will wrap into two or three rows of
|
||||
dashes after the user resizes to 80 columns, looking like duplicated
|
||||
separator lines (the family of bugs tracked by #18449, #19280, #22976).
|
||||
Previously this clamped to ``max(32, min(width, 56))`` as a defense
|
||||
against terminal-emulator reflow on column-shrink (#25975, salvaging
|
||||
#24403). That clamp made response/reasoning borders look stubby on
|
||||
any modern wide terminal. We now trust the prompt_toolkit
|
||||
``_output_screen_diff`` monkey-patch landed in #26137 (salvaging
|
||||
#25981) to keep chrome out of scrollback in the first place, and
|
||||
accept that an aggressive column-shrink may visually reflow already
|
||||
printed Panel borders — that's a cosmetic artifact of stamped
|
||||
scrollback history, not a live-render bug.
|
||||
|
||||
Keep decorative scrollback boxes intentionally narrower than the
|
||||
viewport so a moderate resize never triggers reflow. The live TUI
|
||||
footer (status bar, input rule) still uses the full width — only
|
||||
content that is *stamped into scrollback* needs this clamp.
|
||||
A small floor (32 cols) is kept so the box still renders on tiny
|
||||
terminals without negative ``'─' * (w - 2)`` math.
|
||||
"""
|
||||
if width is None:
|
||||
try:
|
||||
width = shutil.get_terminal_size((80, 24)).columns
|
||||
except Exception:
|
||||
width = 80
|
||||
return max(32, min(int(width or 80), 56))
|
||||
return max(32, int(width or 80))
|
||||
|
||||
def _tui_input_rule_height(self, position: str, width: Optional[int] = None) -> int:
|
||||
"""Return the visible height for the top/bottom input separator rules."""
|
||||
@@ -3113,8 +3334,11 @@ class HermesCLI:
|
||||
percent_label = f"{percent}%" if percent is not None else "--"
|
||||
duration_label = snapshot["duration"]
|
||||
|
||||
yolo_active = bool(os.getenv("HERMES_YOLO_MODE"))
|
||||
if width < 52:
|
||||
text = f"⚕ {snapshot['model_short']} · {duration_label}"
|
||||
if yolo_active:
|
||||
text += " · ⚠ YOLO"
|
||||
return self._trim_status_bar_text(text, width)
|
||||
if width < 76:
|
||||
parts = [f"⚕ {snapshot['model_short']}", percent_label]
|
||||
@@ -3122,6 +3346,8 @@ class HermesCLI:
|
||||
if compressions:
|
||||
parts.append(f"🗜️ {compressions}")
|
||||
parts.append(duration_label)
|
||||
if yolo_active:
|
||||
parts.append("⚠ YOLO")
|
||||
return self._trim_status_bar_text(" · ".join(parts), width)
|
||||
|
||||
if snapshot["context_length"]:
|
||||
@@ -3139,6 +3365,8 @@ class HermesCLI:
|
||||
prompt_elapsed = snapshot.get("prompt_elapsed")
|
||||
if prompt_elapsed:
|
||||
parts.append(prompt_elapsed)
|
||||
if yolo_active:
|
||||
parts.append("⚠ YOLO")
|
||||
return self._trim_status_bar_text(" │ ".join(parts), width)
|
||||
except Exception:
|
||||
return f"⚕ {self.model if getattr(self, 'model', None) else 'Hermes'}"
|
||||
@@ -3155,6 +3383,7 @@ class HermesCLI:
|
||||
# line and produce duplicated status bar rows over long sessions.
|
||||
width = self._get_tui_terminal_width()
|
||||
duration_label = snapshot["duration"]
|
||||
yolo_active = bool(os.getenv("HERMES_YOLO_MODE"))
|
||||
|
||||
if width < 52:
|
||||
frags = [
|
||||
@@ -3162,8 +3391,11 @@ class HermesCLI:
|
||||
("class:status-bar-strong", snapshot["model_short"]),
|
||||
("class:status-bar-dim", " · "),
|
||||
("class:status-bar-dim", duration_label),
|
||||
("class:status-bar", " "),
|
||||
]
|
||||
if yolo_active:
|
||||
frags.append(("class:status-bar-dim", " · "))
|
||||
frags.append(("class:status-bar-yolo", "⚠ YOLO"))
|
||||
frags.append(("class:status-bar", " "))
|
||||
else:
|
||||
percent = snapshot["context_percent"]
|
||||
percent_label = f"{percent}%" if percent is not None else "--"
|
||||
@@ -3181,8 +3413,11 @@ class HermesCLI:
|
||||
frags.extend([
|
||||
("class:status-bar-dim", " · "),
|
||||
("class:status-bar-dim", duration_label),
|
||||
("class:status-bar", " "),
|
||||
])
|
||||
if yolo_active:
|
||||
frags.append(("class:status-bar-dim", " · "))
|
||||
frags.append(("class:status-bar-yolo", "⚠ YOLO"))
|
||||
frags.append(("class:status-bar", " "))
|
||||
else:
|
||||
if snapshot["context_length"]:
|
||||
ctx_total = _format_context_length(snapshot["context_length"])
|
||||
@@ -3215,6 +3450,9 @@ class HermesCLI:
|
||||
if prompt_elapsed:
|
||||
frags.append(("class:status-bar-dim", " │ "))
|
||||
frags.append(("class:status-bar-dim", prompt_elapsed))
|
||||
if yolo_active:
|
||||
frags.append(("class:status-bar-dim", " │ "))
|
||||
frags.append(("class:status-bar-yolo", "⚠ YOLO"))
|
||||
frags.append(("class:status-bar", " "))
|
||||
|
||||
total_width = sum(self._status_bar_display_width(text) for _, text in frags)
|
||||
@@ -5961,6 +6199,38 @@ class HermesCLI:
|
||||
else:
|
||||
_cprint(f" ↻ Resumed session {target_id}{title_part} — no messages, starting fresh.")
|
||||
|
||||
def _handle_sessions_command(self, cmd_original: str) -> None:
|
||||
"""Handle /sessions [list|<id_or_title>] — browse or resume previous sessions.
|
||||
|
||||
Without arguments, prints the same recent-sessions table that /resume
|
||||
shows when called without a target, and tells the user how to resume.
|
||||
With an explicit subcommand or target, delegates to the resume flow so
|
||||
``/sessions <id>`` and ``/resume <id>`` behave identically.
|
||||
|
||||
The TUI ships an interactive picker overlay for this command; the
|
||||
classic CLI prints an inline list because there is no equivalent
|
||||
overlay primitive here. Without this handler the canonical name
|
||||
``sessions`` falls through ``process_command``'s elif chain and
|
||||
prints ``Unknown command: sessions`` even though the command is
|
||||
registered in the central COMMAND_REGISTRY.
|
||||
"""
|
||||
parts = cmd_original.split(None, 1)
|
||||
arg = parts[1].strip() if len(parts) > 1 else ""
|
||||
sub = arg.lower()
|
||||
|
||||
# Bare /sessions or /sessions list — show recent sessions inline.
|
||||
if not arg or sub in {"list", "ls", "browse"}:
|
||||
if not self._session_db:
|
||||
from hermes_state import format_session_db_unavailable
|
||||
_cprint(f" {format_session_db_unavailable()}")
|
||||
return
|
||||
if not self._show_recent_sessions(reason="sessions"):
|
||||
_cprint(" (._.) No previous sessions yet.")
|
||||
return
|
||||
|
||||
# /sessions <id_or_title> behaves the same as /resume <id_or_title>.
|
||||
self._handle_resume_command(f"/resume {arg}")
|
||||
|
||||
def _handle_branch_command(self, cmd_original: str) -> None:
|
||||
"""Handle /branch [name] — fork the current session into a new independent copy.
|
||||
|
||||
@@ -7540,6 +7810,8 @@ class HermesCLI:
|
||||
self.new_session(title=title)
|
||||
elif canonical == "resume":
|
||||
self._handle_resume_command(cmd_original)
|
||||
elif canonical == "sessions":
|
||||
self._handle_sessions_command(cmd_original)
|
||||
elif canonical == "model":
|
||||
self._handle_model_switch(cmd_original)
|
||||
elif canonical == "codex-runtime":
|
||||
@@ -7913,8 +8185,8 @@ class HermesCLI:
|
||||
from hermes_cli.skin_engine import get_active_skin
|
||||
_skin = get_active_skin()
|
||||
label = _skin.get_branding("response_label", "⚕ Hermes")
|
||||
_resp_color = _skin.get_color("response_border", "#CD7F32")
|
||||
_resp_text = _skin.get_color("banner_text", "#FFF8DC")
|
||||
_resp_color = _maybe_remap_for_light_mode(_skin.get_color("response_border", "#CD7F32"))
|
||||
_resp_text = _maybe_remap_for_light_mode(_skin.get_color("banner_text", "#FFF8DC"))
|
||||
except Exception:
|
||||
label = "⚕ Hermes"
|
||||
_resp_color = "#CD7F32"
|
||||
@@ -8515,7 +8787,8 @@ class HermesCLI:
|
||||
|
||||
set_active_skin(new_skin)
|
||||
_ACCENT.reset() # Re-resolve ANSI color for the new skin
|
||||
_DIM.reset() # Re-resolve dim/secondary ANSI color for the new skin
|
||||
# _DIM is now a fixed dim+italic ANSI escape (terminal-default fg)
|
||||
# so it doesn't need re-resolving on skin switch.
|
||||
if save_config_value("display.skin", new_skin):
|
||||
print(f" Skin set to: {new_skin} (saved)")
|
||||
else:
|
||||
@@ -10894,12 +11167,12 @@ class HermesCLI:
|
||||
from hermes_cli.skin_engine import get_active_skin
|
||||
_skin = get_active_skin()
|
||||
label = _skin.get_branding("response_label", "⚕ Hermes")
|
||||
_resp_color = _skin.get_color("response_border", "#CD7F32")
|
||||
_resp_text = _skin.get_color("banner_text", "#FFF8DC")
|
||||
_resp_color = _maybe_remap_for_light_mode(_skin.get_color("response_border", "#CD7F32"))
|
||||
_resp_text = _maybe_remap_for_light_mode(_skin.get_color("banner_text", "#FFF8DC"))
|
||||
except Exception:
|
||||
label = "⚕ Hermes"
|
||||
_resp_color = "#CD7F32"
|
||||
_resp_text = "#FFF8DC"
|
||||
_resp_color = _maybe_remap_for_light_mode("#CD7F32")
|
||||
_resp_text = _maybe_remap_for_light_mode("#FFF8DC")
|
||||
|
||||
is_error_response = result and (result.get("failed") or result.get("partial"))
|
||||
already_streamed = self._stream_started and self._stream_box_opened and not is_error_response
|
||||
@@ -11138,13 +11411,48 @@ class HermesCLI:
|
||||
return "".join(text for _, text in self._get_tui_prompt_fragments())
|
||||
|
||||
def _build_tui_style_dict(self) -> dict[str, str]:
|
||||
"""Layer the active skin's prompt_toolkit colors over the base TUI style."""
|
||||
"""Layer the active skin's prompt_toolkit colors over the base TUI style.
|
||||
|
||||
Also rewrites any hex-color tokens in the resulting style strings
|
||||
to their light-mode equivalents (via _LIGHT_MODE_REMAP) when the
|
||||
terminal is detected as light. This makes the chrome readable
|
||||
on cream Terminal.app backgrounds without per-skin overrides.
|
||||
"""
|
||||
style_dict = dict(getattr(self, "_tui_style_base", {}) or {})
|
||||
try:
|
||||
from hermes_cli.skin_engine import get_prompt_toolkit_style_overrides
|
||||
style_dict.update(get_prompt_toolkit_style_overrides())
|
||||
except Exception:
|
||||
pass
|
||||
# Light-mode remap on the style strings. Each value is a pt
|
||||
# style string like "bg:#1a1a2e #C0C0C0 bold" — split on space,
|
||||
# rewrite any "#XXX" tokens (including "bg:#XXX") through the
|
||||
# light-mode remap, rejoin.
|
||||
#
|
||||
# CRITICAL: skip the remap entirely when a style string already
|
||||
# specifies its own bg (e.g. status-bar / completion-menu styles
|
||||
# with `bg:#1a1a2e ...`). Those colors were tuned for that
|
||||
# specific dark bg and remapping the FG to a dark equivalent
|
||||
# would produce dark-on-dark (invisible). The terminal's BG
|
||||
# mode is irrelevant — what matters is the bg the style itself
|
||||
# paints.
|
||||
try:
|
||||
if _detect_light_mode():
|
||||
def _remap_value(v: str) -> str:
|
||||
if not v:
|
||||
return v
|
||||
tokens = v.split()
|
||||
has_explicit_bg = any(t.startswith("bg:") for t in tokens)
|
||||
if has_explicit_bg:
|
||||
# The style paints its own bg — leave its fg alone.
|
||||
return v
|
||||
return " ".join(
|
||||
_maybe_remap_for_light_mode(t) if t.startswith("#") else t
|
||||
for t in tokens
|
||||
)
|
||||
style_dict = {k: _remap_value(v or "") for k, v in style_dict.items()}
|
||||
except Exception:
|
||||
pass
|
||||
return style_dict
|
||||
|
||||
def _apply_tui_skin_style(self) -> bool:
|
||||
@@ -11230,6 +11538,13 @@ class HermesCLI:
|
||||
|
||||
def run(self):
|
||||
"""Run the interactive CLI loop with persistent input at bottom."""
|
||||
# Detect light/dark terminal mode now (before pt grabs the tty).
|
||||
# Caches the result so subsequent _hex_to_ansi / style calls
|
||||
# don't risk re-querying mid-render.
|
||||
try:
|
||||
_detect_light_mode()
|
||||
except Exception:
|
||||
pass
|
||||
# Push the entire TUI to the bottom of the terminal so the banner,
|
||||
# responses, and prompt all appear pinned to the bottom — empty
|
||||
# space stays above, not below. This prints enough blank lines to
|
||||
@@ -12993,11 +13308,16 @@ class HermesCLI:
|
||||
|
||||
# Style for the application
|
||||
self._tui_style_base = {
|
||||
'input-area': '#FFF8DC',
|
||||
'placeholder': '#555555 italic',
|
||||
'prompt': '#FFF8DC',
|
||||
# Input area / prompt: empty style strings inherit the
|
||||
# terminal's default foreground/background, so the typed
|
||||
# text is readable in both light and dark Terminal.app
|
||||
# color schemes. (Hardcoding a near-white #FFF8DC made
|
||||
# input invisible on light backgrounds.)
|
||||
'input-area': '',
|
||||
'placeholder': '#888888 italic',
|
||||
'prompt': '',
|
||||
'prompt-working': '#888888 italic',
|
||||
'hint': '#555555 italic',
|
||||
'hint': '#888888 italic',
|
||||
'status-bar': 'bg:#1a1a2e #C0C0C0',
|
||||
'status-bar-strong': 'bg:#1a1a2e #FFD700 bold',
|
||||
'status-bar-dim': 'bg:#1a1a2e #8B8682',
|
||||
@@ -13005,6 +13325,7 @@ class HermesCLI:
|
||||
'status-bar-warn': 'bg:#1a1a2e #FFD700 bold',
|
||||
'status-bar-bad': 'bg:#1a1a2e #FF8C00 bold',
|
||||
'status-bar-critical': 'bg:#1a1a2e #FF6B6B bold',
|
||||
'status-bar-yolo': 'bg:#1a1a2e #FF4444 bold',
|
||||
# Bronze horizontal rules around the input area
|
||||
'input-rule': '#CD7F32',
|
||||
# Clipboard image attachment badges
|
||||
@@ -13056,19 +13377,70 @@ class HermesCLI:
|
||||
self._app = app # Store reference for clarify_callback
|
||||
|
||||
# ── Fix ghost status-bar lines on terminal resize ──────────────
|
||||
# When the terminal shrinks (e.g. un-maximize), the emulator reflows
|
||||
# the previously-rendered full-width rows (status bar, input rules)
|
||||
# into multiple narrower rows. prompt_toolkit's _on_resize handler
|
||||
# only cursor_up()s by the stored layout height, missing the extra
|
||||
# rows created by reflow — leaving ghost duplicates visible.
|
||||
# Resize handling: monkey-patch prompt_toolkit's _output_screen_diff
|
||||
# to suppress the deliberate "reserve vertical space" scroll-up.
|
||||
#
|
||||
# It's not just column-shrink: widening, row-shrinking, and
|
||||
# multiplexer-driven SIGWINCH-less redraws (cmux / tmux tab switch)
|
||||
# all produce the same class of drift, where the renderer's tracked
|
||||
# _cursor_pos.y no longer matches terminal reality. The only reliable
|
||||
# recovery is a full screen-clear (\x1b[2J\x1b[H) before the next
|
||||
# redraw, so we force one on every resize rather than trying to
|
||||
# compute the exact drift.
|
||||
# Background: prompt_toolkit's renderer (renderer.py L232-242)
|
||||
# explicitly moves the cursor to the bottom of the canvas after
|
||||
# painting "to make sure the terminal scrolls up, even when the
|
||||
# lower lines of the canvas just contain whitespace". In
|
||||
# non-fullscreen mode this scrolls chrome content (status bar,
|
||||
# input rules) into terminal scrollback on every render. When
|
||||
# the terminal column-shrinks, the emulator reflows the previously
|
||||
# rendered full-width rows into multiple narrower rows that get
|
||||
# pushed up — leaving ghost duplicates AND polluting scrollback.
|
||||
# Same issue as pt #29 (open since 2014), #1675, #1933.
|
||||
#
|
||||
# Surgical fix: wrap _output_screen_diff so that when its internal
|
||||
# `if current_height > previous_screen.height` branch fires (the
|
||||
# one that does the bottom-cursor-move), we make it fall through
|
||||
# by inflating previous_screen.height first.
|
||||
try:
|
||||
import prompt_toolkit.renderer as _pt_renderer
|
||||
from prompt_toolkit.renderer import _output_screen_diff as _orig_osd
|
||||
|
||||
if not getattr(_pt_renderer, "_hermes_osd_patched", False):
|
||||
def _patched_output_screen_diff(
|
||||
app, output, screen, current_pos, color_depth,
|
||||
previous_screen, last_style, is_done, full_screen,
|
||||
attrs_for_style_string, style_string_has_style,
|
||||
size, previous_width,
|
||||
):
|
||||
"""Wraps pt's _output_screen_diff to suppress the
|
||||
reserve-vertical-space scroll (renderer.py L232-242).
|
||||
|
||||
Strategy: ONLY when previous_screen is non-None and
|
||||
its current height is genuinely smaller than the new
|
||||
screen's height, inflate it to match. This prevents
|
||||
the bottom-cursor-move at L242 without changing any
|
||||
other code path's behavior.
|
||||
|
||||
Critical: do NOT replace a None previous_screen with
|
||||
a fresh Screen() — that would skip the proper
|
||||
reset_attributes()+erase_down() at L178-185 which
|
||||
fires when previous_screen is None (first-paint /
|
||||
width-change). Without that reset, ANSI styles
|
||||
leak between renders.
|
||||
"""
|
||||
try:
|
||||
if previous_screen is not None and hasattr(previous_screen, "height"):
|
||||
if previous_screen.height < screen.height:
|
||||
previous_screen.height = screen.height
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return _orig_osd(
|
||||
app, output, screen, current_pos, color_depth,
|
||||
previous_screen, last_style, is_done, full_screen,
|
||||
attrs_for_style_string, style_string_has_style,
|
||||
size, previous_width,
|
||||
)
|
||||
|
||||
_pt_renderer._output_screen_diff = _patched_output_screen_diff
|
||||
_pt_renderer._hermes_osd_patched = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_original_on_resize = app._on_resize
|
||||
|
||||
def _resize_clear_ghosts():
|
||||
@@ -13110,16 +13482,8 @@ class HermesCLI:
|
||||
# and watch pattern matches) while agent is idle.
|
||||
try:
|
||||
from tools.process_registry import process_registry
|
||||
if not process_registry.completion_queue.empty():
|
||||
evt = process_registry.completion_queue.get_nowait()
|
||||
# Skip if the agent already consumed this via wait/poll/log
|
||||
_evt_sid = evt.get("session_id", "")
|
||||
if evt.get("type") == "completion" and process_registry.is_completion_consumed(_evt_sid):
|
||||
pass # already delivered via tool result
|
||||
else:
|
||||
_synth = _format_process_notification(evt)
|
||||
if _synth:
|
||||
self._pending_input.put(_synth)
|
||||
for _evt, _synth in process_registry.drain_notifications():
|
||||
self._pending_input.put(_synth)
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
@@ -13227,15 +13591,8 @@ class HermesCLI:
|
||||
# that arrived while the agent was running.
|
||||
try:
|
||||
from tools.process_registry import process_registry
|
||||
while not process_registry.completion_queue.empty():
|
||||
evt = process_registry.completion_queue.get_nowait()
|
||||
# Skip if the agent already consumed this via wait/poll/log
|
||||
_evt_sid = evt.get("session_id", "")
|
||||
if evt.get("type") == "completion" and process_registry.is_completion_consumed(_evt_sid):
|
||||
continue # already delivered via tool result
|
||||
_synth = _format_process_notification(evt)
|
||||
if _synth:
|
||||
self._pending_input.put(_synth)
|
||||
for _evt, _synth in process_registry.drain_notifications():
|
||||
self._pending_input.put(_synth)
|
||||
except Exception:
|
||||
pass # Non-fatal — don't break the main loop
|
||||
|
||||
@@ -13367,6 +13724,30 @@ class HermesCLI:
|
||||
self._print_exit_summary()
|
||||
return
|
||||
|
||||
# On macOS with uv-managed Python, kqueue's selector cannot register
|
||||
# fd 0, raising OSError(EINVAL) from kqueue.control() when prompt_toolkit
|
||||
# calls loop.add_reader (#6393). Probe kqueue and, if it can't watch
|
||||
# stdin, switch to a SelectSelector-backed event loop policy.
|
||||
if sys.platform == "darwin":
|
||||
try:
|
||||
import selectors as _selectors
|
||||
if hasattr(_selectors, "KqueueSelector"):
|
||||
_kq = _selectors.KqueueSelector()
|
||||
try:
|
||||
_kq.register(0, _selectors.EVENT_READ)
|
||||
_kq.unregister(0)
|
||||
finally:
|
||||
_kq.close()
|
||||
except (OSError, ValueError, KeyError):
|
||||
import asyncio as _aio_probe
|
||||
import selectors as _selectors
|
||||
|
||||
class _SelectEventLoopPolicy(_aio_probe.DefaultEventLoopPolicy):
|
||||
def new_event_loop(self):
|
||||
return _aio_probe.SelectorEventLoop(_selectors.SelectSelector())
|
||||
|
||||
_aio_probe.set_event_loop_policy(_SelectEventLoopPolicy())
|
||||
|
||||
# Run the application with patch_stdout for proper output handling
|
||||
try:
|
||||
with patch_stdout():
|
||||
@@ -13387,12 +13768,20 @@ class HermesCLI:
|
||||
except (KeyError, OSError) as _stdin_err:
|
||||
# Catch selector registration failures from broken stdin (#6393)
|
||||
# and I/O errors from broken stdout during interrupt (#13710).
|
||||
if isinstance(_stdin_err, OSError) and getattr(_stdin_err, "errno", None) == errno.EIO:
|
||||
_errno = getattr(_stdin_err, "errno", None) if isinstance(_stdin_err, OSError) else None
|
||||
_msg = str(_stdin_err)
|
||||
if _errno == errno.EIO:
|
||||
pass # suppress broken-stdout I/O errors on interrupt (#13710)
|
||||
elif "is not registered" in str(_stdin_err) or "Bad file descriptor" in str(_stdin_err):
|
||||
elif (
|
||||
_errno in (errno.EINVAL, errno.EBADF)
|
||||
or "is not registered" in _msg
|
||||
or "Bad file descriptor" in _msg
|
||||
or "Invalid argument" in _msg
|
||||
):
|
||||
print(
|
||||
f"\nError: stdin is not usable ({_stdin_err}).\n"
|
||||
"This can happen with certain Python installations (e.g. uv-managed cPython on macOS).\n"
|
||||
"This can happen with certain Python installations (e.g. uv-managed cPython on macOS)\n"
|
||||
"where kqueue cannot register fd 0.\n"
|
||||
"Try reinstalling Python via pyenv or Homebrew, then re-run: hermes setup"
|
||||
)
|
||||
else:
|
||||
|
||||
+56
-11
@@ -645,6 +645,44 @@ def get_job(job_id: str) -> Optional[Dict[str, Any]]:
|
||||
return None
|
||||
|
||||
|
||||
class AmbiguousJobReference(LookupError):
|
||||
"""Raised when a job name matches more than one job."""
|
||||
|
||||
def __init__(self, ref: str, matches: List[Dict[str, Any]]):
|
||||
self.ref = ref
|
||||
self.matches = matches
|
||||
ids = ", ".join(m["id"] for m in matches)
|
||||
super().__init__(
|
||||
f"Job name '{ref}' is ambiguous — matches {len(matches)} jobs: {ids}. "
|
||||
f"Use the job ID instead."
|
||||
)
|
||||
|
||||
|
||||
def resolve_job_ref(ref: str) -> Optional[Dict[str, Any]]:
|
||||
"""Resolve a job reference (ID or name) to a job record.
|
||||
|
||||
- Exact ID match wins (works even if a different job's name equals this ID).
|
||||
- Otherwise, case-insensitive name match.
|
||||
- If a name matches more than one job, raises AmbiguousJobReference so the
|
||||
caller can surface the matching IDs rather than silently picking one.
|
||||
"""
|
||||
if not ref:
|
||||
return None
|
||||
jobs = load_jobs()
|
||||
for job in jobs:
|
||||
if job["id"] == ref:
|
||||
return _normalize_job_record(job)
|
||||
ref_lower = ref.lower()
|
||||
name_matches = [j for j in jobs if (j.get("name") or "").lower() == ref_lower]
|
||||
if not name_matches:
|
||||
return None
|
||||
if len(name_matches) > 1:
|
||||
raise AmbiguousJobReference(
|
||||
ref, [_normalize_job_record(j) for j in name_matches]
|
||||
)
|
||||
return _normalize_job_record(name_matches[0])
|
||||
|
||||
|
||||
def list_jobs(include_disabled: bool = False) -> List[Dict[str, Any]]:
|
||||
"""List all jobs, optionally including disabled ones."""
|
||||
jobs = [_normalize_job_record(j) for j in load_jobs()]
|
||||
@@ -702,9 +740,12 @@ def update_job(job_id: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]
|
||||
|
||||
|
||||
def pause_job(job_id: str, reason: Optional[str] = None) -> Optional[Dict[str, Any]]:
|
||||
"""Pause a job without deleting it."""
|
||||
"""Pause a job without deleting it. Accepts a job ID or name."""
|
||||
job = resolve_job_ref(job_id)
|
||||
if not job:
|
||||
return None
|
||||
return update_job(
|
||||
job_id,
|
||||
job["id"],
|
||||
{
|
||||
"enabled": False,
|
||||
"state": "paused",
|
||||
@@ -715,14 +756,14 @@ def pause_job(job_id: str, reason: Optional[str] = None) -> Optional[Dict[str, A
|
||||
|
||||
|
||||
def resume_job(job_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Resume a paused job and compute the next future run from now."""
|
||||
job = get_job(job_id)
|
||||
"""Resume a paused job and compute the next future run from now. Accepts a job ID or name."""
|
||||
job = resolve_job_ref(job_id)
|
||||
if not job:
|
||||
return None
|
||||
|
||||
next_run_at = compute_next_run(job["schedule"])
|
||||
return update_job(
|
||||
job_id,
|
||||
job["id"],
|
||||
{
|
||||
"enabled": True,
|
||||
"state": "scheduled",
|
||||
@@ -734,12 +775,12 @@ def resume_job(job_id: str) -> Optional[Dict[str, Any]]:
|
||||
|
||||
|
||||
def trigger_job(job_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Schedule a job to run on the next scheduler tick."""
|
||||
job = get_job(job_id)
|
||||
"""Schedule a job to run on the next scheduler tick. Accepts a job ID or name."""
|
||||
job = resolve_job_ref(job_id)
|
||||
if not job:
|
||||
return None
|
||||
return update_job(
|
||||
job_id,
|
||||
job["id"],
|
||||
{
|
||||
"enabled": True,
|
||||
"state": "scheduled",
|
||||
@@ -751,14 +792,18 @@ def trigger_job(job_id: str) -> Optional[Dict[str, Any]]:
|
||||
|
||||
|
||||
def remove_job(job_id: str) -> bool:
|
||||
"""Remove a job by ID."""
|
||||
"""Remove a job by ID or name."""
|
||||
job = resolve_job_ref(job_id)
|
||||
if not job:
|
||||
return False
|
||||
canonical_id = job["id"]
|
||||
jobs = load_jobs()
|
||||
original_len = len(jobs)
|
||||
jobs = [j for j in jobs if j["id"] != job_id]
|
||||
jobs = [j for j in jobs if j["id"] != canonical_id]
|
||||
if len(jobs) < original_len:
|
||||
save_jobs(jobs)
|
||||
# Clean up output directory to prevent orphaned dirs accumulating
|
||||
job_output_dir = OUTPUT_DIR / job_id
|
||||
job_output_dir = OUTPUT_DIR / canonical_id
|
||||
if job_output_dir.exists():
|
||||
shutil.rmtree(job_output_dir)
|
||||
return True
|
||||
|
||||
@@ -1,324 +0,0 @@
|
||||
# Hermes-Agent Atropos Environments
|
||||
|
||||
This directory contains the integration layer between **hermes-agent's** tool-calling capabilities and the **Atropos** RL training framework. It provides everything needed to run agentic LLMs through multi-turn tool-calling loops, score their output with arbitrary reward functions, and feed results into Atropos for training or evaluation.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
                   Atropos Framework
               ┌───────────────────────┐
               │  BaseEnv              │  (atroposlib)
               │  - Server management  │
               │  - Worker scheduling  │
               │  - Wandb logging      │
               │  - CLI (serve/process/│
               │    evaluate)          │
               └───────────┬───────────┘
                           │ inherits
               ┌───────────┴───────────┐
               │  HermesAgentBaseEnv   │  hermes_base_env.py
               │  - Terminal backend   │
               │  - Tool resolution    │
               │  - Agent loop         │
               │  - ToolContext        │
               │  - Async patches      │
               └───────────┬───────────┘
                           │ inherits
         ┌─────────────────┼─────────────────┐
         │                 │                 │
  TerminalTestEnv     HermesSweEnv   TerminalBench2EvalEnv
  (stack testing)    (SWE training)  (TB2 benchmark eval)
```
|
||||
|
||||
### Inheritance Chain
|
||||
|
||||
**BaseEnv** (from `atroposlib`) is the Atropos base class. It provides:
|
||||
- Server management (OpenAI-compatible API servers, VLLM, SGLang)
|
||||
- Worker scheduling for parallel rollouts
|
||||
- Wandb integration for metrics and rollout logging
|
||||
- CLI interface with three subcommands: `serve`, `process`, `evaluate`
|
||||
- `evaluate_log()` for saving eval results to JSON + samples.jsonl
|
||||
|
||||
**HermesAgentBaseEnv** (`hermes_base_env.py`) extends BaseEnv with hermes-agent specifics:
|
||||
- Sets `os.environ["TERMINAL_ENV"]` to configure the terminal backend (local, docker, ssh, singularity, modal, daytona, vercel_sandbox)
|
||||
- Resolves hermes-agent toolsets via `_resolve_tools_for_group()` (calls `get_tool_definitions()` which queries `tools/registry.py`)
|
||||
- Implements `collect_trajectory()` which runs the full agent loop and computes rewards
|
||||
- Supports two-phase operation (Phase 1: OpenAI server, Phase 2: VLLM ManagedServer)
|
||||
- Applies the `patches.py` async-safety hooks at import time (now a no-op, since async safety is built into the Modal backend itself)
|
||||
|
||||
Concrete environments inherit from `HermesAgentBaseEnv` and implement:
|
||||
- `setup()` -- Load dataset, initialize state
|
||||
- `get_next_item()` -- Return the next item for rollout
|
||||
- `format_prompt()` -- Convert a dataset item into the user message
|
||||
- `compute_reward()` -- Score the rollout using ToolContext
|
||||
- `evaluate()` -- Periodic evaluation logic
|
||||
|
||||
## Core Components
|
||||
|
||||
### Agent Loop (`agent_loop.py`)
|
||||
|
||||
`HermesAgentLoop` is the reusable multi-turn agent engine. It runs the same pattern as hermes-agent's `run_agent.py`:
|
||||
|
||||
1. Send messages + tools to the API via `server.chat_completion()`
2. If the response contains `tool_calls`, execute each one via `handle_function_call()` (which delegates to `tools/registry.py`'s `dispatch()`)
3. Append tool results to the conversation and go back to step 1
4. If the response has no tool_calls, the agent is done
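
The four steps above can be sketched as a small loop. This is a minimal illustration, not the `HermesAgentLoop` implementation: `run_agent_loop`, `fake_chat`, and `fake_dispatch` are hypothetical stand-ins for `server.chat_completion()` and `handle_function_call()`, and the message dictionaries are simplified.

```python
import json
from typing import Any, Callable, Dict, List

Message = Dict[str, Any]


def run_agent_loop(
    chat: Callable[[List[Message]], Message],
    dispatch: Callable[[str, Dict[str, Any]], str],
    messages: List[Message],
    max_turns: int = 8,
) -> List[Message]:
    """Repeat chat -> run tool calls -> append results until no tool calls remain."""
    for _ in range(max_turns):
        reply = chat(messages)  # step 1: send the conversation to the model
        messages.append(reply)
        tool_calls = reply.get("tool_calls") or []
        if not tool_calls:  # step 4: no tool calls means the agent is done
            break
        for call in tool_calls:  # step 2: execute each requested tool
            args = json.loads(call["arguments"])
            result = dispatch(call["name"], args)
            # step 3: feed the tool result back into the conversation
            messages.append(
                {"role": "tool", "tool_call_id": call["id"], "content": result}
            )
    return messages


def fake_chat(history: List[Message]) -> Message:
    """Hypothetical model: asks for one tool call, then answers."""
    if not any(m["role"] == "tool" for m in history):
        return {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {"id": "call_1", "name": "add", "arguments": '{"a": 2, "b": 3}'}
            ],
        }
    return {"role": "assistant", "content": "The sum is 5."}


def fake_dispatch(name: str, args: Dict[str, Any]) -> str:
    """Hypothetical tool registry with a single 'add' tool."""
    return str(args["a"] + args["b"])


history = run_agent_loop(
    fake_chat, fake_dispatch, [{"role": "user", "content": "What is 2 + 3?"}]
)
```

The `max_turns` cap is an assumption added so the sketch always terminates; the real loop's stopping rules may differ.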
|
||||
|
||||
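Tool calls are executed in a thread pool (`run_in_executor`) so that backends which call `asyncio.run()` internally (Modal, Docker) do not fail inside Atropos's already-running event loop; `asyncio.run()` raises `RuntimeError` when invoked from a thread whose loop is running.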
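
A minimal sketch of the thread-pool idea, using only the standard library. `blocking_tool` and `call_tool_in_thread` are hypothetical names for illustration; the real tool handlers are not shown here.

```python
import asyncio


def blocking_tool(command: str) -> str:
    """Hypothetical sync tool whose backend calls asyncio.run() internally."""

    async def _backend() -> str:
        await asyncio.sleep(0)
        return f"ran: {command}"

    # Calling this directly from a coroutine would raise RuntimeError,
    # because asyncio.run() refuses to start inside a running loop.
    return asyncio.run(_backend())


async def call_tool_in_thread(command: str) -> str:
    """Run the sync tool on a worker thread so it gets a loop-free context."""
    loop = asyncio.get_running_loop()
    # Passing None selects the loop's default ThreadPoolExecutor.
    return await loop.run_in_executor(None, blocking_tool, command)


result = asyncio.run(call_tool_in_thread("echo hi"))
```

The worker thread has no running event loop of its own, so the inner `asyncio.run()` can create and tear down a fresh loop without touching the outer one.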
|
||||
|
||||
Returns an `AgentResult` containing the full conversation history, turn count, reasoning content per turn, tool errors, and optional ManagedServer state (for Phase 2).
|
||||
|
||||
### Tool Context (`tool_context.py`)
|
||||
|
||||
`ToolContext` is a per-rollout handle that gives reward/verification functions direct access to **all** hermes-agent tools, scoped to the rollout's `task_id`. The same `task_id` means the terminal/browser session is the SAME one the model used during its rollout -- all state (files, processes, browser tabs) is preserved.
|
||||
|
||||
```python
async def compute_reward(self, item, result, ctx: ToolContext):
    # Run tests in the model's terminal sandbox
    test = ctx.terminal("pytest -v")
    if test["exit_code"] == 0:
        return 1.0

    # Check if a file was created
    content = ctx.read_file("/workspace/solution.py")
    if content.get("content"):
        return 0.5

    # Download files locally for verification (binary-safe)
    ctx.download_file("/remote/output.bin", "/local/output.bin")

    return 0.0
```
|
||||
|
||||
Available methods:
|
||||
- **Terminal**: `terminal(command, timeout)` -- run shell commands
|
||||
- **Files**: `read_file(path)`, `write_file(path, content)`, `search(query, path)`
|
||||
- **Transfers**: `upload_file()`, `upload_dir()`, `download_file()`, `download_dir()` -- binary-safe file transfers between host and sandbox
|
||||
- **Web**: `web_search(query)`, `web_extract(urls)`
|
||||
- **Browser**: `browser_navigate(url)`, `browser_snapshot()`
|
||||
- **Generic**: `call_tool(name, args)` -- call any hermes-agent tool by name
|
||||
- **Cleanup**: `cleanup()` -- release all resources (called automatically after `compute_reward`)
|
||||
|
||||
### Patches (`patches.py`)
|
||||
|
||||
**Problem**: Some hermes-agent tools use `asyncio.run()` internally (e.g., the Modal backend). This crashes when called from inside Atropos's event loop because `asyncio.run()` cannot be nested.
|
||||
|
||||
**Solution**: `ModalEnvironment` uses a dedicated `_AsyncWorker` background thread with its own event loop. The calling code sees a sync interface, but internally all async Modal SDK calls happen on the worker thread so they don't conflict with Atropos's loop. This is built directly into `tools/environments/modal.py` — no monkey-patching required.
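
The worker-thread pattern described above can be sketched with the standard library alone. This is not the code in `tools/environments/modal.py`: `AsyncWorker` and `fake_sdk_call` are hypothetical, and the real `_AsyncWorker` may differ in its lifecycle and error handling.

```python
import asyncio
import threading


class AsyncWorker:
    """Hypothetical sketch: a background thread that owns its own event loop."""

    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._loop.run_forever, daemon=True)
        self._thread.start()

    def run(self, coro, timeout=None):
        """Submit a coroutine to the worker loop and block until it finishes."""
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        return future.result(timeout)

    def stop(self) -> None:
        """Stop the worker loop, wait for the thread, and release the loop."""
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()
        self._loop.close()


async def fake_sdk_call(x: int) -> int:
    """Stand-in for an async SDK call."""
    await asyncio.sleep(0)
    return x * 2


async def main() -> int:
    worker = AsyncWorker()
    try:
        # A sync call made from inside a running loop, with no nested asyncio.run().
        return worker.run(fake_sdk_call(21), timeout=5)
    finally:
        worker.stop()


answer = asyncio.run(main())
```

The caller sees a plain blocking method, while every coroutine runs on the worker's loop, so it never competes with the loop that is already running in the calling thread.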
|
||||
|
||||
`patches.py` is now a no-op (kept for backward compatibility with imports).
|
||||
|
||||
### Tool Call Parsers (`tool_call_parsers/`)
|
||||
|
||||
Client-side parsers that extract structured `tool_calls` from raw model output text. Used in **Phase 2** (VLLM server type) where ManagedServer's `/generate` endpoint returns raw text without tool call parsing.
|
||||
|
||||
Each parser is a standalone reimplementation of the corresponding VLLM parser's `extract_tool_calls()` logic. No VLLM dependency -- only standard library (`re`, `json`, `uuid`) and `openai` types.
|
||||
|
||||
Available parsers:
|
||||
- `hermes` -- Hermes/ChatML `<tool_call>` XML format
|
||||
- `mistral` -- Mistral `[TOOL_CALLS]` format
|
||||
- `llama3_json` -- Llama 3 JSON tool calling
|
||||
- `qwen` -- Qwen tool calling format
|
||||
- `qwen3_coder` -- Qwen3 Coder format
|
||||
- `deepseek_v3` -- DeepSeek V3 format
|
||||
- `deepseek_v3_1` -- DeepSeek V3.1 format
|
||||
- `kimi_k2` -- Kimi K2 format
|
||||
- `longcat` -- Longcat format
|
||||
- `glm45` / `glm47` -- GLM model formats
|
||||
|
||||
Usage:
|
||||
```python
from environments.tool_call_parsers import get_parser

parser = get_parser("hermes")
content, tool_calls = parser.parse(raw_model_output)
```
|
||||
|
||||
In Phase 1 (OpenAI server type), these parsers are not needed -- the server handles tool call parsing natively.
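
To make the client-side parsing concrete, here is a minimal sketch of extracting Hermes/ChatML-style `<tool_call>` blocks using only `re`, `json`, and `uuid`. It is an illustration under assumptions, not the shipped `hermes` parser: `parse_tool_calls` is a hypothetical function, and it returns plain dictionaries rather than `openai` types.

```python
import json
import re
import uuid
from typing import Any, Dict, List, Optional, Tuple

# Each tool call is assumed to be a JSON object wrapped in <tool_call> tags.
_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)


def parse_tool_calls(text: str) -> Tuple[Optional[str], List[Dict[str, Any]]]:
    """Split raw model output into leftover content and structured tool calls."""
    calls: List[Dict[str, Any]] = []
    for match in _TOOL_CALL_RE.finditer(text):
        try:
            payload = json.loads(match.group(1))
        except json.JSONDecodeError:
            continue  # malformed JSON: no call is emitted for this block
        if not isinstance(payload, dict) or "name" not in payload:
            continue
        calls.append(
            {
                "id": f"call_{uuid.uuid4().hex[:8]}",
                "type": "function",
                "function": {
                    "name": payload["name"],
                    # Arguments are re-serialized to a JSON string.
                    "arguments": json.dumps(payload.get("arguments", {})),
                },
            }
        )
    # All tagged blocks (including malformed ones) are removed from the content.
    content = _TOOL_CALL_RE.sub("", text).strip() or None
    return content, calls


raw = (
    "Let me check.\n"
    '<tool_call>\n{"name": "terminal", "arguments": {"command": "ls"}}\n</tool_call>'
)
content, tool_calls = parse_tool_calls(raw)
```

Generating a fresh `id` per call is an assumption made here because raw text output carries no call identifiers.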
|
||||
|
||||
## Two-Phase Operation
|
||||
|
||||
### Phase 1: OpenAI Server (Evaluation / SFT Data Generation)
|
||||
|
||||
Uses `server.chat_completion()` with `tools=` parameter. The server (VLLM, SGLang, OpenRouter, OpenAI) handles tool call parsing natively. Returns `ChatCompletion` objects with structured `tool_calls`.
|
||||
|
||||
- Good for: evaluation, SFT data generation, testing
|
||||
- Run with: `serve` (with `run-api`), `process`, or `evaluate` subcommands
|
||||
- Placeholder tokens are created for the Atropos pipeline
|
||||
|
||||
### Phase 2: VLLM ManagedServer (Full RL Training)
|
||||
|
||||
Uses ManagedServer for exact token IDs + logprobs via `/generate`. Client-side tool call parser (from `tool_call_parsers/`) reconstructs structured `tool_calls` from raw output.
|
||||
|
||||
- Good for: full RL training with GRPO/PPO
|
||||
- Run with: `serve` subcommand
|
||||
- Real tokens, masks, and logprobs flow through the pipeline
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
environments/
├── README.md                     # This file
├── __init__.py                   # Package exports
├── hermes_base_env.py            # Abstract base (HermesAgentBaseEnv)
├── agent_loop.py                 # Multi-turn agent engine (HermesAgentLoop)
├── tool_context.py               # Per-rollout tool access for reward functions
├── patches.py                    # No-op, kept for backward-compatible imports
│
├── tool_call_parsers/            # Phase 2 client-side parsers
│   ├── __init__.py               # Registry + base class
│   ├── hermes_parser.py
│   ├── mistral_parser.py
│   ├── llama_parser.py
│   ├── qwen_parser.py
│   ├── qwen3_coder_parser.py
│   ├── deepseek_v3_parser.py
│   ├── deepseek_v3_1_parser.py
│   ├── kimi_k2_parser.py
│   ├── longcat_parser.py
│   ├── glm45_parser.py
│   └── glm47_parser.py
│
├── terminal_test_env/            # Stack validation environment
│   └── terminal_test_env.py
│
├── hermes_swe_env/               # SWE-bench style training environment
│   └── hermes_swe_env.py
│
└── benchmarks/                   # Evaluation benchmarks
    ├── terminalbench_2/          # 89 terminal tasks, Modal sandboxes
    │   └── terminalbench2_env.py
    ├── tblite/                   # 100 calibrated tasks (fast TB2 proxy)
    │   └── tblite_env.py
    └── yc_bench/                 # Long-horizon strategic benchmark
        └── yc_bench_env.py
```
|
||||
|
||||
## Concrete Environments
|
||||
|
||||
### TerminalTestEnv (`terminal_test_env/`)
|
||||
|
||||
A self-contained environment with inline tasks (no external dataset needed) for validating the full stack end-to-end. Each task asks the model to create a file at a known path, and the verifier checks the content matches.
|
||||
|
||||
```bash
# Serve mode (needs run-api)
run-api
python environments/terminal_test_env/terminal_test_env.py serve

# Process mode (no run-api, saves to JSONL)
python environments/terminal_test_env/terminal_test_env.py process \
    --env.data_path_to_save_groups terminal_test_output.jsonl
```
|
||||
|
||||
### HermesSweEnv (`hermes_swe_env/`)
|
||||
|
||||
SWE-bench style training environment. The model gets a coding task, uses terminal + file + web tools to solve it, and the reward function runs tests in the same Modal sandbox.
|
||||
|
||||
```bash
python environments/hermes_swe_env/hermes_swe_env.py serve \
    --openai.model_name YourModel \
    --env.dataset_name bigcode/humanevalpack \
    --env.terminal_backend modal
```
|
||||
|
||||
### TerminalBench2EvalEnv (`benchmarks/terminalbench_2/`)
|
||||
|
||||
**Eval-only** environment for the Terminal-Bench 2.0 benchmark (89 tasks). Each task gets a pre-built Docker Hub image, a natural language instruction, and a test suite. The agent uses terminal + file tools to solve the task, then the test suite verifies correctness.
|
||||
|
||||
Follows the standard Atropos eval pattern (like GPQA, MMLU, etc.):
|
||||
- Run via `evaluate` subcommand (no `run-api` needed)
|
||||
- `setup()` loads the dataset, `evaluate()` runs all tasks
|
||||
- `rollout_and_score_eval()` handles per-task agent loop + test verification
|
||||
- Downloads verifier output locally for reliable reward checking (Harbor pattern)
|
||||
|
||||
```bash
# Run full benchmark
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
    --openai.model_name anthropic/claude-opus-4.6

# Run subset of tasks
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
    --openai.model_name anthropic/claude-opus-4.6 \
    --env.task_filter fix-git,git-multibranch

# Skip specific tasks
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
    --openai.model_name anthropic/claude-opus-4.6 \
    --env.skip_tasks heavy-task,slow-task
```
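
How `--env.task_filter` and `--env.skip_tasks` combine is not spelled out here. The sketch below assumes both are comma-separated task names, that the filter acts as an allowlist, and that skips are applied afterwards; `select_tasks` is a hypothetical helper, not a function from `terminalbench2_env.py`.

```python
from typing import List, Optional, Set


def _split_names(value: Optional[str]) -> Set[str]:
    """Turn 'a, b,c' into {'a', 'b', 'c'}; None or '' yields an empty set."""
    if not value:
        return set()
    return {name.strip() for name in value.split(",") if name.strip()}


def select_tasks(
    task_names: List[str],
    task_filter: Optional[str] = None,
    skip_tasks: Optional[str] = None,
) -> List[str]:
    """Keep tasks named in task_filter (if given), then drop any in skip_tasks."""
    keep = _split_names(task_filter)
    skip = _split_names(skip_tasks)
    return [
        name
        for name in task_names
        if (not keep or name in keep) and name not in skip
    ]


all_tasks = ["fix-git", "git-multibranch", "heavy-task", "slow-task"]
subset = select_tasks(all_tasks, task_filter="fix-git,git-multibranch")
without_slow = select_tasks(all_tasks, skip_tasks="heavy-task,slow-task")
```

Under these assumptions, an empty or missing filter selects every task, and the original task order is preserved.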
|
||||
|
||||
## Creating a New Environment
|
||||
|
||||
### Training Environment
|
||||
|
||||
1. Create a new directory under `environments/`
|
||||
2. Create your env file inheriting from `HermesAgentBaseEnv`
|
||||
3. Implement the four abstract methods + `evaluate()`
|
||||
|
||||
```python
from atroposlib.envs.base import APIServerConfig  # Atropos server config
from datasets import load_dataset  # Hugging Face datasets

from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig


class MyEnvConfig(HermesAgentEnvConfig):
    pass  # Add custom fields as needed


class MyEnv(HermesAgentBaseEnv):
    name = "my-env"
    env_config_cls = MyEnvConfig

    @classmethod
    def config_init(cls):
        env_config = MyEnvConfig(
            enabled_toolsets=["terminal", "file"],
            terminal_backend="modal",
            # ... other config
        )
        server_configs = [APIServerConfig(...)]
        return env_config, server_configs

    async def setup(self):
        self.dataset = load_dataset(...)
        self.iter = 0

    async def get_next_item(self):
        # Cycle through the dataset round-robin
        item = self.dataset[self.iter % len(self.dataset)]
        self.iter += 1
        return item

    def format_prompt(self, item):
        return item["instruction"]

    async def compute_reward(self, item, result, ctx):
        # ctx gives you full tool access to the rollout's sandbox
        test = ctx.terminal("pytest -v")
        return 1.0 if test["exit_code"] == 0 else 0.0

    async def evaluate(self, *args, **kwargs):
        # Periodic evaluation logic
        ...


if __name__ == "__main__":
    MyEnv.cli()
```
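The round-robin and reward logic in the example can be exercised without Atropos installed. The sketch below is a stand-alone illustration: `FakeToolContext` and `TinyEnv` are hypothetical stand-ins that are not part of hermes-agent, and the only assumption carried over from the example is that `ctx.terminal(...)` returns a dict with an `exit_code` key.

```python
import asyncio


class FakeToolContext:
    """Hypothetical stand-in for the per-rollout tool context."""

    def __init__(self, exit_code: int):
        self._exit_code = exit_code

    def terminal(self, command: str) -> dict:
        # A real context would run the command in the rollout's sandbox.
        return {"exit_code": self._exit_code, "output": ""}


class TinyEnv:
    """Mirrors the get_next_item / compute_reward logic of the example."""

    def __init__(self, dataset):
        self.dataset = dataset
        self.iter = 0

    async def get_next_item(self):
        # Modulo indexing wraps around once the dataset is exhausted.
        item = self.dataset[self.iter % len(self.dataset)]
        self.iter += 1
        return item

    async def compute_reward(self, item, result, ctx):
        test = ctx.terminal("pytest -v")
        return 1.0 if test["exit_code"] == 0 else 0.0


async def demo():
    env = TinyEnv([{"instruction": "a"}, {"instruction": "b"}])
    order = [(await env.get_next_item())["instruction"] for _ in range(3)]
    passed = await env.compute_reward(None, None, FakeToolContext(0))
    failed = await env.compute_reward(None, None, FakeToolContext(1))
    return order, passed, failed


order, passed, failed = asyncio.run(demo())
print(order, passed, failed)
```

The third item wraps back to the first dataset entry, and the reward is binary: a zero exit code from the test command gives `1.0`, anything else gives `0.0`.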
|
||||
|
||||
### Eval-Only Environment (Benchmark)
|
||||
|
||||
For eval benchmarks, follow the pattern in `terminalbench2_env.py`:
1. Create under `environments/benchmarks/your-benchmark/`
2. Inherit from `HermesAgentBaseEnv`
3. Set eval-only config: `eval_handling=STOP_TRAIN`, `steps_per_eval=1`, `total_steps=1`
4. Stub the training methods (`collect_trajectories`, `score`)
5. Implement `rollout_and_score_eval()` and `evaluate()`
6. Run with the `evaluate` subcommand
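The shape of such an eval-only environment can be sketched as follows. This is a stand-alone skeleton under stated assumptions: a real environment would subclass `HermesAgentBaseEnv`, and the method signatures, the `should_pass` field, and the returned dict layout are all hypothetical placeholders chosen so the example runs on its own.

```python
import asyncio
from typing import Dict, List


class EvalOnlySketch:
    """Hypothetical skeleton; not the real base class or its signatures."""

    def __init__(self, tasks: List[Dict]):
        self.tasks = tasks

    # Training hooks are stubbed because this environment never trains.
    async def collect_trajectories(self, item):
        return None, []

    async def score(self, rollout_group_data):
        return None

    async def rollout_and_score_eval(self, task: Dict) -> Dict:
        # Placeholder for "run the agent loop, then run the task's tests".
        passed = bool(task.get("should_pass", False))
        return {"task": task["name"], "passed": passed}

    async def evaluate(self) -> Dict:
        # Run every task concurrently and aggregate a pass rate.
        results = await asyncio.gather(
            *(self.rollout_and_score_eval(t) for t in self.tasks)
        )
        pass_rate = sum(r["passed"] for r in results) / len(results)
        return {"pass_rate": pass_rate, "results": list(results)}


tasks = [
    {"name": "fix-git", "should_pass": True},
    {"name": "git-multibranch", "should_pass": False},
]
summary = asyncio.run(EvalOnlySketch(tasks).evaluate())
print(summary["pass_rate"])
```

Keeping the per-task work in `rollout_and_score_eval()` and the aggregation in `evaluate()` matches the split described in the steps, so a single task can be rerun or debugged in isolation.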
|
||||
|
||||
## Key Config Fields
|
||||
|
||||
| Field | Description | Default |
|-------|-------------|---------|
| `enabled_toolsets` | Which hermes toolsets to enable | `None` (all) |
| `disabled_toolsets` | Toolsets to disable | `None` |
| `distribution` | Probabilistic toolset distribution name | `None` |
| `max_agent_turns` | Max LLM calls per rollout | `30` |
| `agent_temperature` | Sampling temperature | `1.0` |
| `terminal_backend` | `local`, `docker`, `modal`, `daytona`, `ssh`, `singularity` | `local` |
| `system_prompt` | System message for the agent | `None` |
| `tool_call_parser` | Parser name for Phase 2 | `hermes` |
| `eval_handling` | `STOP_TRAIN`, `LIMIT_TRAIN`, `NONE` | `STOP_TRAIN` |
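The defaults in the table can be mirrored in a small stand-in dataclass to show how overrides combine with them. `ConfigSketch` and `validate` are hypothetical and are not the real `HermesAgentEnvConfig`; the field names, defaults, and backend list come from the table, while the `max_agent_turns >= 1` check is an illustrative assumption.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ConfigSketch:
    """Stand-in mirroring the table's defaults; not the real config class."""

    enabled_toolsets: Optional[List[str]] = None  # None means all toolsets
    disabled_toolsets: Optional[List[str]] = None
    distribution: Optional[str] = None
    max_agent_turns: int = 30
    agent_temperature: float = 1.0
    terminal_backend: str = "local"
    system_prompt: Optional[str] = None
    tool_call_parser: str = "hermes"
    eval_handling: str = "STOP_TRAIN"


VALID_BACKENDS = {"local", "docker", "modal", "daytona", "ssh", "singularity"}


def validate(cfg: ConfigSketch) -> ConfigSketch:
    """Reject backends the table does not list (illustrative check only)."""
    if cfg.terminal_backend not in VALID_BACKENDS:
        raise ValueError(f"unknown terminal_backend: {cfg.terminal_backend!r}")
    if cfg.max_agent_turns < 1:
        raise ValueError("max_agent_turns must be at least 1")
    return cfg


cfg = validate(
    ConfigSketch(enabled_toolsets=["terminal", "file"], terminal_backend="modal")
)
print(cfg.terminal_backend, cfg.max_agent_turns, cfg.tool_call_parser)
```

Only the two overridden fields change; every other field keeps the default listed in the table.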
|
||||
@@ -1,36 +0,0 @@
|
||||
"""
|
||||
Hermes-Agent Atropos Environments
|
||||
|
||||
Provides a layered integration between hermes-agent's tool-calling capabilities
|
||||
and the Atropos RL training framework.
|
||||
|
||||
Core layers:
|
||||
- agent_loop: Reusable multi-turn agent loop with standard OpenAI-spec tool calling
|
||||
- tool_context: Per-rollout tool access handle for reward/verification functions
|
||||
- hermes_base_env: Abstract base environment (BaseEnv subclass) for Atropos
|
||||
- tool_call_parsers: Client-side tool call parser registry for Phase 2 (vLLM /generate)
|
||||
|
||||
Concrete environments:
|
||||
- terminal_test_env/: Simple file-creation tasks for testing the stack
|
||||
- hermes_swe_env/: SWE-bench style tasks with Modal sandboxes
|
||||
|
||||
Benchmarks (eval-only):
|
||||
- benchmarks/terminalbench_2/: Terminal-Bench 2.0 evaluation
|
||||
"""
|
||||
|
||||
try:
    from environments.agent_loop import AgentResult, HermesAgentLoop
    from environments.tool_context import ToolContext
    from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
except ImportError:
    # atroposlib not installed: environments are unavailable, but
    # submodules like tool_call_parsers can still be imported directly.
    pass

__all__ = [
    "AgentResult",
    "HermesAgentLoop",
    "ToolContext",
    "HermesAgentBaseEnv",
    "HermesAgentEnvConfig",
]
|
||||
@@ -1,534 +0,0 @@
|
||||
"""
|
||||
HermesAgentLoop -- Reusable Multi-Turn Agent Engine
|
||||
|
||||
Runs the hermes-agent tool-calling loop using standard OpenAI-spec tool calling.
|
||||
Works with any server that returns ChatCompletion objects with tool_calls:
|
||||
- Phase 1: OpenAI server type (vLLM, SGLang, OpenRouter, OpenAI API)
|
||||
- Phase 2: ManagedServer with client-side tool call parser
|
||||
|
||||
The loop passes tools= and checks response.choices[0].message.tool_calls,
|
||||
identical to hermes-agent's run_agent.py. Tool execution is dispatched via
|
||||
handle_function_call() from model_tools.py.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from model_tools import handle_function_call
|
||||
from tools.terminal_tool import get_active_env
|
||||
from tools.tool_result_storage import maybe_persist_tool_result, enforce_turn_budget
|
||||
|
||||
# Thread pool for running sync tool calls that internally use asyncio.run()
|
||||
# (e.g., the Modal/Docker/Daytona terminal backends). Running them in a separate
|
||||
# thread gives them a clean event loop so they don't deadlock inside Atropos's loop.
|
||||
# Size must be large enough for concurrent eval tasks (e.g., 89 TB2 tasks all
|
||||
# making tool calls). Too small = thread pool starvation, tasks queue for minutes.
|
||||
# Resized at runtime by HermesAgentBaseEnv.__init__ via resize_tool_pool().
|
||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=128)
|
||||
|
||||
|
||||
def resize_tool_pool(max_workers: int):
|
||||
"""
|
||||
Replace the global tool executor with a new one of the given size.
|
||||
|
||||
Called by HermesAgentBaseEnv.__init__ based on config.tool_pool_size.
|
||||
Safe to call before any tasks are submitted.
|
||||
"""
|
||||
global _tool_executor
|
||||
old_executor = _tool_executor
|
||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
|
||||
old_executor.shutdown(wait=False)
|
||||
logger.info("Tool thread pool resized to %d workers", max_workers)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolError:
|
||||
"""Record of a tool execution error during the agent loop."""
|
||||
|
||||
turn: int # Which turn the error occurred on
|
||||
tool_name: str # Which tool was called
|
||||
arguments: str # The arguments passed (truncated)
|
||||
error: str # The error message
|
||||
tool_result: str # The raw result returned to the model
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResult:
|
||||
"""Result of running the agent loop."""
|
||||
|
||||
# Full conversation history in OpenAI message format
|
||||
messages: List[Dict[str, Any]]
|
||||
# ManagedServer.get_state() if available (Phase 2), None otherwise
|
||||
managed_state: Optional[Dict[str, Any]] = None
|
||||
# How many LLM calls were made
|
||||
turns_used: int = 0
|
||||
# True if model stopped calling tools naturally (vs hitting max_turns)
|
||||
finished_naturally: bool = False
|
||||
# Extracted reasoning content per turn (from PR #297 helpers)
|
||||
reasoning_per_turn: List[Optional[str]] = field(default_factory=list)
|
||||
# Tool errors encountered during the loop
|
||||
tool_errors: List[ToolError] = field(default_factory=list)
|
||||
|
||||
|
||||
def _extract_reasoning_from_message(message) -> Optional[str]:
|
||||
"""
|
||||
Extract reasoning content from a ChatCompletion message.
|
||||
|
||||
Handles multiple provider formats:
|
||||
1. message.reasoning_content field (some providers)
|
||||
2. message.reasoning field (some providers)
|
||||
3. message.reasoning_details[].text (OpenRouter style)
|
||||
|
||||
Note: <think> block extraction from content is NOT done here. In Phase 1 the
server has already done it; in Phase 2 ManagedServer's patch handles it.
|
||||
|
||||
Args:
|
||||
message: The assistant message from ChatCompletion response
|
||||
|
||||
Returns:
|
||||
Extracted reasoning text, or None if not found
|
||||
"""
|
||||
# Check reasoning_content field (common across providers)
|
||||
if hasattr(message, "reasoning_content") and message.reasoning_content:
|
||||
return message.reasoning_content
|
||||
|
||||
# Check reasoning field
|
||||
if hasattr(message, "reasoning") and message.reasoning:
|
||||
return message.reasoning
|
||||
|
||||
# Check reasoning_details (OpenRouter style)
|
||||
if hasattr(message, "reasoning_details") and message.reasoning_details:
|
||||
for detail in message.reasoning_details:
|
||||
if hasattr(detail, "text") and detail.text:
|
||||
return detail.text
|
||||
if isinstance(detail, dict) and detail.get("text"):
|
||||
return detail["text"]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class HermesAgentLoop:
|
||||
"""
|
||||
Runs hermes-agent's tool-calling loop using standard OpenAI-spec tool calling.
|
||||
|
||||
Same pattern as run_agent.py:
|
||||
- Pass tools= to the API
|
||||
- Check response.choices[0].message.tool_calls
|
||||
- Dispatch via handle_function_call()
|
||||
|
||||
Works identically with any server type -- OpenAI, vLLM, SGLang, OpenRouter,
|
||||
or ManagedServer with a parser. The server determines how tool_calls get
|
||||
populated on the response.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server,
|
||||
tool_schemas: List[Dict[str, Any]],
|
||||
valid_tool_names: Set[str],
|
||||
max_turns: int = 30,
|
||||
task_id: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
budget_config: Optional["BudgetConfig"] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the agent loop.
|
||||
|
||||
Args:
|
||||
server: Server object with chat_completion() method (OpenAIServer,
|
||||
ManagedServer, ServerManager, etc.)
|
||||
tool_schemas: OpenAI-format tool definitions from get_tool_definitions()
|
||||
valid_tool_names: Set of tool names the model is allowed to call
|
||||
max_turns: Maximum number of LLM calls before stopping
|
||||
task_id: Unique ID for terminal/browser session isolation
|
||||
temperature: Sampling temperature for generation
|
||||
max_tokens: Max tokens per generation (None for server default)
|
||||
extra_body: Extra parameters passed to the OpenAI client's create() call.
|
||||
Used for OpenRouter provider preferences, transforms, etc.
|
||||
e.g. {"provider": {"ignore": ["DeepInfra"]}}
|
||||
budget_config: Tool result persistence budget. Controls per-tool
|
||||
thresholds, per-turn aggregate budget, and preview size.
|
||||
If None, uses DEFAULT_BUDGET (current hardcoded values).
|
||||
"""
|
||||
from tools.budget_config import DEFAULT_BUDGET
|
||||
self.server = server
|
||||
self.tool_schemas = tool_schemas
|
||||
self.valid_tool_names = valid_tool_names
|
||||
self.max_turns = max_turns
|
||||
self.task_id = task_id or str(uuid.uuid4())
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.extra_body = extra_body
|
||||
self.budget_config = budget_config or DEFAULT_BUDGET
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
|
||||
"""
|
||||
Execute the full agent loop using standard OpenAI tool calling.
|
||||
|
||||
Args:
|
||||
messages: Initial conversation messages (system + user).
|
||||
Modified in-place as the conversation progresses.
|
||||
|
||||
Returns:
|
||||
AgentResult with full conversation history, managed state, and metadata
|
||||
"""
|
||||
reasoning_per_turn = []
|
||||
tool_errors: List[ToolError] = []
|
||||
|
||||
# Per-loop TodoStore for the todo tool (ephemeral, dies with the loop)
|
||||
from tools.todo_tool import TodoStore, todo_tool as _todo_tool
|
||||
_todo_store = TodoStore()
|
||||
|
||||
# Extract user task from first user message for browser_snapshot context
|
||||
_user_task = None
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str) and content.strip():
|
||||
_user_task = content.strip()[:500] # Cap to avoid huge strings
|
||||
break
|
||||
|
||||
import time as _time
|
||||
|
||||
for turn in range(self.max_turns):
|
||||
turn_start = _time.monotonic()
|
||||
|
||||
# Build the chat_completion kwargs
|
||||
chat_kwargs = {
|
||||
"messages": messages,
|
||||
"n": 1,
|
||||
"temperature": self.temperature,
|
||||
}
|
||||
|
||||
# Only pass tools if we have them
|
||||
if self.tool_schemas:
|
||||
chat_kwargs["tools"] = self.tool_schemas
|
||||
|
||||
# Only pass max_tokens if explicitly set
|
||||
if self.max_tokens is not None:
|
||||
chat_kwargs["max_tokens"] = self.max_tokens
|
||||
|
||||
# Inject extra_body for provider-specific params (e.g., OpenRouter
|
||||
# provider preferences like banned/preferred providers, transforms)
|
||||
if self.extra_body:
|
||||
chat_kwargs["extra_body"] = self.extra_body
|
||||
|
||||
# Make the API call -- standard OpenAI spec
|
||||
api_start = _time.monotonic()
|
||||
try:
|
||||
response = await self.server.chat_completion(**chat_kwargs)
|
||||
except Exception as e:
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
logger.error("API call failed on turn %d (%.1fs): %s", turn + 1, api_elapsed, e)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
|
||||
if not response or not response.choices:
|
||||
logger.warning("Empty response on turn %d (api=%.1fs)", turn + 1, api_elapsed)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
|
||||
assistant_msg = response.choices[0].message
|
||||
|
||||
# Extract reasoning content from the response (all provider formats)
|
||||
reasoning = _extract_reasoning_from_message(assistant_msg)
|
||||
reasoning_per_turn.append(reasoning)
|
||||
|
||||
# Check for tool calls -- standard OpenAI spec.
|
||||
# Fallback: if response has no structured tool_calls but content
|
||||
# contains raw tool call tags (e.g. <tool_call>), parse them using
|
||||
# hermes-agent's standalone parsers. This handles the case where
|
||||
# ManagedServer's ToolCallTranslator couldn't parse because vLLM
|
||||
# isn't installed.
|
||||
if (
|
||||
not assistant_msg.tool_calls
|
||||
and assistant_msg.content
|
||||
and self.tool_schemas
|
||||
and "<tool_call>" in (assistant_msg.content or "")
|
||||
):
|
||||
try:
|
||||
from environments.tool_call_parsers import get_parser
|
||||
fallback_parser = get_parser("hermes")
|
||||
parsed_content, parsed_calls = fallback_parser.parse(
|
||||
assistant_msg.content
|
||||
)
|
||||
if parsed_calls:
|
||||
assistant_msg.tool_calls = parsed_calls
|
||||
if parsed_content is not None:
|
||||
assistant_msg.content = parsed_content
|
||||
logger.debug(
|
||||
"Fallback parser extracted %d tool calls from raw content",
|
||||
len(parsed_calls),
|
||||
)
|
||||
except Exception:
|
||||
pass # Fall through to no tool calls
|
||||
|
||||
if assistant_msg.tool_calls:
|
||||
# Normalize tool calls to dicts — they may come as objects
|
||||
# (OpenAI API) or dicts (vLLM ToolCallTranslator).
|
||||
def _tc_to_dict(tc):
|
||||
if isinstance(tc, dict):
|
||||
return {
|
||||
"id": tc.get("id", f"call_{uuid.uuid4().hex[:8]}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.get("function", {}).get("name", tc.get("name", "")),
|
||||
"arguments": tc.get("function", {}).get("arguments", tc.get("arguments", "{}")),
|
||||
},
|
||||
}
|
||||
return {
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments,
|
||||
},
|
||||
}
|
||||
|
||||
# Build the assistant message dict for conversation history
|
||||
msg_dict: Dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": assistant_msg.content or "",
|
||||
"tool_calls": [_tc_to_dict(tc) for tc in assistant_msg.tool_calls],
|
||||
}
|
||||
|
||||
# Preserve reasoning_content for multi-turn chat template handling
|
||||
# (e.g., Kimi-K2's template renders <think> blocks differently
|
||||
# for history vs. the latest turn based on this field)
|
||||
if reasoning:
|
||||
msg_dict["reasoning_content"] = reasoning
|
||||
|
||||
messages.append(msg_dict)
|
||||
|
||||
# Execute each tool call via hermes-agent's dispatch
|
||||
for tc in assistant_msg.tool_calls:
|
||||
# Handle both object (OpenAI) and dict (vLLM) formats
|
||||
if isinstance(tc, dict):
|
||||
tool_name = tc.get("function", {}).get("name", tc.get("name", ""))
|
||||
tool_args_raw = tc.get("function", {}).get("arguments", tc.get("arguments", "{}"))
|
||||
else:
|
||||
tool_name = tc.function.name
|
||||
tool_args_raw = tc.function.arguments
|
||||
|
||||
# Validate tool name
|
||||
if tool_name not in self.valid_tool_names:
|
||||
tool_result = json.dumps(
|
||||
{
|
||||
"error": f"Unknown tool '{tool_name}'. "
|
||||
f"Available tools: {sorted(self.valid_tool_names)}"
|
||||
}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"Unknown tool '{tool_name}'",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.warning(
|
||||
"Model called unknown tool '%s' on turn %d",
|
||||
tool_name, turn + 1,
|
||||
)
|
||||
else:
|
||||
# Parse arguments
|
||||
try:
|
||||
args = json.loads(tool_args_raw)
|
||||
except json.JSONDecodeError as e:
|
||||
args = None
|
||||
tool_result = json.dumps(
|
||||
{"error": f"Invalid JSON in tool arguments: {e}. Please retry with valid JSON."}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"Invalid JSON: {e}",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.warning(
|
||||
"Invalid JSON in tool call arguments for '%s': %s",
|
||||
tool_name, tool_args_raw[:200],
|
||||
)
|
||||
|
||||
# Dispatch tool only if arguments parsed successfully
|
||||
if args is not None:
|
||||
try:
|
||||
if tool_name == "terminal":
|
||||
backend = os.getenv("TERMINAL_ENV", "local")
|
||||
cmd_preview = args.get("command", "")[:80]
|
||||
logger.info(
|
||||
"[%s] $ %s", self.task_id[:8], cmd_preview,
|
||||
)
|
||||
|
||||
tool_submit_time = _time.monotonic()
|
||||
|
||||
# Todo tool -- handle locally (needs per-loop TodoStore)
|
||||
if tool_name == "todo":
|
||||
tool_result = _todo_tool(
|
||||
todos=args.get("todos"),
|
||||
merge=args.get("merge", False),
|
||||
store=_todo_store,
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "memory":
|
||||
tool_result = json.dumps({"error": "Memory is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "session_search":
|
||||
tool_result = json.dumps({"error": "Session search is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
else:
|
||||
# Run tool calls in a thread pool so backends that
|
||||
# use asyncio.run() internally (modal, docker, daytona) get
|
||||
# a clean event loop instead of deadlocking.
|
||||
loop = asyncio.get_running_loop()
|
||||
# Capture current tool_name/args for the lambda
|
||||
_tn, _ta, _tid = tool_name, args, self.task_id
|
||||
tool_result = await loop.run_in_executor(
|
||||
_tool_executor,
|
||||
lambda: handle_function_call(
|
||||
_tn, _ta, task_id=_tid,
|
||||
user_task=_user_task,
|
||||
),
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
|
||||
# Log slow tools and thread pool stats for debugging
|
||||
pool_active = _tool_executor._work_queue.qsize()
|
||||
if tool_elapsed > 30:
|
||||
logger.warning(
|
||||
"[%s] turn %d: %s took %.1fs (pool queue=%d)",
|
||||
self.task_id[:8], turn + 1, tool_name,
|
||||
tool_elapsed, pool_active,
|
||||
)
|
||||
except Exception as e:
|
||||
tool_result = json.dumps(
|
||||
{"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"{type(e).__name__}: {str(e)}",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.error(
|
||||
"Tool '%s' execution failed on turn %d: %s",
|
||||
tool_name, turn + 1, e,
|
||||
)
|
||||
|
||||
# Also check if the tool returned an error in its JSON result
|
||||
try:
|
||||
result_data = json.loads(tool_result)
|
||||
if isinstance(result_data, dict):
|
||||
err = result_data.get("error")
|
||||
exit_code = result_data.get("exit_code")
|
||||
if err and exit_code and exit_code < 0:
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=str(err),
|
||||
tool_result=tool_result[:500],
|
||||
))
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
tc_id = tc.get("id", "") if isinstance(tc, dict) else tc.id
|
||||
tool_result = maybe_persist_tool_result(
|
||||
content=tool_result,
|
||||
tool_name=tool_name,
|
||||
tool_use_id=tc_id,
|
||||
env=get_active_env(self.task_id),
|
||||
config=self.budget_config,
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc_id,
|
||||
"content": tool_result,
|
||||
}
|
||||
)
|
||||
|
||||
num_tcs = len(assistant_msg.tool_calls)
|
||||
if num_tcs > 0:
|
||||
enforce_turn_budget(
|
||||
messages[-num_tcs:],
|
||||
env=get_active_env(self.task_id),
|
||||
config=self.budget_config,
|
||||
)
|
||||
|
||||
turn_elapsed = _time.monotonic() - turn_start
|
||||
logger.info(
|
||||
"[%s] turn %d: api=%.1fs, %d tools, turn_total=%.1fs",
|
||||
self.task_id[:8], turn + 1, api_elapsed,
|
||||
len(assistant_msg.tool_calls), turn_elapsed,
|
||||
)
|
||||
|
||||
else:
|
||||
# No tool calls -- model is done
|
||||
msg_dict = {
|
||||
"role": "assistant",
|
||||
"content": assistant_msg.content or "",
|
||||
}
|
||||
if reasoning:
|
||||
msg_dict["reasoning_content"] = reasoning
|
||||
messages.append(msg_dict)
|
||||
|
||||
turn_elapsed = _time.monotonic() - turn_start
|
||||
logger.info(
|
||||
"[%s] turn %d: api=%.1fs, no tools (finished), turn_total=%.1fs",
|
||||
self.task_id[:8], turn + 1, api_elapsed, turn_elapsed,
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=True,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
|
||||
# Hit max turns without the model stopping
|
||||
logger.info("Agent hit max_turns (%d) without finishing", self.max_turns)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=self.max_turns,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
|
||||
def _get_managed_state(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get ManagedServer state if the server supports it.
|
||||
|
||||
Returns state dict with SequenceNodes containing tokens/logprobs/masks,
|
||||
or None if the server doesn't support get_state() (e.g., regular OpenAI server).
|
||||
"""
|
||||
if hasattr(self.server, "get_state"):
|
||||
return self.server.get_state()
|
||||
return None
|
||||
@@ -1,73 +0,0 @@
|
||||
# OpenThoughts-TBLite Evaluation Environment
|
||||
|
||||
This environment evaluates terminal agents on the [OpenThoughts-TBLite](https://huggingface.co/datasets/open-thoughts/OpenThoughts-TBLite) benchmark, a difficulty-calibrated subset of [Terminal-Bench 2.0](https://www.tbench.ai/leaderboard/terminal-bench/2.0).
|
||||
|
||||
## Source
|
||||
|
||||
OpenThoughts-TBLite was created by the [OpenThoughts](https://www.openthoughts.ai/) Agent team in collaboration with [Snorkel AI](https://snorkel.ai/) and [Bespoke Labs](https://bespokelabs.ai/). The original dataset and documentation live at:
|
||||
|
||||
- **Dataset (source):** [open-thoughts/OpenThoughts-TBLite](https://huggingface.co/datasets/open-thoughts/OpenThoughts-TBLite)
- **GitHub:** [open-thoughts/OpenThoughts-TBLite](https://github.com/open-thoughts/OpenThoughts-TBLite)
- **Blog post:** [openthoughts.ai/blog/openthoughts-tblite](https://www.openthoughts.ai/blog/openthoughts-tblite)
|
||||
|
||||
## Our Dataset
|
||||
|
||||
We converted the source into the same schema used by our Terminal-Bench 2.0 environment (pre-built Docker Hub images, base64-encoded test tarballs, etc.) and published it as:
|
||||
|
||||
- **Dataset (ours):** [NousResearch/openthoughts-tblite](https://huggingface.co/datasets/NousResearch/openthoughts-tblite)
- **Docker images:** `nousresearch/tblite-<task-name>:latest` on Docker Hub (100 images)
|
||||
|
||||
The conversion script is at `scripts/prepare_tblite_dataset.py`.
|
||||
|
||||
## Why TBLite?
|
||||
|
||||
Terminal-Bench 2.0 is one of the strongest frontier evaluations for terminal agents, but when a model scores near the floor (e.g., Qwen 3 8B at <1%), many changes look identical in aggregate score. TBLite addresses this by calibrating task difficulty using Claude Haiku 4.5 as a reference:
|
||||
|
||||
| Difficulty | Pass Rate Range | Tasks |
|------------|-----------------|-------|
| Easy | >= 70% | 40 |
| Medium | 40-69% | 26 |
| Hard | 10-39% | 26 |
| Extreme | < 10% | 8 |
|
||||
|
||||
This gives enough solvable tasks to detect small improvements quickly, while preserving enough hard tasks to avoid saturation. The correlation between TBLite and TB2 scores is **r = 0.911**.
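The calibration buckets can be written down as a small function. This is a sketch, assuming the ranges in the difficulty table are applied as half-open thresholds on the reference model's pass rate in percent; `difficulty_bucket` is a hypothetical helper, not part of the TBLite tooling.

```python
def difficulty_bucket(pass_rate_percent: float) -> str:
    """Map a reference-model pass rate (0-100) to a TBLite difficulty label."""
    if not 0 <= pass_rate_percent <= 100:
        raise ValueError("pass rate must be between 0 and 100")
    if pass_rate_percent >= 70:
        return "Easy"
    if pass_rate_percent >= 40:
        return "Medium"
    if pass_rate_percent >= 10:
        return "Hard"
    return "Extreme"


# Task counts per bucket, copied from the difficulty table.
TASK_COUNTS = {"Easy": 40, "Medium": 26, "Hard": 26, "Extreme": 8}

print(difficulty_bucket(85), difficulty_bucket(40), difficulty_bucket(9.9))
print(sum(TASK_COUNTS.values()))
```

The four bucket sizes add up to the 100 tasks in the benchmark.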
|
||||
|
||||
TBLite also runs 2.6-8x faster than the full TB2, making it practical for iteration loops.
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
# Run the full benchmark
python environments/benchmarks/tblite/tblite_env.py evaluate

# Filter to specific tasks
python environments/benchmarks/tblite/tblite_env.py evaluate \
    --env.task_filter "broken-python,pandas-etl"

# Use a different model
python environments/benchmarks/tblite/tblite_env.py evaluate \
    --openai.model_name "qwen/qwen3-30b"
```
|
||||
|
||||
## Architecture
|
||||
|
||||
`TBLiteEvalEnv` is a thin subclass of `TerminalBench2EvalEnv`. All evaluation logic (agent loop, Docker sandbox management, test verification, metrics) is inherited. Only the defaults differ:
|
||||
|
||||
| Setting | TB2 | TBLite |
|----------------|----------------------------------|-----------------------------------------|
| Dataset | `NousResearch/terminal-bench-2` | `NousResearch/openthoughts-tblite` |
| Tasks | 89 | 100 |
| Task timeout | 1800s (30 min) | 1200s (20 min) |
| Wandb name | `terminal-bench-2` | `openthoughts-tblite` |
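The "thin subclass" idea can be illustrated with plain class attributes. This is a simplified sketch: `Tb2Sketch` and `TbliteSketch` are hypothetical stand-ins, and the real environments are assumed to carry these defaults in their config objects rather than as class attributes. The values come from the comparison table.

```python
class Tb2Sketch:
    """Stand-in for the parent environment: owns the logic and TB2 defaults."""

    dataset_name = "NousResearch/terminal-bench-2"
    task_timeout = 1800  # seconds
    wandb_name = "terminal-bench-2"

    def describe(self) -> str:
        # Shared logic lives in the parent and reads whichever defaults apply.
        return f"{self.wandb_name}: {self.dataset_name} ({self.task_timeout}s per task)"


class TbliteSketch(Tb2Sketch):
    """Stand-in for the thin subclass: overrides defaults, inherits the rest."""

    dataset_name = "NousResearch/openthoughts-tblite"
    task_timeout = 1200  # seconds
    wandb_name = "openthoughts-tblite"


print(Tb2Sketch().describe())
print(TbliteSketch().describe())
```

Because the subclass defines no methods of its own, any fix to the parent's evaluation logic applies to both benchmarks.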
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
@software{OpenThoughts-TBLite,
  author = {OpenThoughts-Agent team, Snorkel AI, Bespoke Labs},
  month = Feb,
  title = {{OpenThoughts-TBLite: A High-Signal Benchmark for Iterating on Terminal Agents}},
  howpublished = {https://www.openthoughts.ai/blog/openthoughts-tblite},
  year = {2026}
}
```
|
||||
@@ -1,39 +0,0 @@
|
||||
# OpenThoughts-TBLite Evaluation -- Default Configuration
|
||||
#
|
||||
# Eval-only environment for the TBLite benchmark (100 difficulty-calibrated
|
||||
# terminal tasks, a faster proxy for Terminal-Bench 2.0).
|
||||
# Uses Modal terminal backend for per-task cloud-isolated sandboxes
|
||||
# and OpenRouter for inference.
|
||||
#
|
||||
# Usage:
|
||||
# python environments/benchmarks/tblite/tblite_env.py evaluate \
|
||||
# --config environments/benchmarks/tblite/default.yaml
|
||||
#
|
||||
# # Override model:
|
||||
# python environments/benchmarks/tblite/tblite_env.py evaluate \
|
||||
# --config environments/benchmarks/tblite/default.yaml \
|
||||
# --openai.model_name anthropic/claude-sonnet-4
|
||||
|
||||
env:
  enabled_toolsets: ["terminal", "file"]
  max_agent_turns: 60
  max_token_length: 32000
  agent_temperature: 0.8
  terminal_backend: "modal"
  terminal_timeout: 300  # 5 min per command (builds, pip install)
  tool_pool_size: 128  # thread pool for 100 parallel tasks
  dataset_name: "NousResearch/openthoughts-tblite"
  test_timeout: 600
  task_timeout: 1200  # 20 min wall-clock per task (TBLite tasks are faster)
  tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
  use_wandb: true
  wandb_name: "openthoughts-tblite"
  ensure_scores_are_not_same: false
  data_dir_to_save_evals: "environments/benchmarks/evals/openthoughts-tblite"

openai:
  base_url: "https://openrouter.ai/api/v1"
  model_name: "anthropic/claude-opus-4.6"
  server_type: "openai"
  health_check: false
  # api_key loaded from OPENROUTER_API_KEY in .env
|
||||
@@ -1,38 +0,0 @@
|
||||
# OpenThoughts-TBLite Evaluation -- Docker Backend (Local Compute)
|
||||
#
|
||||
# Runs tasks in Docker containers on the local machine.
|
||||
# Sandboxed like Modal but no cloud costs. Good for dev/testing.
|
||||
#
|
||||
# Usage:
|
||||
# python environments/benchmarks/tblite/tblite_env.py evaluate \
|
||||
# --config environments/benchmarks/tblite/local.yaml
|
||||
#
|
||||
# # Override concurrency:
|
||||
# python environments/benchmarks/tblite/tblite_env.py evaluate \
|
||||
# --config environments/benchmarks/tblite/local.yaml \
|
||||
# --env.eval_concurrency 4
|
||||
|
||||
env:
  enabled_toolsets: ["terminal", "file"]
  max_agent_turns: 60
  max_token_length: 32000
  agent_temperature: 0.8
  terminal_backend: "docker"
  terminal_timeout: 300
  tool_pool_size: 16
  dataset_name: "NousResearch/openthoughts-tblite"
  test_timeout: 600
  task_timeout: 1200
  eval_concurrency: 8  # max 8 tasks at once
  tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
  use_wandb: false
  wandb_name: "openthoughts-tblite-local"
  ensure_scores_are_not_same: false
  data_dir_to_save_evals: "environments/benchmarks/evals/openthoughts-tblite-local"

openai:
  base_url: "https://openrouter.ai/api/v1"
  model_name: "anthropic/claude-sonnet-4"
  server_type: "openai"
  health_check: false
  # api_key loaded from OPENROUTER_API_KEY in .env
|
||||
@@ -1,40 +0,0 @@
|
||||
# OpenThoughts-TBLite Evaluation -- Local vLLM Backend
|
||||
#
|
||||
# Runs against a local vLLM server with Docker sandboxes.
|
||||
#
|
||||
# Start the vLLM server from the atropos directory:
|
||||
# python -m example_trainer.vllm_api_server \
|
||||
# --model Qwen/Qwen3-4B-Instruct-2507 \
|
||||
# --port 9001 \
|
||||
# --gpu-memory-utilization 0.8 \
|
||||
# --max-model-len=32000
|
||||
#
|
||||
# Then run:
|
||||
# python environments/benchmarks/tblite/tblite_env.py evaluate \
|
||||
# --config environments/benchmarks/tblite/local_vllm.yaml
|
||||
|
||||
env:
  enabled_toolsets: ["terminal", "file"]
  max_agent_turns: 60
  max_token_length: 16000
  agent_temperature: 0.6
  terminal_backend: "docker"
  terminal_timeout: 300
  tool_pool_size: 16
  dataset_name: "NousResearch/openthoughts-tblite"
  test_timeout: 600
  task_timeout: 1200
  eval_concurrency: 8
  tool_call_parser: "hermes"
  system_prompt: "You are an expert terminal agent. You MUST use the provided tools to complete tasks. Use the terminal tool to run shell commands, read_file to read files, write_file to write files, search_files to search, and patch to edit files. Do NOT write out solutions as text - execute them using the tools. Always start by exploring the environment with terminal commands."
  tokenizer_name: "Qwen/Qwen3-4B-Instruct-2507"
  use_wandb: false
  wandb_name: "tblite-qwen3-4b-instruct"
  ensure_scores_are_not_same: false
  data_dir_to_save_evals: "environments/benchmarks/evals/tblite-qwen3-4b-local"

openai:
  base_url: "http://localhost:9001"
  model_name: "Qwen/Qwen3-4B-Instruct-2507"
  server_type: "vllm"
  health_check: false
|
||||
@@ -1,42 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# OpenThoughts-TBLite Evaluation
|
||||
#
|
||||
# Run from repo root:
|
||||
# bash environments/benchmarks/tblite/run_eval.sh
|
||||
#
|
||||
# Override model:
|
||||
# bash environments/benchmarks/tblite/run_eval.sh \
|
||||
# --openai.model_name anthropic/claude-sonnet-4
|
||||
#
|
||||
# Run a subset:
|
||||
# bash environments/benchmarks/tblite/run_eval.sh \
|
||||
# --env.task_filter broken-python,pandas-etl
|
||||
#
|
||||
# All terminal settings (backend, timeout, lifetime, pool size) are
|
||||
# configured via env config fields -- no env vars needed.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
mkdir -p logs evals/openthoughts-tblite
|
||||
LOG_FILE="logs/tblite_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "OpenThoughts-TBLite Evaluation"
|
||||
echo "Log file: $LOG_FILE"
|
||||
echo ""
|
||||
|
||||
# Unbuffered python output so logs are written in real-time
|
||||
export PYTHONUNBUFFERED=1
|
||||
|
||||
# Show INFO-level agent loop timing (api/tool durations per turn)
|
||||
# These go to the log file; tqdm + [START]/[PASS]/[FAIL] go to terminal
|
||||
export LOGLEVEL=INFO
|
||||
|
||||
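# Paths are relative to the repo root, matching the "Run from repo root" note above
python environments/benchmarks/tblite/tblite_env.py evaluate \
  --config environments/benchmarks/tblite/default.yaml \
  "$@" \
  2>&1 | tee "$LOG_FILE"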
|
||||
|
||||
echo ""
|
||||
echo "Log saved to: $LOG_FILE"
|
||||
echo "Eval results: evals/openthoughts-tblite/"
|
||||
@@ -1,119 +0,0 @@
|
||||
"""
|
||||
OpenThoughts-TBLite Evaluation Environment
|
||||
|
||||
A lighter, faster alternative to Terminal-Bench 2.0 for iterating on terminal
|
||||
agents. Uses the same evaluation logic as TerminalBench2EvalEnv but defaults
|
||||
to the NousResearch/openthoughts-tblite dataset (100 difficulty-calibrated
|
||||
tasks vs TB2's 89 harder tasks).
|
||||
|
||||
TBLite is a curated set of TB2-style tasks with a difficulty distribution
designed to give meaningful signal even for smaller models:
|
||||
- Easy (40 tasks): >= 70% pass rate with Claude Haiku 4.5
|
||||
- Medium (26 tasks): 40-69% pass rate
|
||||
- Hard (26 tasks): 10-39% pass rate
|
||||
- Extreme (8 tasks): < 10% pass rate
|
||||
|
||||
Usage:
|
||||
python environments/benchmarks/tblite/tblite_env.py evaluate
|
||||
|
||||
# Filter to specific tasks:
|
||||
python environments/benchmarks/tblite/tblite_env.py evaluate \\
|
||||
--env.task_filter "broken-python,pandas-etl"
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import EvalHandlingEnum
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
|
||||
from environments.benchmarks.terminalbench_2.terminalbench2_env import (
|
||||
TerminalBench2EvalConfig,
|
||||
TerminalBench2EvalEnv,
|
||||
)
|
||||
|
||||
|
||||
class TBLiteEvalConfig(TerminalBench2EvalConfig):
|
||||
"""Configuration for the OpenThoughts-TBLite evaluation environment.
|
||||
|
||||
Inherits all TB2 config fields. Only the dataset default and task timeout
|
||||
differ -- TBLite tasks are calibrated to be faster.
|
||||
"""
|
||||
|
||||
dataset_name: str = Field(
|
||||
default="NousResearch/openthoughts-tblite",
|
||||
description="HuggingFace dataset containing TBLite tasks.",
|
||||
)
|
||||
|
||||
task_timeout: int = Field(
|
||||
default=1200,
|
||||
description="Maximum wall-clock seconds per task. TBLite tasks are "
|
||||
"generally faster than TB2, so 20 minutes is usually sufficient.",
|
||||
)
|
||||
|
||||
|
||||
class TBLiteEvalEnv(TerminalBench2EvalEnv):
|
||||
"""OpenThoughts-TBLite evaluation environment.
|
||||
|
||||
Inherits all evaluation logic from TerminalBench2EvalEnv (agent loop,
|
||||
test verification, Docker image resolution, metrics, wandb logging).
|
||||
Only the default configuration differs.
|
||||
"""
|
||||
|
||||
name = "openthoughts-tblite"
|
||||
env_config_cls = TBLiteEvalConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[TBLiteEvalConfig, List[APIServerConfig]]:
|
||||
env_config = TBLiteEvalConfig(
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
|
||||
max_agent_turns=60,
|
||||
max_token_length=16000,
|
||||
agent_temperature=0.6,
|
||||
system_prompt=None,
|
||||
|
||||
terminal_backend="modal",
|
||||
terminal_timeout=300,
|
||||
|
||||
test_timeout=180,
|
||||
|
||||
# 100 tasks in parallel
|
||||
tool_pool_size=128,
|
||||
|
||||
eval_handling=EvalHandlingEnum.STOP_TRAIN,
|
||||
group_size=1,
|
||||
steps_per_eval=1,
|
||||
total_steps=1,
|
||||
|
||||
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
use_wandb=True,
|
||||
wandb_name="openthoughts-tblite",
|
||||
ensure_scores_are_not_same=False,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-sonnet-4",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TBLiteEvalEnv.cli()
|
||||
@@ -1,42 +0,0 @@
|
||||
# Terminal-Bench 2.0 Evaluation -- Default Configuration
|
||||
#
|
||||
# Eval-only environment for the TB2 benchmark (89 terminal tasks).
|
||||
# Uses Modal terminal backend for per-task cloud-isolated sandboxes
|
||||
# and OpenRouter for inference.
|
||||
#
|
||||
# Usage:
|
||||
# python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
# --config environments/benchmarks/terminalbench_2/default.yaml
|
||||
#
|
||||
# # Override model:
|
||||
# python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
# --config environments/benchmarks/terminalbench_2/default.yaml \
|
||||
# --openai.model_name anthropic/claude-sonnet-4
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
max_agent_turns: 60
|
||||
max_token_length: 32000
|
||||
agent_temperature: 0.8
|
||||
terminal_backend: "modal"
|
||||
terminal_timeout: 300 # 5 min per command (builds, pip install)
|
||||
tool_pool_size: 128 # thread pool for 89 parallel tasks
|
||||
dataset_name: "NousResearch/terminal-bench-2"
|
||||
test_timeout: 600
|
||||
task_timeout: 1800 # 30 min wall-clock per task, auto-FAIL if exceeded
|
||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
use_wandb: true
|
||||
wandb_name: "terminal-bench-2"
|
||||
ensure_scores_are_not_same: false
|
||||
data_dir_to_save_evals: "environments/benchmarks/evals/terminal-bench-2"
|
||||
# CRITICAL: Limit concurrent Modal sandbox creations to avoid deadlocks.
|
||||
# Modal's blocking calls (App.lookup, etc.) deadlock when too many sandboxes
|
||||
# are created simultaneously inside thread pool workers via asyncio.run().
|
||||
max_concurrent_tasks: 8
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-opus-4.6"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
@@ -1,42 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Terminal-Bench 2.0 Evaluation
|
||||
#
|
||||
# Run from repo root:
|
||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh
|
||||
#
|
||||
# Override model:
|
||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh \
|
||||
# --openai.model_name anthropic/claude-sonnet-4
|
||||
#
|
||||
# Run a subset:
|
||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh \
|
||||
# --env.task_filter fix-git,git-multibranch
|
||||
#
|
||||
# All terminal settings (backend, timeout, lifetime, pool size) are
|
||||
# configured via env config fields -- no env vars needed.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
mkdir -p logs evals/terminal-bench-2
|
||||
LOG_FILE="logs/terminalbench2_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "Terminal-Bench 2.0 Evaluation"
|
||||
echo "Log file: $LOG_FILE"
|
||||
echo ""
|
||||
|
||||
# Unbuffered python output so logs are written in real-time
|
||||
export PYTHONUNBUFFERED=1
|
||||
|
||||
# Show INFO-level agent loop timing (api/tool durations per turn)
|
||||
# These go to the log file; tqdm + [START]/[PASS]/[FAIL] go to terminal
|
||||
export LOGLEVEL=INFO
|
||||
|
||||
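# Paths are relative to the repo root, matching the "Run from repo root" note above
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
  --config environments/benchmarks/terminalbench_2/default.yaml \
  "$@" \
  2>&1 | tee "$LOG_FILE"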
|
||||
|
||||
echo ""
|
||||
echo "Log saved to: $LOG_FILE"
|
||||
echo "Eval results: evals/terminal-bench-2/"
|
||||
File diff suppressed because it is too large
@@ -1,115 +0,0 @@
|
||||
# YC-Bench: Long-Horizon Agent Benchmark
|
||||
|
||||
[YC-Bench](https://github.com/collinear-ai/yc-bench) by [Collinear AI](https://collinear.ai/) is a deterministic, long-horizon benchmark that tests LLM agents' ability to act as a tech startup CEO. The agent manages a simulated company over 1-3 years, making compounding decisions about resource allocation, cash flow, task management, and prestige specialisation across 4 skill domains.
|
||||
|
||||
Unlike TerminalBench2 (which evaluates per-task coding ability with binary pass/fail), YC-Bench measures **long-term strategic coherence** — whether an agent can maintain consistent strategy, manage compounding consequences, and adapt plans over hundreds of turns.
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
# Install yc-bench (optional dependency)
pip install "hermes-agent[yc-bench]"

# Or install from source
git clone https://github.com/collinear-ai/yc-bench
cd yc-bench && pip install -e .

# Verify
yc-bench --help
```
|
||||
|
||||
## Running
|
||||
|
||||
```bash
# From the repo root:
bash environments/benchmarks/yc_bench/run_eval.sh

# Or directly:
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
  --config environments/benchmarks/yc_bench/default.yaml

# Override model:
bash environments/benchmarks/yc_bench/run_eval.sh \
  --openai.model_name anthropic/claude-opus-4-20250514

# Quick single-preset test:
bash environments/benchmarks/yc_bench/run_eval.sh \
  --env.presets '["fast_test"]' --env.seeds '[1]'
```
|
||||
|
||||
## How It Works
|
||||
|
||||
### Architecture
|
||||
|
||||
```
HermesAgentLoop (our agent)
  -> terminal tool -> subprocess("yc-bench company status") -> JSON output
  -> terminal tool -> subprocess("yc-bench task accept --task-id X") -> JSON
  -> terminal tool -> subprocess("yc-bench sim resume") -> JSON (advance time)
  -> ... (100-500 turns per run)
```
|
||||
|
||||
The environment initialises the simulation via `yc-bench sim init` (NOT `yc-bench run`, which would start yc-bench's own built-in agent loop). Our `HermesAgentLoop` then drives all interaction through CLI commands.
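
The CLI round trip in the diagram can be sketched roughly as follows. This is an illustration only: `run_json_command` is a hypothetical helper, not part of yc-bench or this repo, and the JSON key in the demo is made up. The stand-in command lets the sketch run without yc-bench installed.

```python
import json
import subprocess
import sys
from typing import Any, List


def run_json_command(argv: List[str], timeout: float = 60.0) -> Any:
    """Run a CLI command and parse its stdout as JSON (hypothetical helper)."""
    result = subprocess.run(argv, capture_output=True, text=True, timeout=timeout)
    if result.returncode != 0:
        raise RuntimeError(
            f"{argv!r} exited with {result.returncode}: {result.stderr.strip()}"
        )
    return json.loads(result.stdout)


# With yc-bench installed, a call through the terminal tool would look like:
#   status = run_json_command(["yc-bench", "company", "status"])
# Stand-in command so the sketch runs anywhere; the key name is an assumption.
demo = run_json_command(
    [sys.executable, "-c", "print('{\"funds_cents\": 25000000}')"]
)
print(demo["funds_cents"])
```

The 60-second default mirrors the `terminal_timeout: 60` value in `default.yaml`; the real agent loop goes through the terminal tool rather than calling `subprocess` directly.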
|
||||
|
||||
### Simulation Mechanics
|
||||
|
||||
- **4 skill domains**: research, inference, data_environment, training
|
||||
- **Prestige system** (1.0-10.0): Gates access to higher-paying tasks
|
||||
- **Employee management**: Junior/Mid/Senior with domain-specific skill rates
|
||||
- **Throughput splitting**: `effective_rate = base_rate / N` active tasks per employee
|
||||
- **Financial pressure**: Monthly payroll, bankruptcy = game over
|
||||
- **Deterministic**: SHA256-based RNG — same seed + preset = same world
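
The throughput-splitting rule above is simple arithmetic. Here is a minimal sketch, assuming work is measured in abstract units per day; the function names and units are hypothetical and not taken from yc-bench.

```python
def effective_rate(base_rate: float, active_tasks: int) -> float:
    """Per-task rate when one employee is split across `active_tasks` tasks."""
    if active_tasks < 1:
        raise ValueError("active_tasks must be at least 1")
    return base_rate / active_tasks


def days_to_finish(work_units: float, base_rate: float, active_tasks: int) -> float:
    """Days needed to finish one task at the split rate."""
    return work_units / effective_rate(base_rate, active_tasks)


# An employee producing 6 units/day, focused on one task vs split across three.
focused = days_to_finish(work_units=12, base_rate=6.0, active_tasks=1)
split = days_to_finish(work_units=12, base_rate=6.0, active_tasks=3)
print(focused, split)
```

Each task takes three times as long when the employee is split three ways, which is why deadlines slip quickly when assignments are spread thin.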
|
||||
|
||||
### Difficulty Presets
|
||||
|
||||
| Preset | Employees | Tasks | Focus |
|-----------|-----------|-------|-------|
| tutorial | 3 | 50 | Basic loop mechanics |
| easy | 5 | 100 | Throughput awareness |
| **medium** | 5 | 150 | Prestige climbing + domain specialisation |
| **hard** | 7 | 200 | Precise ETA reasoning |
| nightmare | 8 | 300 | Sustained perfection under payroll pressure |
| fast_test | (varies) | (varies) | Quick validation (~50 turns) |
|
||||
|
||||
Default eval runs **fast_test + medium + hard** × 3 seeds = 9 runs.
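
The run matrix is the cross product of presets and seeds. A minimal sketch using the default values from `default.yaml`; the dict shape follows the `{"preset": ..., "seed": ...}` items built in the environment's `setup()`.

```python
from itertools import product

presets = ["fast_test", "medium", "hard"]
seeds = [1, 2, 3]

# One eval item per (preset, seed) pair.
eval_items = [{"preset": p, "seed": s} for p, s in product(presets, seeds)]

print(len(eval_items))
for item in eval_items[:3]:
    print(item)
```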
|
||||
|
||||
### Scoring
|
||||
|
||||
```
composite = 0.5 × survival + 0.5 × normalised_funds
```
|
||||
|
||||
- **Survival** (binary): Did the company avoid bankruptcy?
|
||||
- **Normalised funds** (0.0-1.0): Log-scale relative to initial $250K capital
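
As a worked example, the sketch below follows the log-scale normalisation used by `_compute_composite_score` in `yc_bench_env.py`: a `log1p` ratio capped at 100× the initial capital. The function name `composite_score` is shortened here for illustration.

```python
import math

INITIAL_FUNDS_CENTS = 25_000_000  # $250K seed capital, in cents
MAX_RATIO = 100.0  # funds at 100x initial capital map to a funds score of 1.0


def composite_score(
    final_funds_cents: int,
    survived: bool,
    survival_weight: float = 0.5,
    funds_weight: float = 0.5,
) -> float:
    """Weighted sum of binary survival and log-scaled final funds."""
    survival_score = 1.0 if survived else 0.0
    if final_funds_cents <= 0:
        funds_score = 0.0
    else:
        ratio = final_funds_cents / INITIAL_FUNDS_CENTS
        funds_score = min(math.log1p(ratio) / math.log1p(MAX_RATIO), 1.0)
    return survival_weight * survival_score + funds_weight * funds_score


# Survived with exactly the starting capital: funds score is about 0.15.
print(round(composite_score(25_000_000, True), 3))
# Bankrupt with no funds left.
print(composite_score(0, False))
```

Because of the log scale, merely preserving the initial capital earns only about 0.15 of the funds component; a run needs 100× growth to max it out.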
|
||||
|
||||
## Configuration
|
||||
|
||||
Key fields in `default.yaml`:
|
||||
|
||||
| Field | Default | Description |
|-------|---------|-------------|
| `presets` | `["fast_test", "medium", "hard"]` | Which presets to evaluate |
| `seeds` | `[1, 2, 3]` | RNG seeds per preset |
| `max_agent_turns` | 200 | Max LLM calls per run |
| `run_timeout` | 3600 | Wall-clock timeout per run (seconds) |
| `survival_weight` | 0.5 | Weight of survival in composite score |
| `funds_weight` | 0.5 | Weight of normalised funds in composite |
| `horizon_years` | null | Override horizon (null = auto from preset) |
|
||||
|
||||
## Cost & Time Estimates
|
||||
|
||||
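Turn counts vary by preset, from about 50 for `fast_test` to several hundred for harder presets; note that the default `max_agent_turns` of 200 caps each run. Approximate per-run figures at typical API rates: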
|
||||
|
||||
| Preset | Turns | Time | Est. Cost |
|--------|-------|------|-----------|
| fast_test | ~50 | 5-10 min | $1-5 |
| medium | ~200 | 20-40 min | $5-15 |
| hard | ~300 | 30-60 min | $10-25 |
|
||||
|
||||
Full default eval (9 runs): ~3-6 hours, $50-200 depending on model.
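
The per-run ranges in the table can be totalled for the default matrix of 3 presets × 3 seeds. The calculation below is only a sanity check on the table's own numbers, assuming runs execute one after another. It gives roughly 2.75-5.5 hours and $48-135, so the quoted totals leave some headroom for pricier models.

```python
# (low, high) per-run estimates copied from the table above.
per_run = {
    "fast_test": {"minutes": (5, 10), "usd": (1, 5)},
    "medium": {"minutes": (20, 40), "usd": (5, 15)},
    "hard": {"minutes": (30, 60), "usd": (10, 25)},
}
SEEDS_PER_PRESET = 3


def total(field: str) -> tuple:
    """Sum the low and high bounds of `field` over all presets and seeds."""
    low = SEEDS_PER_PRESET * sum(p[field][0] for p in per_run.values())
    high = SEEDS_PER_PRESET * sum(p[field][1] for p in per_run.values())
    return (low, high)


total_minutes = total("minutes")
total_usd = total("usd")
print(f"time: {total_minutes[0] / 60:.2f}-{total_minutes[1] / 60:.2f} hours")
print(f"cost: ${total_usd[0]}-{total_usd[1]}")
```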
|
||||
|
||||
## References
|
||||
|
||||
- [collinear-ai/yc-bench](https://github.com/collinear-ai/yc-bench) — Official repository
|
||||
- [Collinear AI](https://collinear.ai/) — Company behind yc-bench
|
||||
- [TerminalBench2](../terminalbench_2/) — Per-task coding benchmark (complementary)
|
||||
@@ -1,43 +0,0 @@
|
||||
# YC-Bench Evaluation -- Default Configuration
|
||||
#
|
||||
# Long-horizon agent benchmark: agent plays CEO of an AI startup over
|
||||
# a simulated 1-3 year run, interacting via yc-bench CLI subcommands.
|
||||
#
|
||||
# Requires: pip install "hermes-agent[yc-bench]"
|
||||
#
|
||||
# Usage:
|
||||
# python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
# --config environments/benchmarks/yc_bench/default.yaml
|
||||
#
|
||||
# # Override model:
|
||||
# python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
# --config environments/benchmarks/yc_bench/default.yaml \
|
||||
# --openai.model_name anthropic/claude-opus-4-20250514
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal"]
|
||||
max_agent_turns: 200
|
||||
max_token_length: 32000
|
||||
agent_temperature: 0.0
|
||||
terminal_backend: "local"
|
||||
terminal_timeout: 60
|
||||
presets: ["fast_test", "medium", "hard"]
|
||||
seeds: [1, 2, 3]
|
||||
run_timeout: 3600 # 60 min wall-clock per run, auto-FAIL if exceeded
|
||||
survival_weight: 0.5 # weight of binary survival in composite score
|
||||
funds_weight: 0.5 # weight of normalised final funds in composite score
|
||||
db_dir: "/tmp/yc_bench_dbs"
|
||||
company_name: "BenchCo"
|
||||
start_date: "01/01/2025" # MM/DD/YYYY (yc-bench convention)
|
||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
use_wandb: true
|
||||
wandb_name: "yc-bench"
|
||||
ensure_scores_are_not_same: false
|
||||
data_dir_to_save_evals: "environments/benchmarks/evals/yc-bench"
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-sonnet-4.6"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
@@ -1,34 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# YC-Bench Evaluation
|
||||
#
|
||||
# Requires: pip install "hermes-agent[yc-bench]"
|
||||
#
|
||||
# Run from repo root:
|
||||
# bash environments/benchmarks/yc_bench/run_eval.sh
|
||||
#
|
||||
# Override model:
|
||||
# bash environments/benchmarks/yc_bench/run_eval.sh \
|
||||
# --openai.model_name anthropic/claude-opus-4-20250514
|
||||
#
|
||||
# Run a single preset:
|
||||
# bash environments/benchmarks/yc_bench/run_eval.sh \
|
||||
# --env.presets '["fast_test"]' --env.seeds '[1]'
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
mkdir -p logs evals/yc-bench
|
||||
LOG_FILE="logs/yc_bench_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "YC-Bench Evaluation"
|
||||
echo "Log: $LOG_FILE"
|
||||
echo ""
|
||||
|
||||
PYTHONUNBUFFERED=1 LOGLEVEL="${LOGLEVEL:-INFO}" \
|
||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
--config environments/benchmarks/yc_bench/default.yaml \
|
||||
"$@" \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo ""
|
||||
echo "Log saved to: $LOG_FILE"
|
||||
@@ -1,848 +0,0 @@
|
||||
"""
|
||||
YCBenchEvalEnv -- YC-Bench Long-Horizon Agent Benchmark Environment
|
||||
|
||||
Evaluates agentic LLMs on YC-Bench: a deterministic, long-horizon benchmark
|
||||
where the agent acts as CEO of an AI startup over a simulated 1-3 year run.
|
||||
The agent manages cash flow, employees, tasks, and prestige across 4 domains,
|
||||
interacting exclusively via CLI subprocess calls against a SQLite-backed
|
||||
discrete-event simulation.
|
||||
|
||||
Unlike TerminalBench2 (per-task binary pass/fail), YC-Bench measures sustained
|
||||
multi-turn strategic coherence -- whether an agent can manage compounding
|
||||
decisions over hundreds of turns without going bankrupt.
|
||||
|
||||
This is an eval-only environment. Run via:
|
||||
|
||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
--config environments/benchmarks/yc_bench/default.yaml
|
||||
|
||||
The evaluate flow:
|
||||
1. setup() -- Verifies yc-bench installed, builds eval matrix (preset x seed)
|
||||
2. evaluate() -- Iterates over all runs sequentially through:
|
||||
a. rollout_and_score_eval() -- Per-run agent loop
|
||||
- Initialises a fresh yc-bench simulation via `sim init` (NOT `run`)
|
||||
- Runs HermesAgentLoop with terminal tool only
|
||||
- Reads final SQLite DB to extract score
|
||||
- Returns survival (0/1) + normalised funds score
|
||||
b. Aggregates per-preset and overall metrics
|
||||
c. Logs results via evaluate_log() and wandb
|
||||
|
||||
Key features:
|
||||
- CLI-only interface: agent calls yc-bench subcommands via terminal tool
|
||||
- Deterministic: same seed + preset = same world (SHA256-based RNG)
|
||||
- Multi-dimensional scoring: survival + normalised final funds
|
||||
- Per-preset difficulty breakdown in results
|
||||
- Isolated SQLite DB per run (no cross-run state leakage)
|
||||
|
||||
Requires: pip install "hermes-agent[yc-bench]"
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sqlite3
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import EvalHandlingEnum
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
|
||||
from environments.agent_loop import HermesAgentLoop
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =============================================================================
|
||||
# System prompt
|
||||
# =============================================================================
|
||||
|
||||
YC_BENCH_SYSTEM_PROMPT = """\
|
||||
You are the autonomous CEO of an early-stage AI startup in a deterministic
|
||||
business simulation. You manage the company exclusively through the `yc-bench`
|
||||
CLI tool. Your primary goal is to **survive** until the simulation horizon ends
|
||||
without going bankrupt, while **maximising final funds**.
|
||||
|
||||
## Simulation Mechanics
|
||||
|
||||
- **Funds**: You start with $250,000 seed capital. Revenue comes from completing
|
||||
tasks. Rewards scale with your prestige: `base × (1 + scale × (prestige − 1))`.
|
||||
- **Domains**: There are 4 skill domains: **research**, **inference**,
|
||||
**data_environment**, and **training**. Each has its own prestige level
|
||||
(1.0-10.0). Higher prestige unlocks better-paying tasks.
|
||||
- **Employees**: You have employees (Junior/Mid/Senior) with domain-specific
|
||||
skill rates. **Throughput splits**: `effective_rate = base_rate / N` where N
|
||||
is the number of active tasks assigned to that employee. Focus beats breadth.
|
||||
- **Payroll**: Deducted automatically on the first business day of each month.
|
||||
Running out of funds = bankruptcy = game over.
|
||||
- **Time**: The simulation runs on business days (Mon-Fri), 09:00-18:00.
|
||||
Time only advances when you call `yc-bench sim resume`.
|
||||
|
||||
## Task Lifecycle
|
||||
|
||||
1. Browse market tasks with `market browse`
|
||||
2. Accept a task with `task accept` (this sets its deadline)
|
||||
3. Assign employees with `task assign`
|
||||
4. Dispatch with `task dispatch` to start work
|
||||
5. Call `sim resume` to advance time and let employees make progress
|
||||
6. Tasks complete when all domain requirements are fulfilled
|
||||
|
||||
**Penalties for failure vary by difficulty preset.** Completing a task on time
|
||||
earns full reward + prestige gain. Missing a deadline or cancelling a task
|
||||
incurs prestige penalties -- cancelling is always more costly than letting a
|
||||
task fail, so cancel only as a last resort.
|
||||
|
||||
## CLI Commands
|
||||
|
||||
### Observe
|
||||
- `yc-bench company status` -- funds, prestige, runway
|
||||
- `yc-bench employee list` -- skills, salary, active tasks
|
||||
- `yc-bench market browse [--domain D] [--required-prestige-lte N]` -- available tasks
|
||||
- `yc-bench task list [--status active|planned]` -- your tasks
|
||||
- `yc-bench task inspect --task-id UUID` -- progress, deadline, assignments
|
||||
- `yc-bench finance ledger [--category monthly_payroll|task_reward]` -- transaction history
|
||||
- `yc-bench report monthly` -- monthly P&L
|
||||
|
||||
### Act
|
||||
- `yc-bench task accept --task-id UUID` -- accept from market
|
||||
- `yc-bench task assign --task-id UUID --employee-id UUID` -- assign employee
|
||||
- `yc-bench task dispatch --task-id UUID` -- start work (needs >=1 assignment)
|
||||
- `yc-bench task cancel --task-id UUID --reason "text"` -- cancel (prestige penalty)
|
||||
- `yc-bench sim resume` -- advance simulation clock
|
||||
|
||||
### Memory (persists across context truncation)
|
||||
- `yc-bench scratchpad read` -- read your persistent notes
|
||||
- `yc-bench scratchpad write --content "text"` -- overwrite notes
|
||||
- `yc-bench scratchpad append --content "text"` -- append to notes
|
||||
- `yc-bench scratchpad clear` -- clear notes
|
||||
|
||||
## Strategy Guidelines
|
||||
|
||||
1. **Specialise in 2-3 domains** to climb the prestige ladder faster and unlock
|
||||
high-reward tasks. Don't spread thin across all 4 domains early on.
|
||||
2. **Focus employees** -- assigning one employee to N active tasks divides
their throughput on each task by N. Keep assignments concentrated.
|
||||
3. **Use the scratchpad** to track your strategy, upcoming deadlines, and
|
||||
employee assignments. This persists even if conversation context is truncated.
|
||||
4. **Monitor runway** -- always know how many months of payroll you can cover.
|
||||
Accept high-reward tasks before payroll dates.
|
||||
5. **Don't over-accept** -- taking too many tasks and missing deadlines cascades
|
||||
into prestige loss, locking you out of profitable contracts.
|
||||
6. Use `finance ledger` and `report monthly` to track revenue trends.
|
||||
|
||||
## Your Turn
|
||||
|
||||
Each turn:
|
||||
1. Call `yc-bench company status` and `yc-bench task list` to orient yourself.
|
||||
2. Check for completed tasks and pending deadlines.
|
||||
3. Browse market for profitable tasks within your prestige level.
|
||||
4. Accept, assign, and dispatch tasks strategically.
|
||||
5. Call `yc-bench sim resume` to advance time.
|
||||
6. Repeat until the simulation ends.
|
||||
|
||||
Think step by step before acting."""
|
||||
|
||||
# Starting funds in cents ($250,000)
|
||||
INITIAL_FUNDS_CENTS = 25_000_000
|
||||
|
||||
# Default horizon per preset (years)
|
||||
_PRESET_HORIZONS = {
|
||||
"tutorial": 1,
|
||||
"easy": 1,
|
||||
"medium": 1,
|
||||
"hard": 1,
|
||||
"nightmare": 1,
|
||||
"fast_test": 1,
|
||||
"default": 3,
|
||||
"high_reward": 1,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
class YCBenchEvalConfig(HermesAgentEnvConfig):
|
||||
"""
|
||||
Configuration for the YC-Bench evaluation environment.
|
||||
|
||||
Extends HermesAgentEnvConfig with YC-Bench-specific settings for
|
||||
preset selection, seed control, scoring, and simulation parameters.
|
||||
"""
|
||||
|
||||
presets: List[str] = Field(
|
||||
default=["fast_test", "medium", "hard"],
|
||||
description="YC-Bench preset names to evaluate.",
|
||||
)
|
||||
seeds: List[int] = Field(
|
||||
default=[1, 2, 3],
|
||||
description="Random seeds -- each preset x seed = one run.",
|
||||
)
|
||||
run_timeout: int = Field(
|
||||
default=3600,
|
||||
description="Maximum wall-clock seconds per run. Default 60 minutes.",
|
||||
)
|
||||
survival_weight: float = Field(
|
||||
default=0.5,
|
||||
description="Weight of survival (0/1) in composite score.",
|
||||
)
|
||||
funds_weight: float = Field(
|
||||
default=0.5,
|
||||
description="Weight of normalised final funds in composite score.",
|
||||
)
|
||||
db_dir: str = Field(
|
||||
default="/tmp/yc_bench_dbs",
|
||||
description="Directory for per-run SQLite databases.",
|
||||
)
|
||||
horizon_years: Optional[int] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Simulation horizon in years. If None (default), inferred from "
|
||||
"preset name (1 year for most, 3 for 'default')."
|
||||
),
|
||||
)
|
||||
company_name: str = Field(
|
||||
default="BenchCo",
|
||||
description="Name of the simulated company.",
|
||||
)
|
||||
start_date: str = Field(
|
||||
default="01/01/2025",
|
||||
description="Simulation start date in MM/DD/YYYY format (yc-bench convention).",
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Scoring helpers
|
||||
# =============================================================================
|
||||
|
||||
def _read_final_score(db_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Read final game state from a YC-Bench SQLite database.
|
||||
|
||||
Returns dict with final_funds_cents (int), survived (bool),
|
||||
terminal_reason (str).
|
||||
|
||||
Note: yc-bench table names are plural -- 'companies' not 'company',
|
||||
'sim_events' not 'simulation_log'.
|
||||
"""
|
||||
if not os.path.exists(db_path):
|
||||
logger.warning("DB not found at %s", db_path)
|
||||
return {
|
||||
"final_funds_cents": 0,
|
||||
"survived": False,
|
||||
"terminal_reason": "db_missing",
|
||||
}
|
||||
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cur = conn.cursor()
|
||||
|
||||
# Read final funds from the 'companies' table
|
||||
cur.execute("SELECT funds_cents FROM companies LIMIT 1")
|
||||
row = cur.fetchone()
|
||||
funds = row[0] if row else 0
|
||||
|
||||
# Determine terminal reason from 'sim_events' table
|
||||
terminal_reason = "unknown"
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT event_type FROM sim_events "
|
||||
"WHERE event_type IN ('bankruptcy', 'horizon_end') "
|
||||
"ORDER BY scheduled_at DESC LIMIT 1"
|
||||
)
|
||||
event_row = cur.fetchone()
|
||||
if event_row:
|
||||
terminal_reason = event_row[0]
|
||||
except sqlite3.OperationalError:
|
||||
# Table may not exist if simulation didn't progress
|
||||
pass
|
||||
|
||||
survived = funds >= 0 and terminal_reason != "bankruptcy"
|
||||
return {
|
||||
"final_funds_cents": funds,
|
||||
"survived": survived,
|
||||
"terminal_reason": terminal_reason,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to read DB %s: %s", db_path, e)
|
||||
return {
|
||||
"final_funds_cents": 0,
|
||||
"survived": False,
|
||||
"terminal_reason": f"db_error: {e}",
|
||||
}
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
|
||||
def _compute_composite_score(
|
||||
final_funds_cents: int,
|
||||
survived: bool,
|
||||
survival_weight: float = 0.5,
|
||||
funds_weight: float = 0.5,
|
||||
initial_funds_cents: int = INITIAL_FUNDS_CENTS,
|
||||
) -> float:
|
||||
"""
|
||||
Compute composite score from survival and final funds.
|
||||
|
||||
Score = survival_weight * survival_score
|
||||
+ funds_weight * normalised_funds_score
|
||||
|
||||
Normalised funds uses log-scale relative to initial capital:
|
||||
- funds <= 0: 0.0
|
||||
- funds == initial: ~0.15
|
||||
- funds == 10x: ~0.52
|
||||
- funds == 100x: 1.0
|
||||
"""
|
||||
survival_score = 1.0 if survived else 0.0
|
||||
|
||||
if final_funds_cents <= 0:
|
||||
funds_score = 0.0
|
||||
else:
|
||||
max_ratio = 100.0
|
||||
ratio = final_funds_cents / max(initial_funds_cents, 1)
|
||||
funds_score = min(math.log1p(ratio) / math.log1p(max_ratio), 1.0)
|
||||
|
||||
return survival_weight * survival_score + funds_weight * funds_score
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main Environment
|
||||
# =============================================================================
|
||||
|
||||
class YCBenchEvalEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
YC-Bench long-horizon agent benchmark environment (eval-only).
|
||||
|
||||
Each eval item is a (preset, seed) pair. The environment initialises the
|
||||
simulation via ``yc-bench sim init`` (NOT ``yc-bench run`` which would start
|
||||
a competing built-in agent loop). The HermesAgentLoop then drives the
|
||||
interaction by calling individual yc-bench CLI commands via the terminal tool.
|
||||
|
||||
After the agent loop ends, the SQLite DB is read to extract the final score.
|
||||
|
||||
Scoring:
|
||||
composite = 0.5 * survival + 0.5 * normalised_funds
|
||||
"""
|
||||
|
||||
name = "yc-bench"
|
||||
env_config_cls = YCBenchEvalConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[YCBenchEvalConfig, List[APIServerConfig]]:
|
||||
env_config = YCBenchEvalConfig(
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
max_agent_turns=200,
|
||||
max_token_length=32000,
|
||||
agent_temperature=0.0,
|
||||
system_prompt=YC_BENCH_SYSTEM_PROMPT,
|
||||
terminal_backend="local",
|
||||
terminal_timeout=60,
|
||||
presets=["fast_test", "medium", "hard"],
|
||||
seeds=[1, 2, 3],
|
||||
run_timeout=3600,
|
||||
survival_weight=0.5,
|
||||
funds_weight=0.5,
|
||||
db_dir="/tmp/yc_bench_dbs",
|
||||
eval_handling=EvalHandlingEnum.STOP_TRAIN,
|
||||
group_size=1,
|
||||
steps_per_eval=1,
|
||||
total_steps=1,
|
||||
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
use_wandb=True,
|
||||
wandb_name="yc-bench",
|
||||
ensure_scores_are_not_same=False,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-sonnet-4.6",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
# =========================================================================
|
||||
# Setup
|
||||
# =========================================================================
|
||||
|
||||
async def setup(self):
|
||||
"""Verify yc-bench is installed and build the eval matrix."""
|
||||
# Verify yc-bench CLI is available
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["yc-bench", "--help"], capture_output=True, text=True, timeout=10
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise FileNotFoundError
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
raise RuntimeError(
|
||||
"yc-bench CLI not found. Install with:\n"
|
||||
' pip install "hermes-agent[yc-bench]"\n'
|
||||
"Or: git clone https://github.com/collinear-ai/yc-bench "
|
||||
"&& cd yc-bench && pip install -e ."
|
||||
)
|
||||
print("yc-bench CLI verified.")
|
||||
|
||||
# Build eval matrix: preset x seed
|
||||
self.all_eval_items = [
|
||||
{"preset": preset, "seed": seed}
|
||||
for preset in self.config.presets
|
||||
for seed in self.config.seeds
|
||||
]
|
||||
self.iter = 0
|
||||
|
||||
os.makedirs(self.config.db_dir, exist_ok=True)
|
||||
self.eval_metrics: List[Tuple[str, float]] = []
|
||||
|
||||
# Streaming JSONL log for crash-safe result persistence
|
||||
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
run_ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self._streaming_path = os.path.join(log_dir, f"samples_{run_ts}.jsonl")
|
||||
self._streaming_file = open(self._streaming_path, "w", encoding="utf-8")
|
||||
self._streaming_lock = threading.Lock()
|
||||
|
||||
print(f"\nYC-Bench eval matrix: {len(self.all_eval_items)} runs")
|
||||
for item in self.all_eval_items:
|
||||
print(f" preset={item['preset']!r} seed={item['seed']}")
|
||||
print(f"Streaming results to: {self._streaming_path}\n")
|
||||
|
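The eval matrix and the round-robin lookup used by `get_next_item` can be sketched in isolation as follows. `build_eval_matrix` and `next_item` are hypothetical helper names used only for illustration; the preset and seed values mirror the defaults in `config_init`.

```python
def build_eval_matrix(presets, seeds):
    """Cross every preset with every seed, presets varying slowest."""
    return [{"preset": preset, "seed": seed} for preset in presets for seed in seeds]


def next_item(items, counter):
    """Round-robin lookup: wraps around once every item has been served."""
    return items[counter % len(items)]


MATRIX = build_eval_matrix(["fast_test", "medium", "hard"], [1, 2, 3])

if __name__ == "__main__":
    for entry in MATRIX:
        print(entry)
```

Because presets vary slowest, all seeds of one preset run back to back, which keeps the per-preset progress output grouped.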
    def _save_result(self, result: Dict[str, Any]):
        """Write a single run result to the streaming JSONL file immediately."""
        if not hasattr(self, "_streaming_file") or self._streaming_file.closed:
            return
        with self._streaming_lock:
            self._streaming_file.write(
                json.dumps(result, ensure_ascii=False, default=str) + "\n"
            )
            self._streaming_file.flush()
||||
|
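A minimal, self-contained sketch of the same crash-safe pattern (one JSON object per line, flushed under a lock after every write) is shown here. The class `JsonlResultLog`, the helper `read_results`, and the sample rows are assumptions made for illustration and do not appear in the diff.

```python
import json
import os
import tempfile
import threading


class JsonlResultLog:
    """Append-only JSONL writer: one JSON object per line, flushed per write."""

    def __init__(self, path):
        self.path = path
        self._file = open(path, "w", encoding="utf-8")
        self._lock = threading.Lock()

    def save(self, result):
        if self._file.closed:
            return
        with self._lock:
            # default=str stringifies values json cannot encode (e.g. datetimes).
            self._file.write(json.dumps(result, ensure_ascii=False, default=str) + "\n")
            self._file.flush()

    def close(self):
        if not self._file.closed:
            self._file.close()


def read_results(path):
    """Read rows back, stopping at the first line that fails to parse."""
    rows = []
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError:
                break  # e.g. a partially written final line after a crash
    return rows


def demo():
    path = os.path.join(tempfile.mkdtemp(), "samples.jsonl")
    log = JsonlResultLog(path)
    log.save({"preset": "fast_test", "seed": 1, "composite_score": 0.575})
    log.save({"preset": "fast_test", "seed": 2, "composite_score": 0.0})
    log.close()
    return read_results(path)
```

Flushing after each row means that a crash mid-evaluation should lose at most the row being written, at the cost of one extra flush per run, which is negligible next to a multi-minute rollout.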
||||
# =========================================================================
|
||||
# Training pipeline stubs (eval-only -- not used)
|
||||
# =========================================================================
|
||||
|
    async def get_next_item(self):
        item = self.all_eval_items[self.iter % len(self.all_eval_items)]
        self.iter += 1
        return item

    def format_prompt(self, item: Dict[str, Any]) -> str:
        preset = item["preset"]
        seed = item["seed"]
        return (
            f"A new YC-Bench simulation has been initialized "
            f"(preset='{preset}', seed={seed}).\n"
            f"Your company '{self.config.company_name}' is ready.\n\n"
            "Begin by calling:\n"
            "1. `yc-bench company status` -- see your starting funds and prestige\n"
            "2. `yc-bench employee list` -- see your team and their skills\n"
            "3. `yc-bench market browse --required-prestige-lte 1` -- find tasks "
            "you can take\n\n"
            "Then accept 2-3 tasks, assign employees, dispatch them, and call "
            "`yc-bench sim resume` to advance time. Repeat this loop until the "
            "simulation ends (horizon reached or bankruptcy)."
        )

    async def compute_reward(self, item, result, ctx) -> float:
        return 0.0

    async def collect_trajectories(self, item):
        return None, []

    async def score(self, rollout_group_data):
        return None
||||
|
||||
# =========================================================================
|
||||
# Per-run evaluation
|
||||
# =========================================================================
|
||||
|
||||
async def rollout_and_score_eval(self, eval_item: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Evaluate a single (preset, seed) run.
|
||||
|
||||
1. Sets DATABASE_URL and YC_BENCH_EXPERIMENT env vars
|
||||
2. Initialises the simulation via ``yc-bench sim init`` (NOT ``run``)
|
||||
3. Runs HermesAgentLoop with terminal tool
|
||||
4. Reads SQLite DB to compute final score
|
||||
5. Returns result dict with survival, funds, and composite score
|
||||
"""
|
||||
preset = eval_item["preset"]
|
||||
seed = eval_item["seed"]
|
||||
run_id = str(uuid.uuid4())[:8]
|
||||
run_key = f"{preset}_seed{seed}_{run_id}"
|
||||
|
||||
from tqdm import tqdm
|
||||
tqdm.write(f" [START] preset={preset!r} seed={seed} (run_id={run_id})")
|
||||
run_start = time.time()
|
||||
|
||||
# Isolated DB per run -- prevents cross-run state leakage
|
||||
db_path = os.path.join(self.config.db_dir, f"yc_bench_{run_key}.db")
|
||||
os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"
|
||||
os.environ["YC_BENCH_EXPERIMENT"] = preset
|
||||
|
||||
# Determine horizon: explicit config override > preset lookup > default 1
|
||||
horizon = self.config.horizon_years or _PRESET_HORIZONS.get(preset, 1)
|
||||
|
||||
try:
|
||||
# ----------------------------------------------------------
|
||||
# Step 1: Initialise the simulation via CLI
|
||||
# IMPORTANT: We use `sim init`, NOT `yc-bench run`.
|
||||
# `yc-bench run` starts yc-bench's own LLM agent loop (via
|
||||
# LiteLLM), which would compete with our HermesAgentLoop.
|
||||
# `sim init` just sets up the world and returns.
|
||||
# ----------------------------------------------------------
|
||||
init_cmd = [
|
||||
"yc-bench", "sim", "init",
|
||||
"--seed", str(seed),
|
||||
"--start-date", self.config.start_date,
|
||||
"--company-name", self.config.company_name,
|
||||
"--horizon-years", str(horizon),
|
||||
]
|
||||
init_result = subprocess.run(
|
||||
init_cmd, capture_output=True, text=True, timeout=30,
|
||||
)
|
||||
if init_result.returncode != 0:
|
||||
error_msg = (init_result.stderr or init_result.stdout).strip()
|
||||
raise RuntimeError(f"yc-bench sim init failed: {error_msg}")
|
||||
|
||||
tqdm.write(f" Simulation initialized (horizon={horizon}yr)")
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# Step 2: Run the HermesAgentLoop
|
||||
# ----------------------------------------------------------
|
||||
tools, valid_names = self._resolve_tools_for_group()
|
||||
|
||||
messages: List[Dict[str, Any]] = [
|
||||
{"role": "system", "content": YC_BENCH_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": self.format_prompt(eval_item)},
|
||||
]
|
||||
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=run_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# Step 3: Read final score from the simulation DB
|
||||
# ----------------------------------------------------------
|
||||
score_data = _read_final_score(db_path)
|
||||
final_funds = score_data["final_funds_cents"]
|
||||
survived = score_data["survived"]
|
||||
terminal_reason = score_data["terminal_reason"]
|
||||
|
||||
composite = _compute_composite_score(
|
||||
final_funds_cents=final_funds,
|
||||
survived=survived,
|
||||
survival_weight=self.config.survival_weight,
|
||||
funds_weight=self.config.funds_weight,
|
||||
)
|
||||
|
||||
elapsed = time.time() - run_start
|
||||
status = "SURVIVED" if survived else "BANKRUPT"
|
||||
if final_funds >= 0:
|
||||
funds_str = f"${final_funds / 100:,.0f}"
|
||||
else:
|
||||
funds_str = f"-${abs(final_funds) / 100:,.0f}"
|
||||
|
||||
tqdm.write(
|
||||
f" [{status}] preset={preset!r} seed={seed} "
|
||||
f"funds={funds_str} score={composite:.3f} "
|
||||
f"turns={result.turns_used} ({elapsed:.0f}s)"
|
||||
)
|
||||
|
||||
out = {
|
||||
"preset": preset,
|
||||
"seed": seed,
|
||||
"survived": survived,
|
||||
"final_funds_cents": final_funds,
|
||||
"final_funds_usd": final_funds / 100,
|
||||
"terminal_reason": terminal_reason,
|
||||
"composite_score": composite,
|
||||
"turns_used": result.turns_used,
|
||||
"finished_naturally": result.finished_naturally,
|
||||
"elapsed_seconds": elapsed,
|
||||
"db_path": db_path,
|
||||
"messages": result.messages,
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.time() - run_start
|
||||
logger.error("Run %s failed: %s", run_key, e, exc_info=True)
|
||||
tqdm.write(
|
||||
f" [ERROR] preset={preset!r} seed={seed}: {e} ({elapsed:.0f}s)"
|
||||
)
|
||||
out = {
|
||||
"preset": preset,
|
||||
"seed": seed,
|
||||
"survived": False,
|
||||
"final_funds_cents": 0,
|
||||
"final_funds_usd": 0.0,
|
||||
"terminal_reason": f"error: {e}",
|
||||
"composite_score": 0.0,
|
||||
"turns_used": 0,
|
||||
"error": str(e),
|
||||
"elapsed_seconds": elapsed,
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
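The per-run isolation above relies on a unique SQLite path per `(preset, seed, run_id)`. The sketch below shows one possible variant that builds an explicit environment mapping for a child process instead of mutating `os.environ`; this is an assumption-laden alternative, not what the diff does, and it would only cover subprocesses spawned directly (such as `sim init`), not commands issued later through the terminal tool. `make_run_env` and `child_sees` are hypothetical helper names.

```python
import os
import subprocess
import sys
import tempfile
import uuid


def make_run_env(db_dir, preset, seed):
    """Return (db_path, env) for one isolated run; os.environ is left untouched."""
    run_id = str(uuid.uuid4())[:8]
    db_path = os.path.join(db_dir, f"yc_bench_{preset}_seed{seed}_{run_id}.db")
    env = dict(os.environ)  # copy, so sibling runs never see each other's values
    env["DATABASE_URL"] = f"sqlite:///{db_path}"
    env["YC_BENCH_EXPERIMENT"] = preset
    return db_path, env


def child_sees(env, name):
    """Spawn a child Python process and report the value it sees for `name`."""
    code = f"import os; print(os.environ.get({name!r}, ''))"
    proc = subprocess.run(
        [sys.executable, "-c", code],
        env=env, capture_output=True, text=True, timeout=30,
    )
    return proc.stdout.strip()


def demo():
    db_dir = tempfile.mkdtemp()
    path_a, env_a = make_run_env(db_dir, "fast_test", 1)
    path_b, env_b = make_run_env(db_dir, "fast_test", 1)
    return (
        path_a,
        path_b,
        child_sees(env_a, "DATABASE_URL"),
        child_sees(env_b, "YC_BENCH_EXPERIMENT"),
    )
```

Two runs with the same preset and seed still get distinct database files because of the random `run_id` suffix, which is what prevents cross-run state leakage.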
||||
# =========================================================================
|
||||
# Evaluate
|
||||
# =========================================================================
|
||||
|
    async def _run_with_timeout(self, item: Dict[str, Any]) -> Dict:
        """Wrap a single rollout with a wall-clock timeout."""
        preset = item["preset"]
        seed = item["seed"]
        try:
            return await asyncio.wait_for(
                self.rollout_and_score_eval(item),
                timeout=self.config.run_timeout,
            )
        except asyncio.TimeoutError:
            from tqdm import tqdm
            tqdm.write(
                f" [TIMEOUT] preset={preset!r} seed={seed} "
                f"(exceeded {self.config.run_timeout}s)"
            )
            out = {
                "preset": preset,
                "seed": seed,
                "survived": False,
                "final_funds_cents": 0,
                "final_funds_usd": 0.0,
                "terminal_reason": f"timeout ({self.config.run_timeout}s)",
                "composite_score": 0.0,
                "turns_used": 0,
                "error": "timeout",
            }
            self._save_result(out)
            return out
||||
|
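The timeout wrapper can be exercised without any of the benchmark machinery. The sketch below assumes two stand-in coroutines, `slow_rollout` and `fast_rollout`, which are hypothetical and only simulate a rollout that does or does not finish in time; the fallback dict is a trimmed version of the one used in the diff.

```python
import asyncio


async def run_with_timeout(coro_fn, item, timeout_s):
    """Await coro_fn(item); on timeout return a zero-score placeholder result."""
    try:
        return await asyncio.wait_for(coro_fn(item), timeout=timeout_s)
    except asyncio.TimeoutError:
        return {
            "preset": item["preset"],
            "seed": item["seed"],
            "survived": False,
            "composite_score": 0.0,
            "error": "timeout",
        }


async def slow_rollout(item):
    # Hypothetical stand-in for an agent rollout that overruns its budget.
    await asyncio.sleep(1.0)
    return {"preset": item["preset"], "seed": item["seed"], "survived": True}


async def fast_rollout(item):
    # Hypothetical stand-in for a rollout that finishes immediately.
    await asyncio.sleep(0)
    return {"preset": item["preset"], "seed": item["seed"], "survived": True}


def demo():
    item = {"preset": "fast_test", "seed": 1}
    timed_out = asyncio.run(run_with_timeout(slow_rollout, item, timeout_s=0.05))
    finished = asyncio.run(run_with_timeout(fast_rollout, item, timeout_s=1.0))
    return timed_out, finished
```

Note that `asyncio.wait_for` cancels the wrapped coroutine when the timeout fires, so any external state the rollout started (subprocesses, sandboxes) may still need separate cleanup.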
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
Run YC-Bench evaluation over all (preset, seed) combinations.
|
||||
|
||||
Runs sequentially: each run is 100-500 turns, so parallelising would
|
||||
be prohibitively expensive and cause conflicts on process-wide env vars.
|
||||
"""
|
||||
start_time = time.time()
|
||||
from tqdm import tqdm
|
||||
|
||||
# --- tqdm-compatible logging handler (TB2 pattern) ---
|
||||
class _TqdmHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
try:
|
||||
tqdm.write(self.format(record))
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
root = logging.getLogger()
|
||||
handler = _TqdmHandler()
|
||||
handler.setFormatter(
|
||||
logging.Formatter("%(levelname)s %(name)s: %(message)s")
|
||||
)
|
||||
root.handlers = [handler]
|
||||
for noisy in ("httpx", "openai"):
|
||||
logging.getLogger(noisy).setLevel(logging.WARNING)
|
||||
|
||||
# --- Print config summary ---
|
||||
print(f"\n{'='*60}")
|
||||
print("Starting YC-Bench Evaluation")
|
||||
print(f"{'='*60}")
|
||||
print(f" Presets: {self.config.presets}")
|
||||
print(f" Seeds: {self.config.seeds}")
|
||||
print(f" Total runs: {len(self.all_eval_items)}")
|
||||
print(f" Max turns/run: {self.config.max_agent_turns}")
|
||||
print(f" Run timeout: {self.config.run_timeout}s")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
results = []
|
||||
pbar = tqdm(
|
||||
total=len(self.all_eval_items), desc="YC-Bench", dynamic_ncols=True
|
||||
)
|
||||
|
||||
try:
|
||||
for item in self.all_eval_items:
|
||||
result = await self._run_with_timeout(item)
|
||||
results.append(result)
|
||||
survived_count = sum(1 for r in results if r.get("survived"))
|
||||
pbar.set_postfix_str(
|
||||
f"survived={survived_count}/{len(results)}"
|
||||
)
|
||||
pbar.update(1)
|
||||
|
||||
except (KeyboardInterrupt, asyncio.CancelledError):
|
||||
tqdm.write("\n[INTERRUPTED] Stopping evaluation...")
|
||||
pbar.close()
|
||||
try:
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
cleanup_all_environments()
|
||||
except Exception:
|
||||
pass
|
||||
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
|
||||
self._streaming_file.close()
|
||||
return
|
||||
|
||||
pbar.close()
|
||||
end_time = time.time()
|
||||
|
||||
# --- Compute metrics ---
|
||||
valid = [r for r in results if r is not None]
|
||||
if not valid:
|
||||
print("Warning: No valid results.")
|
||||
return
|
||||
|
||||
total = len(valid)
|
||||
survived_total = sum(1 for r in valid if r.get("survived"))
|
||||
survival_rate = survived_total / total if total else 0.0
|
||||
avg_score = (
|
||||
sum(r.get("composite_score", 0) for r in valid) / total
|
||||
if total
|
||||
else 0.0
|
||||
)
|
||||
|
||||
preset_results: Dict[str, List[Dict]] = defaultdict(list)
|
||||
for r in valid:
|
||||
preset_results[r["preset"]].append(r)
|
||||
|
||||
eval_metrics = {
|
||||
"eval/survival_rate": survival_rate,
|
||||
"eval/avg_composite_score": avg_score,
|
||||
"eval/total_runs": total,
|
||||
"eval/survived_runs": survived_total,
|
||||
"eval/evaluation_time_seconds": end_time - start_time,
|
||||
}
|
||||
|
||||
for preset, items in sorted(preset_results.items()):
|
||||
ps = sum(1 for r in items if r.get("survived"))
|
||||
pt = len(items)
|
||||
pa = (
|
||||
sum(r.get("composite_score", 0) for r in items) / pt
|
||||
if pt
|
||||
else 0
|
||||
)
|
||||
key = preset.replace("-", "_")
|
||||
eval_metrics[f"eval/survival_rate_{key}"] = ps / pt if pt else 0
|
||||
eval_metrics[f"eval/avg_score_{key}"] = pa
|
||||
|
||||
self.eval_metrics = list(eval_metrics.items())
|
||||
|
||||
# --- Print summary ---
|
||||
print(f"\n{'='*60}")
|
||||
print("YC-Bench Evaluation Results")
|
||||
print(f"{'='*60}")
|
||||
print(
|
||||
f"Overall survival rate: {survival_rate:.1%} "
|
||||
f"({survived_total}/{total})"
|
||||
)
|
||||
print(f"Average composite score: {avg_score:.4f}")
|
||||
print(f"Evaluation time: {end_time - start_time:.1f}s")
|
||||
|
||||
print("\nPer-preset breakdown:")
|
||||
for preset, items in sorted(preset_results.items()):
|
||||
ps = sum(1 for r in items if r.get("survived"))
|
||||
pt = len(items)
|
||||
pa = (
|
||||
sum(r.get("composite_score", 0) for r in items) / pt
|
||||
if pt
|
||||
else 0
|
||||
)
|
||||
print(f" {preset}: {ps}/{pt} survived avg_score={pa:.4f}")
|
||||
for r in items:
|
||||
status = "SURVIVED" if r.get("survived") else "BANKRUPT"
|
||||
funds = r.get("final_funds_usd", 0)
|
||||
print(
|
||||
f" seed={r['seed']} [{status}] "
|
||||
f"${funds:,.0f} "
|
||||
f"score={r.get('composite_score', 0):.3f}"
|
||||
)
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# --- Log results ---
|
||||
samples = [
|
||||
{k: v for k, v in r.items() if k != "messages"} for r in valid
|
||||
]
|
||||
|
||||
try:
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
generation_parameters={
|
||||
"temperature": self.config.agent_temperature,
|
||||
"max_tokens": self.config.max_token_length,
|
||||
"max_agent_turns": self.config.max_agent_turns,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error logging results: {e}")
|
||||
|
||||
# --- Cleanup (TB2 pattern) ---
|
||||
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
|
||||
self._streaming_file.close()
|
||||
print(f"Results saved to: {self._streaming_path}")
|
||||
|
||||
try:
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
cleanup_all_environments()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
from environments.agent_loop import _tool_executor
|
||||
_tool_executor.shutdown(wait=False, cancel_futures=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
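The metric aggregation in `evaluate` reduces to grouping results by preset and averaging. A minimal sketch, assuming a hypothetical helper `summarise` and made-up sample rows (the scores below are invented for illustration, not benchmark results):

```python
from collections import defaultdict


def summarise(results):
    """Aggregate overall and per-preset survival rate and mean composite score."""
    valid = [r for r in results if r is not None]
    if not valid:
        return {}

    by_preset = defaultdict(list)
    for r in valid:
        by_preset[r["preset"]].append(r)

    total = len(valid)
    metrics = {
        "eval/survival_rate": sum(1 for r in valid if r.get("survived")) / total,
        "eval/avg_composite_score": sum(r.get("composite_score", 0) for r in valid) / total,
        "eval/total_runs": total,
    }
    for preset, items in sorted(by_preset.items()):
        key = preset.replace("-", "_")  # keep metric names free of hyphens
        count = len(items)
        metrics[f"eval/survival_rate_{key}"] = sum(1 for r in items if r.get("survived")) / count
        metrics[f"eval/avg_score_{key}"] = sum(r.get("composite_score", 0) for r in items) / count
    return metrics


# Invented sample rows, used only to exercise the helper.
SAMPLE = [
    {"preset": "fast_test", "seed": 1, "survived": True, "composite_score": 0.6},
    {"preset": "fast_test", "seed": 2, "survived": False, "composite_score": 0.0},
    {"preset": "hard", "seed": 1, "survived": True, "composite_score": 0.8},
]
```

Errored and timed-out runs carry `survived=False` and `composite_score=0.0`, so they pull both averages down rather than being silently excluded.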
||||
# =========================================================================
|
||||
# Wandb logging
|
||||
# =========================================================================
|
||||
|
    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log YC-Bench-specific metrics to wandb."""
        if wandb_metrics is None:
            wandb_metrics = {}
        for k, v in self.eval_metrics:
            wandb_metrics[k] = v
        self.eval_metrics = []
        await super().wandb_log(wandb_metrics)
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
YCBenchEvalEnv.cli()
|
||||
@@ -1,714 +0,0 @@
|
||||
"""
|
||||
HermesAgentBaseEnv -- Abstract Base Environment for Hermes-Agent + Atropos
|
||||
|
||||
Provides the Atropos integration plumbing that all hermes-agent environments share:
|
||||
- Two-mode operation (OpenAI server for Phase 1, VLLM ManagedServer for Phase 2)
|
||||
- Per-group toolset/distribution resolution
|
||||
- Agent loop orchestration via HermesAgentLoop
|
||||
- ToolContext creation for reward functions
|
||||
- ScoredDataGroup construction from ManagedServer state
|
||||
|
||||
Subclasses only need to implement:
|
||||
setup() -- Load dataset, initialize state
|
||||
get_next_item() -- Return the next item from the dataset
|
||||
format_prompt() -- Convert a dataset item into the user message
|
||||
compute_reward() -- Score the rollout (has full ToolContext access)
|
||||
evaluate() -- Periodic evaluation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
# Ensure the hermes-agent repo root is on sys.path so that imports like
|
||||
# `from model_tools import ...` and `from environments.X import ...` work
|
||||
# regardless of where the script is invoked from.
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
# Load API keys from hermes-agent/.env so all environments can access them
|
||||
_env_path = _repo_root / ".env"
|
||||
if _env_path.exists():
|
||||
load_dotenv(dotenv_path=_env_path)
|
||||
|
||||
# Apply monkey patches for async-safe tool operation inside Atropos's event loop.
|
||||
# This patches SwerexModalEnvironment to use a background thread instead of
|
||||
# asyncio.run(), which would deadlock inside Atropos. Safe for normal CLI too.
|
||||
from environments.patches import apply_patches
|
||||
apply_patches()
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
ScoredDataItem,
|
||||
)
|
||||
from atroposlib.envs.server_handling.server_manager import (
|
||||
APIServerConfig,
|
||||
ServerBaseline,
|
||||
ServerManager,
|
||||
)
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from environments.tool_context import ToolContext
|
||||
from tools.budget_config import (
|
||||
DEFAULT_RESULT_SIZE_CHARS,
|
||||
DEFAULT_TURN_BUDGET_CHARS,
|
||||
DEFAULT_PREVIEW_SIZE_CHARS,
|
||||
)
|
||||
|
||||
# Import hermes-agent toolset infrastructure
|
||||
from model_tools import get_tool_definitions
|
||||
from toolset_distributions import sample_toolsets_from_distribution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
"""
|
||||
Configuration for hermes-agent Atropos environments.
|
||||
|
||||
Extends BaseEnvConfig with agent-specific settings for toolsets,
|
||||
terminal backend, dataset loading, and tool call parsing.
|
||||
"""
|
||||
|
||||
# --- Toolset configuration ---
|
||||
# Mutually exclusive: use either enabled_toolsets OR distribution
|
||||
enabled_toolsets: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="Explicit list of hermes toolsets to enable (e.g., ['terminal', 'file', 'web']). "
|
||||
"If None and distribution is also None, all available toolsets are enabled.",
|
||||
)
|
||||
disabled_toolsets: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="Toolsets to disable. Applied as a filter on top of enabled_toolsets or distribution.",
|
||||
)
|
||||
distribution: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Name of a toolset distribution from toolset_distributions.py "
|
||||
"(e.g., 'development', 'terminal_tasks'). Sampled once per group. "
|
||||
"Mutually exclusive with enabled_toolsets.",
|
||||
)
|
||||
|
||||
# --- Agent loop configuration ---
|
||||
max_agent_turns: int = Field(
|
||||
default=30,
|
||||
description="Maximum number of LLM calls (tool-calling iterations) per rollout.",
|
||||
)
|
||||
system_prompt: Optional[str] = Field(
|
||||
default=None,
|
||||
description="System prompt for the agent. Tools are handled via the tools= parameter, "
|
||||
"not embedded in the prompt text.",
|
||||
)
|
||||
agent_temperature: float = Field(
|
||||
default=1.0,
|
||||
description="Sampling temperature for agent generation during rollouts.",
|
||||
)
|
||||
|
||||
# --- Terminal backend ---
|
||||
terminal_backend: str = Field(
|
||||
default="local",
|
||||
description="Terminal backend: 'local', 'docker', 'modal', 'daytona', 'ssh', 'singularity'. "
|
||||
"Modal or Daytona recommended for production RL (cloud isolation per rollout).",
|
||||
)
|
||||
terminal_timeout: int = Field(
|
||||
default=120,
|
||||
description="Per-command timeout in seconds for terminal tool calls. "
|
||||
"Commands exceeding this are killed. Increase for tasks with long-running "
|
||||
"commands (compilation, pip install, etc.).",
|
||||
)
|
||||
terminal_lifetime: int = Field(
|
||||
default=3600,
|
||||
description="Sandbox inactivity lifetime in seconds. The cleanup thread kills "
|
||||
"sandboxes that have been idle longer than this. Must be longer than "
|
||||
"the longest gap between tool calls (e.g., waiting for LLM response).",
|
||||
)
|
||||
|
||||
# --- Dataset ---
|
||||
dataset_name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="HuggingFace dataset name. Optional if tasks are defined inline.",
|
||||
)
|
||||
dataset_split: str = Field(
|
||||
default="train",
|
||||
description="Dataset split to use.",
|
||||
)
|
||||
prompt_field: str = Field(
|
||||
default="prompt",
|
||||
description="Which field in the dataset contains the prompt.",
|
||||
)
|
||||
|
||||
# --- Thread pool ---
|
||||
tool_pool_size: int = Field(
|
||||
default=128,
|
||||
description="Thread pool size for tool execution. Each concurrent task needs a "
|
||||
"thread for tool calls. Must be large enough for parallel evaluation. "
|
||||
"Too small = thread pool starvation.",
|
||||
)
|
||||
|
||||
# --- Phase 2: Tool call parsing ---
|
||||
tool_call_parser: str = Field(
|
||||
default="hermes",
|
||||
description="Tool call parser name for Phase 2 (VLLM server type). "
|
||||
"Ignored in Phase 1 (OpenAI server type, where the server parses tool calls natively). "
|
||||
"Options: hermes, mistral, llama3_json, qwen, deepseek_v3, etc.",
|
||||
)
|
||||
|
||||
# --- Tool result budget ---
|
||||
# Defaults imported from tools.budget_config (single source of truth).
|
||||
default_result_size_chars: int = Field(
|
||||
default=DEFAULT_RESULT_SIZE_CHARS,
|
||||
description="Default per-tool threshold (chars) for persisting large results "
|
||||
"to sandbox. Results exceeding this are written to /tmp/hermes-results/ "
|
||||
"and replaced with a preview. Per-tool registry values take precedence "
|
||||
"unless overridden via tool_result_overrides.",
|
||||
)
|
||||
turn_budget_chars: int = Field(
|
||||
default=DEFAULT_TURN_BUDGET_CHARS,
|
||||
description="Aggregate char budget per assistant turn. If the combined size of all tool "
|
||||
"results in a single turn exceeds this, the largest are persisted to disk first.",
|
||||
)
|
||||
preview_size_chars: int = Field(
|
||||
default=DEFAULT_PREVIEW_SIZE_CHARS,
|
||||
description="Size of the inline preview shown after a tool result is persisted.",
|
||||
)
|
||||
tool_result_overrides: Optional[Dict[str, int]] = Field(
|
||||
default=None,
|
||||
description="Per-tool threshold overrides (chars). Keys are tool names, "
|
||||
"values are char thresholds. Overrides both the default and registry "
|
||||
"per-tool values. Example: {'terminal': 10000, 'search_files': 5000}. "
|
||||
"Note: read_file is pinned to infinity and cannot be overridden.",
|
||||
)
|
||||
|
||||
# --- Provider-specific parameters ---
|
||||
# Passed as extra_body to the OpenAI client's chat.completions.create() call.
|
||||
# Useful for OpenRouter provider preferences, transforms, route settings, etc.
|
||||
# Example YAML:
|
||||
# extra_body:
|
||||
# provider:
|
||||
# ignore: ["DeepInfra", "Fireworks"]
|
||||
# order: ["Together"]
|
||||
# transforms: ["middle-out"]
|
||||
extra_body: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Extra body parameters passed to the OpenAI client's "
|
||||
"chat.completions.create(). Used for OpenRouter provider preferences, "
|
||||
"transforms, and other provider-specific settings.",
|
||||
)
|
||||
|
    def build_budget_config(self):
        """Build a BudgetConfig from env config fields."""
        from tools.budget_config import BudgetConfig
        return BudgetConfig(
            default_result_size=self.default_result_size_chars,
            turn_budget=self.turn_budget_chars,
            preview_size=self.preview_size_chars,
            tool_overrides=dict(self.tool_result_overrides) if self.tool_result_overrides else {},
        )
||||
|
||||
|
||||
class HermesAgentBaseEnv(BaseEnv):
|
||||
"""
|
||||
Abstract base environment for hermes-agent Atropos integration.
|
||||
|
||||
Handles two modes of operation:
|
||||
- Phase 1 (OpenAI server type): Uses server.chat_completion() directly.
|
||||
The server (VLLM, SGLang, OpenRouter, OpenAI) handles tool call parsing
|
||||
and reasoning extraction natively. DummyManagedServer provides placeholder
|
||||
tokens. Good for SFT data gen, verifier testing, evaluation.
|
||||
|
||||
- Phase 2 (VLLM server type): Uses ManagedServer for exact token IDs + logprobs
|
||||
via /generate. Client-side tool call parser reconstructs structured tool_calls
|
||||
from raw output. Full RL training capability.
|
||||
|
||||
Subclasses must implement:
|
||||
setup() -- Load dataset, initialize state
|
||||
get_next_item() -- Return the next item to roll out
|
||||
format_prompt() -- Convert a dataset item into the user message string
|
||||
compute_reward() -- Score the rollout using ToolContext
|
||||
evaluate() -- Periodic evaluation
|
||||
"""
|
||||
|
||||
name: Optional[str] = "hermes-agent"
|
||||
env_config_cls = HermesAgentEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: HermesAgentEnvConfig,
|
||||
server_configs: Union[ServerBaseline, List[APIServerConfig]],
|
||||
slurm=False,
|
||||
testing=False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
|
||||
# Set terminal environment variables so hermes tools pick them up.
|
||||
# These can all be overridden per-environment via config fields instead
|
||||
# of requiring users to set shell env vars.
|
||||
if config.terminal_backend:
|
||||
os.environ["TERMINAL_ENV"] = config.terminal_backend
|
||||
os.environ["TERMINAL_TIMEOUT"] = str(config.terminal_timeout)
|
||||
os.environ["TERMINAL_LIFETIME_SECONDS"] = str(config.terminal_lifetime)
|
||||
print(
|
||||
f"🖥️ Terminal: backend={config.terminal_backend}, "
|
||||
f"timeout={config.terminal_timeout}s, lifetime={config.terminal_lifetime}s"
|
||||
)
|
||||
|
||||
# Resize the agent loop's thread pool for tool execution.
|
||||
# This must be large enough for the number of concurrent tasks
|
||||
# (e.g., 89 parallel TB2 eval tasks each need a thread for tool calls).
|
||||
from environments.agent_loop import resize_tool_pool
|
||||
resize_tool_pool(config.tool_pool_size)
|
||||
|
||||
# Set tool_parser on the ServerManager so ManagedServer uses it
|
||||
# for bidirectional tool call translation (raw text ↔ OpenAI tool_calls).
|
||||
if hasattr(self.server, 'tool_parser'):
|
||||
self.server.tool_parser = config.tool_call_parser
|
||||
print(f"🔧 Tool parser: {config.tool_call_parser}")
|
||||
|
||||
# Current group's resolved tools (set in collect_trajectories)
|
||||
self._current_group_tools: Optional[Tuple[List[Dict], Set[str]]] = None
|
||||
|
||||
# Tool error tracking for wandb logging
|
||||
self._tool_error_buffer: List[Dict[str, Any]] = []
|
||||
|
||||
# =========================================================================
|
||||
# Toolset resolution (per-group)
|
||||
# =========================================================================
|
||||
|
    def _resolve_tools_for_group(self) -> Tuple[List[Dict[str, Any]], Set[str]]:
        """
        Resolve toolsets for a group. Called once in collect_trajectories(),
        then shared by all collect_trajectory() calls in the group.

        If distribution is set, samples probabilistically.
        If enabled_toolsets is set, uses that explicit list.
        disabled_toolsets is applied as a filter on top.

        Returns:
            (tool_schemas, valid_tool_names) tuple
        """
        config = self.config

        if config.distribution:
            group_toolsets = sample_toolsets_from_distribution(config.distribution)
            logger.info("Sampled toolsets from '%s': %s", config.distribution, group_toolsets)
        else:
            group_toolsets = config.enabled_toolsets  # None means "all available"
            if group_toolsets is None:
                logger.warning(
                    "enabled_toolsets is None -- loading ALL tools including messaging. "
                    "Set explicit enabled_toolsets for RL training."
                )

        tools = get_tool_definitions(
            enabled_toolsets=group_toolsets,
            disabled_toolsets=config.disabled_toolsets,
            quiet_mode=True,
        )

        valid_names = {t["function"]["name"] for t in tools} if tools else set()
        logger.info("Resolved %d tools for group: %s", len(valid_names), sorted(valid_names))
        return tools, valid_names
||||
|
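The `valid_names` set is derived from OpenAI-style tool schemas, where each entry carries its name under `function.name`. The sketch below illustrates that extraction and one way such a set might be used to reject unknown tool calls. The schema entries and the helper `split_tool_calls` are hypothetical; how the agent loop actually consumes `valid_tool_names` is not shown in this diff.

```python
def valid_tool_names(tools):
    """Collect function names from OpenAI-style tool schemas (None or [] gives set())."""
    return {t["function"]["name"] for t in tools} if tools else set()


def split_tool_calls(requested, tools):
    """Partition requested tool names into (allowed, rejected), preserving order."""
    names = valid_tool_names(tools)
    allowed = [name for name in requested if name in names]
    rejected = [name for name in requested if name not in names]
    return allowed, rejected


# Hypothetical schema entries, shaped like OpenAI function-calling tool definitions.
SAMPLE_TOOLS = [
    {"type": "function", "function": {"name": "terminal", "parameters": {"type": "object"}}},
    {"type": "function", "function": {"name": "read_file", "parameters": {"type": "object"}}},
]
```

Resolving the set once per group, as the method above does, keeps every rollout in the group on an identical tool surface even when a distribution is sampled probabilistically.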
||||
# =========================================================================
|
||||
# Server mode detection
|
||||
# =========================================================================
|
||||
|
    def _use_managed_server(self) -> bool:
        """
        Determine if we should use ManagedServer (Phase 2) or direct server (Phase 1).

        Phase 2 (ManagedServer) is used when the server type is 'vllm' or 'sglang',
        which go through the /generate endpoint for exact token tracking.

        Phase 1 (direct server) is used for 'openai' server type, which uses
        /v1/chat/completions with native tool call parsing.
        """
        if not self.server.servers:
            return False

        server = self.server.servers[0]
        # If the server is an OpenAI server (not VLLM/SGLang), use direct mode
        from atroposlib.envs.server_handling.openai_server import OpenAIServer
        return not isinstance(server, OpenAIServer)
||||
|
||||
# =========================================================================
|
||||
# Core Atropos integration
|
||||
# =========================================================================
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item: Item
|
||||
) -> Tuple[
|
||||
Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]],
|
||||
List[Item],
|
||||
]:
|
||||
"""
|
||||
Override collect_trajectories to resolve toolsets once per group,
|
||||
then delegate to the standard group-level collection.
|
||||
|
||||
The default BaseEnv.collect_trajectories() calls collect_trajectory()
|
||||
group_size times in parallel. We resolve tools once here and store
|
||||
them for all those calls to use.
|
||||
"""
|
||||
# Resolve toolsets for this group (shared by all rollouts in the group)
|
||||
self._current_group_tools = self._resolve_tools_for_group()
|
||||
|
||||
# Delegate to the default implementation which calls collect_trajectory()
|
||||
# group_size times via asyncio.gather
|
||||
return await super().collect_trajectories(item)
|
||||
|
||||
# =========================================================================
|
||||
# Wandb rollout display -- format trajectories nicely
|
||||
# =========================================================================
|
||||
|
||||
@staticmethod
|
||||
def _format_trajectory_for_display(messages: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Format a conversation's messages into a readable trajectory string
|
||||
for wandb rollout tables. Shows tool calls, tool results, and reasoning
|
||||
in a structured way instead of raw token decoding.
|
||||
"""
|
||||
parts = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content") or ""  # content may be present but None
|
||||
|
||||
if role == "system":
|
||||
parts.append(f"[SYSTEM]\n{content}")
|
||||
|
||||
elif role == "user":
|
||||
parts.append(f"[USER]\n{content}")
|
||||
|
||||
elif role == "assistant":
|
||||
# Show reasoning if present
|
||||
reasoning = msg.get("reasoning_content", "")
|
||||
if reasoning:
|
||||
# Truncate long reasoning for display
|
||||
if len(reasoning) > 300:
|
||||
reasoning = reasoning[:300] + "..."
|
||||
parts.append(f"[ASSISTANT thinking]\n{reasoning}")
|
||||
|
||||
# Show content
|
||||
if content:
|
||||
parts.append(f"[ASSISTANT]\n{content}")
|
||||
|
||||
# Show tool calls
|
||||
tool_calls = msg.get("tool_calls") or []  # key may be present but None
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function", {})
|
||||
name = func.get("name", "?")
|
||||
args = func.get("arguments", "{}")
|
||||
# Truncate long arguments for display
|
||||
if len(args) > 200:
|
||||
args = args[:200] + "..."
|
||||
parts.append(f"[TOOL CALL] {name}({args})")
|
||||
|
||||
elif role == "tool":
|
||||
tool_id = msg.get("tool_call_id", "")
|
||||
result = content
|
||||
# Truncate long tool results for display
|
||||
if len(result) > 500:
|
||||
result = result[:500] + "..."
|
||||
parts.append(f"[TOOL RESULT] {result}")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
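A condensed, standalone sketch of the same formatting idea may help when testing display output without a full environment. It assumes OpenAI-style message dicts and treats a `None` value for `content` or `tool_calls` as empty; the names `truncate` and `format_messages` are hypothetical and the reasoning branch is omitted for brevity.

```python
from typing import Any, Dict, List


def truncate(text: str, limit: int) -> str:
    """Cut text to `limit` characters and mark the cut with an ellipsis."""
    return text if len(text) <= limit else text[:limit] + "..."


def format_messages(messages: List[Dict[str, Any]]) -> str:
    """Hypothetical, simplified formatter for chat messages with tool calls."""
    parts = []
    for msg in messages:
        role = msg.get("role", "unknown")
        content = msg.get("content") or ""  # tolerate content=None
        if role == "assistant":
            if content:
                parts.append(f"[ASSISTANT]\n{content}")
            for tc in msg.get("tool_calls") or []:  # tolerate tool_calls=None
                func = tc.get("function", {})
                args = truncate(str(func.get("arguments", "{}")), 200)
                parts.append(f"[TOOL CALL] {func.get('name', '?')}({args})")
        elif role == "tool":
            parts.append(f"[TOOL RESULT] {truncate(content, 500)}")
        else:
            parts.append(f"[{role.upper()}]\n{content}")
    return "\n\n".join(parts)


example = [
    {"role": "user", "content": "hi"},
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {"function": {"name": "terminal", "arguments": '{"command": "ls"}'}}
        ],
    },
    {"role": "tool", "content": "x" * 600},
]
rendered = format_messages(example)
```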
|
||||
async def add_rollouts_for_wandb(
|
||||
self,
|
||||
scored_data,
|
||||
item=None,
|
||||
):
|
||||
"""
|
||||
Override to show formatted trajectories with tool calls visible,
|
||||
instead of raw token decoding which loses all structure.
|
||||
"""
|
||||
num_keep = self.config.num_rollouts_per_group_for_logging
|
||||
if num_keep == -1:
|
||||
num_keep = self.config.group_size
|
||||
|
||||
group = []
|
||||
for i in range(min(num_keep, len(scored_data.get("scores", [])))):
|
||||
score = scored_data["scores"][i]
|
||||
|
||||
# Use messages if available for rich display
|
||||
messages = None
|
||||
if scored_data.get("messages") and i < len(scored_data["messages"]):
|
||||
messages = scored_data["messages"][i]
|
||||
|
||||
if messages:
|
||||
text = self._format_trajectory_for_display(messages)
|
||||
elif scored_data.get("tokens") and i < len(scored_data["tokens"]):
|
||||
text = self.tokenizer.decode(scored_data["tokens"][i])
|
||||
else:
|
||||
text = "(no data)"
|
||||
|
||||
group.append((text, score))
|
||||
|
||||
self.rollouts_for_wandb.append(group)
|
||||
if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
|
||||
self.rollouts_for_wandb.pop(0)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log base metrics including tool errors to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
# Log tool error stats
|
||||
if self._tool_error_buffer:
|
||||
wandb_metrics["train/tool_errors_count"] = len(self._tool_error_buffer)
|
||||
|
||||
# Log error details as a summary string (tables can crash wandb on tmp cleanup)
|
||||
error_summaries = []
|
||||
for err in self._tool_error_buffer:
|
||||
error_summaries.append(
|
||||
f"[turn {err['turn']}] {err['tool']}({err['args'][:80]}) -> {err['error'][:150]}"
|
||||
)
|
||||
wandb_metrics["train/tool_error_details"] = "\n".join(error_summaries)
|
||||
|
||||
# Also print to stdout for immediate visibility
|
||||
for summary in error_summaries:
|
||||
print(f" Tool Error: {summary}")
|
||||
|
||||
self._tool_error_buffer = []
|
||||
else:
|
||||
wandb_metrics["train/tool_errors_count"] = 0
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
async def collect_trajectory(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
|
||||
"""
|
||||
Run a single rollout: agent loop + reward computation.
|
||||
|
||||
This is called group_size times in parallel by collect_trajectories().
|
||||
Each call gets its own task_id for terminal/browser session isolation.
|
||||
"""
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# Get group-level tools (resolved once in collect_trajectories)
|
||||
if self._current_group_tools is None:
|
||||
# Fallback: resolve per-trajectory if called outside collect_trajectories
|
||||
tools, valid_names = self._resolve_tools_for_group()
|
||||
else:
|
||||
tools, valid_names = self._current_group_tools
|
||||
|
||||
# Build initial messages
|
||||
messages: List[Dict[str, Any]] = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
messages.append({"role": "user", "content": self.format_prompt(item)})
|
||||
|
||||
# Run the agent loop
|
||||
result: AgentResult
|
||||
if self._use_managed_server():
|
||||
# Phase 2: ManagedServer with ToolCallTranslator -- exact tokens + logprobs
|
||||
# tool_parser is set on ServerManager in __init__ and passed through
|
||||
# to ManagedServer, which uses ToolCallTranslator for bidirectional
|
||||
# translation between raw text and OpenAI tool_calls.
|
||||
try:
|
||||
async with self.server.managed_server(
|
||||
tokenizer=self.tokenizer,
|
||||
preserve_think_blocks=bool(self.config.thinking_mode),
|
||||
) as managed:
|
||||
agent = HermesAgentLoop(
|
||||
server=managed,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
except NotImplementedError:
|
||||
# DummyManagedServer not allowed -- fall back to Phase 1
|
||||
logger.warning(
|
||||
"ManagedServer not available (OpenAI server?). "
|
||||
"Falling back to direct server mode."
|
||||
)
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
else:
|
||||
# Phase 1: OpenAI server -- native tool_calls, placeholder tokens
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Skip reward computation if the agent loop produced no meaningful work
|
||||
# (e.g., API call failed on turn 1). No point spinning up a Modal sandbox
|
||||
# just to verify files that were never created.
|
||||
only_system_and_user = all(
|
||||
msg.get("role") in {"system", "user"} for msg in result.messages
|
||||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
logger.warning(
|
||||
"Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
|
||||
result.turns_used, len(result.messages),
|
||||
)
|
||||
reward = 0.0
|
||||
else:
|
||||
# Compute reward using ToolContext (gives verifier full tool access)
|
||||
ctx = ToolContext(task_id)
|
||||
try:
|
||||
reward = await self.compute_reward(item, result, ctx)
|
||||
except Exception as e:
|
||||
logger.error("compute_reward failed: %s", e, exc_info=True)
|
||||
reward = 0.0
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
|
||||
# Track tool errors for wandb logging
|
||||
if result.tool_errors:
|
||||
for err in result.tool_errors:
|
||||
self._tool_error_buffer.append({
|
||||
"turn": err.turn,
|
||||
"tool": err.tool_name,
|
||||
"args": err.arguments[:150],
|
||||
"error": err.error[:300],
|
||||
"result": err.tool_result[:300],
|
||||
})
|
||||
|
||||
# Build ScoredDataItem from ManagedServer state
|
||||
# Phase 2: real tokens/masks/logprobs from SequenceNodes
|
||||
# Phase 1: placeholder tokens (still need a valid ScoredDataItem for the pipeline)
|
||||
nodes = (result.managed_state or {}).get("nodes", [])
|
||||
|
||||
if nodes:
|
||||
# Phase 2 (or DummyManagedServer): use actual node data
|
||||
node = nodes[-1] # Final sequence node = full trajectory
|
||||
scored_item: Dict[str, Any] = {
|
||||
"tokens": node.tokens,
|
||||
"masks": node.masked_tokens,
|
||||
"scores": reward,
|
||||
}
|
||||
|
||||
# If logprobs are available (Phase 2), add placeholder fields for the trainer to fill in
|
||||
if hasattr(node, "logprobs") and node.logprobs:
|
||||
scored_item["advantages"] = None # Computed by trainer
|
||||
scored_item["ref_logprobs"] = None
|
||||
else:
|
||||
# Phase 1 with no managed state: create placeholder tokens
|
||||
# so the data pipeline doesn't break. These are NOT suitable
|
||||
# for training but allow process mode (SFT data gen) to work.
|
||||
# Tokenize the full conversation to get approximate tokens.
|
||||
full_text = "\n".join(
|
||||
msg.get("content", "") for msg in result.messages if msg.get("content")
|
||||
)
|
||||
if self.tokenizer:
|
||||
tokens = self.tokenizer.encode(full_text, add_special_tokens=True)
|
||||
else:
|
||||
tokens = list(range(min(len(full_text) // 4, 128)))
|
||||
|
||||
scored_item = {
"tokens": tokens,
# Mask the first token as prompt; keep masks the same length as tokens
"masks": [-100] + tokens[1:] if tokens else [],
|
||||
"scores": reward,
|
||||
}
|
||||
|
||||
# Always include messages for wandb rollout display and data logging
|
||||
scored_item["messages"] = result.messages
|
||||
|
||||
return scored_item, []
|
||||
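The Phase 1 placeholder path can be illustrated on its own. The sketch below only assumes the convention visible in the code above, where `-100` marks a masked position and the first token is treated as prompt; the helper name `placeholder_item` is hypothetical, and the empty-input guard is an addition for illustration.

```python
from typing import Any, Dict, List


def placeholder_item(tokens: List[int], reward: float) -> Dict[str, Any]:
    """Hypothetical helper: build a minimal scored item from approximate tokens.

    The first position is masked with -100 (treated as prompt); the rest keep
    their token ids. Empty input yields empty masks so both lists stay aligned.
    """
    masks = [-100] + tokens[1:] if tokens else []
    return {"tokens": tokens, "masks": masks, "scores": reward}


item = placeholder_item([5, 6, 7], reward=1.0)
empty = placeholder_item([], reward=0.0)
```

As the comments above note, items built this way keep the pipeline running in process mode but are not suitable for training.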
|
||||
# =========================================================================
|
||||
# Abstract methods -- subclasses must implement
|
||||
# =========================================================================
|
||||
|
||||
@abstractmethod
|
||||
async def setup(self):
|
||||
"""
|
||||
Load dataset, initialize state.
|
||||
|
||||
Called once when the environment starts. Typical implementation:
|
||||
self.dataset = load_dataset(self.config.dataset_name, split=self.config.dataset_split)
|
||||
self.iter = 0
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_next_item(self) -> Item:
|
||||
"""
|
||||
Return the next item from the dataset for rollout.
|
||||
|
||||
Called by the base env's main loop to get items for workers.
|
||||
Should cycle through the dataset.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def format_prompt(self, item: Item) -> str:
|
||||
"""
|
||||
Convert a dataset item into the user message for the agent.
|
||||
|
||||
Args:
|
||||
item: Dataset item (dict, tuple, etc.)
|
||||
|
||||
Returns:
|
||||
The prompt string to send to the agent
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def compute_reward(
|
||||
self, item: Item, result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Score the rollout. Has full access to:
|
||||
- item: the original dataset item (ground truth, test commands, etc.)
|
||||
- result: AgentResult with full messages, turn count, reasoning, etc.
|
||||
- ctx: ToolContext -- call ANY hermes-agent tool (terminal, file, web,
|
||||
browser, vision...) scoped to this rollout's sandbox. Nothing
|
||||
is off-limits.
|
||||
|
||||
Args:
|
||||
item: The dataset item that was rolled out
|
||||
result: The agent's rollout result
|
||||
ctx: ToolContext with full tool access for verification
|
||||
|
||||
Returns:
|
||||
Reward float (typically 0.0 to 1.0, but any float is valid)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
Periodic evaluation. Called every steps_per_eval steps.
|
||||
|
||||
Typical implementation runs the agent on a held-out eval set
|
||||
and logs metrics via wandb/evaluate_log.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -1,34 +0,0 @@
|
||||
# SWE Environment -- Default Configuration
|
||||
#
|
||||
# SWE-bench style tasks with Modal sandboxes for cloud isolation.
|
||||
# Uses terminal + file + web toolsets.
|
||||
#
|
||||
# Usage:
|
||||
# python environments/hermes_swe_env/hermes_swe_env.py serve \
|
||||
# --config environments/hermes_swe_env/default.yaml
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file", "web"]
|
||||
max_agent_turns: 30
|
||||
max_token_length: 4096
|
||||
group_size: 4
|
||||
terminal_backend: "modal"
|
||||
tool_call_parser: "hermes"
|
||||
tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
||||
dataset_name: "bigcode/humanevalpack"
|
||||
dataset_split: "test"
|
||||
prompt_field: "prompt"
|
||||
steps_per_eval: 50
|
||||
total_steps: 500
|
||||
use_wandb: true
|
||||
wandb_name: "hermes-swe"
|
||||
system_prompt: >
|
||||
You are a skilled software engineer. You have access to a terminal,
|
||||
file tools, and web search. Use these tools to complete the coding task.
|
||||
Write clean, working code and verify it runs correctly before finishing.
|
||||
|
||||
openai:
|
||||
base_url: "http://localhost:8000/v1"
|
||||
model_name: "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
||||
server_type: "openai"
|
||||
api_key: ""
|
||||
@@ -1,229 +0,0 @@
|
||||
"""
|
||||
HermesSweEnv -- SWE-Bench Style Environment with Modal Sandboxes
|
||||
|
||||
A concrete environment for software engineering tasks where the model writes code
|
||||
and the reward function runs tests to verify correctness. Uses Modal terminal
|
||||
backend for cloud-isolated sandboxes per rollout.
|
||||
|
||||
The reward function uses ToolContext.terminal() to run test commands in the same
|
||||
Modal sandbox the model used during its agentic loop. All filesystem state from
|
||||
the model's tool calls is preserved for verification.
|
||||
|
||||
Usage:
|
||||
# Phase 1: OpenAI server type
|
||||
vllm serve YourModel --enable-auto-tool-choice --tool-call-parser hermes
|
||||
run-api
|
||||
python environments/hermes_swe_env/hermes_swe_env.py serve \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel \\
|
||||
--openai.server_type openai \\
|
||||
--env.dataset_name bigcode/humanevalpack \\
|
||||
--env.terminal_backend modal
|
||||
|
||||
# Phase 2: VLLM server type (full RL training)
|
||||
python environments/hermes_swe_env/hermes_swe_env.py serve \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel \\
|
||||
--openai.server_type vllm \\
|
||||
--env.tool_call_parser hermes \\
|
||||
--env.terminal_backend modal
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path for imports
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HermesSweEnvConfig(HermesAgentEnvConfig):
|
||||
"""Config with defaults for SWE-bench style tasks."""
|
||||
|
||||
pass # Inherits all fields, overrides defaults in config_init
|
||||
|
||||
|
||||
class HermesSweEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
SWE-bench style environment using Modal terminal backend.
|
||||
|
||||
The model gets a coding task, uses terminal + file + web tools to solve it,
|
||||
and the reward function runs tests in the same Modal sandbox to verify.
|
||||
|
||||
Subclass this for specific SWE datasets (HumanEval, SWE-bench, etc.)
|
||||
and customize format_prompt() and compute_reward() as needed.
|
||||
"""
|
||||
|
||||
name = "hermes-swe"
|
||||
env_config_cls = HermesSweEnvConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[HermesSweEnvConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Default configuration for the SWE environment.
|
||||
|
||||
Uses Modal terminal backend for cloud isolation and terminal + file + web toolsets.
|
||||
"""
|
||||
env_config = HermesSweEnvConfig(
|
||||
# Toolsets: terminal for running code, file for reading/writing, web for docs
|
||||
enabled_toolsets=["terminal", "file", "web"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
# Agent settings -- SWE tasks need more turns
|
||||
max_agent_turns=30,
|
||||
max_token_length=4096,
|
||||
agent_temperature=1.0,
|
||||
system_prompt=(
|
||||
"You are a skilled software engineer. You have access to a terminal, "
|
||||
"file tools, and web search. Use these tools to complete the coding task. "
|
||||
"Write clean, working code and verify it runs correctly before finishing."
|
||||
),
|
||||
# Modal backend for cloud-isolated sandboxes
|
||||
terminal_backend="modal",
|
||||
# Dataset -- override via CLI for your specific SWE dataset
|
||||
dataset_name="bigcode/humanevalpack",
|
||||
dataset_split="test",
|
||||
prompt_field="prompt",
|
||||
# Atropos settings
|
||||
group_size=4,
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
tool_call_parser="hermes",
|
||||
steps_per_eval=50,
|
||||
total_steps=500,
|
||||
use_wandb=True,
|
||||
wandb_name="hermes-swe",
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="http://localhost:8000/v1",
|
||||
model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
server_type="openai", # Phase 1; switch to "vllm" for Phase 2
|
||||
api_key="",
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
"""Load the SWE dataset."""
|
||||
if self.config.dataset_name:
|
||||
self.dataset = load_dataset(
|
||||
self.config.dataset_name, split=self.config.dataset_split
|
||||
)
|
||||
else:
|
||||
# Placeholder if no dataset specified
|
||||
self.dataset = []
|
||||
self.iter = 0
|
||||
self.reward_buffer: List[float] = []
|
||||
|
||||
async def get_next_item(self) -> Dict[str, Any]:
|
||||
"""Cycle through the SWE dataset."""
|
||||
if not self.dataset:
|
||||
raise ValueError("No dataset loaded. Set dataset_name in config.")
|
||||
item = self.dataset[self.iter % len(self.dataset)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Format the SWE task prompt.
|
||||
|
||||
Override this in subclasses for different dataset formats.
|
||||
Default assumes the dataset has a 'prompt' field and optionally a 'test' field.
|
||||
"""
|
||||
prompt = item.get(self.config.prompt_field, "")
|
||||
|
||||
# If the dataset has test information, include it in the prompt
|
||||
test_info = item.get("test", item.get("test_code", item.get("tests", "")))
|
||||
if test_info:
|
||||
prompt += f"\n\nTests to pass:\n{test_info}"
|
||||
|
||||
return prompt
|
||||
|
||||
async def compute_reward(
|
||||
self, item: Dict[str, Any], result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Score by running tests in the model's Modal sandbox.
|
||||
|
||||
Default implementation:
|
||||
- If the dataset item has a 'test', 'test_code', or 'tests' field, run it
|
||||
- Check exit code: 0 = pass, non-zero = fail
|
||||
- Partial credit for file creation
|
||||
|
||||
Override this in subclasses for more sophisticated reward logic.
|
||||
"""
|
||||
# Find the test command from the dataset item
|
||||
test_code = item.get("test", item.get("test_code", item.get("tests", "")))
|
||||
|
||||
if test_code:
|
||||
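# Run the test in the model's sandbox. Quote the test code so that quotes,
# "$", and backticks inside it are not interpreted by the shell.
import shlex
test_result = ctx.terminal(
f"cd /workspace && python3 -c {shlex.quote(test_code)}", timeout=60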
|
||||
)
|
||||
|
||||
if test_result["exit_code"] == 0:
|
||||
self.reward_buffer.append(1.0)
|
||||
return 1.0
|
||||
|
||||
# Partial credit: check if the model created any Python files
|
||||
file_check = ctx.terminal("find /workspace -name '*.py' -newer /tmp/.start_marker 2>/dev/null | head -5")
|
||||
if file_check["exit_code"] == 0 and file_check.get("output", "").strip():
|
||||
self.reward_buffer.append(0.1)
|
||||
return 0.1
|
||||
|
||||
self.reward_buffer.append(0.0)
|
||||
return 0.0
|
||||
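Because the test code is interpolated into a shell command, quoting matters: dataset test code usually contains quotes and newlines. A minimal sketch of building the command with the standard library's `shlex.quote` follows; the helper name `build_test_command` is hypothetical, and `/workspace` is taken from the code above.

```python
import shlex


def build_test_command(test_code: str, workdir: str = "/workspace") -> str:
    """Hypothetical helper: build a shell command that runs test_code verbatim."""
    # shlex.quote wraps the string in single quotes and escapes embedded ones,
    # so the shell passes it to python3 as exactly one argument.
    return f"cd {shlex.quote(workdir)} && python3 -c {shlex.quote(test_code)}"


code = 'assert add(1, 2) == 3, "bad $HOME"\nprint(\'ok\')'
command = build_test_command(code)
# Splitting the command the way a POSIX shell would recovers the original code
recovered = shlex.split(command)[-1]
```

Writing the test code to a file in the sandbox and running that file would be another reasonable choice, and it avoids command-length limits for long tests.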
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
Run evaluation on a held-out set.
|
||||
|
||||
Override for dataset-specific evaluation logic.
|
||||
"""
|
||||
start_time = time.time()
|
||||
end_time = time.time()
|
||||
|
||||
eval_metrics = {"eval/placeholder": 0.0}
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log SWE-specific metrics."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.reward_buffer:
|
||||
wandb_metrics["train/avg_reward"] = sum(self.reward_buffer) / len(
|
||||
self.reward_buffer
|
||||
)
|
||||
wandb_metrics["train/pass_rate"] = sum(
|
||||
1 for r in self.reward_buffer if r == 1.0
|
||||
) / len(self.reward_buffer)
|
||||
self.reward_buffer = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
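The two metrics logged above are simple aggregates over the reward buffer. A standalone sketch, with the hypothetical helper name `summarize_rewards`, makes the definitions explicit: the pass rate counts only rewards equal to 1.0, so partial credit (0.1) raises the average but not the pass rate.

```python
from typing import Dict, List


def summarize_rewards(rewards: List[float]) -> Dict[str, float]:
    """Hypothetical helper mirroring the avg_reward / pass_rate computation."""
    if not rewards:
        return {}  # nothing to log for an empty buffer
    n = len(rewards)
    return {
        "train/avg_reward": sum(rewards) / n,
        "train/pass_rate": sum(1 for r in rewards if r == 1.0) / n,
    }


stats = summarize_rewards([1.0, 0.1, 0.0, 1.0])
```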
|
||||
|
||||
if __name__ == "__main__":
|
||||
HermesSweEnv.cli()
|
||||
@@ -1,35 +0,0 @@
|
||||
"""
|
||||
Monkey patches for making hermes-agent tools work inside async frameworks (Atropos).
|
||||
|
||||
Problem:
|
||||
Some tools use asyncio.run() internally (e.g., Modal backend via SWE-ReX,
|
||||
web_extract). This crashes when called from inside Atropos's event loop because
|
||||
asyncio.run() can't be nested.
|
||||
|
||||
Solution:
|
||||
The Modal environment (tools/environments/modal.py) now uses a dedicated
|
||||
_AsyncWorker thread internally, making it safe for both CLI and Atropos use.
|
||||
No monkey-patching is required.
|
||||
|
||||
This module is kept for backward compatibility. apply_patches() is a no-op.
|
||||
|
||||
Usage:
|
||||
Call apply_patches() once at import time (done automatically by hermes_base_env.py).
|
||||
This is idempotent and safe to call multiple times.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_patches_applied = False
|
||||
|
||||
|
||||
def apply_patches():
"""No-op kept for backward compatibility; async safety is now built into the tools."""
|
||||
global _patches_applied
|
||||
if _patches_applied:
|
||||
return
|
||||
|
||||
logger.debug("apply_patches() called; no patches needed (async safety is built-in)")
|
||||
_patches_applied = True
|
||||
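The dedicated-worker-thread approach mentioned in the docstring can be sketched as follows. This is an illustration of the general technique under stated assumptions, not the actual `_AsyncWorker` in tools/environments/modal.py: the class `AsyncWorker`, its methods, and the `fetch_value` coroutine are hypothetical.

```python
import asyncio
import threading
from typing import Any, Coroutine


class AsyncWorker:
    """Hypothetical worker: runs coroutines on a private loop in its own thread."""

    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._loop.run_forever, daemon=True)
        self._thread.start()

    def run(self, coro: Coroutine[Any, Any, Any], timeout: float = 30.0) -> Any:
        # Schedule on the worker's loop and block the calling thread for the result
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        return future.result(timeout=timeout)

    def stop(self) -> None:
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()
        self._loop.close()


async def fetch_value() -> int:
    await asyncio.sleep(0.01)
    return 42


async def caller() -> int:
    worker = AsyncWorker()
    try:
        # A synchronous call made from inside a running event loop.
        # Calling asyncio.run() here would raise RuntimeError instead.
        return worker.run(fetch_value())
    finally:
        worker.stop()


value = asyncio.run(caller())
```

Note that `worker.run()` blocks the caller's thread until the coroutine finishes, which matches how a synchronous tool call behaves. Offloading that blocking call with `asyncio.to_thread` would be one way to keep the outer loop responsive.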
@@ -1,34 +0,0 @@
|
||||
# Terminal Test Environment -- Default Configuration
|
||||
#
|
||||
# Simple file-creation tasks for validating the full Atropos + hermes-agent stack.
|
||||
# Uses Modal terminal backend and OpenRouter (Claude) for inference.
|
||||
# API keys loaded from ~/hermes-agent/.env
|
||||
#
|
||||
# Usage:
|
||||
# run-api
|
||||
# python environments/terminal_test_env/terminal_test_env.py serve \
|
||||
# --config environments/terminal_test_env/default.yaml
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
max_agent_turns: 10
|
||||
max_token_length: 2048
|
||||
group_size: 3
|
||||
total_steps: 3
|
||||
steps_per_eval: 3
|
||||
terminal_backend: "modal"
|
||||
tool_call_parser: "hermes"
|
||||
tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
||||
ensure_scores_are_not_same: false
|
||||
use_wandb: false
|
||||
system_prompt: >
|
||||
You are a helpful assistant with access to a terminal and file tools.
|
||||
Complete the user's request by using the available tools.
|
||||
Be precise and follow instructions exactly.
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-opus-4.6"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
@@ -1,292 +0,0 @@
|
||||
"""
|
||||
TerminalTestEnv -- Simple Test Environment for Validating the Stack
|
||||
|
||||
A self-contained environment with inline tasks (no external dataset needed).
|
||||
Each task asks the model to create a file at a known path with specific content.
|
||||
The reward verifier cats the file and checks if the content matches.
|
||||
|
||||
Enables only terminal + file toolsets. Uses Modal terminal backend with
|
||||
OpenRouter (Claude) by default.
|
||||
|
||||
Training tasks (3):
|
||||
1. Create ~/greeting.txt with "Hello from Hermes Agent"
|
||||
2. Create ~/count.txt with numbers 1-5, one per line
|
||||
3. Create ~/answer.txt with the result of 123 + 456
|
||||
|
||||
Eval task (1):
|
||||
1. Create ~/result.txt with the result of 6 * 7
|
||||
|
||||
Usage:
|
||||
# Start Atropos API server
|
||||
run-api
|
||||
|
||||
# Run environment (uses OpenRouter + Modal by default)
|
||||
python environments/terminal_test_env/terminal_test_env.py serve
|
||||
|
||||
# Process mode (no run-api needed, saves to JSONL)
|
||||
python environments/terminal_test_env/terminal_test_env.py process \\
|
||||
--env.data_path_to_save_groups terminal_test_output.jsonl
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path for imports
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Inline task definitions -- no external dataset needed
|
||||
# =============================================================================
|
||||
|
||||
TRAIN_TASKS = [
|
||||
{
|
||||
"prompt": "Create a file at ~/greeting.txt containing exactly the text: Hello from Hermes Agent",
|
||||
"verify_path": "~/greeting.txt",
|
||||
"expected_content": "Hello from Hermes Agent",
|
||||
},
|
||||
{
|
||||
"prompt": "Create a file at ~/count.txt containing the numbers 1 through 5, one per line",
|
||||
"verify_path": "~/count.txt",
|
||||
"expected_content": "1\n2\n3\n4\n5",
|
||||
},
|
||||
{
|
||||
"prompt": "Create a file at ~/answer.txt containing the result of 123 + 456",
|
||||
"verify_path": "~/answer.txt",
|
||||
"expected_content": "579",
|
||||
},
|
||||
]
|
||||
|
||||
EVAL_TASKS = [
|
||||
{
|
||||
"prompt": "Create a file at ~/result.txt containing the result of 6 * 7",
|
||||
"verify_path": "~/result.txt",
|
||||
"expected_content": "42",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class TerminalTestEnvConfig(HermesAgentEnvConfig):
|
||||
"""Config with defaults suitable for terminal testing."""
|
||||
|
||||
pass # Inherits all fields, overrides defaults in config_init
|
||||
|
||||
|
||||
class TerminalTestEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
Simple test environment with inline file-creation tasks.
|
||||
|
||||
All tasks follow the same pattern: "create a file at ~/X.txt with content Y".
|
||||
The verifier runs `cat ~/X.txt` in the rollout's terminal and checks the output
|
||||
against the expected string. Same verifier logic for all tasks.
|
||||
|
||||
This environment is designed to validate the full stack end-to-end:
|
||||
- Agent loop executes tool calls (terminal/file)
|
||||
- ToolContext provides terminal access to the reward function
|
||||
- Reward function verifies file content via cat
|
||||
- Scored data flows through the Atropos pipeline
|
||||
"""
|
||||
|
||||
name = "terminal-test"
|
||||
env_config_cls = TerminalTestEnvConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[TerminalTestEnvConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Default configuration for the terminal test environment.
|
||||
|
||||
Uses Modal terminal backend for cloud isolation and OpenRouter with
|
||||
Claude for inference. API keys loaded from ~/hermes-agent/.env.
|
||||
"""
|
||||
env_config = TerminalTestEnvConfig(
|
||||
# Terminal + file tools only
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
# Agent settings
|
||||
max_agent_turns=10, # Simple tasks, don't need many turns
|
||||
max_token_length=16000,
|
||||
agent_temperature=1.0,
|
||||
system_prompt=(
|
||||
"You are a helpful assistant with access to a terminal and file tools. "
|
||||
"Complete the user's request by using the available tools. "
|
||||
"Be precise and follow instructions exactly."
|
||||
),
|
||||
# Modal terminal backend for cloud-isolated sandboxes per rollout
|
||||
terminal_backend="modal",
|
||||
# Atropos settings
|
||||
group_size=3, # 3 rollouts per group
|
||||
tokenizer_name="NousResearch/q-30b-t-h45-e1",
|
||||
tool_call_parser="hermes",
|
||||
steps_per_eval=3, # Eval after all 3 steps
|
||||
total_steps=3, # 3 groups total (1 group per step)
|
||||
use_wandb=True,
|
||||
wandb_name="terminal-test",
|
||||
ensure_scores_are_not_same=False, # Allow all-same scores for simple tasks
|
||||
# No external dataset
|
||||
dataset_name=None,
|
||||
)
|
||||
|
||||
# OpenRouter with Claude -- API key loaded from .env (OPENROUTER_API_KEY)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-opus-4.6",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False, # OpenRouter doesn't have a /health endpoint
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
"""Initialize inline task lists."""
|
||||
self.train_tasks = list(TRAIN_TASKS)
|
||||
self.eval_tasks = list(EVAL_TASKS)
|
||||
self.iter = 0
|
||||
# Track reward stats for wandb logging
|
||||
self.reward_buffer: List[float] = []
|
||||
|
||||
async def get_next_item(self) -> Dict[str, str]:
|
||||
"""Cycle through training tasks."""
|
||||
item = self.train_tasks[self.iter % len(self.train_tasks)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item: Dict[str, str]) -> str:
|
||||
"""The prompt is directly in the task item."""
|
||||
return item["prompt"]
|
||||
|
||||
async def compute_reward(
|
||||
self, item: Dict[str, str], result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Verify by cat-ing the expected file path and checking content matches.
|
||||
Same verifier for all tasks -- they all write a file at a known path.
|
||||
|
||||
Scoring:
|
||||
1.0 = exact match
|
||||
0.5 = expected content is present but has extra stuff
|
||||
0.0 = file doesn't exist or content doesn't match
|
||||
"""
|
||||
verify_result = ctx.terminal(f"cat {item['verify_path']}")
|
||||
|
||||
# File doesn't exist or can't be read
|
||||
if verify_result["exit_code"] != 0:
|
||||
self.reward_buffer.append(0.0)
|
||||
return 0.0
|
||||
|
||||
actual = verify_result.get("output", "").strip()
|
||||
expected = item["expected_content"].strip()
|
||||
|
||||
# Exact match
|
||||
if actual == expected:
|
||||
self.reward_buffer.append(1.0)
|
||||
return 1.0
|
||||
|
||||
# Partial credit: expected content is present but has extra stuff
|
||||
if expected in actual:
|
||||
self.reward_buffer.append(0.5)
|
||||
return 0.5
|
||||
|
||||
self.reward_buffer.append(0.0)
|
||||
return 0.0
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
        """
        Run eval tasks as single-turn completions and log the collected samples.

        Responses are not verified against the expected content here, so the only
        metric reported is the number of eval samples.
        """
        start_time = time.time()
        total = len(self.eval_tasks)
        samples = []
|
||||
|
||||
for eval_item in self.eval_tasks:
|
||||
try:
|
||||
# For eval, we do a simple single-turn completion (not full agent loop)
|
||||
# to keep eval fast. The agent loop is tested via training.
|
||||
completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": self.config.system_prompt or ""},
|
||||
{"role": "user", "content": eval_item["prompt"]},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
|
||||
response_content = (
|
||||
completion.choices[0].message.content if completion.choices else ""
|
||||
)
|
||||
|
||||
samples.append(
|
||||
{
|
||||
"prompt": eval_item["prompt"],
|
||||
"response": response_content,
|
||||
"expected": eval_item["expected_content"],
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Eval failed for item: %s", e)
|
||||
samples.append(
|
||||
{
|
||||
"prompt": eval_item["prompt"],
|
||||
"response": f"ERROR: {e}",
|
||||
"expected": eval_item["expected_content"],
|
||||
}
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
eval_metrics = {
|
||||
"eval/num_samples": total,
|
||||
}
|
||||
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log training metrics including reward stats and accuracy."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.reward_buffer:
|
||||
total = len(self.reward_buffer)
|
||||
correct = sum(1 for r in self.reward_buffer if r == 1.0)
|
||||
partial = sum(1 for r in self.reward_buffer if r == 0.5)
|
||||
|
||||
wandb_metrics["train/avg_reward"] = sum(self.reward_buffer) / total
|
||||
wandb_metrics["train/accuracy"] = correct / total
|
||||
wandb_metrics["train/partial_match_rate"] = partial / total
|
||||
wandb_metrics["train/total_rollouts"] = total
|
||||
self.reward_buffer = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TerminalTestEnv.cli()
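The three scoring tiers in `compute_reward` can be illustrated in isolation. The sketch below is a minimal standalone version, not the environment's code: `score_file_content` is a hypothetical helper name, and it takes the `cat` output as a plain string instead of going through `ToolContext`.

```python
def score_file_content(actual: str, expected: str) -> float:
    """Hypothetical helper mirroring the three-tier scoring described above."""
    actual = actual.strip()
    expected = expected.strip()
    if actual == expected:
        return 1.0  # exact match after stripping surrounding whitespace
    if expected in actual:
        return 0.5  # expected text is present, but with extra content
    return 0.0  # content does not match


print(score_file_content("hello\n", "hello"))       # 1.0
print(score_file_content("hello world", "hello"))   # 0.5
print(score_file_content("goodbye", "hello"))       # 0.0
```

Because both sides are stripped before comparison, a trailing newline written by `echo` does not cost the rollout its full reward.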
|
||||
@@ -1,120 +0,0 @@
|
||||
"""
|
||||
Tool Call Parser Registry
|
||||
|
||||
Client-side parsers that extract structured tool_calls from raw model output text.
|
||||
Used in Phase 2 (VLLM server type) where ManagedServer's /generate endpoint returns
|
||||
raw text without tool call parsing.
|
||||
|
||||
Each parser is a standalone reimplementation of the corresponding VLLM parser's
|
||||
non-streaming extract_tool_calls() logic. No VLLM dependency -- only standard library
|
||||
(re, json, uuid) and openai types.
|
||||
|
||||
Usage:
|
||||
from environments.tool_call_parsers import get_parser
|
||||
|
||||
parser = get_parser("hermes")
|
||||
content, tool_calls = parser.parse(raw_model_output)
|
||||
# content = text with tool call markup stripped
|
||||
# tool_calls = list of ChatCompletionMessageToolCall objects, or None
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type alias for parser return value
|
||||
ParseResult = Tuple[Optional[str], Optional[List[ChatCompletionMessageToolCall]]]
|
||||
|
||||
|
||||
class ToolCallParser(ABC):
|
||||
"""
|
||||
Base class for tool call parsers.
|
||||
|
||||
Each parser knows how to extract structured tool_calls from a specific
|
||||
model family's raw output text format.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
"""
|
||||
Parse raw model output text for tool calls.
|
||||
|
||||
Args:
|
||||
text: Raw decoded text from the model's completion
|
||||
|
||||
Returns:
|
||||
Tuple of (content, tool_calls) where:
|
||||
- content: text with tool call markup stripped (the message 'content' field),
|
||||
or None if the entire output was tool calls
|
||||
- tool_calls: list of ChatCompletionMessageToolCall objects,
|
||||
or None if no tool calls were found
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Global parser registry: name -> parser class
|
||||
PARSER_REGISTRY: Dict[str, Type[ToolCallParser]] = {}
|
||||
|
||||
|
||||
def register_parser(name: str):
|
||||
"""
|
||||
Decorator to register a parser class under a given name.
|
||||
|
||||
Usage:
|
||||
@register_parser("hermes")
|
||||
class HermesToolCallParser(ToolCallParser):
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(cls: Type[ToolCallParser]) -> Type[ToolCallParser]:
|
||||
PARSER_REGISTRY[name] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_parser(name: str) -> ToolCallParser:
|
||||
"""
|
||||
Get a parser instance by name.
|
||||
|
||||
Args:
|
||||
name: Parser name (e.g., "hermes", "mistral", "llama3_json")
|
||||
|
||||
Returns:
|
||||
Instantiated parser
|
||||
|
||||
Raises:
|
||||
KeyError: If parser name is not found in registry
|
||||
"""
|
||||
if name not in PARSER_REGISTRY:
|
||||
available = sorted(PARSER_REGISTRY.keys())
|
||||
raise KeyError(
|
||||
f"Tool call parser '{name}' not found. Available parsers: {available}"
|
||||
)
|
||||
return PARSER_REGISTRY[name]()
|
||||
|
||||
|
||||
def list_parsers() -> List[str]:
|
||||
"""Return sorted list of registered parser names."""
|
||||
return sorted(PARSER_REGISTRY.keys())
|
||||
|
||||
|
||||
# Import all parser modules to trigger registration via @register_parser decorators
|
||||
# Each module registers itself when imported
|
||||
from environments.tool_call_parsers.hermes_parser import HermesToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.longcat_parser import LongcatToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.mistral_parser import MistralToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.llama_parser import LlamaToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.qwen_parser import QwenToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.deepseek_v3_parser import DeepSeekV3ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.deepseek_v3_1_parser import DeepSeekV31ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.kimi_k2_parser import KimiK2ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.glm45_parser import Glm45ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.glm47_parser import Glm47ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.qwen3_coder_parser import Qwen3CoderToolCallParser # noqa: E402, F401
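The decorator-based registry above can be tried out without the real parsers. The following is a minimal sketch using only the standard library; `REGISTRY`, `UpperParser`, and the names `"upper"` and `"shout"` are made up for illustration and are not part of the module. It also shows that stacking two decorators registers the same class under two names, as the DeepSeek V3.1 and Llama parsers do.

```python
from typing import Dict, Optional, Tuple

# Hypothetical stand-in for PARSER_REGISTRY: name -> class
REGISTRY: Dict[str, type] = {}


def register_parser(name: str):
    """Return a decorator that stores the class under `name` and returns it unchanged."""
    def decorator(cls):
        REGISTRY[name] = cls
        return cls
    return decorator


@register_parser("upper")
@register_parser("shout")  # stacked decorators: one class, two registry names
class UpperParser:
    def parse(self, text: str) -> Tuple[Optional[str], None]:
        return text.upper(), None


def get_parser(name: str):
    """Instantiate the parser registered under `name`, or raise KeyError."""
    if name not in REGISTRY:
        raise KeyError(f"Parser '{name}' not found. Available: {sorted(REGISTRY)}")
    return REGISTRY[name]()


content, tool_calls = get_parser("shout").parse("hi")
print(content, tool_calls)
```

Since the decorator returns the class unchanged, registration has no effect on how the class behaves when it is imported and used directly.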
|
||||
@@ -1,72 +0,0 @@
|
||||
"""
|
||||
DeepSeek V3.1 tool call parser.
|
||||
|
||||
Similar to V3 but with a slightly different format:
|
||||
<|tool▁call▁begin|>function_name<|tool▁sep|>arguments<|tool▁call▁end|>
|
||||
|
||||
Note: V3 has type+name before the separator, V3.1 has name before and args after.
|
||||
|
||||
Based on VLLM's DeepSeekV31ToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("deepseek_v3_1")
|
||||
@register_parser("deepseek_v31")
|
||||
class DeepSeekV31ToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for DeepSeek V3.1 tool calls.
|
||||
|
||||
Slightly different regex than V3: function_name comes before the separator,
|
||||
arguments come after (no type field, no json code block wrapper).
|
||||
"""
|
||||
|
||||
    # DeepSeek's special tokens use fullwidth vertical bars (U+FF5C), not ASCII "|".
    # An unescaped ASCII "|" inside the regex would be read as alternation.
    START_TOKEN = "<｜tool▁calls▁begin｜>"

    # Regex captures: function_name, function_arguments
    PATTERN = re.compile(
        r"<｜tool▁call▁begin｜>(?P<function_name>.*?)<｜tool▁sep｜>(?P<function_arguments>.*?)<｜tool▁call▁end｜>",
        re.DOTALL,
    )
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.START_TOKEN not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
func_name, func_args = match
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=func_name.strip(),
|
||||
arguments=func_args.strip(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
content = text[: text.find(self.START_TOKEN)].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
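To see how the V3.1 format is matched, here is a small standalone sketch. It assumes the model emits DeepSeek's special tokens with fullwidth vertical bars (U+FF5C); with ASCII `|` the bars would have to be escaped, because an unescaped `|` means alternation in a regex. The sample text and function names are invented for illustration.

```python
import re

# Assumption: tokens are written with fullwidth bars (U+FF5C), so no escaping is needed.
CALL = re.compile(
    r"<｜tool▁call▁begin｜>(?P<name>.*?)<｜tool▁sep｜>(?P<args>.*?)<｜tool▁call▁end｜>",
    re.DOTALL,
)

sample = (
    "Checking."
    "<｜tool▁calls▁begin｜>"
    "<｜tool▁call▁begin｜>get_weather<｜tool▁sep｜>{\"city\": \"Paris\"}<｜tool▁call▁end｜>"
    "<｜tool▁call▁begin｜>get_time<｜tool▁sep｜>{\"tz\": \"UTC\"}<｜tool▁call▁end｜>"
    "<｜tool▁calls▁end｜>"
)

# finditer yields one match per call, so multiple calls in one section are all captured.
calls = [(m.group("name").strip(), m.group("args").strip()) for m in CALL.finditer(sample)]
content = sample[: sample.find("<｜tool▁calls▁begin｜>")].strip()
print(content, calls)
```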
|
||||
@@ -1,89 +0,0 @@
|
||||
"""
|
||||
DeepSeek V3 tool call parser.
|
||||
|
||||
Format uses special unicode tokens:
|
||||
<|tool▁calls▁begin|>
|
||||
<|tool▁call▁begin|>type<|tool▁sep|>function_name
|
||||
```json
|
||||
{"arg": "value"}
|
||||
```
|
||||
<|tool▁call▁end|>
|
||||
<|tool▁calls▁end|>
|
||||
|
||||
Fixes Issue #989: Support for multiple simultaneous tool calls.
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@register_parser("deepseek_v3")
|
||||
class DeepSeekV3ToolCallParser(ToolCallParser):
|
||||
    """
    Parser for DeepSeek V3 tool calls.

    Uses special tokens built from fullwidth vertical bars (U+FF5C) and the
    lower one-eighth block character (U+2581).
    Extracts type, function name, and JSON arguments from the structured format.
    Ensures all tool calls are captured when the model emits multiple calls.
    """

    START_TOKEN = "<｜tool▁calls▁begin｜>"

    # Uses \s* instead of a literal \n for robustness against variations in
    # model formatting (Issue #989). The bars are fullwidth (U+FF5C); an
    # unescaped ASCII "|" here would be read as regex alternation.
    PATTERN = re.compile(
        r"<｜tool▁call▁begin｜>(?P<type>.*?)<｜tool▁sep｜>(?P<function_name>.*?)\s*```json\s*(?P<function_arguments>.*?)\s*```\s*<｜tool▁call▁end｜>",
        re.DOTALL,
    )
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
"""
|
||||
Parses the input text and extracts all available tool calls.
|
||||
"""
|
||||
if self.START_TOKEN not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
# Using finditer to capture ALL tool calls in the sequence
|
||||
matches = list(self.PATTERN.finditer(text))
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
|
||||
for match in matches:
|
||||
func_name = match.group("function_name").strip()
|
||||
func_args = match.group("function_arguments").strip()
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=func_name,
|
||||
arguments=func_args,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if tool_calls:
|
||||
# Content is text before the first tool call block
|
||||
content_index = text.find(self.START_TOKEN)
|
||||
content = text[:content_index].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
return text, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing DeepSeek V3 tool calls: {e}")
|
||||
return text, None
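The V3 format differs from V3.1 in that the type comes before the separator and the arguments sit inside a json code block. The sketch below is a standalone illustration under the assumption that the tokens use fullwidth bars (U+FF5C); the sample text is invented.

```python
import re

# Assumption: tokens use fullwidth bars (U+FF5C), as in DeepSeek's tokenizer.
PATTERN = re.compile(
    r"<｜tool▁call▁begin｜>(?P<type>.*?)<｜tool▁sep｜>(?P<function_name>.*?)"
    r"\s*```json\s*(?P<function_arguments>.*?)\s*```\s*<｜tool▁call▁end｜>",
    re.DOTALL,
)

sample = (
    "<｜tool▁calls▁begin｜>"
    "<｜tool▁call▁begin｜>function<｜tool▁sep｜>get_weather\n"
    "```json\n{\"city\": \"Paris\"}\n```"
    "<｜tool▁call▁end｜>"
    "<｜tool▁calls▁end｜>"
)

calls = [
    (
        m.group("type"),
        m.group("function_name").strip(),
        m.group("function_arguments").strip(),
    )
    for m in PATTERN.finditer(sample)
]
print(calls)
```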
|
||||
@@ -1,109 +0,0 @@
|
||||
"""
|
||||
GLM 4.5 (GLM-4-MoE) tool call parser.
|
||||
|
||||
Format uses custom arg_key/arg_value tags rather than standard JSON:
|
||||
<tool_call>function_name
|
||||
<arg_key>param1</arg_key><arg_value>value1</arg_value>
|
||||
<arg_key>param2</arg_key><arg_value>value2</arg_value>
|
||||
</tool_call>
|
||||
|
||||
Values are deserialized using json.loads -> ast.literal_eval -> raw string fallback.
|
||||
|
||||
Based on VLLM's Glm4MoeModelToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
def _deserialize_value(value: str) -> Any:
|
||||
"""
|
||||
Try to deserialize a string value to its native Python type.
|
||||
Attempts json.loads, then ast.literal_eval, then returns raw string.
|
||||
"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
try:
|
||||
return ast.literal_eval(value)
|
||||
except (ValueError, SyntaxError, TypeError):
|
||||
pass
|
||||
|
||||
return value
|
||||
|
||||
|
||||
@register_parser("glm45")
|
||||
class Glm45ToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for GLM 4.5 (GLM-4-MoE) tool calls.
|
||||
|
||||
Uses <tool_call>...</tool_call> tags with <arg_key>/<arg_value> pairs
|
||||
instead of standard JSON arguments.
|
||||
"""
|
||||
|
||||
FUNC_CALL_REGEX = re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL)
|
||||
FUNC_DETAIL_REGEX = re.compile(r"<tool_call>([^\n]*)\n(.*)</tool_call>", re.DOTALL)
|
||||
FUNC_ARG_REGEX = re.compile(
|
||||
r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL
|
||||
)
|
||||
|
||||
START_TOKEN = "<tool_call>"
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.START_TOKEN not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matched_calls = self.FUNC_CALL_REGEX.findall(text)
|
||||
if not matched_calls:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
|
||||
for match in matched_calls:
|
||||
detail = self.FUNC_DETAIL_REGEX.search(match)
|
||||
if not detail:
|
||||
continue
|
||||
|
||||
func_name = detail.group(1).strip()
|
||||
func_args_raw = detail.group(2)
|
||||
|
||||
# Parse arg_key/arg_value pairs
|
||||
pairs = self.FUNC_ARG_REGEX.findall(func_args_raw) if func_args_raw else []
|
||||
arg_dict: Dict[str, Any] = {}
|
||||
for key, value in pairs:
|
||||
arg_key = key.strip()
|
||||
arg_val = _deserialize_value(value.strip())
|
||||
arg_dict[arg_key] = arg_val
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=func_name,
|
||||
arguments=json.dumps(arg_dict, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
content = text[: text.find(self.START_TOKEN)].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
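The key/value format and the `json.loads -> ast.literal_eval -> raw string` fallback are easiest to follow on a concrete input. This is a minimal sketch, not the parser itself: `deserialize` is a compact stand-in for `_deserialize_value`, and the argument names and values are made up.

```python
import ast
import json
import re
from typing import Any

ARG = re.compile(r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL)


def deserialize(value: str) -> Any:
    """Try JSON first, then a Python literal, then fall back to the raw string."""
    for loader in (json.loads, ast.literal_eval):
        try:
            return loader(value)
        except (ValueError, SyntaxError, TypeError):
            continue  # json.JSONDecodeError is a subclass of ValueError
    return value


body = (
    "<arg_key>city</arg_key><arg_value>Paris</arg_value>\n"
    "<arg_key>days</arg_key><arg_value>3</arg_value>\n"
    "<arg_key>metric</arg_key><arg_value>True</arg_value>"
)

args = {key.strip(): deserialize(value.strip()) for key, value in ARG.findall(body)}
print(json.dumps(args, ensure_ascii=False))
```

Here `Paris` survives as a raw string, `3` is decoded by `json.loads`, and `True` (not valid JSON) is picked up by `ast.literal_eval`.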
|
||||
@@ -1,35 +0,0 @@
|
||||
"""
|
||||
GLM 4.7 tool call parser.
|
||||
|
||||
Same as GLM 4.5 but with slightly different regex patterns.
|
||||
The tool_call tags may wrap differently and arg parsing handles
|
||||
newlines between key/value pairs.
|
||||
|
||||
Based on VLLM's Glm47MoeModelToolParser (extends Glm4MoeModelToolParser).
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, register_parser
|
||||
from environments.tool_call_parsers.glm45_parser import Glm45ToolCallParser
|
||||
|
||||
|
||||
@register_parser("glm47")
|
||||
class Glm47ToolCallParser(Glm45ToolCallParser):
|
||||
"""
|
||||
Parser for GLM 4.7 tool calls.
|
||||
Extends GLM 4.5 with updated regex patterns.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# GLM 4.7 uses a slightly different detail regex that includes
|
||||
# the <tool_call> wrapper and optional arg_key content
|
||||
self.FUNC_DETAIL_REGEX = re.compile(
|
||||
r"<tool_call>(.*?)(<arg_key>.*?)?</tool_call>", re.DOTALL
|
||||
)
|
||||
        # GLM 4.7 tolerates whitespace, or a literal backslash-n sequence,
        # between the arg_key and arg_value tags
        self.FUNC_ARG_REGEX = re.compile(
            r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
            re.DOTALL,
        )
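The difference between the two argument regexes is subtle, so a quick comparison may help. As written, the GLM 4.7 pattern accepts either real whitespace or the two-character sequence backslash + `n` between the tags, while the GLM 4.5 pattern accepts only whitespace. The sketch below copies both patterns; the sample strings are invented.

```python
import re

GLM45_ARG = re.compile(
    r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL
)
GLM47_ARG = re.compile(
    r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>", re.DOTALL
)

real_newline = "<arg_key>city</arg_key>\n<arg_value>Paris</arg_value>"
# "\\n" in Python source is a backslash followed by the letter n, not a newline.
literal_escape = "<arg_key>city</arg_key>\\n<arg_value>Paris</arg_value>"

print(GLM45_ARG.findall(real_newline))
print(GLM45_ARG.findall(literal_escape))
print(GLM47_ARG.findall(real_newline))
print(GLM47_ARG.findall(literal_escape))
```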
|
||||
@@ -1,75 +0,0 @@
|
||||
"""
|
||||
Hermes tool call parser.
|
||||
|
||||
Format: <tool_call>{"name": "func", "arguments": {...}}</tool_call>
|
||||
Based on VLLM's Hermes2ProToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("hermes")
|
||||
class HermesToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Hermes-format tool calls.
|
||||
|
||||
Matches <tool_call>...</tool_call> tags containing JSON with "name" and "arguments".
|
||||
Also handles unclosed <tool_call> at end-of-string (truncated generation).
|
||||
"""
|
||||
|
||||
# Matches both closed and unclosed tool_call tags
|
||||
PATTERN = re.compile(
|
||||
r"<tool_call>\s*(.*?)\s*</tool_call>|<tool_call>\s*(.*)", re.DOTALL
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if "<tool_call>" not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
# match is a tuple: (closed_content, unclosed_content)
|
||||
raw_json = match[0] if match[0] else match[1]
|
||||
if not raw_json.strip():
|
||||
continue
|
||||
|
||||
tc_data = json.loads(raw_json)
|
||||
if "name" not in tc_data:
|
||||
continue
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc_data["name"],
|
||||
arguments=json.dumps(
|
||||
tc_data.get("arguments", {}), ensure_ascii=False
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Content is everything before the first <tool_call> tag
|
||||
content = text[: text.find("<tool_call>")].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
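The two-branch pattern is what lets the parser recover a tool call whose closing tag was cut off. The sketch below reuses the same regex on an invented sample with one closed and one truncated call; it collects plain dicts rather than `ChatCompletionMessageToolCall` objects so that it runs with the standard library alone.

```python
import json
import re

PATTERN = re.compile(
    r"<tool_call>\s*(.*?)\s*</tool_call>|<tool_call>\s*(.*)", re.DOTALL
)

text = (
    "Let me check.\n"
    '<tool_call>{"name": "a", "arguments": {"x": 1}}</tool_call>\n'
    '<tool_call>{"name": "b", "arguments": {}}'  # closing tag missing (truncated output)
)

# Each findall item is (closed_content, unclosed_content); exactly one is non-empty.
found = [json.loads(closed or unclosed) for closed, unclosed in PATTERN.findall(text)]
content = text[: text.find("<tool_call>")].strip()

print(content)
print([call["name"] for call in found])
```

Note that a truncated call is only recovered when the JSON itself is complete; if the cut falls inside the JSON, `json.loads` raises and the parser falls back to returning the raw text.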
|
||||
@@ -1,93 +0,0 @@
|
||||
"""
|
||||
Kimi K2 tool call parser.
|
||||
|
||||
Format:
|
||||
<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|>function_id:0<|tool_call_argument_begin|>{"arg": "val"}<|tool_call_end|>
|
||||
<|tool_calls_section_end|>
|
||||
|
||||
The function_id format is typically "functions.func_name:index" or "func_name:index".
|
||||
|
||||
Based on VLLM's KimiK2ToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("kimi_k2")
|
||||
class KimiK2ToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Kimi K2 tool calls.
|
||||
|
||||
Uses section begin/end tokens wrapping individual tool call begin/end tokens.
|
||||
The tool_call_id contains the function name (after last dot, before colon).
|
||||
"""
|
||||
|
||||
# Support both singular and plural variants
|
||||
START_TOKENS = [
|
||||
"<|tool_calls_section_begin|>",
|
||||
"<|tool_call_section_begin|>",
|
||||
]
|
||||
|
||||
# Regex captures: tool_call_id (e.g., "functions.get_weather:0"), function_arguments
|
||||
PATTERN = re.compile(
|
||||
r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[^<]+:\d+)\s*"
|
||||
r"<\|tool_call_argument_begin\|>\s*"
|
||||
r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
|
||||
r"<\|tool_call_end\|>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
# Check for any variant of the start token
|
||||
has_start = any(token in text for token in self.START_TOKENS)
|
||||
if not has_start:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
function_id, function_args = match
|
||||
|
||||
# Extract function name from ID format: "functions.get_weather:0" -> "get_weather"
|
||||
function_name = function_id.split(":")[0].split(".")[-1]
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=function_id, # Preserve the original ID format
|
||||
type="function",
|
||||
function=Function(
|
||||
name=function_name,
|
||||
arguments=function_args.strip(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Content is everything before the tool calls section
|
||||
earliest_start = len(text)
|
||||
for token in self.START_TOKENS:
|
||||
idx = text.find(token)
|
||||
if idx >= 0 and idx < earliest_start:
|
||||
earliest_start = idx
|
||||
|
||||
content = text[:earliest_start].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
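The mapping from a Kimi K2 tool call ID to a function name can be checked on its own. The sketch below reuses the pattern from the parser on an invented sample; `function_name_from_id` is a hypothetical helper name wrapping the same split logic.

```python
import re

PATTERN = re.compile(
    r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[^<]+:\d+)\s*"
    r"<\|tool_call_argument_begin\|>\s*"
    r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
    r"<\|tool_call_end\|>",
    re.DOTALL,
)


def function_name_from_id(tool_call_id: str) -> str:
    """'functions.get_weather:0' -> 'get_weather'; 'get_time:1' -> 'get_time'."""
    return tool_call_id.split(":")[0].split(".")[-1]


sample = (
    "<|tool_calls_section_begin|>"
    '<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|>'
    "<|tool_call_begin|>get_time:1<|tool_call_argument_begin|>{}<|tool_call_end|>"
    "<|tool_calls_section_end|>"
)

calls = [
    (function_name_from_id(call_id), args.strip())
    for call_id, args in PATTERN.findall(sample)
]
print(calls)
```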
|
||||
@@ -1,96 +0,0 @@
|
||||
"""
|
||||
Llama 3.x / 4 tool call parser.
|
||||
|
||||
Format: The model outputs JSON objects with "name" and "arguments" (or "parameters") keys.
|
||||
May be preceded by <|python_tag|> token. Supports multiple JSON objects separated
|
||||
by content or semicolons.
|
||||
|
||||
Based on VLLM's Llama3JsonToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("llama3_json")
|
||||
@register_parser("llama4_json")
|
||||
class LlamaToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Llama 3.x and 4 JSON-format tool calls.
|
||||
|
||||
Finds JSON objects containing "name" + ("arguments" or "parameters") keys.
|
||||
Uses Python's json.JSONDecoder.raw_decode for robust extraction of
|
||||
JSON objects from mixed text.
|
||||
"""
|
||||
|
||||
BOT_TOKEN = "<|python_tag|>"
|
||||
|
||||
# Regex to find the start of potential JSON objects
|
||||
JSON_START = re.compile(r"\{")
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
# Quick check: need either the bot token or a JSON brace
|
||||
if self.BOT_TOKEN not in text and "{" not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
decoder = json.JSONDecoder()
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
end_index = -1 # Track where the last parsed JSON ended
|
||||
|
||||
for match in self.JSON_START.finditer(text):
|
||||
start = match.start()
|
||||
                # Skip braces that fall inside the previously parsed JSON object.
                # end_index is exclusive, so a brace exactly at end_index starts a new object.
                if start < end_index:
                    continue
|
||||
|
||||
try:
|
||||
obj, json_end = decoder.raw_decode(text[start:])
|
||||
end_index = start + json_end
|
||||
|
||||
# Must have "name" and either "arguments" or "parameters"
|
||||
name = obj.get("name")
|
||||
args = obj.get("arguments", obj.get("parameters"))
|
||||
|
||||
if not name or args is None:
|
||||
continue
|
||||
|
||||
# Normalize arguments to JSON string
|
||||
if isinstance(args, dict):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
elif not isinstance(args, str):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(name=name, arguments=args),
|
||||
)
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError, ValueError):
|
||||
continue
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Content is everything before the first tool call JSON
|
||||
# Find where the first tool call starts in the text
|
||||
first_tc_start = text.find("{")
|
||||
if self.BOT_TOKEN in text:
|
||||
first_tc_start = text.find(self.BOT_TOKEN)
|
||||
content = text[:first_tc_start].strip() if first_tc_start > 0 else None
|
||||
|
||||
return content, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
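The brace-scanning approach with `json.JSONDecoder.raw_decode` can be exercised on its own. This is a simplified sketch, not the parser itself: `extract_calls` is a hypothetical name, it returns `(name, args)` tuples instead of OpenAI tool call objects, and it treats the end index as exclusive so that two back-to-back objects are both picked up.

```python
import json
import re
from typing import Any, List, Tuple


def extract_calls(text: str) -> List[Tuple[str, Any]]:
    """Collect (name, args) from every JSON object that looks like a tool call."""
    decoder = json.JSONDecoder()
    calls: List[Tuple[str, Any]] = []
    end_index = 0  # exclusive end of the last successfully decoded object
    for match in re.finditer(r"\{", text):
        start = match.start()
        if start < end_index:
            continue  # this brace is nested inside an object already decoded
        try:
            obj, length = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            continue  # not valid JSON from this brace; try the next one
        end_index = start + length
        name = obj.get("name")
        args = obj.get("arguments", obj.get("parameters"))
        if name and args is not None:
            calls.append((name, args))
    return calls


sample = (
    '<|python_tag|>{"name": "a", "parameters": {"x": {"y": 1}}}; '
    '{"name": "b", "arguments": {}}'
)
print(extract_calls(sample))
```

`raw_decode` stops at the end of the first complete JSON value and reports how many characters it consumed, which is what makes it suitable for pulling objects out of mixed text.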
|
||||
@@ -1,69 +0,0 @@
|
||||
"""
|
||||
Longcat Flash Chat tool call parser.
|
||||
|
||||
Same as Hermes but uses <longcat_tool_call> tags instead of <tool_call>.
|
||||
Based on VLLM's LongcatFlashToolParser (extends Hermes2ProToolParser).
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("longcat")
|
||||
class LongcatToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Longcat Flash Chat tool calls.
|
||||
Identical logic to Hermes, just different tag names.
|
||||
"""
|
||||
|
||||
PATTERN = re.compile(
|
||||
r"<longcat_tool_call>\s*(.*?)\s*</longcat_tool_call>|<longcat_tool_call>\s*(.*)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if "<longcat_tool_call>" not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
raw_json = match[0] if match[0] else match[1]
|
||||
if not raw_json.strip():
|
||||
continue
|
||||
|
||||
tc_data = json.loads(raw_json)
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc_data["name"],
|
||||
arguments=json.dumps(
|
||||
tc_data.get("arguments", {}), ensure_ascii=False
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
content = text[: text.find("<longcat_tool_call>")].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
@@ -1,137 +0,0 @@
|
||||
"""
|
||||
Mistral tool call parser.

Supports two formats depending on tokenizer version:
- Pre-v11: content[TOOL_CALLS] [{"name": ..., "arguments": {...}}, ...]
- v11+:    content[TOOL_CALLS]tool_name1{"arg": "val"}[TOOL_CALLS]tool_name2{"arg": "val"}

Based on vLLM's MistralToolParser.extract_tool_calls()
The [TOOL_CALLS] token is the bot_token used by Mistral models.
"""

import json
import uuid
from typing import List, Optional

from openai.types.chat.chat_completion_message_tool_call import (
    ChatCompletionMessageToolCall,
    Function,
)

from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser


def _generate_mistral_id() -> str:
    """Mistral tool call IDs are 9-char alphanumeric strings."""
    import random
    import string

    return "".join(random.choices(string.ascii_letters + string.digits, k=9))


@register_parser("mistral")
class MistralToolCallParser(ToolCallParser):
    """
    Parser for Mistral-format tool calls.

    Detects format by checking if the content after [TOOL_CALLS] starts with '['
    or '{' (pre-v11 JSON array or single object) or with a tool name (v11+ format).
    """

    # The [TOOL_CALLS] token -- may appear as different strings depending on tokenizer
    BOT_TOKEN = "[TOOL_CALLS]"

    def parse(self, text: str) -> ParseResult:
        if self.BOT_TOKEN not in text:
            return text, None

        try:
            parts = text.split(self.BOT_TOKEN)
            content = parts[0].strip()
            raw_tool_calls = parts[1:]

            # Detect format: if the first raw part starts with '[' (JSON array)
            # or '{' (single JSON object), it's pre-v11
            first_raw = raw_tool_calls[0].strip() if raw_tool_calls else ""
            is_pre_v11 = first_raw.startswith("[") or first_raw.startswith("{")

            tool_calls: List[ChatCompletionMessageToolCall] = []

            if not is_pre_v11:
                # v11+ format: [TOOL_CALLS]tool_name{args}[TOOL_CALLS]tool_name2{args2}
                for raw in raw_tool_calls:
                    raw = raw.strip()
                    if not raw or "{" not in raw:
                        continue

                    brace_idx = raw.find("{")
                    tool_name = raw[:brace_idx].strip()
                    args_str = raw[brace_idx:]

                    # Validate and clean the JSON arguments
                    try:
                        parsed_args = json.loads(args_str)
                        args_str = json.dumps(parsed_args, ensure_ascii=False)
                    except json.JSONDecodeError:
                        pass  # Keep raw if parsing fails

                    tool_calls.append(
                        ChatCompletionMessageToolCall(
                            id=_generate_mistral_id(),
                            type="function",
                            function=Function(name=tool_name, arguments=args_str),
                        )
                    )
            else:
                # Pre-v11 format: [TOOL_CALLS] [{"name": ..., "arguments": {...}}]
                try:
                    parsed = json.loads(first_raw)
                    if isinstance(parsed, dict):
                        parsed = [parsed]

                    for tc in parsed:
                        if "name" not in tc:
                            continue
                        args = tc.get("arguments", {})
                        if isinstance(args, dict):
                            args = json.dumps(args, ensure_ascii=False)

                        tool_calls.append(
                            ChatCompletionMessageToolCall(
                                id=_generate_mistral_id(),
                                type="function",
                                function=Function(
                                    name=tc["name"], arguments=args
                                ),
                            )
                        )
                except json.JSONDecodeError:
                    # Fallback: extract JSON objects using raw_decode
                    decoder = json.JSONDecoder()
                    idx = 0
                    while idx < len(first_raw):
                        try:
                            obj, end_idx = decoder.raw_decode(first_raw, idx)
                            if isinstance(obj, dict) and "name" in obj:
                                args = obj.get("arguments", {})
                                if isinstance(args, dict):
                                    args = json.dumps(args, ensure_ascii=False)
                                tool_calls.append(
                                    ChatCompletionMessageToolCall(
                                        id=_generate_mistral_id(),
                                        type="function",
                                        function=Function(
                                            name=obj["name"], arguments=args
                                        ),
                                    )
                                )
                            idx = end_idx
                        except json.JSONDecodeError:
                            idx += 1

            if not tool_calls:
                return text, None

            return content if content else None, tool_calls

        except Exception:
            return text, None
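The format detection in the Mistral parser can be illustrated without the OpenAI types. The block below is a minimal sketch rather than the parser itself: `split_mistral_calls` is a hypothetical helper, it returns plain `(name, arguments_json)` tuples instead of `ChatCompletionMessageToolCall` objects, and it covers only well-formed input (no `raw_decode` fallback).

```python
import json
from typing import List, Optional, Tuple

BOT_TOKEN = "[TOOL_CALLS]"


def split_mistral_calls(text: str) -> Tuple[Optional[str], List[Tuple[str, str]]]:
    """Hypothetical helper: return (content, [(tool_name, arguments_json), ...])."""
    if BOT_TOKEN not in text:
        return text, []

    parts = text.split(BOT_TOKEN)
    content = parts[0].strip() or None
    first = parts[1].strip()
    calls: List[Tuple[str, str]] = []

    if first.startswith(("[", "{")):
        # Pre-v11: a JSON array (or single object) of {"name", "arguments"} entries.
        parsed = json.loads(first)
        if isinstance(parsed, dict):
            parsed = [parsed]
        for entry in parsed:
            calls.append((entry["name"], json.dumps(entry.get("arguments", {}))))
    else:
        # v11+: each segment is a tool name immediately followed by a JSON object.
        for raw in parts[1:]:
            raw = raw.strip()
            brace = raw.find("{")
            if brace < 0:
                continue
            calls.append((raw[:brace].strip(), json.dumps(json.loads(raw[brace:]))))

    return content, calls
```

Both formats reduce to the same list of calls, which is why the parser only needs to look at the first character after the token to pick a branch.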
@@ -1,163 +0,0 @@
"""
Qwen3-Coder tool call parser.

Format uses XML-style nested tags:
    <tool_call>
    <function=function_name>
    <parameter=param_name>value</parameter>
    <parameter=param_name2>value2</parameter>
    </function>
    </tool_call>

Parameters are extracted from <parameter=name>value</parameter> tags and
converted to native types on a best-effort basis (JSON first, then Python
literals); values that do not parse are kept as strings. No tool schema is
consulted in this module.

Based on vLLM's Qwen3CoderToolParser.extract_tool_calls()
"""

import ast
import json
import re
import uuid
from typing import Any, Dict, List, Optional

from openai.types.chat.chat_completion_message_tool_call import (
    ChatCompletionMessageToolCall,
    Function,
)

from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser


def _try_convert_value(value: str) -> Any:
    """
    Try to convert a parameter value string to a native Python type.
    Handles null, numbers, booleans, JSON objects/arrays, and falls back to string.
    """
    stripped = value.strip()

    # Handle null
    if stripped.lower() == "null":
        return None

    # Try JSON first (handles objects, arrays, strings, numbers, booleans)
    try:
        return json.loads(stripped)
    except (json.JSONDecodeError, TypeError):
        pass

    # Try Python literal eval (handles tuples, etc.)
    try:
        return ast.literal_eval(stripped)
    except (ValueError, SyntaxError, TypeError):
        pass

    # Return as string
    return stripped


@register_parser("qwen3_coder")
class Qwen3CoderToolCallParser(ToolCallParser):
    """
    Parser for Qwen3-Coder XML-format tool calls.

    Uses nested XML tags: <tool_call><function=name><parameter=key>val</parameter></function></tool_call>
    """

    START_TOKEN = "<tool_call>"
    FUNCTION_PREFIX = "<function="

    # Find complete tool_call blocks (or unclosed at end)
    TOOL_CALL_REGEX = re.compile(
        r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
    )

    # Find function blocks within a tool_call
    FUNCTION_REGEX = re.compile(
        r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
    )

    # Find parameter blocks within a function
    PARAMETER_REGEX = re.compile(
        r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
        re.DOTALL,
    )

    def _parse_function_call(self, function_str: str) -> Optional[ChatCompletionMessageToolCall]:
        """Parse a single <function=name>...</function> block into a ToolCall."""
        try:
            # Extract function name: everything before the first '>'
            gt_idx = function_str.index(">")
            func_name = function_str[:gt_idx].strip()
            params_str = function_str[gt_idx + 1:]

            # Extract parameters
            param_dict: Dict[str, Any] = {}
            for match_text in self.PARAMETER_REGEX.findall(params_str):
                if ">" not in match_text:
                    continue
                # The parameter name ends at the first '>' of the opening tag
                name_end = match_text.index(">")
                param_name = match_text[:name_end].strip()
                param_value = match_text[name_end + 1:]

                # Strip at most one leading and one trailing newline
                if param_value.startswith("\n"):
                    param_value = param_value[1:]
                if param_value.endswith("\n"):
                    param_value = param_value[:-1]

                param_dict[param_name] = _try_convert_value(param_value)

            return ChatCompletionMessageToolCall(
                id=f"call_{uuid.uuid4().hex[:24]}",
                type="function",
                function=Function(
                    name=func_name,
                    arguments=json.dumps(param_dict, ensure_ascii=False),
                ),
            )
        except (ValueError, IndexError):
            return None

    def parse(self, text: str) -> ParseResult:
        if self.FUNCTION_PREFIX not in text:
            return text, None

        try:
            # Find all tool_call blocks
            tc_matches = self.TOOL_CALL_REGEX.findall(text)
            raw_blocks = [m[0] if m[0] else m[1] for m in tc_matches]

            # Fallback: if no tool_call tags, try the whole text
            if not raw_blocks:
                raw_blocks = [text]

            # Find function blocks within each tool_call
            function_strs: List[str] = []
            for block in raw_blocks:
                func_matches = self.FUNCTION_REGEX.findall(block)
                function_strs.extend(m[0] if m[0] else m[1] for m in func_matches)

            if not function_strs:
                return text, None

            # Parse each function call
            tool_calls: List[ChatCompletionMessageToolCall] = []
            for func_str in function_strs:
                tc = self._parse_function_call(func_str)
                if tc is not None:
                    tool_calls.append(tc)

            if not tool_calls:
                return text, None

            # Content before tool calls
            first_tc = text.find(self.START_TOKEN)
            if first_tc < 0:
                first_tc = text.find(self.FUNCTION_PREFIX)
            content = text[:first_tc].strip() if first_tc > 0 else None

            return content, tool_calls

        except Exception:
            return text, None
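How the parameter pattern tolerates unclosed tags is easier to see on a small input. The following is a simplified sketch, assuming the same regular expression as `PARAMETER_REGEX`; `extract_call` is a hypothetical helper, and its value conversion is reduced to JSON-or-string (the `ast.literal_eval` step and the null check are left out).

```python
import json
import re
from typing import Any, Dict, Tuple

# Same pattern as the parser: a parameter runs until its closing tag, the next
# parameter, the end of the function, or the end of the string.
PARAMETER_REGEX = re.compile(
    r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
    re.DOTALL,
)


def extract_call(function_body: str) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical helper: parse 'name>...parameters...' into (name, params)."""
    name, _, rest = function_body.partition(">")
    params: Dict[str, Any] = {}
    for match in PARAMETER_REGEX.findall(rest):
        key, sep, value = match.partition(">")
        if not sep:
            continue  # malformed opening tag without '>'
        value = value.strip()
        try:
            params[key.strip()] = json.loads(value)  # numbers, booleans, objects
        except json.JSONDecodeError:
            params[key.strip()] = value  # anything else stays a string
    return name.strip(), params
```

The lookahead alternatives are what let a model output with a missing `</parameter>` still yield every parameter, since each value simply ends where the next tag begins.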
@@ -1,19 +0,0 @@
"""
Qwen 2.5 tool call parser.

Uses the same <tool_call> format as Hermes.
Registered as a separate parser name for clarity when using --tool-parser=qwen.
"""

from environments.tool_call_parsers import register_parser
from environments.tool_call_parsers.hermes_parser import HermesToolCallParser


@register_parser("qwen")
class QwenToolCallParser(HermesToolCallParser):
    """
    Parser for Qwen 2.5 tool calls.
    Same <tool_call>{"name": ..., "arguments": ...}</tool_call> format as Hermes.
    """

    pass  # Identical format -- inherits everything from Hermes
@@ -1,473 +0,0 @@
"""
ToolContext -- Unrestricted Tool Access for Reward Functions

A per-rollout handle that gives reward/verification functions direct access to
ALL hermes-agent tools, scoped to the rollout's task_id. The same task_id means
the terminal/browser session is the SAME one the model used during its rollout --
all state (files, processes, browser tabs) is preserved.

The verifier author decides which tools to use. Nothing is hardcoded or gated.

Example usage in a compute_reward():
    async def compute_reward(self, item, result, ctx):
        # Run tests in the model's terminal sandbox
        test = ctx.terminal("pytest -v")
        if test["exit_code"] == 0:
            return 1.0

        # Check if a file was created
        content = ctx.read_file("/workspace/solution.py")
        if content.get("content"):
            return 0.5

        return 0.0
"""

import json
import logging
import os
from typing import Any, Dict, List, Optional

import asyncio
import concurrent.futures

from model_tools import handle_function_call
from tools.terminal_tool import cleanup_vm
from tools.browser_tool import cleanup_browser

logger = logging.getLogger(__name__)

# Thread pool for running sync tool calls that internally use asyncio.run()
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)


def _run_tool_in_thread(tool_name: str, arguments: Dict[str, Any], task_id: str) -> str:
    """
    Run a tool call in a thread pool executor so backends that use asyncio.run()
    internally (modal, docker, daytona) get a clean event loop.

    If we're already in an async context, executes handle_function_call() in a
    disposable worker thread and blocks for the result.
    If not (e.g., called from sync code), runs directly.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running event loop -- safe to call directly
        return handle_function_call(tool_name, arguments, task_id)

    # We're in an async context -- need to run in thread. This sits outside the
    # try block so that a RuntimeError raised by the tool itself propagates
    # instead of being mistaken for "no running loop" and re-running the tool.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(
            handle_function_call, tool_name, arguments, task_id
        )
        return future.result(timeout=300)


class ToolContext:
    """
    Open-ended access to all hermes-agent tools for a specific rollout.

    Passed to compute_reward() so verifiers can use any tool they need:
    terminal commands, file reads/writes, web searches, browser automation, etc.
    All calls share the rollout's task_id for session isolation.
    """

    def __init__(self, task_id: str):
        self.task_id = task_id

    # -------------------------------------------------------------------------
    # Terminal tools
    # -------------------------------------------------------------------------

    def terminal(self, command: str, timeout: int = 180) -> Dict[str, Any]:
        """
        Run a command in the rollout's terminal session.

        Args:
            command: Shell command to execute
            timeout: Command timeout in seconds

        Returns:
            Dict with 'exit_code' (int) and 'output' (str)
        """
        # os is already imported at module level
        backend = os.getenv("TERMINAL_ENV", "local")
        logger.debug("ToolContext.terminal [%s backend] task=%s: %s", backend, self.task_id[:8], command[:100])

        # Run via thread helper so modal/docker/daytona backends' asyncio.run() doesn't deadlock
        result = _run_tool_in_thread(
            "terminal",
            {"command": command, "timeout": timeout},
            self.task_id,
        )
        try:
            return json.loads(result)
        except json.JSONDecodeError:
            return {"exit_code": -1, "output": result}

    # -------------------------------------------------------------------------
    # File tools
    # -------------------------------------------------------------------------

    def read_file(self, path: str) -> Dict[str, Any]:
        """
        Read a file from the rollout's filesystem.

        Args:
            path: File path to read

        Returns:
            Dict with file content or error
        """
        result = handle_function_call(
            "read_file", {"path": path}, task_id=self.task_id
        )
        try:
            return json.loads(result)
        except json.JSONDecodeError:
            return {"error": result}

    def write_file(self, path: str, content: str) -> Dict[str, Any]:
        """
        Write a TEXT file in the rollout's filesystem.

        Uses a shell heredoc under the hood, so this is only safe for text content.
        For binary files (images, compiled artifacts, etc.), use upload_file() instead.

        Args:
            path: File path to write
            content: Text content to write

        Returns:
            Dict with success status or error
        """
        result = handle_function_call(
            "write_file", {"path": path, "content": content}, task_id=self.task_id
        )
        try:
            return json.loads(result)
        except json.JSONDecodeError:
            return {"error": result}

    def upload_file(self, local_path: str, remote_path: str) -> Dict[str, Any]:
        """
        Upload a local file to the rollout's sandbox (binary-safe).

        Unlike write_file() which passes content through a shell heredoc (text-only),
        this method base64-encodes the file and decodes it inside the sandbox.
        Safe for any file type: binaries, images, archives, etc.

        When the base64 payload exceeds 60,000 characters (about 45 KB of raw
        data), the content is split into chunks to avoid hitting shell
        command-length limits.

        Args:
            local_path: Path to a local file on the host
            remote_path: Destination path inside the sandbox

        Returns:
            Dict with 'exit_code' and 'output'
        """
        import base64
        from pathlib import Path as _Path

        local = _Path(local_path)
        if not local.exists():
            return {"exit_code": -1, "output": f"Local file not found: {local_path}"}

        raw = local.read_bytes()
        b64 = base64.b64encode(raw).decode("ascii")

        # Ensure parent directory exists in the sandbox
        parent = str(_Path(remote_path).parent)
        if parent not in {".", "/"}:
            self.terminal(f"mkdir -p {parent}", timeout=10)

        # For small files, single command is fine
        chunk_size = 60_000  # ~60KB per chunk (well within shell limits)
        if len(b64) <= chunk_size:
            result = self.terminal(
                f"printf '%s' '{b64}' | base64 -d > {remote_path}",
                timeout=30,
            )
        else:
            # For larger files, write base64 in chunks then decode
            tmp_b64 = "/tmp/_hermes_upload.b64"
            self.terminal(f": > {tmp_b64}", timeout=5)  # truncate
            for i in range(0, len(b64), chunk_size):
                chunk = b64[i : i + chunk_size]
                self.terminal(f"printf '%s' '{chunk}' >> {tmp_b64}", timeout=15)
            result = self.terminal(
                f"base64 -d {tmp_b64} > {remote_path} && rm -f {tmp_b64}",
                timeout=30,
            )

        return result

    def upload_dir(self, local_dir: str, remote_dir: str) -> List[Dict[str, Any]]:
        """
        Upload an entire local directory to the rollout's sandbox (binary-safe).

        Recursively uploads all files, preserving directory structure.

        Args:
            local_dir: Path to a local directory on the host
            remote_dir: Destination directory inside the sandbox

        Returns:
            List of results, one per file uploaded
        """
        from pathlib import Path as _Path

        local = _Path(local_dir)
        if not local.exists() or not local.is_dir():
            return [{"exit_code": -1, "output": f"Local directory not found: {local_dir}"}]

        results = []
        for file_path in sorted(local.rglob("*")):
            if file_path.is_file():
                relative = file_path.relative_to(local)
                target = f"{remote_dir}/{relative}"
                results.append(self.upload_file(str(file_path), target))
        return results

    def download_file(self, remote_path: str, local_path: str) -> Dict[str, Any]:
        """
        Download a file from the rollout's sandbox to the host (binary-safe).

        The inverse of upload_file(). Base64-encodes the file inside the sandbox,
        reads the encoded data through the terminal, and decodes it locally.
        Safe for any file type.

        Args:
            remote_path: Path to the file inside the sandbox
            local_path: Destination path on the host

        Returns:
            Dict with 'success' (bool) and 'bytes' (int) or 'error' (str)
        """
        import base64
        from pathlib import Path as _Path

        # Base64-encode the file inside the sandbox and capture output
        result = self.terminal(
            f"base64 {remote_path} 2>/dev/null",
            timeout=30,
        )

        if result.get("exit_code", -1) != 0:
            return {
                "success": False,
                "error": f"Failed to read remote file: {result.get('output', '')}",
            }

        b64_data = result.get("output", "").strip()
        if not b64_data:
            return {"success": False, "error": f"Remote file is empty or missing: {remote_path}"}

        try:
            raw = base64.b64decode(b64_data)
        except Exception as e:
            return {"success": False, "error": f"Base64 decode failed: {e}"}

        # Write to local host filesystem
        local = _Path(local_path)
        local.parent.mkdir(parents=True, exist_ok=True)
        local.write_bytes(raw)

        return {"success": True, "bytes": len(raw)}

    def download_dir(self, remote_dir: str, local_dir: str) -> List[Dict[str, Any]]:
        """
        Download a directory from the rollout's sandbox to the host (binary-safe).

        Lists all files in the remote directory, then downloads each one.
        Preserves directory structure.

        Args:
            remote_dir: Path to the directory inside the sandbox
            local_dir: Destination directory on the host

        Returns:
            List of results, one per file downloaded
        """
        from pathlib import Path as _Path

        # List files in the remote directory
        ls_result = self.terminal(
            f"find {remote_dir} -type f 2>/dev/null",
            timeout=15,
        )

        if ls_result.get("exit_code", -1) != 0:
            return [{"success": False, "error": f"Failed to list remote dir: {remote_dir}"}]

        file_list = ls_result.get("output", "").strip()
        if not file_list:
            return [{"success": False, "error": f"Remote directory is empty or missing: {remote_dir}"}]

        results = []
        for remote_file in file_list.splitlines():
            remote_file = remote_file.strip()
            if not remote_file:
                continue
            # Compute the relative path to preserve directory structure
            if remote_file.startswith(remote_dir):
                relative = remote_file[len(remote_dir):].lstrip("/")
            else:
                relative = _Path(remote_file).name
            local_file = str(_Path(local_dir) / relative)
            results.append(self.download_file(remote_file, local_file))

        return results

    def search(self, query: str, path: str = ".") -> Dict[str, Any]:
        """
        Search for text in the rollout's filesystem.

        Args:
            query: Search query
            path: Directory to search in

        Returns:
            Dict with search results
        """
        result = handle_function_call(
            "search_files", {"pattern": query, "path": path}, task_id=self.task_id
        )
        try:
            return json.loads(result)
        except json.JSONDecodeError:
            return {"error": result}

    # -------------------------------------------------------------------------
    # Web tools
    # -------------------------------------------------------------------------

    def web_search(self, query: str) -> Dict[str, Any]:
        """
        Search the web.

        Args:
            query: Search query

        Returns:
            Dict with search results
        """
        result = handle_function_call("web_search", {"query": query})
        try:
            return json.loads(result)
        except json.JSONDecodeError:
            return {"error": result}

    def web_extract(self, urls: List[str]) -> Dict[str, Any]:
        """
        Extract content from URLs.

        Args:
            urls: List of URLs to extract content from

        Returns:
            Dict with extracted content
        """
        result = handle_function_call("web_extract", {"urls": urls})
        try:
            return json.loads(result)
        except json.JSONDecodeError:
            return {"error": result}

    # -------------------------------------------------------------------------
    # Browser tools
    # -------------------------------------------------------------------------

    def browser_navigate(self, url: str) -> Dict[str, Any]:
        """
        Navigate the rollout's browser session to a URL.

        Args:
            url: URL to navigate to

        Returns:
            Dict with page snapshot or error
        """
        result = handle_function_call(
            "browser_navigate", {"url": url}, task_id=self.task_id
        )
        try:
            return json.loads(result)
        except json.JSONDecodeError:
            return {"error": result}

    def browser_snapshot(self) -> Dict[str, Any]:
        """
        Take a snapshot of the current browser page.

        Returns:
            Dict with page content/accessibility snapshot
        """
        result = handle_function_call(
            "browser_snapshot", {}, task_id=self.task_id
        )
        try:
            return json.loads(result)
        except json.JSONDecodeError:
            return {"error": result}

    # -------------------------------------------------------------------------
    # Generic tool access
    # -------------------------------------------------------------------------

    def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
        """
        Call any hermes-agent tool by name.

        This is the generic escape hatch -- if a tool doesn't have a convenience
        wrapper above, you can call it directly here.

        Args:
            tool_name: Name of the tool (e.g., "vision_analyze", "skills_list")
            arguments: Dict of arguments for the tool

        Returns:
            Raw JSON string result from the tool
        """
        return _run_tool_in_thread(tool_name, arguments, self.task_id)

    # -------------------------------------------------------------------------
    # Cleanup
    # -------------------------------------------------------------------------

    def cleanup(self):
        """
        Release all resources (terminal VMs, browser sessions, background processes)
        for this rollout.

        Called automatically by the base environment via try/finally after
        compute_reward() completes. You generally don't need to call this yourself.
        """
        # Kill any background processes from this rollout (safety net)
        try:
            from tools.process_registry import process_registry
            killed = process_registry.kill_all(task_id=self.task_id)
            if killed:
                logger.debug("Process cleanup for task %s: killed %d process(es)", self.task_id, killed)
        except Exception as e:
            logger.debug("Process cleanup for task %s: %s", self.task_id, e)

        try:
            cleanup_vm(self.task_id)
        except Exception as e:
            logger.debug("VM cleanup for task %s: %s", self.task_id, e)

        # Suppress browser_tool's noisy debug prints during cleanup.
        # The cleanup still runs (safe), it just doesn't spam the console.
        _prev_quiet = os.environ.get("HERMES_QUIET")
        os.environ["HERMES_QUIET"] = "1"
        try:
            cleanup_browser(self.task_id)
        except Exception as e:
            logger.debug("Browser cleanup for task %s: %s", self.task_id, e)
        finally:
            if _prev_quiet is None:
                os.environ.pop("HERMES_QUIET", None)
            else:
                os.environ["HERMES_QUIET"] = _prev_quiet
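The reason ToolContext routes some calls through a worker thread is that `asyncio.run()` raises `RuntimeError` when an event loop is already running in the same thread. The sketch below shows the pattern in isolation, under stated assumptions: `sync_tool` is a made-up stand-in for a tool backend that calls `asyncio.run()` internally, and `run_in_thread_if_needed` is a hypothetical name, not part of the module above.

```python
import asyncio
import concurrent.futures


def sync_tool(x: int) -> int:
    """Stand-in for a sync tool backend that calls asyncio.run() internally."""

    async def _work() -> int:
        await asyncio.sleep(0)
        return x * 2

    return asyncio.run(_work())


def run_in_thread_if_needed(x: int) -> int:
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so a direct call is safe.
        return sync_tool(x)
    # A loop is running here; calling asyncio.run() on this thread would fail,
    # so the tool gets its own thread and therefore its own fresh event loop.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(sync_tool, x).result(timeout=30)


async def main() -> int:
    # Called from async code: the helper blocks this coroutine until done.
    return run_in_thread_if_needed(21)
```

Note that the helper blocks the calling event loop while the worker thread runs, which is acceptable for a reward function that has nothing else to do in the meantime but would stall other coroutines sharing that loop.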
@@ -1,719 +0,0 @@
"""
WebResearchEnv — RL Environment for Multi-Step Web Research
============================================================

Trains models to do accurate, efficient, multi-source web research.

Reward signals:
  - Answer correctness (LLM judge, 0.0–1.0)
  - Source diversity (used ≥2 distinct domains)
  - Efficiency (penalizes excessive tool calls)
  - Tool usage (bonus for actually using web tools)

Dataset: FRAMES benchmark (Google, 2024) — multi-hop factual questions
  HuggingFace: google/frames-benchmark
  Fallback: built-in sample questions (no HF token needed)

Usage:
    # Phase 1 (OpenAI-compatible server)
    python environments/web_research_env.py serve \\
        --openai.base_url http://localhost:8000/v1 \\
        --openai.model_name YourModel \\
        --openai.server_type openai

    # Process mode (offline data generation)
    python environments/web_research_env.py process \\
        --env.data_path_to_save_groups data/web_research.jsonl

    # Standalone eval
    python environments/web_research_env.py evaluate \\
        --openai.base_url http://localhost:8000/v1 \\
        --openai.model_name YourModel

Built by: github.com/jackx707
Inspired by: GroceryMind — production Hermes agent doing live web research
across German grocery stores (firecrawl + hermes-agent)
"""

from __future__ import annotations

import asyncio
import json
import logging
import os
import random
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse

from pydantic import Field

# Ensure hermes-agent root is on path
_repo_root = Path(__file__).resolve().parent.parent
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))

# ---------------------------------------------------------------------------
# Optional HuggingFace datasets import
# ---------------------------------------------------------------------------
try:
    from datasets import load_dataset
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False

from atroposlib.envs.base import ScoredDataGroup
from atroposlib.envs.server_handling.server_manager import APIServerConfig
from atroposlib.type_definitions import Item

from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
from environments.agent_loop import AgentResult
from environments.tool_context import ToolContext

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Fallback sample dataset (used when HuggingFace is unavailable)
# Multi-hop questions requiring real web search to answer.
# ---------------------------------------------------------------------------
SAMPLE_QUESTIONS = [
    {
        "question": "What is the current population of the capital city of the country that won the 2022 FIFA World Cup?",
        "answer": "Buenos Aires has approximately 3 million people in the city proper, or around 15 million in the greater metro area.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "Who is the CEO of the company that makes the most widely used open-source container orchestration platform?",
        "answer": "The Linux Foundation oversees Kubernetes. CNCF (Cloud Native Computing Foundation) is the specific body — it does not have a traditional CEO but has an executive director.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "What programming language was used to write the original version of the web framework used by Instagram?",
        "answer": "Django, which Instagram was built on, is written in Python.",
        "difficulty": "easy",
        "hops": 2,
    },
    {
        "question": "In what year was the university founded where the inventor of the World Wide Web currently holds a professorship?",
        "answer": "Tim Berners-Lee holds a professorship at MIT (founded 1861) and the University of Southampton (founded 1952).",
        "difficulty": "hard",
        "hops": 3,
    },
    {
        "question": "What is the latest stable version of the programming language that ranks #1 on the TIOBE index as of this year?",
        "answer": "Python is currently #1 on TIOBE. The latest stable version should be verified via the official python.org site.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "How many employees does the parent company of Instagram have?",
        "answer": "Meta Platforms (parent of Instagram) employs approximately 70,000+ people as of recent reports.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "What is the current interest rate set by the central bank of the country where the Eiffel Tower is located?",
        "answer": "The European Central Bank sets rates for France/eurozone. The current rate should be verified — it has changed frequently in 2023-2025.",
        "difficulty": "hard",
        "hops": 2,
    },
    {
        "question": "Which company acquired the startup founded by the creator of Oculus VR?",
        "answer": "Palmer Luckey founded Oculus VR, which was acquired by Facebook (now Meta). He later founded Anduril Industries.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "What is the market cap of the company that owns the most popular search engine in Russia?",
        "answer": "Yandex (now split into separate entities after 2024 restructuring). Current market cap should be verified via financial sources.",
        "difficulty": "hard",
        "hops": 2,
    },
    {
        "question": "What was the GDP growth rate of the country that hosted the most recent Summer Olympics?",
        "answer": "Paris, France hosted the 2024 Summer Olympics. France's recent GDP growth should be verified via World Bank or IMF data.",
        "difficulty": "hard",
        "hops": 2,
    },
]


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

class WebResearchEnvConfig(HermesAgentEnvConfig):
    """Configuration for the web research RL environment."""

    # Reward weights
    correctness_weight: float = Field(
        default=0.6,
        description="Weight for answer correctness in reward (LLM judge score).",
    )
    tool_usage_weight: float = Field(
        default=0.2,
        description="Weight for tool usage signal (did the model actually use web tools?).",
    )
    efficiency_weight: float = Field(
        default=0.2,
        description="Weight for efficiency signal (penalizes excessive tool calls).",
    )
    diversity_bonus: float = Field(
        default=0.1,
        description="Bonus reward for citing ≥2 distinct domains.",
    )

    # Efficiency thresholds
    efficient_max_calls: int = Field(
        default=5,
        description="Maximum tool calls before efficiency penalty begins.",
    )
    heavy_penalty_calls: int = Field(
        default=10,
        description="Tool call count where efficiency penalty steepens.",
    )

    # Eval
    eval_size: int = Field(
        default=20,
        description="Number of held-out items for evaluation.",
    )
    eval_split_ratio: float = Field(
        default=0.1,
        description="Fraction of dataset to hold out for evaluation (0.0–1.0).",
    )

    # Dataset
    dataset_name: str = Field(
        default="google/frames-benchmark",
        description="HuggingFace dataset name for research questions.",
    )


# ---------------------------------------------------------------------------
|
||||
# Environment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WebResearchEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
RL environment for training multi-step web research skills.
|
||||
|
||||
The model is given a factual question requiring 2-3 hops of web research
|
||||
and must use web_search / web_extract tools to find and synthesize the answer.
|
||||
|
||||
Reward is multi-signal:
|
||||
60% — answer correctness (LLM judge)
|
||||
20% — tool usage (did the model actually search the web?)
|
||||
20% — efficiency (penalizes >5 tool calls)
|
||||
|
||||
Bonus +0.1 for source diversity (≥2 distinct domains cited).
|
||||
"""
|
||||
|
||||
name = "web-research"
|
||||
env_config_cls = WebResearchEnvConfig
|
||||
|
||||
# Default toolsets for this environment — web + file for saving notes
|
||||
default_toolsets = ["web", "file"]
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[WebResearchEnvConfig, List[APIServerConfig]]:
|
||||
"""Default configuration for the web research environment."""
|
||||
env_config = WebResearchEnvConfig(
|
||||
enabled_toolsets=["web", "file"],
|
||||
max_agent_turns=15,
|
||||
agent_temperature=1.0,
|
||||
system_prompt=(
|
||||
"You are a highly capable research agent. When asked a factual question, "
|
||||
"always use web_search to find current, accurate information before answering. "
|
||||
"Cite at least 2 sources. Be concise and accurate."
|
||||
),
|
||||
group_size=4,
|
||||
total_steps=1000,
|
||||
steps_per_eval=100,
|
||||
use_wandb=True,
|
||||
wandb_name="web-research",
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-sonnet-4.5",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._items: list[dict] = []
|
||||
self._eval_items: list[dict] = []
|
||||
self._index: int = 0
|
||||
|
||||
# Metrics tracking for wandb
|
||||
self._reward_buffer: list[float] = []
|
||||
self._correctness_buffer: list[float] = []
|
||||
self._tool_usage_buffer: list[float] = []
|
||||
self._efficiency_buffer: list[float] = []
|
||||
self._diversity_buffer: list[float] = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Setup — load dataset
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def setup(self) -> None:
|
||||
"""Load the FRAMES benchmark or fall back to built-in samples."""
|
||||
if HF_AVAILABLE:
|
||||
try:
|
||||
logger.info("Loading FRAMES benchmark from HuggingFace...")
|
||||
ds = load_dataset(self.config.dataset_name, split="test")
|
||||
self._items = [
|
||||
{
|
||||
"question": row["Prompt"],
|
||||
"answer": row["Answer"],
|
||||
"difficulty": row.get("reasoning_types", "unknown"),
|
||||
"hops": 2,
|
||||
}
|
||||
for row in ds
|
||||
]
|
||||
# Hold out for eval
|
||||
eval_size = max(
|
||||
self.config.eval_size,
|
||||
int(len(self._items) * self.config.eval_split_ratio),
|
||||
)
|
||||
random.shuffle(self._items)
|
||||
self._eval_items = self._items[:eval_size]
|
||||
self._items = self._items[eval_size:]
|
||||
logger.info(
|
||||
f"Loaded {len(self._items)} train / {len(self._eval_items)} eval items "
|
||||
f"from FRAMES benchmark."
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load FRAMES from HuggingFace: {e}. Using built-in samples.")
|
||||
|
||||
# Fallback
|
||||
samples = random.sample(SAMPLE_QUESTIONS, len(SAMPLE_QUESTIONS))  # shuffled copy; leaves the module-level list untouched
|
||||
split = max(1, len(samples) * 8 // 10)
|
||||
self._items = samples[:split]
|
||||
self._eval_items = samples[split:]
|
||||
logger.info(
|
||||
f"Using built-in sample dataset: {len(self._items)} train / "
|
||||
f"{len(self._eval_items)} eval items."
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2. get_next_item — return the next question
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_next_item(self) -> dict:
|
||||
"""Return the next item, cycling through the dataset."""
|
||||
if not self._items:
|
||||
raise RuntimeError("Dataset is empty. Did you call setup()?")
|
||||
item = self._items[self._index % len(self._items)]
|
||||
self._index += 1
|
||||
return item
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. format_prompt — build the user-facing prompt
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def format_prompt(self, item: dict) -> str:
|
||||
"""Format the research question as a task prompt."""
|
||||
return (
|
||||
f"Research the following question thoroughly using web search. "
|
||||
f"You MUST search the web to find current, accurate information — "
|
||||
f"do not rely solely on your training data.\n\n"
|
||||
f"Question: {item['question']}\n\n"
|
||||
f"Requirements:\n"
|
||||
f"- Use web_search and/or web_extract tools to find information\n"
|
||||
f"- Search at least 2 different sources\n"
|
||||
f"- Provide a concise, accurate answer (2-4 sentences)\n"
|
||||
f"- Cite the sources you used"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4. compute_reward — multi-signal scoring
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def compute_reward(
|
||||
self,
|
||||
item: dict,
|
||||
result: AgentResult,
|
||||
ctx: ToolContext,
|
||||
) -> float:
|
||||
"""
|
||||
Multi-signal reward function:
|
||||
|
||||
correctness_weight * correctness — LLM judge comparing answer to ground truth
|
||||
tool_usage_weight * tool_used — binary: did the model use web tools?
|
||||
efficiency_weight * efficiency — penalizes wasteful tool usage
|
||||
+ diversity_bonus — source diversity (≥2 distinct domains)
|
||||
"""
|
||||
# Extract final response from messages (last assistant message with content)
|
||||
final_response = ""
|
||||
tools_used: list[str] = []
|
||||
for msg in reversed(result.messages):
|
||||
if msg.get("role") == "assistant" and msg.get("content") and not final_response:
|
||||
final_response = msg["content"]
|
||||
# Collect tool names from tool call messages
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
|
||||
name = fn.get("name", "")
|
||||
if name:
|
||||
tools_used.append(name)
|
||||
tool_call_count: int = result.turns_used or len(tools_used)
|
||||
|
||||
cfg = self.config
|
||||
|
||||
# ---- Signal 1: Answer correctness (LLM judge) ----------------
|
||||
correctness = await self._llm_judge(
|
||||
question=item["question"],
|
||||
expected=item["answer"],
|
||||
model_answer=final_response,
|
||||
)
|
||||
|
||||
# ---- Signal 2: Web tool usage --------------------------------
|
||||
web_tools = {"web_search", "web_extract", "search", "firecrawl"}
|
||||
tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0
|
||||
|
||||
# ---- Signal 3: Efficiency ------------------------------------
|
||||
if tool_call_count <= cfg.efficient_max_calls:
|
||||
efficiency = 1.0
|
||||
elif tool_call_count <= cfg.heavy_penalty_calls:
|
||||
efficiency = 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.08
|
||||
else:
|
||||
efficiency = max(0.0, 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.12)
|
||||
|
||||
# ---- Bonus: Source diversity ---------------------------------
|
||||
domains = self._extract_domains(final_response)
|
||||
diversity = cfg.diversity_bonus if len(domains) >= 2 else 0.0
|
||||
|
||||
# ---- Combine ------------------------------------------------
|
||||
reward = (
|
||||
cfg.correctness_weight * correctness
|
||||
+ cfg.tool_usage_weight * tool_used
|
||||
+ cfg.efficiency_weight * efficiency
|
||||
+ diversity
|
||||
)
|
||||
reward = min(1.0, max(0.0, reward)) # clamp to [0, 1]
|
||||
|
||||
# Track for wandb
|
||||
self._reward_buffer.append(reward)
|
||||
self._correctness_buffer.append(correctness)
|
||||
self._tool_usage_buffer.append(tool_used)
|
||||
self._efficiency_buffer.append(efficiency)
|
||||
self._diversity_buffer.append(diversity)
|
||||
|
||||
logger.debug(
|
||||
f"Reward breakdown — correctness={correctness:.2f}, "
|
||||
f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, "
|
||||
f"diversity={diversity:.1f} → total={reward:.3f}"
|
||||
)
|
||||
|
||||
return reward
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 5. evaluate — run on held-out eval split
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
"""Run evaluation on the held-out split using the full agent loop with tools.
|
||||
|
||||
Each eval item runs through the same agent loop as training —
|
||||
the model can use web_search, web_extract, etc. to research answers.
|
||||
This measures actual agentic research capability, not just knowledge.
|
||||
"""
|
||||
import time
|
||||
import uuid
|
||||
from environments.agent_loop import HermesAgentLoop
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
items = self._eval_items
|
||||
if not items:
|
||||
logger.warning("No eval items available.")
|
||||
return
|
||||
|
||||
eval_size = min(self.config.eval_size, len(items))
|
||||
eval_items = items[:eval_size]
|
||||
|
||||
logger.info(f"Running eval on {len(eval_items)} questions (with agent loop + tools)...")
|
||||
start_time = time.time()
|
||||
samples = []
|
||||
|
||||
# Resolve tools once for all eval items
|
||||
tools, valid_names = self._resolve_tools_for_group()
|
||||
|
||||
for i, item in enumerate(eval_items):
|
||||
task_id = str(uuid.uuid4())
|
||||
logger.info(f"Eval [{i+1}/{len(eval_items)}]: {item['question'][:80]}...")
|
||||
|
||||
try:
|
||||
# Build messages
|
||||
messages: List[Dict[str, Any]] = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
messages.append({"role": "user", "content": self.format_prompt(item)})
|
||||
|
||||
# Run the full agent loop with tools
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=0.0, # Deterministic for eval
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Extract final response and tool usage from messages
|
||||
final_response = ""
|
||||
tool_call_count = 0
|
||||
for msg in reversed(result.messages):
|
||||
if msg.get("role") == "assistant" and msg.get("content") and not final_response:
|
||||
final_response = msg["content"]
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
tool_call_count += len(msg["tool_calls"])
|
||||
|
||||
# Compute reward (includes LLM judge for correctness)
|
||||
# Temporarily save buffer lengths so we can extract the
|
||||
# correctness score without calling judge twice, and avoid
|
||||
# polluting training metric buffers with eval data.
|
||||
buf_len = len(self._correctness_buffer)
|
||||
ctx = ToolContext(task_id)
|
||||
try:
|
||||
reward = await self.compute_reward(item, result, ctx)
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
|
||||
# Extract correctness from the buffer (compute_reward appended it)
|
||||
# then remove eval entries from training buffers
|
||||
correctness = (
|
||||
self._correctness_buffer[buf_len]
|
||||
if len(self._correctness_buffer) > buf_len
|
||||
else 0.0
|
||||
)
|
||||
# Roll back buffers to avoid polluting training metrics
|
||||
for buf in (
|
||||
self._reward_buffer, self._correctness_buffer,
|
||||
self._tool_usage_buffer, self._efficiency_buffer,
|
||||
self._diversity_buffer,
|
||||
):
|
||||
if len(buf) > buf_len:
|
||||
buf.pop()
|
||||
|
||||
samples.append({
|
||||
"prompt": item["question"],
|
||||
"response": final_response[:500],
|
||||
"expected": item["answer"],
|
||||
"correctness": correctness,
|
||||
"reward": reward,
|
||||
"tool_calls": tool_call_count,
|
||||
"turns": result.turns_used,
|
||||
})
|
||||
|
||||
logger.info(
|
||||
f" → correctness={correctness:.2f}, reward={reward:.3f}, "
|
||||
f"tools={tool_call_count}, turns={result.turns_used}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Eval error on item: {e}")
|
||||
samples.append({
|
||||
"prompt": item["question"],
|
||||
"response": f"ERROR: {e}",
|
||||
"expected": item["answer"],
|
||||
"correctness": 0.0,
|
||||
"reward": 0.0,
|
||||
"tool_calls": 0,
|
||||
"turns": 0,
|
||||
})
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
# Compute aggregate metrics
|
||||
correctness_scores = [s["correctness"] for s in samples]
|
||||
rewards = [s["reward"] for s in samples]
|
||||
tool_counts = [s["tool_calls"] for s in samples]
|
||||
n = len(samples)
|
||||
|
||||
eval_metrics = {
|
||||
"eval/mean_correctness": sum(correctness_scores) / n if n else 0.0,
|
||||
"eval/mean_reward": sum(rewards) / n if n else 0.0,
|
||||
"eval/mean_tool_calls": sum(tool_counts) / n if n else 0.0,
|
||||
"eval/tool_usage_rate": sum(1 for t in tool_counts if t > 0) / n if n else 0.0,
|
||||
"eval/n_items": n,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Eval complete — correctness={eval_metrics['eval/mean_correctness']:.3f}, "
|
||||
f"reward={eval_metrics['eval/mean_reward']:.3f}, "
|
||||
f"tool_usage={eval_metrics['eval/tool_usage_rate']:.0%}"
|
||||
)
|
||||
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 6. wandb_log — custom metrics
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
|
||||
"""Log reward breakdown metrics to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self._reward_buffer:
|
||||
n = len(self._reward_buffer)
|
||||
wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
|
||||
wandb_metrics["train/mean_correctness"] = sum(self._correctness_buffer) / n
|
||||
wandb_metrics["train/mean_tool_usage"] = sum(self._tool_usage_buffer) / n
|
||||
wandb_metrics["train/mean_efficiency"] = sum(self._efficiency_buffer) / n
|
||||
wandb_metrics["train/mean_diversity"] = sum(self._diversity_buffer) / n
|
||||
wandb_metrics["train/total_rollouts"] = n
|
||||
|
||||
# Accuracy buckets
|
||||
wandb_metrics["train/correct_rate"] = (
|
||||
sum(1 for c in self._correctness_buffer if c >= 0.7) / n
|
||||
)
|
||||
wandb_metrics["train/tool_usage_rate"] = (
|
||||
sum(1 for t in self._tool_usage_buffer if t > 0) / n
|
||||
)
|
||||
|
||||
# Clear buffers
|
||||
self._reward_buffer.clear()
|
||||
self._correctness_buffer.clear()
|
||||
self._tool_usage_buffer.clear()
|
||||
self._efficiency_buffer.clear()
|
||||
self._diversity_buffer.clear()
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _llm_judge(
|
||||
self,
|
||||
question: str,
|
||||
expected: str,
|
||||
model_answer: str,
|
||||
) -> float:
|
||||
"""
|
||||
Use the server's LLM to judge answer correctness.
|
||||
Falls back to keyword heuristic if LLM call fails.
|
||||
"""
|
||||
if not model_answer or not model_answer.strip():
|
||||
return 0.0
|
||||
|
||||
judge_prompt = (
|
||||
"You are an impartial judge evaluating the quality of an AI research answer.\n\n"
|
||||
f"Question: {question}\n\n"
|
||||
f"Reference answer: {expected}\n\n"
|
||||
f"Model answer: {model_answer}\n\n"
|
||||
"Score the model answer on a scale from 0.0 to 1.0 where:\n"
|
||||
" 1.0 = fully correct and complete\n"
|
||||
" 0.7 = mostly correct with minor gaps\n"
|
||||
" 0.4 = partially correct\n"
|
||||
" 0.1 = mentions relevant topic but wrong or very incomplete\n"
|
||||
" 0.0 = completely wrong or no answer\n\n"
|
||||
"Consider: factual accuracy, completeness, and relevance.\n"
|
||||
'Respond with ONLY a JSON object: {"score": <float>, "reason": "<one sentence>"}'
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.server.chat_completion(
|
||||
messages=[{"role": "user", "content": judge_prompt}],
|
||||
n=1,
|
||||
max_tokens=150,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
text = response.choices[0].message.content if response.choices else ""
|
||||
parsed = self._parse_judge_json(text)
|
||||
if parsed is not None:
|
||||
return float(parsed)
|
||||
except Exception as e:
|
||||
logger.debug(f"LLM judge failed: {e}. Using heuristic.")
|
||||
|
||||
return self._heuristic_score(expected, model_answer)
|
||||
|
||||
@staticmethod
|
||||
def _parse_judge_json(text: str) -> Optional[float]:
|
||||
"""Extract the score float from LLM judge JSON response."""
|
||||
try:
|
||||
clean = re.sub(r"```(?:json)?|```", "", text).strip()
|
||||
data = json.loads(clean)
|
||||
score = float(data.get("score", -1))
|
||||
if 0.0 <= score <= 1.0:
|
||||
return score
|
||||
except Exception:
|
||||
match = re.search(r'"score"\s*:\s*([0-9.]+)', text)
|
||||
if match:
|
||||
score = float(match.group(1))
|
||||
if 0.0 <= score <= 1.0:
|
||||
return score
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _heuristic_score(expected: str, model_answer: str) -> float:
|
||||
"""Lightweight keyword overlap score as fallback."""
|
||||
stopwords = {
|
||||
"the", "a", "an", "is", "are", "was", "were", "of", "in", "on",
|
||||
"at", "to", "for", "with", "and", "or", "but", "it", "its",
|
||||
"this", "that", "as", "by", "from", "be", "has", "have", "had",
|
||||
}
|
||||
|
||||
def tokenize(text: str) -> set:
|
||||
tokens = re.findall(r'\b\w+\b', text.lower())
|
||||
return {t for t in tokens if t not in stopwords and len(t) > 2}
|
||||
|
||||
expected_tokens = tokenize(expected)
|
||||
answer_tokens = tokenize(model_answer)
|
||||
|
||||
if not expected_tokens:
|
||||
return 0.5
|
||||
|
||||
overlap = len(expected_tokens & answer_tokens)
|
||||
union = len(expected_tokens | answer_tokens)
|
||||
|
||||
jaccard = overlap / union if union > 0 else 0.0
|
||||
recall = overlap / len(expected_tokens)
|
||||
return min(1.0, 0.4 * jaccard + 0.6 * recall)
|
||||
|
||||
@staticmethod
|
||||
def _extract_domains(text: str) -> set:
|
||||
"""Extract unique domains from URLs cited in the response."""
|
||||
urls = re.findall(r'https?://[^\s\)>\]"\']+', text)
|
||||
domains = set()
|
||||
for url in urls:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
domain = parsed.netloc.lower().removeprefix("www.")  # lstrip("www.") strips any leading 'w'/'.' characters, not the prefix
|
||||
if domain:
|
||||
domains.add(domain)
|
||||
except Exception:
|
||||
pass
|
||||
return domains
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
WebResearchEnv.cli()
|
||||
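The reward combination in `compute_reward` above can be sketched in isolation. This is a minimal sketch, assuming the default weights and thresholds from `WebResearchEnvConfig`; the function names `efficiency_score` and `combine_reward` are hypothetical and do not appear in the diff.

```python
def efficiency_score(tool_calls: int, efficient_max: int = 5, heavy_at: int = 10) -> float:
    """Piecewise efficiency signal using the thresholds and slopes from the diff."""
    if tool_calls <= efficient_max:
        return 1.0  # no penalty at or below the efficient threshold
    over = tool_calls - efficient_max
    if tool_calls <= heavy_at:
        return 1.0 - over * 0.08  # mild linear penalty
    return max(0.0, 1.0 - over * 0.12)  # steeper penalty, floored at 0


def combine_reward(
    correctness: float,
    used_web_tool: bool,
    tool_calls: int,
    distinct_domains: int,
) -> float:
    """Weighted sum of the three signals plus the diversity bonus, clamped to [0, 1]."""
    reward = (
        0.6 * correctness
        + 0.2 * (1.0 if used_web_tool else 0.0)
        + 0.2 * efficiency_score(tool_calls)
        + (0.1 if distinct_domains >= 2 else 0.0)
    )
    return min(1.0, max(0.0, reward))
```

Because the three weights already sum to 1.0, the diversity bonus only changes the result when at least one other signal is below its maximum; the clamp absorbs it otherwise.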
@@ -941,6 +941,14 @@ def load_gateway_config() -> GatewayConfig:
|
||||
if isinstance(ntc, list):
|
||||
ntc = ",".join(str(v) for v in ntc)
|
||||
os.environ["DISCORD_NO_THREAD_CHANNELS"] = str(ntc)
|
||||
# history_backfill: recover missed channel messages for shared sessions
|
||||
# when require_mention is active. Fetches messages between bot turns
|
||||
# and prepends them to the user message for context.
|
||||
if "history_backfill" in discord_cfg and not os.getenv("DISCORD_HISTORY_BACKFILL"):
|
||||
os.environ["DISCORD_HISTORY_BACKFILL"] = str(discord_cfg["history_backfill"]).lower()
|
||||
hbl = discord_cfg.get("history_backfill_limit")
|
||||
if hbl is not None and not os.getenv("DISCORD_HISTORY_BACKFILL_LIMIT"):
|
||||
os.environ["DISCORD_HISTORY_BACKFILL_LIMIT"] = str(hbl)
|
||||
# allow_mentions: granular control over what the bot can ping.
|
||||
# Safe defaults (no @everyone/roles) are applied in the adapter;
|
||||
# these YAML keys only override when set and let users opt back
|
||||
|
||||
@@ -356,15 +356,34 @@ class ResponseStore:
|
||||
# Evict oldest entries beyond max_size
|
||||
count = self._conn.execute("SELECT COUNT(*) FROM responses").fetchone()[0]
|
||||
if count > self._max_size:
|
||||
self._conn.execute(
|
||||
"DELETE FROM responses WHERE response_id IN "
|
||||
"(SELECT response_id FROM responses ORDER BY accessed_at ASC LIMIT ?)",
|
||||
(count - self._max_size,),
|
||||
)
|
||||
# Collect IDs that will be evicted
|
||||
evict_ids = [
|
||||
row[0]
|
||||
for row in self._conn.execute(
|
||||
"SELECT response_id FROM responses ORDER BY accessed_at ASC LIMIT ?",
|
||||
(count - self._max_size,),
|
||||
).fetchall()
|
||||
]
|
||||
if evict_ids:
|
||||
placeholders = ",".join("?" for _ in evict_ids)
|
||||
# Clear conversation mappings pointing to evicted responses
|
||||
self._conn.execute(
|
||||
f"DELETE FROM conversations WHERE response_id IN ({placeholders})",
|
||||
evict_ids,
|
||||
)
|
||||
# Delete evicted responses
|
||||
self._conn.execute(
|
||||
f"DELETE FROM responses WHERE response_id IN ({placeholders})",
|
||||
evict_ids,
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def delete(self, response_id: str) -> bool:
|
||||
"""Remove a response from the store. Returns True if found and deleted."""
|
||||
# Clear conversation mappings pointing to this response
|
||||
self._conn.execute(
|
||||
"DELETE FROM conversations WHERE response_id = ?", (response_id,)
|
||||
)
|
||||
cursor = self._conn.execute(
|
||||
"DELETE FROM responses WHERE response_id = ?", (response_id,)
|
||||
)
|
||||
|
||||
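The eviction change in the `ResponseStore` hunk above (collect the IDs first, then delete dependent conversation mappings before the responses) can be tried against an in-memory database. This is a rough sketch: the two-table schema and the helpers `make_demo_store` and `evict_oldest` are assumptions made for illustration, not the project's actual schema or API.

```python
import sqlite3


def make_demo_store() -> sqlite3.Connection:
    """Create a hypothetical two-table schema in memory (assumed, not from the diff)."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE responses (response_id TEXT PRIMARY KEY, accessed_at REAL)")
    conn.execute("CREATE TABLE conversations (name TEXT PRIMARY KEY, response_id TEXT)")
    return conn


def evict_oldest(conn: sqlite3.Connection, max_size: int) -> int:
    """Evict least-recently-accessed responses and any mappings that point at them."""
    count = conn.execute("SELECT COUNT(*) FROM responses").fetchone()[0]
    if count <= max_size:
        return 0
    # Collect the IDs first so both DELETE statements target the same rows.
    evict_ids = [
        row[0]
        for row in conn.execute(
            "SELECT response_id FROM responses ORDER BY accessed_at ASC LIMIT ?",
            (count - max_size,),
        )
    ]
    placeholders = ",".join("?" for _ in evict_ids)
    # Remove dependent conversation mappings before the responses themselves.
    conn.execute(
        f"DELETE FROM conversations WHERE response_id IN ({placeholders})", evict_ids
    )
    conn.execute(
        f"DELETE FROM responses WHERE response_id IN ({placeholders})", evict_ids
    )
    conn.commit()
    return len(evict_ids)
```

Only the placeholder string is interpolated into the SQL; the IDs themselves are still bound as parameters.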
@@ -955,6 +955,12 @@ class MessageEvent:
|
||||
# Per-channel ephemeral system prompt (e.g. Discord channel_prompts).
|
||||
# Applied at API call time and never persisted to transcript history.
|
||||
channel_prompt: Optional[str] = None
|
||||
|
||||
# Channel context recovered by history backfill (e.g. messages between
|
||||
# bot turns that were missed due to require_mention). Kept separate
|
||||
# from ``text`` so the sender-prefix logic in run.py can operate on the
|
||||
# trigger message alone, then prepend this context afterward.
|
||||
channel_context: Optional[str] = None
|
||||
|
||||
# Internal flag — set for synthetic events (e.g. background process
|
||||
# completion notifications) that must bypass user authorization checks.
|
||||
|
||||
@@ -589,6 +589,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# chunk only, default), "all" (reply-reference on every chunk).
|
||||
self._reply_to_mode: str = getattr(config, 'reply_to_mode', 'first') or 'first'
|
||||
self._slash_commands: bool = self.config.extra.get("slash_commands", True)
|
||||
# In-memory cache of the bot's last message ID per channel, used by
|
||||
# history backfill to skip the full scan on hot paths. Falls back to
|
||||
# scanning channel.history() on cache miss (cold start / restart).
|
||||
self._last_self_message_id: Dict[str, str] = {}
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Discord and start receiving events."""
|
||||
@@ -1459,6 +1463,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
raise
|
||||
message_ids.append(str(msg.id))
|
||||
|
||||
# Track the last message we sent in this channel for history
|
||||
# backfill — avoids a full channel.history() scan on hot paths.
|
||||
if message_ids:
|
||||
_target_id = thread_id or chat_id
|
||||
self._last_self_message_id[str(_target_id)] = message_ids[-1]  # key as str to match the str(channel.id) lookup
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_ids[0] if message_ids else None,
|
||||
@@ -3596,6 +3606,134 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
return bool(configured)
|
||||
return os.getenv("DISCORD_THREAD_REQUIRE_MENTION", "false").lower() in ("true", "1", "yes", "on")
|
||||
|
||||
def _discord_history_backfill(self) -> bool:
|
||||
"""Return whether history backfill is enabled for shared sessions."""
|
||||
configured = self.config.extra.get("history_backfill")
|
||||
if configured is not None:
|
||||
if isinstance(configured, str):
|
||||
return configured.lower() not in ("false", "0", "no", "off")
|
||||
return bool(configured)
|
||||
return os.getenv("DISCORD_HISTORY_BACKFILL", "true").lower() in ("true", "1", "yes", "on")
|
||||
|
||||
def _discord_history_backfill_limit(self) -> int:
|
||||
"""Return the max number of messages to scan backwards for context.
|
||||
|
||||
In practice the scan usually stops much earlier — at the bot's own
|
||||
last message in the channel (the natural partition point). This
|
||||
limit is a safety cap for cold starts and long gaps where no prior
|
||||
bot message exists in recent history.
|
||||
"""
|
||||
configured = self.config.extra.get("history_backfill_limit")
|
||||
if configured is not None:
|
||||
try:
|
||||
return int(configured)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
raw = os.getenv("DISCORD_HISTORY_BACKFILL_LIMIT", "50")
|
||||
try:
|
||||
return int(raw)
|
||||
except (ValueError, TypeError):
|
||||
return 50
|
||||
|
||||
async def _fetch_channel_context(
|
||||
self,
|
||||
channel: Any,
|
||||
before: "DiscordMessage",
|
||||
) -> str:
|
||||
"""Fetch recent channel messages for conversational context.
|
||||
|
||||
Scans backwards from *before* and collects messages until it hits
|
||||
a message sent by this bot (the natural partition point between
|
||||
bot turns) or reaches ``history_backfill_limit``.
|
||||
|
||||
Returns a formatted block like::
|
||||
|
||||
[Recent channel messages]
|
||||
[Alice] some message
|
||||
[Bob [bot]] another message
|
||||
|
||||
Returns an empty string if no context is available.
|
||||
"""
|
||||
limit = self._discord_history_backfill_limit()
|
||||
if limit <= 0:
|
||||
return ""
|
||||
|
||||
# Determine which bot messages to include in context
|
||||
allow_bots_raw = os.getenv("DISCORD_ALLOW_BOTS", "none").lower().strip()
|
||||
include_other_bots = allow_bots_raw != "none"
|
||||
|
||||
# Use the in-memory cache to narrow the fetch window on hot paths.
|
||||
# If we know our last message ID in this channel, pass it as `after`
|
||||
# to avoid scanning the full limit. Falls back to scanning on cache
|
||||
# miss (cold start / restart).
|
||||
# Guard: only use the cache when it's chronologically before the
|
||||
# trigger — Discord snowflake IDs are monotonically increasing, so
|
||||
# a simple int comparison suffices.
|
||||
channel_id = str(getattr(channel, "id", ""))
|
||||
_cached_id = self._last_self_message_id.get(channel_id)
|
||||
_after_obj = None
|
||||
try:
|
||||
if _cached_id and int(_cached_id) < int(before.id):
|
||||
_after_obj = discord.Object(id=int(_cached_id))
|
||||
except (ValueError, TypeError):
|
||||
pass # Malformed cache entry — fall back to cold-start scan
|
||||
|
||||
try:
|
||||
collected = []
|
||||
# IMPORTANT: pass oldest_first=False explicitly. discord.py 2.x
|
||||
# silently flips the default to True when `after=` is supplied,
|
||||
# which would select the *earliest* N messages after our last
|
||||
# response instead of the *latest* N before the trigger. In
|
||||
# high-traffic windows that returns stale tool traces and drops
|
||||
# the actual final answer. See the regression test
|
||||
# `test_fetch_channel_context_cache_uses_latest_window_when_after_set`.
|
||||
async for msg in channel.history(
|
||||
limit=limit,
|
||||
before=before,
|
||||
after=_after_obj,
|
||||
oldest_first=False,
|
||||
):
|
||||
# Stop at our own message — this is the partition point.
|
||||
# Everything before this is already in the session transcript.
|
||||
# (Redundant when _after_obj is set, but needed for cold start.)
|
||||
if msg.author == self._client.user:
|
||||
break
|
||||
|
||||
# Skip system messages (pins, joins, thread renames, etc.)
|
||||
if msg.type not in (discord.MessageType.default, discord.MessageType.reply):
|
||||
continue
|
||||
|
||||
# Respect DISCORD_ALLOW_BOTS for other bots.
|
||||
# For history context, "mentions" is treated as "all" — we are
|
||||
# deciding what context to show, not whether to respond.
|
||||
if getattr(msg.author, "bot", False) and not include_other_bots:
|
||||
continue
|
||||
|
||||
content = getattr(msg, "clean_content", msg.content) or ""
|
||||
if not content and msg.attachments:
|
||||
content = "(attachment)"
|
||||
if not content:
|
||||
continue
|
||||
|
||||
name = msg.author.display_name
|
||||
if getattr(msg.author, "bot", False):
|
||||
name = f"{name} [bot]"
|
||||
collected.append(f"[{name}] {content}")
|
||||
|
||||
if not collected:
|
||||
return ""
|
||||
|
||||
# channel.history returns newest-first (oldest_first=False); reverse for chronological order
|
||||
collected.reverse()
|
||||
return "[Recent channel messages]\n" + "\n".join(collected)
|
||||
|
||||
except discord.Forbidden:
|
||||
logger.debug("[%s] Missing permissions to fetch channel history", self.name)
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Failed to fetch channel history: %s", self.name, e)
|
||||
return ""
|
||||
|
||||
def _thread_parent_channel(self, channel: Any) -> Any:
|
||||
"""Return the parent text channel when invoked from a thread."""
|
||||
return getattr(channel, "parent", None) or channel
|
||||
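The partition rule used by `_fetch_channel_context` above (walk newest-first, stop at the bot's own message, then reverse into chronological order) can be illustrated without a Discord client. This is a simplified sketch: `Msg` and `collect_context` are hypothetical stand-ins, and the system-message and attachment handling from the diff is omitted.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class Msg:
    """Hypothetical stand-in for a Discord message (not the discord.py type)."""
    author: str
    content: str
    is_bot: bool = False


def collect_context(
    newest_first: Iterable[Msg],
    self_name: str,
    include_other_bots: bool = False,
) -> str:
    """Collect messages newer than the bot's last reply, in chronological order."""
    collected: List[str] = []
    for msg in newest_first:
        if msg.author == self_name:
            break  # partition point: everything older is already in the transcript
        if msg.is_bot and not include_other_bots:
            continue  # skip other bots unless explicitly allowed
        if not msg.content:
            continue  # skip empty messages
        name = f"{msg.author} [bot]" if msg.is_bot else msg.author
        collected.append(f"[{name}] {msg.content}")
    if not collected:
        return ""
    collected.reverse()  # input is newest-first; emit oldest-first
    return "[Recent channel messages]\n" + "\n".join(collected)
```

For example, with a history of Bob, Alice, the bot's own reply, and an older Alice message (newest first), only the Alice and Bob messages after the bot's reply are returned, in the order they were sent.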
@@ -4504,9 +4642,50 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if pending_text_injection:
|
||||
event_text = f"{pending_text_injection}\n\n{event_text}" if event_text else pending_text_injection
|
||||
|
||||
# ── History backfill ─────────────────────────────────────────
|
||||
# When require_mention is active, the bot only processes messages
|
||||
# that @mention it. Messages in the channel between bot turns are
|
||||
# invisible to the session transcript. To recover that context,
|
||||
# fetch recent channel history and prepend it to the user message.
|
||||
#
|
||||
# The fetch window is: everything after the bot's last message in
|
||||
# the channel up to (but not including) the current trigger. On
|
||||
# cold start (no prior bot message found), fetch the last N messages
|
||||
# and stop at the first self-message encountered.
|
||||
#
|
||||
# Threads naturally scope to thread-only history (channel.history()
|
||||
# on a thread returns only that thread's messages). DMs are skipped
|
||||
# because every DM message triggers the bot — there's no mention gap
|
||||
# to fill; the session transcript already has everything.
|
||||
#
|
||||
# Per-user sessions also benefit: Alice's session is missing the
|
||||
# other channel participants' context, and her own messages from
|
||||
# before she mentioned the bot. Backfill fills that gap.
|
||||
#
|
||||
# Messages that arrive while the bot is processing (between trigger
|
||||
# and response) are not captured — this is an accepted simplification
|
||||
# to keep the partition rule clean.
|
||||
_channel_context = None
|
||||
_is_dm = isinstance(message.channel, discord.DMChannel)
|
||||
if not _is_dm:
|
||||
_needed_mention = (
|
||||
require_mention
|
||||
and not is_free_channel
|
||||
and not in_bot_thread
|
||||
)
|
||||
_backfill_enabled = self._discord_history_backfill()
|
||||
if _needed_mention and _backfill_enabled:
|
||||
_backfill_text = await self._fetch_channel_context(
|
||||
message.channel, before=message,
|
||||
)
|
||||
if _backfill_text:
|
||||
_channel_context = _backfill_text
|
||||
|
||||
# Defense-in-depth: prevent empty user messages from entering session
|
||||
# (can happen when user sends @mention-only with no other text)
|
||||
if not event_text or not event_text.strip():
|
||||
# (can happen when user sends @mention-only with no other text).
|
||||
# When channel_context is present, a bare mention means "catch me up"
|
||||
# — the context IS the message, so skip the placeholder.
|
||||
if (not event_text or not event_text.strip()) and not _channel_context:
|
||||
event_text = "(The user sent a message with no text content)"
|
||||
|
||||
_chan = message.channel
|
||||
@@ -4535,6 +4714,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
timestamp=message.created_at,
|
||||
auto_skill=_skills,
|
||||
channel_prompt=_channel_prompt,
|
||||
channel_context=_channel_context,
|
||||
)
|
||||
|
||||
# Track thread participation so the bot won't require @mention for
|
||||
|
||||
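The backfill formatting in the Discord hunks above can be tried in isolation. The following is a minimal sketch, not the adapter's code: `FakeAuthor`, `FakeMessage`, and `format_backfill` are hypothetical stand-ins for the discord.py objects and for `_fetch_channel_context`, and the input list is assumed to arrive newest-first, as the diff's comment states.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeAuthor:
    # Hypothetical stand-in for a discord.py author object.
    display_name: str
    bot: bool = False


@dataclass
class FakeMessage:
    # Hypothetical stand-in for a discord.py message.
    author: FakeAuthor
    content: str = ""
    attachments: List[str] = field(default_factory=list)


def format_backfill(messages_newest_first, include_other_bots=False):
    """Turn newest-first messages into a chronological context block."""
    collected = []
    for msg in messages_newest_first:
        # Other bots are skipped unless explicitly included.
        if msg.author.bot and not include_other_bots:
            continue
        content = msg.content or ""
        # Attachment-only messages get a placeholder so they stay visible.
        if not content and msg.attachments:
            content = "(attachment)"
        if not content:
            continue
        name = msg.author.display_name
        if msg.author.bot:
            name = f"{name} [bot]"
        collected.append(f"[{name}] {content}")
    if not collected:
        return ""
    # Input is newest-first; reverse it for chronological order.
    collected.reverse()
    return "[Recent channel messages]\n" + "\n".join(collected)


history = [
    FakeMessage(FakeAuthor("bob"), "second"),
    FakeMessage(FakeAuthor("helper", bot=True), "ignored"),
    FakeMessage(FakeAuthor("alice"), "", ["img.png"]),
]
block = format_backfill(history)
block_with_bots = format_backfill(history, include_other_bots=True)
```

With the sample history, the attachment-only message from alice comes first and the bot message is dropped unless `include_other_bots` is set.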
@@ -2785,7 +2785,10 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
from hermes_cli.commands import slack_subcommand_map
|
||||
subcommand_map = slack_subcommand_map()
|
||||
subcommand_map["compact"] = "/compress"
|
||||
first_word = text.split()[0] if text else ""
|
||||
# Guard against whitespace-only text where ``text`` is truthy but
|
||||
# ``text.split()`` returns ``[]`` (e.g. user sends ``/hermes ``).
|
||||
parts = text.split() if text else []
|
||||
first_word = parts[0] if parts else ""
|
||||
if first_word in subcommand_map:
|
||||
rest = text[len(first_word):].strip()
|
||||
text = f"{subcommand_map[first_word]} {rest}".strip() if rest else subcommand_map[first_word]
|
||||
|
||||
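The Slack change above guards against text that is truthy but contains only whitespace. A small sketch of the difference between the old and new expressions follows; the function names are hypothetical and exist only for this illustration.

```python
def first_word_unguarded(text: str) -> str:
    # Old form: a whitespace-only string is truthy, but split() returns [],
    # so indexing [0] raises IndexError.
    return text.split()[0] if text else ""


def first_word_guarded(text: str) -> str:
    # New form: split once, then index only when there is something to index.
    parts = text.split() if text else []
    return parts[0] if parts else ""


try:
    first_word_unguarded("   ")
    unguarded_raised = False
except IndexError:
    unguarded_raised = True
```

The guarded version returns an empty string for both `""` and `"   "`, so the later `first_word in subcommand_map` lookup simply misses.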
+150
-34
@@ -147,6 +147,9 @@ _YB_RES_REF_RE = re.compile(
|
||||
r"\[(image|voice|video|file(?::[^|\]]*)?)\|ybres:([A-Za-z0-9_\-]+)\]"
|
||||
)
|
||||
|
||||
# Media kinds that can be resolved and injected into the model context
|
||||
_RESOLVABLE_MEDIA_KINDS = frozenset({"image", "file"})
|
||||
|
||||
# Strip page indicators like (1/3) appended by BasePlatformAdapter
|
||||
_INDICATOR_RE = re.compile(r'\s*\(\d+/\d+\)$')
|
||||
|
||||
@@ -925,6 +928,7 @@ class InboundContext:
|
||||
# Populated by QuoteContextMiddleware
|
||||
reply_to_message_id: Optional[str] = None
|
||||
reply_to_text: Optional[str] = None
|
||||
quote_media_refs: list = dc_field(default_factory=list) # List of (rid, kind, filename)
|
||||
|
||||
# Populated by MediaResolveMiddleware
|
||||
media_urls: list = dc_field(default_factory=list)
|
||||
@@ -1645,6 +1649,25 @@ class ExtractContentMiddleware(InboundMiddleware):
|
||||
return None
|
||||
return f"[link: {link} | visit link for full content]"
|
||||
|
||||
@staticmethod
|
||||
def _parse_resource_id(url: str) -> str:
|
||||
"""Extract resourceId from Yuanbao resource URL query parameters.
|
||||
|
||||
Args:
|
||||
url: Resource URL (e.g., https://...?resourceId=abc123)
|
||||
|
||||
Returns:
|
||||
Resource ID string, or empty string if not found
|
||||
"""
|
||||
if not url:
|
||||
return ""
|
||||
try:
|
||||
query = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
|
||||
ids = query.get("resourceId") or query.get("resourceid") or []
|
||||
return str(ids[0]).strip() if ids else ""
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def _extract_text(cls, msg_body: list) -> str:
|
||||
"""Extract plain text content from MsgBody.
|
||||
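The resource-ID extraction added above relies on the standard library's query-string parser. The sketch below shows the same idea outside the class. The URL host and the `image_anchor` helper are made up for illustration; only the `resourceId` parameter name and the `[image|ybres:...]` anchor shape come from the diff.

```python
import urllib.parse


def parse_resource_id(url: str) -> str:
    """Return the resourceId query parameter, or "" when it is absent."""
    if not url:
        return ""
    query = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
    # parse_qs maps each key to a list of values; accept either spelling.
    ids = query.get("resourceId") or query.get("resourceid") or []
    return str(ids[0]).strip() if ids else ""


def image_anchor(url: str) -> str:
    # Hypothetical helper: build the placeholder the text extractor emits.
    rid = parse_resource_id(url)
    return f"[image|ybres:{rid}]" if rid else "[image]"


anchor = image_anchor("https://example.invalid/res?resourceId=abc123&size=m")
```

A URL without the parameter (or with an empty value, which `parse_qs` drops by default) falls back to the plain `[image]` placeholder.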
@@ -1668,14 +1691,35 @@ class ExtractContentMiddleware(InboundMiddleware):
|
||||
if text:
|
||||
parts.append(text)
|
||||
elif elem_type == "TIMImageElem":
|
||||
parts.append("[image]")
|
||||
# Extract resourceId from image_info_array URL
|
||||
image_info_array = content.get("image_info_array")
|
||||
if not isinstance(image_info_array, list):
|
||||
image_info_array = []
|
||||
image_info = None
|
||||
# Prefer medium image (index 1), fallback to index 0
|
||||
if len(image_info_array) > 1 and isinstance(image_info_array[1], dict):
|
||||
image_info = image_info_array[1]
|
||||
elif len(image_info_array) > 0 and isinstance(image_info_array[0], dict):
|
||||
image_info = image_info_array[0]
|
||||
image_url = str((image_info or {}).get("url") or "").strip()
|
||||
rid = cls._parse_resource_id(image_url)
|
||||
parts.append(f"[image|ybres:{rid}]" if rid else "[image]")
|
||||
elif elem_type == "TIMFileElem":
|
||||
filename = content.get("file_name", content.get("fileName", content.get("filename", "")))
|
||||
parts.append(f"[file: {filename}]" if filename else "[file]")
|
||||
file_url = str(content.get("url") or "").strip()
|
||||
rid = cls._parse_resource_id(file_url)
|
||||
if rid:
|
||||
parts.append(f"[file:{filename}|ybres:{rid}]" if filename else f"[file|ybres:{rid}]")
|
||||
else:
|
||||
parts.append(f"[file: {filename}]" if filename else "[file]")
|
||||
elif elem_type == "TIMSoundElem":
|
||||
parts.append("[voice]")
|
||||
sound_url = str(content.get("url") or "").strip()
|
||||
rid = cls._parse_resource_id(sound_url)
|
||||
parts.append(f"[voice|ybres:{rid}]" if rid else "[voice]")
|
||||
elif elem_type == "TIMVideoFileElem":
|
||||
parts.append("[video]")
|
||||
video_url = str(content.get("url") or "").strip()
|
||||
rid = cls._parse_resource_id(video_url)
|
||||
parts.append(f"[video|ybres:{rid}]" if rid else "[video]")
|
||||
elif elem_type == "TIMCustomElem":
|
||||
data_val = content.get("data", "")
|
||||
if data_val:
|
||||
@@ -2132,22 +2176,23 @@ class QuoteContextMiddleware(InboundMiddleware):
|
||||
name = "quote-context"
|
||||
|
||||
@staticmethod
|
||||
def _extract_quote_context(cloud_custom_data: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
def _extract_quote_context(cloud_custom_data: str) -> Tuple[Optional[str], Optional[str], list]:
|
||||
"""Extract quote context, mapping to MessageEvent.reply_to_*.
|
||||
|
||||
Returns:
|
||||
(reply_to_message_id, reply_to_text)
|
||||
(reply_to_message_id, reply_to_text, quote_media_refs)
|
||||
where quote_media_refs is a list of (rid, kind, filename) tuples
|
||||
"""
|
||||
if not cloud_custom_data:
|
||||
return None, None
|
||||
return None, None, []
|
||||
try:
|
||||
parsed = json.loads(cloud_custom_data)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None, None
|
||||
return None, None, []
|
||||
|
||||
quote = parsed.get("quote") if isinstance(parsed, dict) else None
|
||||
if not isinstance(quote, dict):
|
||||
return None, None
|
||||
return None, None, []
|
||||
|
||||
# type=2 corresponds to image reference; desc may be empty, provide a placeholder.
|
||||
quote_type = int(quote.get("type") or 0)
|
||||
@@ -2155,15 +2200,26 @@ class QuoteContextMiddleware(InboundMiddleware):
|
||||
if quote_type == 2 and not desc:
|
||||
desc = "[image]"
|
||||
if not desc:
|
||||
return None, None
|
||||
return None, None, []
|
||||
|
||||
quote_id = str(quote.get("id") or "").strip() or None
|
||||
sender = str(quote.get("sender_nickname") or quote.get("sender_id") or "").strip()
|
||||
quote_text = f"{sender}: {desc}" if sender else desc
|
||||
return quote_id, quote_text
|
||||
|
||||
# Extract media references from desc using _YB_RES_REF_RE regex
|
||||
media_refs: list = []
|
||||
for m in _YB_RES_REF_RE.finditer(desc):
|
||||
head = m.group(1) # "image" | "file:<name>" | "voice" | "video"
|
||||
rid = m.group(2)
|
||||
kind, _, filename = head.partition(":")
|
||||
kind = kind.strip()
|
||||
media_refs.append((rid, kind, filename.strip()))
|
||||
|
||||
return quote_id, quote_text, media_refs
|
||||
|
||||
async def handle(self, ctx: InboundContext, next_fn) -> None:
|
||||
ctx.reply_to_message_id, ctx.reply_to_text = self._extract_quote_context(ctx.cloud_custom_data)
|
||||
ctx.reply_to_message_id, ctx.reply_to_text, ctx.quote_media_refs = self._extract_quote_context(ctx.cloud_custom_data)
|
||||
|
||||
await next_fn()
|
||||
|
||||
|
||||
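To see what `quote_media_refs` ends up containing, the reference regex from the top of this file can be run against a sample description. This is a sketch under assumptions: the sample `desc` string and the `extract_media_refs` name are invented here, while the pattern and the `(rid, kind, filename)` tuple order are copied from the diff.

```python
import re

# Same pattern as _YB_RES_REF_RE in the diff.
YB_RES_REF_RE = re.compile(
    r"\[(image|voice|video|file(?::[^|\]]*)?)\|ybres:([A-Za-z0-9_\-]+)\]"
)
RESOLVABLE_MEDIA_KINDS = frozenset({"image", "file"})


def extract_media_refs(desc: str):
    """Return (rid, kind, filename) tuples for every anchor in desc."""
    refs = []
    for m in YB_RES_REF_RE.finditer(desc):
        # group(1) is "image", "voice", "video", or "file:<name>".
        kind, _, filename = m.group(1).partition(":")
        refs.append((m.group(2), kind.strip(), filename.strip()))
    return refs


desc = "see [image|ybres:abc-1] and [file:report.pdf|ybres:F_2] plus [voice|ybres:v3]"
refs = extract_media_refs(desc)
resolvable = [r for r in refs if r[1] in RESOLVABLE_MEDIA_KINDS]
```

All three anchors are captured, but only the image and file entries pass the resolvable-kind filter that the dispatch step applies later.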
@@ -2332,7 +2388,7 @@ class MediaResolveMiddleware(InboundMiddleware):
|
||||
for ref in media_refs:
|
||||
kind = str(ref.get("kind") or "").strip().lower()
|
||||
url = str(ref.get("url") or "").strip()
|
||||
if kind not in {"image", "file"} or not url:
|
||||
if kind not in _RESOLVABLE_MEDIA_KINDS or not url:
|
||||
continue
|
||||
|
||||
try:
|
||||
@@ -2391,7 +2447,7 @@ class MediaResolveMiddleware(InboundMiddleware):
|
||||
rid = m.group(2)
|
||||
kind, _, filename = head.partition(":")
|
||||
kind = kind.strip()
|
||||
if kind not in {"image", "file"}:
|
||||
if kind not in _RESOLVABLE_MEDIA_KINDS:
|
||||
continue
|
||||
if rid in seen:
|
||||
continue
|
||||
@@ -2458,26 +2514,82 @@ class DispatchMiddleware(InboundMiddleware):
|
||||
media_urls = list(ctx.media_urls)
|
||||
media_types = list(ctx.media_types)
|
||||
|
||||
# Backfill observed media from recent transcript history
|
||||
extra_img_urls: List[str] = []
|
||||
extra_img_mimes: List[str] = []
|
||||
try:
|
||||
extra_img_urls, extra_img_mimes = await MediaResolveMiddleware._collect_observed_media(
|
||||
adapter, ctx.source,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[%s] observed-image hydration raised, continuing anyway: %s",
|
||||
adapter.name, exc,
|
||||
)
|
||||
if extra_img_urls:
|
||||
current = set(media_urls)
|
||||
for u, m in zip(extra_img_urls, extra_img_mimes):
|
||||
if u in current:
|
||||
# If user quoted a message (reply_to_message_id is set), resolve only
|
||||
# quote_media_refs to avoid injecting unrelated history media.
|
||||
# Otherwise, backfill observed media from recent transcript history.
|
||||
if ctx.reply_to_message_id is not None:
|
||||
# Fallback: if desc didn't contain ybres refs, look up transcript
|
||||
if not ctx.quote_media_refs:
|
||||
try:
|
||||
store = getattr(adapter, "_session_store", None)
|
||||
if store:
|
||||
session_entry = store.get_or_create_session(ctx.source)
|
||||
history = store.load_transcript(session_entry.session_id)
|
||||
for msg in reversed(history or []):
|
||||
mid = msg.get("message_id", "")
|
||||
if mid and mid == ctx.reply_to_message_id:
|
||||
_content = msg.get("content", "")
|
||||
if isinstance(_content, str) and "|ybres:" in _content:
|
||||
for m in _YB_RES_REF_RE.finditer(_content):
|
||||
head = m.group(1)
|
||||
rid = m.group(2)
|
||||
kind, _, filename = head.partition(":")
|
||||
kind = kind.strip()
|
||||
if kind in _RESOLVABLE_MEDIA_KINDS:
|
||||
ctx.quote_media_refs.append((rid, kind, filename.strip()))
|
||||
break
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[%s] quote transcript lookup failed: %s",
|
||||
adapter.name, exc,
|
||||
)
|
||||
# User quoted a message — resolve only media from the quote
|
||||
for rid, kind, filename in ctx.quote_media_refs:
|
||||
if kind not in _RESOLVABLE_MEDIA_KINDS:
|
||||
continue
|
||||
media_urls.append(u)
|
||||
media_types.append(m)
|
||||
current.add(u)
|
||||
try:
|
||||
fresh_url = await MediaResolveMiddleware._resolve_by_resource_id(adapter, rid)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[%s] quote media resolve failed: rid=%s kind=%s err=%s",
|
||||
adapter.name, rid, kind, exc,
|
||||
)
|
||||
continue
|
||||
cached = await MediaResolveMiddleware._download_and_cache(
|
||||
adapter,
|
||||
fetch_url=fresh_url,
|
||||
kind=kind,
|
||||
file_name=filename or None,
|
||||
log_tag=f"quote rid={rid}",
|
||||
)
|
||||
if cached is None:
|
||||
continue
|
||||
path, mime = cached
|
||||
# Avoid duplicates
|
||||
if path not in media_urls:
|
||||
media_urls.append(path)
|
||||
media_types.append(mime)
|
||||
else:
|
||||
# No quote — backfill observed media from recent transcript history
|
||||
extra_img_urls: List[str] = []
|
||||
extra_img_mimes: List[str] = []
|
||||
try:
|
||||
extra_img_urls, extra_img_mimes = await MediaResolveMiddleware._collect_observed_media(
|
||||
adapter, ctx.source,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[%s] observed-image hydration raised, continuing anyway: %s",
|
||||
adapter.name, exc,
|
||||
)
|
||||
if extra_img_urls:
|
||||
current = set(media_urls)
|
||||
for u, m in zip(extra_img_urls, extra_img_mimes):
|
||||
if u in current:
|
||||
continue
|
||||
media_urls.append(u)
|
||||
media_types.append(m)
|
||||
current.add(u)
|
||||
|
||||
# Replace [kind|ybres:xxx] anchors with local cache paths so
|
||||
# the transcript records usable paths for the model.
|
||||
@@ -2506,7 +2618,11 @@ class DispatchMiddleware(InboundMiddleware):
|
||||
|
||||
event = MessageEvent(
|
||||
text=_patched_event_text,
|
||||
message_type=ctx.msg_type,
|
||||
message_type=(
|
||||
MessageType.DOCUMENT
|
||||
if any(mt.startswith(("application/", "text/")) for mt in media_types)
|
||||
else ctx.msg_type
|
||||
),
|
||||
source=ctx.source,
|
||||
message_id=ctx.msg_id or None,
|
||||
raw_message=ctx.push,
|
||||
|
||||
@@ -6809,6 +6809,12 @@ class GatewayRunner:
|
||||
if _is_shared_multi_user and source.user_name:
|
||||
message_text = f"[{source.user_name}] {message_text}"
|
||||
|
||||
# Prepend channel context from history backfill (if any). This
|
||||
# happens after sender-prefix so the prefix only applies to the
|
||||
# trigger message, not the backfill block.
|
||||
if getattr(event, "channel_context", None):
|
||||
message_text = f"{event.channel_context}\n\n[New message]\n{message_text}"
|
||||
|
||||
if event.media_urls:
|
||||
image_paths = []
|
||||
audio_paths = []
|
||||
@@ -7985,6 +7991,8 @@ class GatewayRunner:
|
||||
try:
|
||||
if _err_body is not None:
|
||||
_err_json = _err_body.json().get("error", {})
|
||||
if not isinstance(_err_json, dict):
|
||||
_err_json = {}
|
||||
except Exception:
|
||||
pass
|
||||
if _err_json.get("type") == "usage_limit_reached":
|
||||
|
||||
@@ -518,6 +518,9 @@ class SessionEntry:
|
||||
else None
|
||||
),
|
||||
"is_fresh_reset": self.is_fresh_reset,
|
||||
"was_auto_reset": self.was_auto_reset,
|
||||
"auto_reset_reason": self.auto_reset_reason,
|
||||
"reset_had_activity": self.reset_had_activity,
|
||||
}
|
||||
if self.origin:
|
||||
result["origin"] = self.origin.to_dict()
|
||||
@@ -567,6 +570,9 @@ class SessionEntry:
|
||||
resume_reason=data.get("resume_reason"),
|
||||
last_resume_marked_at=last_resume_marked_at,
|
||||
is_fresh_reset=data.get("is_fresh_reset", False),
|
||||
was_auto_reset=data.get("was_auto_reset", False),
|
||||
auto_reset_reason=data.get("auto_reset_reason"),
|
||||
reset_had_activity=data.get("reset_had_activity", False),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -470,6 +470,9 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
model_short = model_short[:25] + "..."
|
||||
ctx_str = f" [dim {dim}]·[/] [dim {dim}]{_format_context_length(context_length)} context[/]" if context_length else ""
|
||||
left_lines.append(f"[{accent}]{model_short}[/]{ctx_str} [dim {dim}]·[/] [dim {dim}]Nous Research[/]")
|
||||
|
||||
if os.getenv("HERMES_YOLO_MODE"):
|
||||
left_lines.append(f"[bold red]⚠ YOLO mode[/] [dim {dim}]— all approval prompts bypassed[/]")
|
||||
left_lines.append(f"[dim {dim}]{cwd}[/]")
|
||||
if session_id:
|
||||
left_lines.append(f"[dim {session_color}]Session: {session_id}[/]")
|
||||
|
||||
@@ -304,6 +304,103 @@ def render_codex_toml_section(
|
||||
return "\n".join(out) + "\n"
|
||||
|
||||
|
||||
def _insert_managed_block_at_top_level(user_text: str, managed_block: str) -> str:
|
||||
"""Insert Hermes' managed Codex TOML block while keeping root keys root-scoped.
|
||||
|
||||
TOML has no syntax to return to the document root after a table header.
|
||||
Therefore appending a root key like `default_permissions = ...` after a
|
||||
user table such as `[features]` actually creates `features.default_permissions`,
|
||||
which Codex rejects. Insert the managed block before the first table header
|
||||
so its root keys remain top-level, while preserving user content verbatim.
|
||||
"""
|
||||
if not user_text.strip():
|
||||
return managed_block
|
||||
|
||||
lines = user_text.splitlines(keepends=True)
|
||||
first_table_idx: Optional[int] = None
|
||||
for idx, line in enumerate(lines):
|
||||
stripped = line.lstrip()
|
||||
if stripped.startswith("["):
|
||||
first_table_idx = idx
|
||||
break
|
||||
|
||||
if first_table_idx is None:
|
||||
prefix = user_text.rstrip("\n")
|
||||
return f"{prefix}\n\n{managed_block}" if prefix else managed_block
|
||||
|
||||
prefix = "".join(lines[:first_table_idx]).rstrip("\n")
|
||||
suffix = "".join(lines[first_table_idx:]).lstrip("\n")
|
||||
if prefix:
|
||||
return f"{prefix}\n\n{managed_block}\n{suffix}"
|
||||
return f"{managed_block}\n{suffix}"
|
||||
|
||||
|
||||
def _strip_unmanaged_plugin_tables(toml_text: str) -> str:
|
||||
"""Remove ``[plugins."<name>@<marketplace>"]`` tables that live OUTSIDE the
|
||||
managed block.
|
||||
|
||||
Codex itself writes these tables when the user runs ``codex plugins enable``
|
||||
directly (i.e. before Hermes' migrate has ever touched the file). When we
|
||||
later run migrate, ``_query_codex_plugins()`` reports the same plugins via
|
||||
the live ``plugin/list`` RPC and we re-emit them inside the managed block.
|
||||
The result without this strip is duplicate ``[plugins."X@Y"]`` table
|
||||
headers — codex's strict TOML parser then refuses to load the file.
|
||||
|
||||
We own the ``[plugins.*]`` namespace once migrate has run, so dropping any
|
||||
pre-existing ``[plugins.*]`` tables is safe: ``plugin/list`` is the source
|
||||
of truth for what's actually installed. The caller is expected to only
|
||||
invoke this strip when ``plugin/list`` succeeded — otherwise we'd lose
|
||||
plugins the user installed via ``codex`` without a way to re-emit them.
|
||||
|
||||
Behavior:
|
||||
* Lines beginning with ``[plugins.`` start a swallow region that ends at
|
||||
the next non-``[plugins.`` table header or end-of-file.
|
||||
* Content inside the managed block is untouched (callers should run
|
||||
``_strip_existing_managed_block`` first so the managed block has
|
||||
already been removed when this runs).
|
||||
"""
|
||||
lines = toml_text.splitlines(keepends=True)
|
||||
out: list[str] = []
|
||||
in_plugin_table = False
|
||||
for line in lines:
|
||||
stripped = line.lstrip()
|
||||
# Only treat a line as a table header when it has the shape
|
||||
# ``[...]`` (optionally followed by a comment). Multi-line array
|
||||
# continuations like ``["nested"],`` also start with ``[`` after
|
||||
# lstrip but are not headers — without this guard they would
|
||||
# falsely flip ``in_plugin_table`` to False mid-table and leak
|
||||
# array fragments into the output.
|
||||
if _looks_like_table_header(stripped):
|
||||
in_plugin_table = stripped.startswith("[plugins.")
|
||||
if in_plugin_table:
|
||||
continue
|
||||
if in_plugin_table:
|
||||
# Swallow keys/comments/blanks until the next table header.
|
||||
continue
|
||||
out.append(line)
|
||||
return "".join(out)
|
||||
|
||||
|
||||
def _looks_like_table_header(stripped_line: str) -> bool:
|
||||
"""Return True if ``stripped_line`` is a TOML table header.
|
||||
|
||||
A header has the shape ``[name]`` or ``[[name]]`` (array-of-tables),
|
||||
optionally followed by a comment. The closing ``]`` (or ``]]``) must
|
||||
appear on the same line, and no key-assignment ``=`` can precede it.
|
||||
This distinguishes real headers from multi-line array continuation
|
||||
lines that also start with ``[`` after ``lstrip()``.
|
||||
"""
|
||||
if not stripped_line.startswith("["):
|
||||
return False
|
||||
# Drop trailing comment so e.g. ``[features] # note`` still matches.
|
||||
head = stripped_line.split("#", 1)[0].rstrip()
|
||||
if not head.endswith("]"):
|
||||
return False
|
||||
# ``key = [x]`` would have an ``=`` before the bracket; a header doesn't.
|
||||
bracket_idx = head.index("]")
|
||||
return "=" not in head[: bracket_idx + 1]
|
||||
|
||||
|
||||
def _strip_existing_managed_block(toml_text: str) -> str:
|
||||
"""Remove any prior managed section so re-runs idempotently replace it.
|
||||
|
||||
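The root-scoping problem described in the docstring above comes down to where the managed block lands relative to the first table header. Below is a minimal sketch of that placement rule. The function names are hypothetical, the sample TOML is invented, and unlike the diff's `_insert_managed_block_at_top_level` (which checks only for a leading `[`), this sketch reuses the stricter header test as an assumption.

```python
def looks_like_table_header(stripped_line: str) -> bool:
    """True for lines shaped like [name] or [[name]], optionally commented."""
    if not stripped_line.startswith("["):
        return False
    head = stripped_line.split("#", 1)[0].rstrip()
    if not head.endswith("]"):
        return False
    return "=" not in head[: head.index("]") + 1]


def insert_before_first_table(user_text: str, managed_block: str) -> str:
    """Place managed_block ahead of the first table so its keys stay at root."""
    lines = user_text.splitlines(keepends=True)
    for idx, line in enumerate(lines):
        if looks_like_table_header(line.lstrip()):
            prefix = "".join(lines[:idx]).rstrip("\n")
            suffix = "".join(lines[idx:])
            head = f"{prefix}\n\n" if prefix else ""
            return f"{head}{managed_block}\n{suffix}"
    # No table header at all: appending is safe.
    prefix = user_text.rstrip("\n")
    return f"{prefix}\n\n{managed_block}" if prefix else managed_block


user = 'model = "x"\n\n[features]\nfoo = true\n'
managed = '# managed\ndefault_permissions = "ask"\n'
merged = insert_before_first_table(user, managed)
```

In `merged`, `default_permissions` appears before `[features]`, so a TOML parser reads it as a root key rather than as `features.default_permissions`.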
@@ -431,6 +528,32 @@ def _query_codex_plugins(
|
||||
return out, None
|
||||
|
||||
|
||||
def _looks_like_test_tempdir(path: str) -> bool:
|
||||
"""Heuristic: does ``path`` look like a pytest/transient tempdir?
|
||||
|
||||
pytest tempdirs live under ``pytest-of-<user>/pytest-<n>/`` (created via
|
||||
``tmp_path`` / ``tmp_path_factory``) and are reaped between sessions.
|
||||
macOS points ``$TMPDIR`` at ``/private/var/folders/<…>/T``, which is
|
||||
what pytest's tempdir factory uses by default. If a HERMES_HOME pointing
|
||||
at one of those paths is burned into ``~/.codex/config.toml``, every
|
||||
codex-routed hermes-tools call fails silently once the directory is GC'd.
|
||||
|
||||
We err on the side of refusing — losing a (very unlikely) real
|
||||
``~/.hermes`` symlink that happens to live under ``/private/var/folders``
|
||||
is much less harmful than silently bricking codex's tool surface.
|
||||
"""
|
||||
if not path:
|
||||
return False
|
||||
needles = (
|
||||
"pytest-of-",
|
||||
"/pytest-",
|
||||
"/tmp/pytest",
|
||||
"/private/var/folders/", # macOS tempdir root
|
||||
)
|
||||
normalized = path.lower()
|
||||
return any(needle in normalized for needle in needles)
|
||||
|
||||
|
||||
def _build_hermes_tools_mcp_entry() -> dict:
|
||||
"""Build the codex stdio-transport entry that launches Hermes' own
|
||||
tool surface as an MCP server. Codex's subprocess will call back into
|
||||
@@ -443,9 +566,22 @@ def _build_hermes_tools_mcp_entry() -> dict:
|
||||
import sys
|
||||
|
||||
env: dict[str, str] = {}
|
||||
# HERMES_HOME passes through if set so the MCP subprocess sees the
|
||||
# same config / auth / sessions DB as the parent CLI.
|
||||
hermes_home = os.environ.get("HERMES_HOME")
|
||||
# HERMES_HOME passes through IF SET so the MCP subprocess sees the same
|
||||
# config / auth / sessions DB as the parent CLI. Read from os.environ
|
||||
# (not get_hermes_home()) on purpose: when the env var is unset we want
|
||||
# codex's subprocess to inherit whatever HERMES_HOME its launcher sets
|
||||
# at runtime (systemd unit, gateway, kanban dispatcher, custom shell),
|
||||
# rather than burning the migrate-time resolved default into config.toml
|
||||
# — that would override the launcher's HERMES_HOME and pin the subprocess
|
||||
# to the wrong profile.
|
||||
#
|
||||
# The pytest-tempdir guard below catches the issue #26250 Bug C scenario:
|
||||
# a sibling test's monkeypatch.setenv("HERMES_HOME", tmp_path) would
|
||||
# otherwise leak a transient pytest tempdir into the user's real
|
||||
# ~/.codex/config.toml and silently brick codex once the tempdir is GC'd.
|
||||
hermes_home = os.environ.get("HERMES_HOME") or ""
|
||||
if hermes_home and _looks_like_test_tempdir(hermes_home):
|
||||
hermes_home = ""
|
||||
if hermes_home:
|
||||
env["HERMES_HOME"] = hermes_home
|
||||
# PYTHONPATH passes through so a worktree-launched hermes finds the
|
||||
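The tempdir guard above is a plain substring heuristic, so it is easy to exercise on its own. The sketch below mirrors the needles from the diff. `hermes_home_for_config` is a hypothetical wrapper that takes the environment as a dict so the behaviour can be checked without touching `os.environ`, and the sample paths are invented.

```python
def looks_like_test_tempdir(path: str) -> bool:
    """Heuristic: does path look like a pytest or macOS transient tempdir?"""
    if not path:
        return False
    needles = (
        "pytest-of-",
        "/pytest-",
        "/tmp/pytest",
        "/private/var/folders/",  # macOS tempdir root
    )
    normalized = path.lower()
    return any(needle in normalized for needle in needles)


def hermes_home_for_config(environ: dict) -> str:
    # Hypothetical wrapper: return the value that would be written to the
    # config, or "" when it is unset or looks transient.
    home = environ.get("HERMES_HOME") or ""
    if home and looks_like_test_tempdir(home):
        return ""
    return home


kept = hermes_home_for_config({"HERMES_HOME": "/home/alice/.hermes"})
dropped = hermes_home_for_config(
    {"HERMES_HOME": "/tmp/pytest-of-alice/pytest-3/test_x0"}
)
```

Because the match is by substring, an ordinary directory whose name happens to contain `/pytest-` would also be refused; the docstring above accepts that trade-off on purpose.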
@@ -533,10 +669,16 @@ def migrate(
|
||||
# Discover installed Codex curated plugins. Best-effort — never blocks
|
||||
# the migration if codex is unreachable or the RPC fails.
|
||||
plugins: list[dict] = []
|
||||
plugin_query_succeeded = False
|
||||
if discover_plugins and not dry_run:
|
||||
plugins, plugin_err = _query_codex_plugins(codex_home=codex_home)
|
||||
if plugin_err:
|
||||
report.plugin_query_error = plugin_err
|
||||
else:
|
||||
# plugin/list returned authoritatively (even if the list is empty).
|
||||
# That means we own [plugins.*] for this re-render and can safely
|
||||
# strip any pre-existing tables outside the managed block.
|
||||
plugin_query_succeeded = True
|
||||
for p in plugins:
|
||||
report.migrated_plugins.append(f"{p['name']}@{p['marketplace']}")
|
||||
|
||||
@@ -571,14 +713,15 @@ def migrate(
|
||||
report.errors.append(f"could not read {target}: {exc}")
|
||||
return report
|
||||
without_managed = _strip_existing_managed_block(existing)
|
||||
# Ensure exactly one blank line between user content and managed block
|
||||
if without_managed and not without_managed.endswith("\n"):
|
||||
without_managed += "\n"
|
||||
new_text = (
|
||||
without_managed.rstrip("\n") + "\n\n" + managed_block
|
||||
if without_managed.strip()
|
||||
else managed_block
|
||||
)
|
||||
# Bug B: when plugin/list ran authoritatively, codex's own
|
||||
# [plugins."<name>@<marketplace>"] tables outside our managed block
|
||||
# would survive _strip_existing_managed_block and then collide with
|
||||
# the entries we re-emit inside the managed block — producing
|
||||
# duplicate-table-header parse errors on codex's next startup. Drop
|
||||
# those pre-existing tables since plugin/list is the source of truth.
|
||||
if plugin_query_succeeded:
|
||||
without_managed = _strip_unmanaged_plugin_tables(without_managed)
|
||||
new_text = _insert_managed_block_at_top_level(without_managed, managed_block)
|
||||
else:
|
||||
new_text = managed_block
|
||||
|
||||
|
||||
+4
-18
@@ -846,6 +846,7 @@ DEFAULT_CONFIG = {
|
||||
"timeout": 30,
|
||||
"extra_body": {},
|
||||
"max_concurrency": 3, # Clamp parallel summaries to avoid request-burst 429s on small providers
|
||||
"default_mode": "fast", # 'fast' | 'summary' — which mode session_search uses when caller passes none
|
||||
},
|
||||
"skills_hub": {
|
||||
"provider": "auto",
|
||||
@@ -1251,6 +1252,8 @@ DEFAULT_CONFIG = {
|
||||
"allowed_channels": "", # If set, bot ONLY responds in these channel IDs (whitelist)
|
||||
"auto_thread": True, # Auto-create threads on @mention in channels (like Slack)
|
||||
"thread_require_mention": False, # If True, require @mention in threads too (multi-bot threads)
|
||||
"history_backfill": True, # If True, prepend recent channel scrollback when bot is triggered (recovers messages missed while require_mention gated them out)
|
||||
"history_backfill_limit": 50, # Max number of recent messages to scan when assembling the backfill block
|
||||
"reactions": True, # Add 👀/✅/❌ reactions to messages during processing
|
||||
"channel_prompts": {}, # Per-channel ephemeral system prompts (forum parents apply to child threads)
|
||||
# Opt-in DM role-based auth (#12136). By default, DISCORD_ALLOWED_ROLES
|
||||
@@ -2136,22 +2139,6 @@ OPTIONAL_ENV_VARS = {
|
||||
"password": True,
|
||||
"category": "tool",
|
||||
},
|
||||
"TINKER_API_KEY": {
|
||||
"description": "Tinker API key for RL training",
|
||||
"prompt": "Tinker API key",
|
||||
"url": "https://tinker-console.thinkingmachines.ai/keys",
|
||||
"tools": ["rl_start_training", "rl_check_status", "rl_stop_training"],
|
||||
"password": True,
|
||||
"category": "tool",
|
||||
},
|
||||
"WANDB_API_KEY": {
|
||||
"description": "Weights & Biases API key for experiment tracking",
|
||||
"prompt": "WandB API key",
|
||||
"url": "https://wandb.ai/authorize",
|
||||
"tools": ["rl_get_results", "rl_check_status"],
|
||||
"password": True,
|
||||
"category": "tool",
|
||||
},
|
||||
"VOICE_TOOLS_OPENAI_KEY": {
|
||||
"description": "OpenAI API key for voice transcription (Whisper) and OpenAI TTS",
|
||||
"prompt": "OpenAI API Key (for Whisper STT + TTS)",
|
||||
@@ -4988,8 +4975,7 @@ def set_config_value(key: str, value: str):
|
||||
'FAL_KEY', 'TELEGRAM_BOT_TOKEN', 'DISCORD_BOT_TOKEN',
|
||||
'TERMINAL_SSH_HOST', 'TERMINAL_SSH_USER', 'TERMINAL_SSH_KEY',
|
||||
'SUDO_PASSWORD', 'SLACK_BOT_TOKEN', 'SLACK_APP_TOKEN',
|
||||
'GITHUB_TOKEN', 'HONCHO_API_KEY', 'WANDB_API_KEY',
|
||||
'TINKER_API_KEY',
|
||||
'GITHUB_TOKEN', 'HONCHO_API_KEY',
|
||||
]
|
||||
|
||||
if key.upper() in api_keys or key.upper().endswith(('_API_KEY', '_TOKEN')) or key.upper().startswith('TERMINAL_SSH'):
|
||||
|
||||
+8
-2
@@ -196,9 +196,15 @@ def cron_create(args):
|
||||
|
||||
|
||||
def cron_edit(args):
|
||||
from cron.jobs import get_job
|
||||
from cron.jobs import AmbiguousJobReference, resolve_job_ref
|
||||
|
||||
job = get_job(args.job_id)
|
||||
try:
|
||||
job = resolve_job_ref(args.job_id)
|
||||
except AmbiguousJobReference as exc:
|
||||
print(color(str(exc), Colors.RED))
|
||||
for m in exc.matches:
|
||||
print(f" {m['id']} (name: {m.get('name')!r})")
|
||||
return 1
|
||||
if not job:
|
||||
print(color(f"Job not found: {args.job_id}", Colors.RED))
|
||||
return 1
|
||||
|
||||
@@ -1595,28 +1595,6 @@ def run_doctor(args):
|
||||
for _issue in _r.issues:
|
||||
issues.append(_issue)
|
||||
|
||||
# =========================================================================
|
||||
# Check: Submodules
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Submodules", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
# tinker-atropos (RL training backend)
|
||||
tinker_dir = PROJECT_ROOT / "tinker-atropos"
|
||||
if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
|
||||
if py_version >= (3, 11):
|
||||
try:
|
||||
__import__("tinker_atropos")
|
||||
check_ok("tinker-atropos", "(RL training backend)")
|
||||
except ImportError:
|
||||
install_cmd = f"{_python_install_cmd()} -e ./tinker-atropos"
|
||||
check_warn("tinker-atropos found but not installed", f"(run: {install_cmd})")
|
||||
issues.append(f"Install tinker-atropos: {install_cmd}")
|
||||
else:
|
||||
check_warn("tinker-atropos requires Python 3.11+", f"(current: {py_version.major}.{py_version.minor})")
|
||||
else:
|
||||
check_warn("tinker-atropos not found", "(run: git submodule update --init --recursive)")
|
||||
|
||||
# =========================================================================
|
||||
# Check: Tool Availability
|
||||
# =========================================================================
|
||||
|
||||
+35
-1
@@ -45,6 +45,16 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_MAX_TURNS = 20
|
||||
DEFAULT_JUDGE_TIMEOUT = 30.0
|
||||
# Judge output budget. The freeform judge returns a one-line JSON verdict, but
|
||||
# reasoning models (deepseek-v4, qwq, etc.) burn tokens on hidden reasoning
|
||||
# before emitting the visible JSON — and the first /goal turn's prompt is
|
||||
# larger than later turns, which pushes total reply length past tight caps.
|
||||
# 200 tokens (the original default) reliably truncated the JSON on reasoning
|
||||
# models, leaving '{"done": true, "reason": "The agent successfully' and
|
||||
# triggering the auto-pause. 4096 covers reasoning + verdict on every model
|
||||
# we've live-tested; override via auxiliary.goal_judge.max_tokens for
|
||||
# specifically constrained setups.
|
||||
DEFAULT_JUDGE_MAX_TOKENS = 4096
|
||||
# Cap how much of the last response + recent messages we send to the judge.
|
||||
_JUDGE_RESPONSE_SNIPPET_CHARS = 4000
|
||||
# After this many consecutive judge *parse* failures (empty output / non-JSON),
|
||||
@@ -282,6 +292,30 @@ def _truncate(text: str, limit: int) -> str:
|
||||
_JSON_OBJECT_RE = re.compile(r"\{.*?\}", re.DOTALL)
|
||||
|
||||
|
||||
def _goal_judge_max_tokens() -> int:
    """Resolve auxiliary.goal_judge.max_tokens, falling back to the default.

    ``load_config()`` is cached on the config file's (mtime, size), so calling
    this once per judge turn is cheap. A non-positive or non-int value falls
    back to the default rather than crashing the goal loop.
    """
    try:
        from hermes_cli.config import load_config

        cfg = load_config()
        value = (
            (cfg.get("auxiliary") or {})
            .get("goal_judge", {})
            .get("max_tokens", DEFAULT_JUDGE_MAX_TOKENS)
        )
        value = int(value)
        if value > 0:
            return value
    except Exception:
        pass
    return DEFAULT_JUDGE_MAX_TOKENS
|
||||
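The fallback rule described in the docstring can be exercised in isolation. The following is a minimal sketch, not the project's code: `resolve_max_tokens` and its `cfg` argument are hypothetical stand-ins for `_goal_judge_max_tokens()` and the dict returned by `load_config()`.

```python
from typing import Any, Mapping

DEFAULT_JUDGE_MAX_TOKENS = 4096  # same default as in the hunk above


def resolve_max_tokens(cfg: Mapping[str, Any], default: int = DEFAULT_JUDGE_MAX_TOKENS) -> int:
    """Hypothetical helper: read auxiliary.goal_judge.max_tokens from a plain dict."""
    try:
        section = (cfg.get("auxiliary") or {}).get("goal_judge") or {}
        value = int(section.get("max_tokens", default))
    except (AttributeError, TypeError, ValueError):
        # Malformed config (wrong container type, non-numeric string) falls back.
        return default
    # Zero and negative budgets are treated as unset.
    return value if value > 0 else default


print(resolve_max_tokens({"auxiliary": {"goal_judge": {"max_tokens": "8192"}}}))
print(resolve_max_tokens({"auxiliary": {"goal_judge": {"max_tokens": -5}}}))
print(resolve_max_tokens({"auxiliary": None}))
```

Unlike the function in the diff, the sketch catches only the exception types it expects; the broad `except Exception` in the real code also covers failures inside `load_config()` itself.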
|
||||
|
||||
def _parse_judge_response(raw: str) -> Tuple[bool, str, bool]:
|
||||
"""Parse the judge's reply. Fail-open to ``(False, "<reason>", parse_failed)``.
|
||||
|
||||
@@ -404,7 +438,7 @@ def judge_goal(
                 {"role": "user", "content": prompt},
             ],
             temperature=0,
-            max_tokens=200,
+            max_tokens=_goal_judge_max_tokens(),
             timeout=timeout,
             extra_body=get_auxiliary_extra_body() or None,
         )
|
||||
|
||||
+141
-16
@@ -1452,6 +1452,17 @@ def cmd_gateway(args):
|
||||
gateway_command(args)
|
||||
|
||||
|
||||
def cmd_proxy(args):
|
||||
"""Local OpenAI-compatible proxy to OAuth providers."""
|
||||
    # Lazy import: the proxy pulls in aiohttp, which ships only with an
    # optional extras install, so users who never run the proxy or the
    # messaging gateway do not need it.
|
||||
from hermes_cli.proxy.cli import cmd_proxy as _cmd_proxy
|
||||
|
||||
rc = _cmd_proxy(args)
|
||||
if isinstance(rc, int) and rc != 0:
|
||||
raise SystemExit(rc)
|
||||
|
||||
|
||||
def cmd_whatsapp(args):
|
||||
"""Set up WhatsApp: choose mode, configure, install bridge, pair via QR."""
|
||||
_require_tty("whatsapp")
|
||||
@@ -5670,21 +5681,50 @@ def _build_web_ui(web_dir: Path, *, fatal: bool = False) -> bool:
|
||||
if not _web_ui_build_needed(web_dir):
|
||||
return True
|
||||
|
||||
    # Console-encoding-safe print: Windows consoles default to cp1252
    # (or similar) and will raise UnicodeEncodeError on arrow / check
    # glyphs unless PYTHONIOENCODING=utf-8 is set. Routing every print
    # in this function through _say() with errors="replace" keeps the
    # build path usable on a stock `py -m hermes_cli.main web` invocation.
    def _say(text: str) -> None:
        try:
            print(text)
        except UnicodeEncodeError:
            encoding = getattr(sys.stdout, "encoding", None) or "ascii"
            print(text.encode(encoding, errors="replace").decode(encoding, errors="replace"))
|
||||
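The encode/decode round trip that `_say()` falls back to can be seen without a Windows console. This is a small sketch; `console_safe` is a hypothetical name, and the encoding is passed in explicitly instead of being read from `sys.stdout`.

```python
def console_safe(text: str, encoding: str) -> str:
    """Round-trip text through an encoding, replacing glyphs it cannot represent."""
    return text.encode(encoding, errors="replace").decode(encoding, errors="replace")


# cp1252 has no arrow glyph, so it degrades to "?"; utf-8 keeps it.
print(console_safe("→ Building web UI...", "cp1252"))
print(console_safe("→ Building web UI...", "utf-8"))
```

The message stays readable with a `?` in place of the glyph, which is the trade-off the comment above describes: degraded output rather than an aborted build.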
|
||||
npm = shutil.which("npm")
|
||||
if not npm:
|
||||
if fatal:
|
||||
print("Web UI frontend not built and npm is not available.")
|
||||
print("Install Node.js, then run: cd web && npm install && npm run build")
|
||||
_say("Web UI frontend not built and npm is not available.")
|
||||
_say("Install Node.js, then run: cd web && npm install && npm run build")
|
||||
return not fatal
|
||||
print("→ Building web UI...")
|
||||
_say("→ Building web UI...")
|
||||
|
||||
def _relay(result: "subprocess.CompletedProcess") -> None:
|
||||
"""Print captured npm output so users can see *why* a step failed.
|
||||
|
||||
Windows users hitting `rm -rf` / `cp -r` errors (or any other
|
||||
sync-assets / Vite failure) would otherwise see only ``Web UI
|
||||
build failed`` with no hint of the underlying cause, because
|
||||
the npm calls run with ``capture_output=True``.
|
||||
"""
|
||||
for blob in (result.stdout, result.stderr):
|
||||
if not blob:
|
||||
continue
|
||||
text = blob.decode("utf-8", errors="replace").rstrip() if isinstance(blob, bytes) else blob.rstrip()
|
||||
if text:
|
||||
_say(text)
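The normalisation step inside `_relay()` (skip empty streams, decode bytes leniently, trim trailing whitespace) can be sketched as a pure function. `captured_text` is a hypothetical helper that returns the lines instead of printing them.

```python
from typing import Iterable, List, Optional, Union


def captured_text(blobs: Iterable[Optional[Union[bytes, str]]]) -> List[str]:
    """Return the non-empty captured streams as trimmed text."""
    lines: List[str] = []
    for blob in blobs:
        if not blob:
            continue  # None or empty stream: nothing to relay
        text = blob.decode("utf-8", errors="replace") if isinstance(blob, bytes) else blob
        text = text.rstrip()
        if text:
            lines.append(text)
    return lines


print(captured_text([b"npm ERR! code ELIFECYCLE\n", None, "   \n", "warn: stale lockfile\n"]))
```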
|
||||
|
||||
r1 = _run_npm_install_deterministic(npm, web_dir, extra_args=("--silent",))
|
||||
if r1.returncode != 0:
|
||||
print(
|
||||
_say(
|
||||
f" {'✗' if fatal else '⚠'} Web UI npm install failed"
|
||||
+ ("" if fatal else " (hermes web will not be available)")
|
||||
)
|
||||
_relay(r1)
|
||||
if fatal:
|
||||
print(" Run manually: cd web && npm install && npm run build")
|
||||
_say(" Run manually: cd web && npm install && npm run build")
|
||||
return False
|
||||
# First attempt
|
||||
r2 = subprocess.run(
|
||||
@@ -5719,21 +5759,20 @@ def _build_web_ui(web_dir: Path, *, fatal: bool = False) -> bool:
|
||||
# A stale UI is far better than no UI for non-interactive callers
|
||||
# (Windows Scheduled Tasks, CI) — issue #23817.
|
||||
if dist_index.exists():
|
||||
print(" ⚠ Web UI build failed — serving stale dist as fallback")
|
||||
_say(" ⚠ Web UI build failed — serving stale dist as fallback")
|
||||
if stderr_tail:
|
||||
print(f" Build error:\n {stderr_tail}")
|
||||
_say(f" Build error:\n {stderr_tail}")
|
||||
return True
|
||||
|
||||
print(
|
||||
_say(
|
||||
f" {'✗' if fatal else '⚠'} Web UI build failed"
|
||||
+ ("" if fatal else " (hermes web will not be available)")
|
||||
)
|
||||
if stderr_tail:
|
||||
print(f" Build error:\n {stderr_tail}")
|
||||
_relay(r2)
|
||||
if fatal:
|
||||
print(" Run manually: cd web && npm install && npm run build")
|
||||
_say(" Run manually: cd web && npm install && npm run build")
|
||||
return False
|
||||
print(" ✓ Web UI built")
|
||||
_say(" ✓ Web UI built")
|
||||
return True
|
||||
|
||||
|
||||
@@ -9385,7 +9424,7 @@ _BUILTIN_SUBCOMMANDS = frozenset(
|
||||
"config", "cron", "curator", "dashboard", "debug", "doctor",
|
||||
"dump", "fallback", "gateway", "hooks", "import", "insights",
|
||||
"kanban", "login", "logout", "logs", "lsp", "mcp", "memory",
|
||||
"model", "pairing", "plugins", "profile", "sessions", "setup",
|
||||
"model", "pairing", "plugins", "profile", "proxy", "sessions", "setup",
|
||||
"skills", "slack", "status", "tools", "uninstall", "update",
|
||||
"version", "webhook", "whatsapp", "chat",
|
||||
# Help-ish invocations — plugin commands not being listed in
|
||||
@@ -9727,6 +9766,51 @@ def main():
|
||||
help="Skip the confirmation prompt",
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# proxy command — local OpenAI-compatible proxy that attaches the user's
|
||||
# OAuth-authenticated provider credentials to outbound requests. Lets
|
||||
# external apps (OpenViking, Karakeep, Open WebUI, ...) ride a logged-in
|
||||
# subscription without copy-pasting static API keys.
|
||||
# =========================================================================
|
||||
proxy_parser = subparsers.add_parser(
|
||||
"proxy",
|
||||
help="Local OpenAI-compatible proxy to OAuth providers",
|
||||
description=(
|
||||
"Run a local HTTP server that forwards OpenAI-compatible requests "
|
||||
"to an OAuth-authenticated provider (e.g. Nous Portal). External "
|
||||
"apps can point at the proxy with any bearer token; the proxy "
|
||||
"attaches your real credentials."
|
||||
),
|
||||
)
|
||||
proxy_subparsers = proxy_parser.add_subparsers(dest="proxy_command")
|
||||
|
||||
proxy_start = proxy_subparsers.add_parser(
|
||||
"start", help="Run the proxy in the foreground"
|
||||
)
|
||||
proxy_start.add_argument(
|
||||
"--provider",
|
||||
default="nous",
|
||||
help="Upstream provider (default: nous). See `hermes proxy providers`.",
|
||||
)
|
||||
proxy_start.add_argument(
|
||||
"--host",
|
||||
default=None,
|
||||
help="Bind address (default: 127.0.0.1). Use 0.0.0.0 to expose on LAN.",
|
||||
)
|
||||
proxy_start.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Bind port (default: 8645)",
|
||||
)
|
||||
|
||||
proxy_subparsers.add_parser(
|
||||
"status", help="Show which proxy upstreams are ready"
|
||||
)
|
||||
proxy_subparsers.add_parser(
|
||||
"providers", help="List available proxy upstream providers"
|
||||
)
|
||||
proxy_parser.set_defaults(func=cmd_proxy)
|
||||
gateway_parser.set_defaults(func=cmd_gateway)
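To see how the nested `proxy` subparsers resolve, here is a stripped-down sketch of the same argparse layout. `build_parser` is a hypothetical wrapper; help strings and the `set_defaults(func=...)` wiring are omitted.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Minimal stand-in for the `hermes proxy` argument layout."""
    parser = argparse.ArgumentParser(prog="hermes")
    subparsers = parser.add_subparsers(dest="command")

    proxy = subparsers.add_parser("proxy")
    proxy_sub = proxy.add_subparsers(dest="proxy_command")

    start = proxy_sub.add_parser("start")
    start.add_argument("--provider", default="nous")
    # None means "not given"; the handler substitutes its own defaults.
    start.add_argument("--host", default=None)
    start.add_argument("--port", type=int, default=None)

    proxy_sub.add_parser("status")
    proxy_sub.add_parser("providers")
    return parser


args = build_parser().parse_args(["proxy", "start", "--port", "9000"])
print(args.proxy_command, args.provider, args.host, args.port)
```

Leaving `--host` and `--port` at `None` keeps the real defaults (`127.0.0.1`, `8645`) in one place, the server module, instead of duplicating them in the parser.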
|
||||
|
||||
# =========================================================================
|
||||
@@ -11615,16 +11699,57 @@ Examples:
|
||||
description="Start Hermes Agent in ACP mode for editor integration (VS Code, Zed, JetBrains)",
|
||||
)
|
||||
_add_accept_hooks_flag(acp_parser)
|
||||
acp_parser.add_argument(
|
||||
"--version",
|
||||
action="store_true",
|
||||
dest="acp_version",
|
||||
help="Print Hermes ACP version and exit",
|
||||
)
|
||||
acp_parser.add_argument(
|
||||
"--check",
|
||||
action="store_true",
|
||||
help="Verify ACP dependencies and adapter imports, then exit",
|
||||
)
|
||||
acp_parser.add_argument(
|
||||
"--setup",
|
||||
action="store_true",
|
||||
help="Run interactive Hermes provider/model setup for ACP terminal auth",
|
||||
)
|
||||
acp_parser.add_argument(
|
||||
"--setup-browser",
|
||||
action="store_true",
|
||||
help="Install agent-browser + Playwright Chromium into ~/.hermes/node/ "
|
||||
"for browser tool support (idempotent).",
|
||||
)
|
||||
acp_parser.add_argument(
|
||||
"--yes",
|
||||
"-y",
|
||||
action="store_true",
|
||||
dest="assume_yes",
|
||||
help="Accept all prompts (used by --setup-browser to skip the "
|
||||
"~400 MB Chromium download confirmation).",
|
||||
)
|
||||
|
||||
def cmd_acp(args):
|
||||
"""Launch Hermes Agent as an ACP server."""
|
||||
try:
|
||||
from acp_adapter.entry import main as acp_main
|
||||
|
||||
acp_main()
|
||||
acp_argv = []
|
||||
if getattr(args, "acp_version", False):
|
||||
acp_argv.append("--version")
|
||||
if getattr(args, "check", False):
|
||||
acp_argv.append("--check")
|
||||
if getattr(args, "setup", False):
|
||||
acp_argv.append("--setup")
|
||||
if getattr(args, "setup_browser", False):
|
||||
acp_argv.append("--setup-browser")
|
||||
if getattr(args, "assume_yes", False):
|
||||
acp_argv.append("--yes")
|
||||
acp_main(acp_argv)
|
||||
except ImportError:
|
||||
print("ACP dependencies not installed.")
|
||||
print("Install them with: pip install -e '.[acp]'")
|
||||
print("ACP dependencies not installed.", file=sys.stderr)
|
||||
print("Install them with: pip install -e '.[acp]'", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
acp_parser.set_defaults(func=cmd_acp)
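The chain of `getattr` checks in `cmd_acp` rebuilds an argv list from the parsed namespace. A table-driven variant is sketched below as an alternative, not as what the diff does; `_FLAG_MAP` and `rebuild_acp_argv` are hypothetical names.

```python
import argparse
from typing import List

# (namespace attribute, flag forwarded to the ACP entry point)
_FLAG_MAP = (
    ("acp_version", "--version"),
    ("check", "--check"),
    ("setup", "--setup"),
    ("setup_browser", "--setup-browser"),
    ("assume_yes", "--yes"),
)


def rebuild_acp_argv(args: argparse.Namespace) -> List[str]:
    """Forward only the flags that were actually set, in a fixed order."""
    return [flag for attr, flag in _FLAG_MAP if getattr(args, attr, False)]


print(rebuild_acp_argv(argparse.Namespace(check=True, assume_yes=True)))
```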
|
||||
|
||||
@@ -25,6 +25,7 @@ from hermes_cli.config import (
|
||||
)
|
||||
from hermes_cli.colors import Colors, color
|
||||
from hermes_constants import display_hermes_home
|
||||
from tools.mcp_tool import _ENV_VAR_PATTERN
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -551,7 +552,7 @@ def cmd_mcp_test(args):
|
||||
for k, v in headers.items():
|
||||
if isinstance(v, str) and ("key" in k.lower() or "auth" in k.lower()):
|
||||
# Mask the value
|
||||
resolved = _interpolate_value(v)
|
||||
resolved = _ENV_VAR_PATTERN.sub(lambda m: os.getenv(m.group(1), ""), v)
|
||||
if len(resolved) > 8:
|
||||
masked = resolved[:4] + "***" + resolved[-4:]
|
||||
else:
|
||||
@@ -581,13 +582,6 @@ def cmd_mcp_test(args):
|
||||
print()
|
||||
|
||||
|
||||
def _interpolate_value(value: str) -> str:
|
||||
"""Resolve ``${ENV_VAR}`` references in a string."""
|
||||
def _replace(m):
|
||||
return os.getenv(m.group(1), "")
|
||||
return re.sub(r"\$\{(\w+)\}", _replace, value)
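The resolve-then-mask step in `cmd_mcp_test` can be sketched on its own. Two assumptions are made here: the regex is the one from the removed `_interpolate_value` helper (the definition of `_ENV_VAR_PATTERN` in `tools.mcp_tool` is not shown in this diff and may differ), and the short-value branch of the mask is a guess because that branch is cut off by the hunk boundary.

```python
import os
import re

# Assumption: same pattern as the removed helper; _ENV_VAR_PATTERN may differ.
_ENV_REF = re.compile(r"\$\{(\w+)\}")


def interpolate(value: str) -> str:
    """Replace ${NAME} references with environment values (empty if unset)."""
    return _ENV_REF.sub(lambda m: os.getenv(m.group(1), ""), value)


def mask(secret: str) -> str:
    """Show only the first and last four characters of a long secret."""
    if len(secret) > 8:
        return secret[:4] + "***" + secret[-4:]
    return "***"  # assumption: the real short-value branch is not visible here


os.environ["DEMO_TOKEN"] = "sk-1234567890abcdef"
print(mask(interpolate("Bearer ${DEMO_TOKEN}")))
```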
|
||||
|
||||
|
||||
# ─── hermes mcp login ────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_login(args):
|
||||
|
||||
@@ -3702,13 +3702,12 @@ def validate_requested_model(
|
||||
|
||||
# Static-catalog fallback: when the /models probe was unreachable,
|
||||
# validate against the curated list from provider_model_ids() — same
|
||||
# pattern as the openai-codex and minimax branches above. This fixes
|
||||
# /model switches in the gateway for providers like opencode-go and
|
||||
# opencode-zen whose /models endpoint returns 404 against the HTML
|
||||
# marketing site. Without this block, validate_requested_model would
|
||||
# reject every model on such providers, switch_model() would return
|
||||
# success=False, and the gateway would never write to
|
||||
# _session_model_overrides.
|
||||
# pattern as the openai-codex and minimax branches above. This keeps
|
||||
# /model switches working in the gateway for providers whose /models
|
||||
# endpoint is temporarily unreachable or returns a non-JSON payload.
|
||||
# Without this block, validate_requested_model would reject every model
|
||||
# on such providers, switch_model() would return success=False, and
|
||||
# the gateway would never write to _session_model_overrides.
|
||||
provider_label = _PROVIDER_LABELS.get(normalized, normalized)
|
||||
try:
|
||||
catalog_models = provider_model_ids(normalized)
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
"""Local OpenAI-compatible proxy that forwards to OAuth-authenticated upstreams.
|
||||
|
||||
Lets external apps (OpenViking, Karakeep, Open WebUI, ...) ride the user's
|
||||
already-logged-in provider subscription instead of needing a static API key
|
||||
copy-pasted into each app's config.
|
||||
|
||||
The proxy listens on ``127.0.0.1:<port>``, accepts any bearer (the client's
|
||||
``Authorization`` header is discarded), and attaches the user's real
|
||||
upstream credential to the forwarded request. The credential is refreshed
|
||||
automatically when it approaches expiry.
|
||||
|
||||
First-class adapter:
|
||||
- ``nous`` — Nous Portal (https://inference-api.nousresearch.com/v1)
|
||||
|
||||
Future adapters can plug in by implementing ``UpstreamAdapter``.
|
||||
"""
|
||||
|
||||
from hermes_cli.proxy.adapters.base import UpstreamAdapter
|
||||
|
||||
__all__ = ["UpstreamAdapter"]
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Upstream adapter registry for the local proxy server.
|
||||
|
||||
Each adapter wraps a provider's OAuth state and exposes a uniform interface
|
||||
the proxy server can use to forward requests with a freshly-minted bearer
|
||||
token. See :class:`UpstreamAdapter` for the contract.
|
||||
"""
|
||||
|
||||
from typing import Dict, Type
|
||||
|
||||
from hermes_cli.proxy.adapters.base import UpstreamAdapter
|
||||
from hermes_cli.proxy.adapters.nous_portal import NousPortalAdapter
|
||||
|
||||
# Registry of available adapter classes keyed by provider name as used on
|
||||
# the ``hermes proxy start --provider <name>`` CLI flag.
|
||||
ADAPTERS: Dict[str, Type[UpstreamAdapter]] = {
|
||||
"nous": NousPortalAdapter,
|
||||
}
|
||||
|
||||
|
||||
def get_adapter(name: str) -> UpstreamAdapter:
    """Instantiate an adapter by provider name.

    Raises:
        ValueError: if ``name`` is not a registered adapter.
    """
    key = (name or "").strip().lower()
    if key not in ADAPTERS:
        available = ", ".join(sorted(ADAPTERS)) or "(none)"
        raise ValueError(
            f"Unknown proxy upstream provider: {name!r}. Available: {available}"
        )
    return ADAPTERS[key]()
|
||||
|
||||
|
||||
__all__ = ["UpstreamAdapter", "ADAPTERS", "get_adapter"]
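How a second adapter would plug into this registry can be sketched with stand-in classes. Everything except the lookup logic is hypothetical: `Adapter`, `NousAdapter`, and `ExampleAdapter` replace the real classes so the snippet runs on its own.

```python
from typing import Dict, Type


class Adapter:
    """Stand-in for UpstreamAdapter."""
    display_name = "base"


class NousAdapter(Adapter):
    display_name = "Nous Portal"


class ExampleAdapter(Adapter):
    """Hypothetical future provider."""
    display_name = "Example Provider"


ADAPTERS: Dict[str, Type[Adapter]] = {"nous": NousAdapter}
ADAPTERS["example"] = ExampleAdapter  # registering is a single dict entry


def get_adapter(name: str) -> Adapter:
    key = (name or "").strip().lower()
    if key not in ADAPTERS:
        available = ", ".join(sorted(ADAPTERS)) or "(none)"
        raise ValueError(f"Unknown proxy upstream provider: {name!r}. Available: {available}")
    return ADAPTERS[key]()


print(get_adapter(" NOUS ").display_name)
print(get_adapter("example").display_name)
```

Because the lookup normalises case and whitespace, `--provider " NOUS "` and `--provider nous` resolve to the same class.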
|
||||
@@ -0,0 +1,94 @@
|
||||
"""Abstract base for proxy upstream adapters.
|
||||
|
||||
An :class:`UpstreamAdapter` represents one OAuth-authenticated provider the
|
||||
local proxy can forward requests to. The adapter is responsible for:
|
||||
|
||||
- locating the user's auth state for that provider
|
||||
- refreshing/minting credentials when needed
|
||||
- reporting the resolved upstream base URL
|
||||
- declaring which request paths it accepts
|
||||
|
||||
The proxy server is otherwise provider-agnostic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import FrozenSet, Optional
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UpstreamCredential:
|
||||
"""A resolved bearer + base URL ready to forward to."""
|
||||
|
||||
    bearer: str
    """Token sent upstream in the ``Authorization`` header (token only, no ``Bearer`` prefix)."""
|
||||
|
||||
base_url: str
|
||||
"""Upstream base URL, e.g. ``https://inference-api.nousresearch.com/v1``."""
|
||||
|
||||
token_type: str = "Bearer"
|
||||
"""Auth scheme — currently always ``Bearer`` for supported providers."""
|
||||
|
||||
expires_at: Optional[str] = None
|
||||
"""ISO-8601 expiry timestamp for the bearer, when known. Informational."""
|
||||
|
||||
|
||||
class UpstreamAdapter(ABC):
|
||||
"""Contract for an upstream provider the proxy can forward to."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Adapter key used on the CLI (e.g. ``"nous"``)."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def display_name(self) -> str:
|
||||
"""Human-readable provider name for logs and ``proxy status``."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def allowed_paths(self) -> FrozenSet[str]:
|
||||
"""Set of relative request paths the upstream accepts.
|
||||
|
||||
Paths are relative to the proxy's ``/v1`` mount point. For example,
|
||||
``"/chat/completions"`` corresponds to a client request to
|
||||
``http://127.0.0.1:<port>/v1/chat/completions``. Requests to paths
|
||||
not in this set get a 404 with a helpful error body.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def is_authenticated(self) -> bool:
|
||||
"""Return True if the user has usable credentials for this upstream.
|
||||
|
||||
Should be cheap — no network calls. Used by ``proxy start`` for a
|
||||
clear up-front error before binding a port.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_credential(self) -> UpstreamCredential:
|
||||
"""Return a fresh credential, refreshing/minting if necessary.
|
||||
|
||||
Implementations should:
|
||||
- refresh the access token if it's near expiry
|
||||
- mint/rotate the upstream bearer key if it's near expiry
|
||||
- persist any refreshed state back to disk
|
||||
|
||||
Raises:
|
||||
RuntimeError: if the user isn't authenticated or the upstream
|
||||
refresh fails. The proxy will return 401 to the client.
|
||||
"""
|
||||
|
||||
def describe(self) -> str:
|
||||
"""One-line status summary for ``proxy status``."""
|
||||
try:
|
||||
cred = self.get_credential()
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
return f"{self.display_name}: not ready ({exc})"
|
||||
ttl = f" (expires {cred.expires_at})" if cred.expires_at else ""
|
||||
return f"{self.display_name}: {cred.base_url}{ttl}"
|
||||
|
||||
|
||||
__all__ = ["UpstreamAdapter", "UpstreamCredential"]
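A concrete adapter only has to supply the five abstract members. The sketch below re-declares a trimmed copy of the contract so it is self-contained; `StaticTokenAdapter` is hypothetical and performs no OAuth refresh.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import FrozenSet, Optional


@dataclass(frozen=True)
class UpstreamCredential:
    bearer: str
    base_url: str
    token_type: str = "Bearer"
    expires_at: Optional[str] = None


class UpstreamAdapter(ABC):
    """Trimmed stand-in for the contract above (abstract members only)."""

    @property
    @abstractmethod
    def name(self) -> str: ...

    @property
    @abstractmethod
    def display_name(self) -> str: ...

    @property
    @abstractmethod
    def allowed_paths(self) -> FrozenSet[str]: ...

    @abstractmethod
    def is_authenticated(self) -> bool: ...

    @abstractmethod
    def get_credential(self) -> UpstreamCredential: ...


class StaticTokenAdapter(UpstreamAdapter):
    """Hypothetical adapter backed by a fixed token."""

    def __init__(self, token: str, base_url: str) -> None:
        self._token = token
        self._base_url = base_url.rstrip("/")

    @property
    def name(self) -> str:
        return "static"

    @property
    def display_name(self) -> str:
        return "Static Token"

    @property
    def allowed_paths(self) -> FrozenSet[str]:
        return frozenset({"/chat/completions", "/models"})

    def is_authenticated(self) -> bool:
        return bool(self._token)  # cheap: no network call

    def get_credential(self) -> UpstreamCredential:
        if not self._token:
            raise RuntimeError("No token configured.")
        return UpstreamCredential(bearer=self._token, base_url=self._base_url)


cred = StaticTokenAdapter("tok-123", "https://example.invalid/v1/").get_credential()
print(f"{cred.token_type} {cred.bearer}", cred.base_url)
```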
|
||||
@@ -0,0 +1,137 @@
|
||||
"""Nous Portal upstream adapter.
|
||||
|
||||
Reads the user's Nous OAuth state from ``~/.hermes/auth.json``, refreshes
|
||||
the access token and mints a fresh agent key when needed, and exposes the
|
||||
upstream base URL plus minted bearer for the proxy server to forward to.
|
||||
|
||||
The minted ``agent_key`` (not the OAuth ``access_token``) is what
|
||||
``inference-api.nousresearch.com`` accepts as a bearer. The refresh helper
|
||||
already handles both — see :func:`hermes_cli.auth.refresh_nous_oauth_from_state`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Dict, FrozenSet, Optional
|
||||
|
||||
from hermes_cli.auth import (
|
||||
DEFAULT_NOUS_INFERENCE_URL,
|
||||
_load_auth_store,
|
||||
_save_auth_store,
|
||||
_write_shared_nous_state,
|
||||
refresh_nous_oauth_from_state,
|
||||
)
|
||||
from hermes_cli.proxy.adapters.base import UpstreamAdapter, UpstreamCredential
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Endpoints inference-api.nousresearch.com actually serves. Anything else
|
||||
# the proxy will reject with 404 — keeps stray clients from leaking weird
|
||||
# requests to the upstream.
|
||||
_ALLOWED_PATHS: FrozenSet[str] = frozenset(
|
||||
{
|
||||
"/chat/completions",
|
||||
"/completions",
|
||||
"/embeddings",
|
||||
"/models",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class NousPortalAdapter(UpstreamAdapter):
|
||||
"""Proxy upstream for the Nous Portal inference API."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Lock guards _load → refresh → _save against parallel proxy requests
|
||||
# racing to refresh expired tokens. Refresh itself is HTTP, so we
|
||||
# hold the lock across the network call (brief; OAuth refresh is fast).
|
||||
self._lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "nous"
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return "Nous Portal"
|
||||
|
||||
@property
|
||||
def allowed_paths(self) -> FrozenSet[str]:
|
||||
return _ALLOWED_PATHS
|
||||
|
||||
def is_authenticated(self) -> bool:
|
||||
state = self._read_state()
|
||||
if state is None:
|
||||
return False
|
||||
# We need either a usable agent_key OR (refresh_token + access_token)
|
||||
# to recover. The refresh helper will mint/refresh as needed.
|
||||
return bool(
|
||||
state.get("agent_key")
|
||||
or (state.get("refresh_token") and state.get("access_token"))
|
||||
)
|
||||
|
||||
def get_credential(self) -> UpstreamCredential:
|
||||
with self._lock:
|
||||
state = self._read_state()
|
||||
if state is None:
|
||||
raise RuntimeError(
|
||||
"Not logged into Nous Portal. Run `hermes login nous` first."
|
||||
)
|
||||
|
||||
try:
|
||||
refreshed = refresh_nous_oauth_from_state(state)
|
||||
except Exception as exc:
|
||||
raise RuntimeError(
|
||||
f"Failed to refresh Nous Portal credentials: {exc}"
|
||||
) from exc
|
||||
|
||||
self._save_state(refreshed)
|
||||
|
||||
agent_key = refreshed.get("agent_key")
|
||||
if not agent_key:
|
||||
raise RuntimeError(
|
||||
"Nous Portal refresh did not return a usable agent_key. "
|
||||
"Try `hermes login nous` to re-authenticate."
|
||||
)
|
||||
|
||||
base_url = refreshed.get("inference_base_url") or DEFAULT_NOUS_INFERENCE_URL
|
||||
base_url = base_url.rstrip("/")
|
||||
|
||||
return UpstreamCredential(
|
||||
bearer=agent_key,
|
||||
base_url=base_url,
|
||||
expires_at=refreshed.get("agent_key_expires_at"),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers — auth.json access. Kept local rather than added
|
||||
# to hermes_cli.auth to avoid expanding that module's public surface.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _read_state(self) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
store = _load_auth_store()
|
||||
except Exception as exc:
|
||||
logger.warning("proxy: failed to load auth store: %s", exc)
|
||||
return None
|
||||
providers = store.get("providers") or {}
|
||||
state = providers.get("nous")
|
||||
if not isinstance(state, dict):
|
||||
return None
|
||||
return dict(state) # copy so the refresh helper can mutate freely
|
||||
|
||||
def _save_state(self, state: Dict[str, Any]) -> None:
|
||||
try:
|
||||
store = _load_auth_store()
|
||||
providers = store.setdefault("providers", {})
|
||||
providers["nous"] = state
|
||||
_save_auth_store(store)
|
||||
_write_shared_nous_state(state)
|
||||
except Exception as exc:
|
||||
# Best effort — we still return the fresh credential. The next
|
||||
# request just won't see cached state, which means another refresh.
|
||||
logger.warning("proxy: failed to persist refreshed Nous state: %s", exc)
|
||||
|
||||
|
||||
__all__ = ["NousPortalAdapter"]
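The reason for holding a lock across the refresh can be demonstrated with threads. This sketch differs from the adapter above in one labelled way: it checks expiry itself and caches in memory, whereas `NousPortalAdapter` hands the state to `refresh_nous_oauth_from_state` on every call. `CachedCredential` and `fake_refresh` are hypothetical.

```python
import threading
import time
from typing import Any, Callable, Dict, Optional


class CachedCredential:
    """Hypothetical cache: refresh at most once per expiry window, under a lock."""

    def __init__(self, refresh: Callable[[], Dict[str, Any]], skew_seconds: float = 60.0) -> None:
        self._refresh = refresh
        self._skew = skew_seconds
        self._lock = threading.Lock()
        self._state: Optional[Dict[str, Any]] = None

    def get(self) -> str:
        # Holding the lock across the (possibly slow) refresh makes concurrent
        # callers wait for one refresh instead of each starting their own.
        with self._lock:
            state = self._state
            if state is None or state["expires_at"] - time.time() < self._skew:
                state = self._refresh()
                self._state = state
            return state["token"]


calls = {"n": 0}


def fake_refresh() -> Dict[str, Any]:
    """Stand-in for a network refresh; counts how often it runs."""
    calls["n"] += 1
    return {"token": f"tok-{calls['n']}", "expires_at": time.time() + 3600}


cache = CachedCredential(fake_refresh)
threads = [threading.Thread(target=cache.get) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(calls["n"])
```

Eight concurrent callers trigger a single refresh, because the expiry check happens inside the lock.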
|
||||
@@ -0,0 +1,141 @@
|
||||
"""CLI handlers for the ``hermes proxy`` subcommand."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from hermes_cli.proxy.adapters import ADAPTERS, get_adapter
|
||||
from hermes_cli.proxy.server import (
|
||||
AIOHTTP_AVAILABLE,
|
||||
DEFAULT_HOST,
|
||||
DEFAULT_PORT,
|
||||
run_server,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _print_aiohttp_missing() -> None:
|
||||
print(
|
||||
"hermes proxy requires aiohttp. Install one of:\n"
|
||||
" pip install 'hermes-agent[messaging]'\n"
|
||||
" pip install aiohttp",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
|
||||
def cmd_proxy_start(args: Any) -> int:
|
||||
"""Run the proxy server in the foreground.
|
||||
|
||||
Returns process exit code (0 on clean shutdown).
|
||||
"""
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
_print_aiohttp_missing()
|
||||
return 1
|
||||
|
||||
provider = getattr(args, "provider", None) or "nous"
|
||||
try:
|
||||
adapter = get_adapter(provider)
|
||||
except ValueError as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
return 2
|
||||
|
||||
if not adapter.is_authenticated():
|
||||
print(
|
||||
f"Not logged into {adapter.display_name}. "
|
||||
f"Run `hermes login {adapter.name}` first.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 2
|
||||
|
||||
host = getattr(args, "host", None) or DEFAULT_HOST
|
||||
port = getattr(args, "port", None) or DEFAULT_PORT
|
||||
|
||||
print(
|
||||
f"Starting Hermes proxy for {adapter.display_name}\n"
|
||||
f" Listening on: http://{host}:{port}/v1\n"
|
||||
f" Forwarding to: (resolved per-request from your subscription)\n"
|
||||
f" Use any bearer token in the client — the proxy attaches your real credential.\n"
|
||||
f"\n"
|
||||
f"Press Ctrl+C to stop.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.run(run_server(adapter, host=host, port=port))
|
||||
except KeyboardInterrupt:
|
||||
print("\nproxy: stopped", file=sys.stderr)
|
||||
except OSError as exc:
|
||||
print(f"proxy: failed to bind {host}:{port}: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_proxy_status(args: Any) -> int:
|
||||
"""Print the status of each configured upstream adapter."""
|
||||
print("Hermes proxy upstream adapters\n")
|
||||
for name in sorted(ADAPTERS):
|
||||
adapter = get_adapter(name)
|
||||
if not adapter.is_authenticated():
|
||||
print(f" [{name:8s}] {adapter.display_name} — not logged in")
|
||||
continue
|
||||
try:
|
||||
cred = adapter.get_credential()
|
||||
except Exception as exc:
|
||||
print(
|
||||
f" [{name:8s}] {adapter.display_name} — credentials need attention "
|
||||
f"({exc})"
|
||||
)
|
||||
continue
|
||||
expires = f" (bearer expires {cred.expires_at})" if cred.expires_at else ""
|
||||
print(f" [{name:8s}] {adapter.display_name} — ready{expires}")
|
||||
print(
|
||||
"\nStart the proxy with: hermes proxy start [--provider <name>]"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_proxy_list_providers(args: Any) -> int:
|
||||
"""List available proxy upstream providers."""
|
||||
print("Available proxy upstream providers:")
|
||||
for name in sorted(ADAPTERS):
|
||||
adapter = get_adapter(name)
|
||||
print(f" {name} — {adapter.display_name}")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_proxy(args: Any) -> int:
|
||||
"""Dispatch ``hermes proxy <subcommand>``."""
|
||||
sub = getattr(args, "proxy_command", None)
|
||||
if sub == "start":
|
||||
return cmd_proxy_start(args)
|
||||
if sub == "status":
|
||||
return cmd_proxy_status(args)
|
||||
if sub in ("providers", "list"):
|
||||
return cmd_proxy_list_providers(args)
|
||||
# No subcommand → print short help.
|
||||
print(
|
||||
"hermes proxy — local OpenAI-compatible proxy that attaches your\n"
|
||||
"OAuth-authenticated provider credentials to outbound requests.\n"
|
||||
"\n"
|
||||
"Subcommands:\n"
|
||||
" hermes proxy start [--provider nous] [--host 127.0.0.1] [--port 8645]\n"
|
||||
" Run the proxy in the foreground.\n"
|
||||
" hermes proxy status\n"
|
||||
" Show which upstream adapters are ready.\n"
|
||||
" hermes proxy providers\n"
|
||||
" List available upstream providers.\n",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
__all__ = [
|
||||
"cmd_proxy",
|
||||
"cmd_proxy_start",
|
||||
"cmd_proxy_status",
|
||||
"cmd_proxy_list_providers",
|
||||
]
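From the client side, talking to the proxy is an ordinary OpenAI-style request with a throwaway bearer. The sketch below only builds the request; the model name is a placeholder, and `build_chat_request` is a hypothetical helper.

```python
import json
import urllib.request


def build_chat_request(base_url: str, model: str, prompt: str) -> urllib.request.Request:
    """Build a chat-completions POST aimed at the local proxy."""
    payload = {"model": model, "messages": [{"role": "user", "content": prompt}]}
    return urllib.request.Request(
        f"{base_url.rstrip('/')}/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            # Any value works: the proxy discards it and attaches the real credential.
            "Authorization": "Bearer placeholder",
        },
        method="POST",
    )


req = build_chat_request("http://127.0.0.1:8645/v1", "example-model", "hello")
print(req.get_method(), req.full_url)
# Sending it needs a running proxy: urllib.request.urlopen(req)
```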
|
||||
@@ -0,0 +1,265 @@
|
||||
"""HTTP server that forwards OpenAI-compatible requests to a configured upstream.
|
||||
|
||||
Listens on ``http://<host>:<port>/v1/<path>`` and forwards each request to
|
||||
``<upstream-base-url>/<path>`` with the client's ``Authorization`` header
|
||||
replaced by a freshly-resolved bearer from the configured adapter. The
|
||||
response is streamed back unmodified, preserving SSE.
|
||||
|
||||
The server is intentionally minimal: it does NOT mediate, log, transform,
|
||||
or rewrite request/response bodies. It's a credential-attaching forwarder.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import signal
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
aiohttp = None # type: ignore[assignment]
|
||||
web = None # type: ignore[assignment]
|
||||
AIOHTTP_AVAILABLE = False
|
||||
|
||||
from hermes_cli.proxy.adapters.base import UpstreamAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Headers we strip when forwarding to the upstream. ``host``/``content-length``
|
||||
# are recomputed by aiohttp; ``authorization`` is replaced with our bearer.
|
||||
# Everything else (content-type, accept, user-agent, x-* headers) passes through.
|
||||
_HOP_BY_HOP_HEADERS = frozenset(
|
||||
{
|
||||
"host",
|
||||
"content-length",
|
||||
"connection",
|
||||
"keep-alive",
|
||||
"proxy-authenticate",
|
||||
"proxy-authorization",
|
||||
"te",
|
||||
"trailers",
|
||||
"transfer-encoding",
|
||||
"upgrade",
|
||||
"authorization", # we replace this one
|
||||
}
|
||||
)
|
||||
|
||||
DEFAULT_PORT = 8645
|
||||
DEFAULT_HOST = "127.0.0.1"
|
||||
|
||||
|
||||
def _json_error(status: int, message: str, code: str = "proxy_error") -> "web.Response":
|
||||
"""Return an OpenAI-style error JSON response."""
|
||||
body = {"error": {"message": message, "type": code, "code": code}}
|
||||
return web.json_response(body, status=status)
|
||||
|
||||
|
||||
def _filter_request_headers(headers: "aiohttp.typedefs.LooseHeaders") -> dict:
|
||||
"""Strip hop-by-hop + auth headers from the inbound request."""
|
||||
out = {}
|
||||
for key, value in headers.items():
|
||||
if key.lower() in _HOP_BY_HOP_HEADERS:
|
||||
continue
|
||||
out[key] = value
|
||||
return out
|
||||
|
||||
|
||||
def _filter_response_headers(headers) -> dict:
|
||||
"""Strip hop-by-hop headers from the upstream response."""
|
||||
out = {}
|
||||
for key, value in headers.items():
|
||||
if key.lower() in _HOP_BY_HOP_HEADERS:
|
||||
continue
|
||||
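        # The aiohttp client decompresses the upstream body by default, so the
        # original Content-Encoding/Content-Length no longer describe what we
        # stream back; let aiohttp set them.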
|
||||
if key.lower() in ("content-encoding", "content-length"):
|
||||
continue
|
||||
out[key] = value
|
||||
return out
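The request-side filtering and the credential swap fit in a few lines when taken together. This is a sketch under the assumption that headers arrive as a plain mapping; `forward_headers` is a hypothetical name, and the strip list is copied from `_HOP_BY_HOP_HEADERS` above.

```python
from typing import Dict, Mapping

_STRIP = frozenset(
    {
        "host", "content-length", "connection", "keep-alive",
        "proxy-authenticate", "proxy-authorization", "te", "trailers",
        "transfer-encoding", "upgrade", "authorization",
    }
)


def forward_headers(inbound: Mapping[str, str], token: str, token_type: str = "Bearer") -> Dict[str, str]:
    """Drop hop-by-hop and client auth headers, then attach the real credential."""
    out = {k: v for k, v in inbound.items() if k.lower() not in _STRIP}
    out["Authorization"] = f"{token_type} {token}"
    return out


print(
    forward_headers(
        {"Host": "127.0.0.1:8645", "authorization": "Bearer client", "X-Trace": "1"},
        "real-token",
    )
)
```

Matching on the lowercased name means a client-supplied `authorization` header is removed regardless of how the client capitalised it.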
|
||||
|
||||
|
||||
def create_app(adapter: UpstreamAdapter) -> "web.Application":
    """Build the aiohttp application bound to a specific upstream adapter."""
    if not AIOHTTP_AVAILABLE:
        raise RuntimeError(
            "aiohttp is required for `hermes proxy`. Install with: "
            "pip install 'hermes-agent[messaging]' or `pip install aiohttp`."
        )

    app = web.Application()
    # AppKey ensures forward-compat with future aiohttp versions that strip
    # bare-string keys.
    _adapter_key = web.AppKey("adapter", UpstreamAdapter)
    app[_adapter_key] = adapter

    async def handle_health(request: "web.Request") -> "web.Response":
        return web.json_response(
            {
                "status": "ok",
                "upstream": adapter.display_name,
                "authenticated": adapter.is_authenticated(),
            }
        )

    async def handle_models_fallback(request: "web.Request") -> "web.Response":
        # Most clients hit /v1/models on startup. If the upstream doesn't
        # serve /models, synthesize a minimal response so clients don't
        # crash. The actual forwarding path handles /models when allowed.
        return web.json_response(
            {
                "object": "list",
                "data": [],
            }
        )

    async def handle_proxy(request: "web.Request") -> "web.StreamResponse":
        # Extract the path *after* /v1
        rel_path = request.match_info.get("tail", "")
        rel_path = "/" + rel_path.lstrip("/")

        if rel_path not in adapter.allowed_paths:
            allowed = ", ".join(sorted(adapter.allowed_paths))
            return _json_error(
                404,
                f"Path /v1{rel_path} is not forwarded by this proxy. "
                f"Allowed: {allowed}",
                code="path_not_allowed",
            )

        try:
            cred = adapter.get_credential()
        except Exception as exc:
            logger.warning("proxy: credential resolution failed: %s", exc)
            return _json_error(401, str(exc), code="upstream_auth_failed")

        upstream_url = f"{cred.base_url.rstrip('/')}{rel_path}"
        # Preserve query string verbatim.
        if request.query_string:
            upstream_url = f"{upstream_url}?{request.query_string}"

        # Forward body verbatim. Read into memory once — request bodies for
        # chat/completions/embeddings are small (<1MB typically). If we ever
        # need to forward large multipart uploads we'll switch to streaming
        # the request body too.
        body = await request.read()

        fwd_headers = _filter_request_headers(request.headers)
        fwd_headers["Authorization"] = f"{cred.token_type} {cred.bearer}"

        logger.debug(
            "proxy: forwarding %s %s -> %s (body=%d bytes)",
            request.method, rel_path, upstream_url, len(body),
        )

        # Use a per-request session so connection state doesn't leak between
        # clients. Could be optimized to a shared session later.
        timeout = aiohttp.ClientTimeout(total=None, sock_connect=15, sock_read=300)
        try:
            session = aiohttp.ClientSession(timeout=timeout)
        except Exception as exc:  # pragma: no cover - aiohttp setup issue
            return _json_error(500, f"proxy session init failed: {exc}")

        try:
            upstream_resp = await session.request(
                request.method,
                upstream_url,
                data=body if body else None,
                headers=fwd_headers,
                allow_redirects=False,
            )
        except aiohttp.ClientError as exc:
            await session.close()
            logger.warning("proxy: upstream connection failed: %s", exc)
            return _json_error(502, f"upstream connection failed: {exc}",
                               code="upstream_unreachable")
        except asyncio.TimeoutError:
            await session.close()
            return _json_error(504, "upstream request timed out",
                               code="upstream_timeout")

        # Stream response back. Headers first, then chunked body.
        resp = web.StreamResponse(
            status=upstream_resp.status,
            headers=_filter_response_headers(upstream_resp.headers),
        )
        await resp.prepare(request)

        try:
            async for chunk in upstream_resp.content.iter_any():
                if chunk:
                    await resp.write(chunk)
        except (aiohttp.ClientError, asyncio.CancelledError) as exc:
            logger.warning("proxy: streaming interrupted: %s", exc)
        finally:
            upstream_resp.release()
            await session.close()

        await resp.write_eof()
        return resp

    # /health doesn't go through the upstream
    app.router.add_get("/health", handle_health)
    # Catch-all under /v1 — forwards if the path is allowed.
    app.router.add_route("*", "/v1/{tail:.*}", handle_proxy)

    return app


async def run_server(
    adapter: UpstreamAdapter,
    host: str = DEFAULT_HOST,
    port: int = DEFAULT_PORT,
    shutdown_event: Optional[asyncio.Event] = None,
) -> None:
    """Run the proxy in the current event loop until shutdown_event is set.

    If shutdown_event is None, runs until cancelled (Ctrl+C or SIGTERM).
    """
    if not AIOHTTP_AVAILABLE:
        raise RuntimeError(
            "aiohttp is required for `hermes proxy`. Install with: "
            "pip install 'hermes-agent[messaging]' or `pip install aiohttp`."
        )

    app = create_app(adapter)
    runner = web.AppRunner(app, access_log=None)
    await runner.setup()
    site = web.TCPSite(runner, host=host, port=port)
    await site.start()

    logger.info(
        "proxy: listening on http://%s:%d/v1 -> %s",
        host, port, adapter.display_name,
    )

    stop_event = shutdown_event or asyncio.Event()

    # Wire signal handlers when we own the loop's lifetime.
    if shutdown_event is None:
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(sig, stop_event.set)  # windows-footgun: ok
            except NotImplementedError:
                # Windows / restricted environments — Ctrl+C will still
                # raise KeyboardInterrupt and unwind us.
                pass

    try:
        await stop_event.wait()
    finally:
        logger.info("proxy: shutting down")
        await runner.cleanup()


__all__ = [
    "create_app",
    "run_server",
    "DEFAULT_HOST",
    "DEFAULT_PORT",
    "AIOHTTP_AVAILABLE",
]
@@ -102,8 +102,10 @@ def _auto_detect_local_model(base_url: str) -> str:
         model_id = models[0].get("id", "")
         if model_id:
             return model_id
-    except Exception:
-        pass
+    except Exception as exc:
+        # Log instead of silently swallowing — aids debugging when
+        # local model auto-detection fails unexpectedly.
+        logger.debug("Auto-detect model from %s failed: %s", base_url, exc)
     return ""


@@ -522,14 +522,6 @@ def _print_setup_summary(config: dict, hermes_home):
     elif managed_nous_tools_enabled() and subscription_features.nous_auth_present:
         tool_status.append(("Modal Execution (optional via Nous subscription)", True, None))

-    # Tinker + WandB (RL training)
-    if get_env_value("TINKER_API_KEY") and get_env_value("WANDB_API_KEY"):
-        tool_status.append(("RL Training (Tinker)", True, None))
-    elif get_env_value("TINKER_API_KEY"):
-        tool_status.append(("RL Training (Tinker)", False, "WANDB_API_KEY"))
-    else:
-        tool_status.append(("RL Training (Tinker)", False, "TINKER_API_KEY"))
-
     # Home Assistant
     if get_env_value("HASS_TOKEN"):
         tool_status.append(("Smart Home (Home Assistant)", True, None))
@@ -849,10 +849,14 @@ def get_prompt_toolkit_style_overrides() -> Dict[str, str]:
     except Exception:
         return {}

-    prompt = skin.get_color("prompt", "#FFF8DC")
+    # Input/prompt: leave unset by default so the typed text inherits
+    # the terminal's foreground color (readable in both light and dark
+    # color schemes). Skins can opt into a colored prompt by setting
+    # `prompt` explicitly in their YAML.
+    prompt = skin.get_color("prompt", "")
     input_rule = skin.get_color("input_rule", "#CD7F32")
     title = skin.get_color("banner_title", "#FFD700")
-    text = skin.get_color("banner_text", prompt)
+    text = skin.get_color("banner_text", "#FFF8DC")
     dim = skin.get_color("banner_dim", "#555555")
     label = skin.get_color("ui_label", title)
     warn = skin.get_color("ui_warn", "#FF8C00")
@@ -872,7 +876,11 @@ def get_prompt_toolkit_style_overrides() -> Dict[str, str]:
     menu_meta_current_bg = skin.get_color("completion_menu_meta_current_bg", menu_current_bg)

     return {
-        "input-area": prompt,
+        # Typed input always uses terminal default fg/bg so it's
+        # readable in both light and dark Terminal.app modes. The
+        # skin's `prompt` color (if any) only styles the prompt symbol,
+        # NOT the user's typed text.
+        "input-area": "",
         "placeholder": f"{dim} italic",
         "prompt": prompt,
         "prompt-working": f"{dim} italic",
@@ -141,8 +141,6 @@ def show_status(args):
         "Browser Use": "BROWSER_USE_API_KEY",  # Optional — local browser works without this
         "Browserbase": "BROWSERBASE_API_KEY",  # Optional — direct credentials only
         "FAL": "FAL_KEY",
-        "Tinker": "TINKER_API_KEY",
-        "WandB": "WANDB_API_KEY",
         "ElevenLabs": "ELEVENLABS_API_KEY",
         "GitHub": "GITHUB_TOKEN",
     }
@@ -71,7 +71,6 @@ CONFIGURABLE_TOOLSETS = [
     ("delegation", "👥 Task Delegation", "delegate_task"),
     ("cronjob", "⏰ Cron Jobs", "create/list/update/pause/resume/run, with optional attached skills"),
     ("messaging", "📨 Cross-Platform Messaging", "send_message"),
-    ("rl", "🧪 RL Training", "Tinker-Atropos training tools"),
     ("homeassistant", "🏠 Home Assistant", "smart home device control"),
     ("spotify", "🎵 Spotify", "playback, search, playlists, library"),
     ("discord", "💬 Discord (read/participate)", "fetch messages, search members, create thread"),
@@ -87,7 +86,7 @@ CONFIGURABLE_TOOLSETS = [
 # Video gen is off by default — it's a niche, paid, slow feature. Users
 # who want it opt in via `hermes tools` → Video Generation, which walks
 # them through provider + model selection.
-_DEFAULT_OFF_TOOLSETS = {"moa", "homeassistant", "rl", "spotify", "discord", "discord_admin", "video", "video_gen"}
+_DEFAULT_OFF_TOOLSETS = {"moa", "homeassistant", "spotify", "discord", "discord_admin", "video", "video_gen"}

 # Platform-scoped toolsets: only appear in the `hermes tools` checklist for
 # these platforms, and only resolve/save for these platforms. A toolset
@@ -424,22 +423,6 @@ TOOL_CATEGORIES = {
             },
         ],
     },
-    "rl": {
-        "name": "RL Training",
-        "icon": "🧪",
-        "requires_python": (3, 11),
-        "providers": [
-            {
-                "name": "Tinker / Atropos",
-                "tag": "RL training platform",
-                "env_vars": [
-                    {"key": "TINKER_API_KEY", "prompt": "Tinker API key", "url": "https://tinker-console.thinkingmachines.ai/keys"},
-                    {"key": "WANDB_API_KEY", "prompt": "WandB API key", "url": "https://wandb.ai/authorize"},
-                ],
-                "post_setup": "rl_training",
-            },
-        ],
-    },
     "langfuse": {
         "name": "Langfuse Observability",
         "icon": "📊",
@@ -912,24 +895,6 @@ def _run_post_setup(post_setup_key: str):
             _print_warning(f" Spotify login failed: {exc}")
             _print_info(" Run manually: hermes auth spotify")

-    elif post_setup_key == "rl_training":
-        try:
-            __import__("tinker_atropos")
-        except ImportError:
-            tinker_dir = PROJECT_ROOT / "tinker-atropos"
-            if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
-                _print_info(" Installing tinker-atropos submodule...")
-                result = _pip_install(["-e", str(tinker_dir)])
-                if result.returncode == 0:
-                    _print_success(" tinker-atropos installed")
-                else:
-                    _print_warning(" tinker-atropos install failed - run manually:")
-                    _print_info(' uv pip install -e "./tinker-atropos"')
-            else:
-                _print_warning(" tinker-atropos submodule not found - run:")
-                _print_info(" git submodule update --init --recursive")
-                _print_info(' uv pip install -e "./tinker-atropos"')
-
     elif post_setup_key == "langfuse":
         # Install the langfuse SDK.
         try:
+223
-4
@@ -25,7 +25,7 @@ from pathlib import Path

 from agent.memory_manager import sanitize_context
 from hermes_constants import get_hermes_home
-from typing import Any, Callable, Dict, List, Optional, TypeVar
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar

 logger = logging.getLogger(__name__)

@@ -1618,6 +1618,185 @@ class SessionDB:
             result.append(msg)
         return result

+    def get_messages_around(
+        self,
+        session_id: str,
+        around_message_id: int,
+        window: int = 5,
+    ) -> List[Dict[str, Any]]:
+        """Load a window of messages anchored on a specific message id.
+
+        Returns up to ``window`` messages before the anchor, the anchor itself,
+        and up to ``window`` messages after — all from the same session,
+        ordered by id ascending. Boundaries are honoured: if the anchor is
+        near the start or end of the session, fewer messages are returned on
+        the truncated side.
+
+        If ``around_message_id`` is not a message id within ``session_id``,
+        returns an empty list. Callers decide whether to surface that as an
+        error.
+
+        Used by ``session_search`` mode='guided' to provide anchored
+        drill-down into a specific session at a specific message — without
+        the cost of summarisation or the risk of 100k-char truncation.
+        """
+        if window < 0:
+            window = 0
+        with self._lock:
+            # Confirm the anchor exists in this session — cheap guard against
+            # cross-session contamination if a caller mixes up session/message
+            # ids.
+            anchor_exists = self._conn.execute(
+                "SELECT 1 FROM messages WHERE id = ? AND session_id = ? LIMIT 1",
+                (around_message_id, session_id),
+            ).fetchone()
+            if not anchor_exists:
+                return []
+
+            # Two queries: anchor + before (DESC, take window+1), and after
+            # (ASC, take window). Final order is id ASC.
+            before_rows = self._conn.execute(
+                "SELECT * FROM messages "
+                "WHERE session_id = ? AND id <= ? "
+                "ORDER BY id DESC LIMIT ?",
+                (session_id, around_message_id, window + 1),
+            ).fetchall()
+            after_rows = self._conn.execute(
+                "SELECT * FROM messages "
+                "WHERE session_id = ? AND id > ? "
+                "ORDER BY id ASC LIMIT ?",
+                (session_id, around_message_id, window),
+            ).fetchall()
+
+        # before_rows is DESC; reverse so it's ASC, then concatenate after_rows.
+        rows = list(reversed(before_rows)) + list(after_rows)
+        result = []
+        for row in rows:
+            msg = dict(row)
+            if "content" in msg:
+                msg["content"] = self._decode_content(msg["content"])
+            if msg.get("tool_calls"):
+                try:
+                    msg["tool_calls"] = json.loads(msg["tool_calls"])
+                except (json.JSONDecodeError, TypeError):
+                    logger.warning(
+                        "Failed to deserialize tool_calls in get_messages_around, falling back to []"
+                    )
+                    msg["tool_calls"] = []
+            result.append(msg)
+        return result
+
+    def get_anchored_view(
+        self,
+        session_id: str,
+        around_message_id: int,
+        window: int = 5,
+        bookend: int = 3,
+        keep_roles: Optional[Tuple[str, ...]] = ("user", "assistant"),
+    ) -> Dict[str, Any]:
+        """Return an anchored window plus session bookends, opinionated for guided recall.
+
+        Built on top of ``get_messages_around``:
+        - ``window``: messages immediately surrounding the anchor. Filtered to
+          ``keep_roles`` (tool-response noise dropped by default), EXCEPT the
+          anchor itself is always included regardless of role — callers may
+          have anchored on a tool message and dropping it would break the
+          contract.
+        - ``bookend_start``: first ``bookend`` messages of the session
+          (filtered to ``keep_roles``), but ONLY those whose id sits strictly
+          before the window's first message id. If the window already covers
+          the session start, ``bookend_start`` is an empty list.
+        - ``bookend_end``: last ``bookend`` messages of the session (same
+          filter + non-overlap rule applied at the tail).
+
+        Bookends exist so an FTS5 hit anywhere in a long session still yields
+        the goal (opening) and the resolution (closing) on a single guided
+        call — without the cost of fetching the whole transcript.
+
+        Returns empty ``window``, ``bookend_start`` and ``bookend_end`` lists
+        when the anchor isn't in the session; the caller decides how to surface that.
+
+        ``keep_roles=None`` disables role filtering entirely (raw window +
+        raw bookends). Pass an explicit tuple to override the default.
+        """
+        if bookend < 0:
+            bookend = 0
+
+        # Reuse the primitive — it already handles the anchor-existence check,
+        # window clamping, content decoding, and tool_calls deserialisation.
+        window_rows = self.get_messages_around(
+            session_id, around_message_id, window=window
+        )
+        if not window_rows:
+            return {"window": [], "bookend_start": [], "bookend_end": []}
+
+        # Apply role filter to the window, but never drop the anchor itself.
+        if keep_roles is not None:
+            keep_set = set(keep_roles)
+            filtered_window = [
+                m for m in window_rows
+                if m.get("id") == around_message_id or m.get("role") in keep_set
+            ]
+        else:
+            filtered_window = window_rows
+
+        window_min_id = window_rows[0]["id"]
+        window_max_id = window_rows[-1]["id"]
+
+        # Fetch bookends only if there's space outside the window. SQL filters
+        # by id range, role, and non-empty content — tool-call-only assistant
+        # turns (content='' with tool_calls populated) are excluded so they
+        # don't crowd out the actual prose openings/closings. ``bookend=0``
+        # short-circuits both queries.
+        bookend_start_rows: List[Any] = []
+        bookend_end_rows: List[Any] = []
+        if bookend > 0:
+            with self._lock:
+                role_clause = ""
+                role_params: list = []
+                if keep_roles is not None:
+                    role_placeholders = ",".join("?" for _ in keep_roles)
+                    role_clause = f" AND role IN ({role_placeholders})"
+                    role_params = list(keep_roles)
+
+                bookend_start_rows = self._conn.execute(
+                    f"SELECT * FROM messages "
+                    f"WHERE session_id = ? AND id < ?{role_clause} "
+                    f"AND length(content) > 0 "
+                    f"ORDER BY id ASC LIMIT ?",
+                    (session_id, window_min_id, *role_params, bookend),
+                ).fetchall()
+
+                bookend_end_rows = self._conn.execute(
+                    f"SELECT * FROM messages "
+                    f"WHERE session_id = ? AND id > ?{role_clause} "
+                    f"AND length(content) > 0 "
+                    f"ORDER BY id DESC LIMIT ?",
+                    (session_id, window_max_id, *role_params, bookend),
+                ).fetchall()
+                # End rows came back DESC for the LIMIT cap; flip to ASC.
+                bookend_end_rows = list(reversed(bookend_end_rows))
+
+        def _hydrate(row) -> Dict[str, Any]:
+            msg = dict(row)
+            if "content" in msg:
+                msg["content"] = self._decode_content(msg["content"])
+            if msg.get("tool_calls"):
+                try:
+                    msg["tool_calls"] = json.loads(msg["tool_calls"])
+                except (json.JSONDecodeError, TypeError):
+                    logger.warning(
+                        "Failed to deserialize tool_calls in get_anchored_view, falling back to []"
+                    )
+                    msg["tool_calls"] = []
+            return msg
+
+        return {
+            "window": filtered_window,
+            "bookend_start": [_hydrate(r) for r in bookend_start_rows],
+            "bookend_end": [_hydrate(r) for r in bookend_end_rows],
+        }
+
     def resolve_resume_session_id(self, session_id: str) -> str:
         """Redirect a resume target to the descendant session that holds the messages.

@@ -1885,6 +2064,7 @@ class SessionDB:
         role_filter: List[str] = None,
         limit: int = 20,
         offset: int = 0,
+        sort: str = None,
     ) -> List[Dict[str, Any]]:
         """
         Full-text search across session messages using FTS5.
@@ -1897,6 +2077,19 @@ class SessionDB:

         Returns matching messages with session metadata, content snippet,
         and surrounding context (1 message before and after the match).
+
+        ``sort`` controls temporal ordering of results:
+          - ``None`` (default): FTS5 BM25 relevance only. Time-neutral, but
+            ties between equally-relevant messages are broken arbitrarily.
+          - ``"newest"``: order by message timestamp DESC, then by rank.
+            Recent matches surface first; rank breaks same-timestamp ties.
+          - ``"oldest"``: order by message timestamp ASC, then by rank.
+            For "how did this start" / "what was the original X" questions.
+
+        The LIKE fallback path (short CJK queries) has no rank to combine
+        with, so it orders by timestamp alone: DESC by default and for
+        ``"newest"``, ASC for ``"oldest"``. The trigram CJK path honours
+        ``sort`` like the main FTS5 path.
         """
         if not query or not query.strip():
             return []
@@ -1905,6 +2098,25 @@ class SessionDB:
         if not query:
             return []

+        # Normalise sort. Anything not in the allowed set falls back to None
+        # (FTS5 rank-only) — be forgiving to callers who pass empty string or
+        # an unexpected value rather than failing the search.
+        if isinstance(sort, str):
+            sort_norm = sort.strip().lower()
+            if sort_norm not in ("newest", "oldest"):
+                sort_norm = None
+        else:
+            sort_norm = None
+
+        # ORDER BY shared by both FTS5 paths. With sort set, timestamp is
+        # primary and rank is the tiebreaker; otherwise rank alone.
+        if sort_norm == "newest":
+            order_by_sql = "ORDER BY m.timestamp DESC, rank"
+        elif sort_norm == "oldest":
+            order_by_sql = "ORDER BY m.timestamp ASC, rank"
+        else:
+            order_by_sql = "ORDER BY rank"
+
         # Build WHERE clauses dynamically
         where_clauses = ["messages_fts MATCH ?"]
         params: list = [query]
@@ -1943,7 +2155,7 @@ class SessionDB:
             JOIN messages m ON m.id = messages_fts.rowid
             JOIN sessions s ON s.id = m.session_id
             WHERE {where_sql}
-            ORDER BY rank
+            {order_by_sql}
             LIMIT ? OFFSET ?
         """

@@ -2012,7 +2224,7 @@ class SessionDB:
                 JOIN messages m ON m.id = messages_fts_trigram.rowid
                 JOIN sessions s ON s.id = m.session_id
                 WHERE {' AND '.join(tri_where)}
-                ORDER BY rank
+                {order_by_sql}
                 LIMIT ? OFFSET ?
             """
             tri_params.extend([limit, offset])
@@ -2051,6 +2263,13 @@ class SessionDB:
             if role_filter:
                 like_where.append(f"m.role IN ({','.join('?' for _ in role_filter)})")
                 like_params.extend(role_filter)
+            # LIKE fallback has no rank to combine with — just timestamp
+            # direction. Default/"newest" → DESC; "oldest" → ASC.
+            like_order_sql = (
+                "ORDER BY m.timestamp ASC"
+                if sort_norm == "oldest"
+                else "ORDER BY m.timestamp DESC"
+            )
             like_sql = f"""
                 SELECT m.id, m.session_id, m.role,
                     substr(m.content,
@@ -2061,7 +2280,7 @@ class SessionDB:
                 FROM messages m
                 JOIN sessions s ON s.id = m.session_id
                 WHERE {' AND '.join(like_where)}
-                ORDER BY m.timestamp DESC
+                {like_order_sql}
                 LIMIT ? OFFSET ?
             """
             like_params.extend([limit, offset])
+1
-10
@@ -97,9 +97,7 @@ def _run_async(coro):
     asyncio.run()'s create-and-destroy lifecycle.

     This is the single source of truth for sync->async bridging in tool
-    handlers. The RL paths (agent_loop.py, tool_context.py) also provide
-    outer thread-pool wrapping as defense-in-depth, but each handler is
-    self-protecting via this function.
+    handlers. Each handler is self-protecting via this function.
     """
     try:
         loop = asyncio.get_running_loop()
@@ -231,13 +229,6 @@ _LEGACY_TOOLSET_MAP = {
         "browser_vision", "browser_console"
     ],
     "cronjob_tools": ["cronjob"],
-    "rl_tools": [
-        "rl_list_environments", "rl_select_environment",
-        "rl_get_current_config", "rl_edit_config",
-        "rl_start_training", "rl_check_status",
-        "rl_stop_training", "rl_get_results",
-        "rl_list_runs", "rl_test_inference"
-    ],
     "file_tools": ["read_file", "write_file", "patch", "search_files"],
     "tts_tools": ["text_to_speech"],
 }
@@ -192,7 +192,6 @@ stdenv.mkDerivation {
       source .venv/bin/activate
       uv pip install -e ".[all]"
       [ -d mini-swe-agent ] && uv pip install -e ./mini-swe-agent 2>/dev/null || true
-      [ -d tinker-atropos ] && uv pip install -e ./tinker-atropos 2>/dev/null || true
       mkdir -p .nix-stamps
       echo "$STAMP_VALUE" > "$STAMP"
     else
@@ -1,303 +0,0 @@
---
name: hermes-atropos-environments
description: Build, test, and debug Hermes Agent RL environments for Atropos training. Covers the HermesAgentBaseEnv interface, reward functions, agent loop integration, evaluation with tools, wandb logging, and the three CLI modes (serve/process/evaluate). Use when creating, reviewing, or fixing RL environments in the hermes-agent repo.
version: 1.1.0
author: Hermes Agent
license: MIT
platforms: [linux, macos, windows]
metadata:
  hermes:
    tags: [atropos, rl, environments, training, reinforcement-learning, reward-functions]
    related_skills: [axolotl, fine-tuning-with-trl, lm-evaluation-harness]
---

# Hermes Agent Atropos Environments

Guide for building RL environments in the hermes-agent repo that integrate with the Atropos training framework.

## Architecture Overview

```
Atropos BaseEnv (atroposlib/envs/base.py)
    └── HermesAgentBaseEnv (environments/hermes_base_env.py)
            ├── Handles agent loop orchestration
            ├── Handles tool resolution per group
            ├── Handles ToolContext for reward verification
            └── YOUR ENVIRONMENT (environments/your_env.py)
                    Only implements: setup, get_next_item, format_prompt,
                                     compute_reward, evaluate, wandb_log
```

Hermes environments are special because they run a **multi-turn agent loop with tool calling** — not just single-turn completions. The base env handles the loop; you implement the task and scoring.

## File Locations

| File | Purpose |
|------|---------|
| `environments/hermes_base_env.py` | Base class with agent loop + tool resolution |
| `environments/agent_loop.py` | `HermesAgentLoop` + `AgentResult` dataclass |
| `environments/tool_context.py` | `ToolContext` for reward verification |
| `environments/tool_call_parsers.py` | Phase 2 tool call parsers (hermes, mistral, etc.) |
| `environments/your_env.py` | Your environment implementation |

## Inference Setup — Ask the User First

**IMPORTANT:** Before running any test, evaluation, or data generation command, always ask the user how they want to handle inference. Do NOT assume OpenRouter or any specific endpoint. Present these options:

1. **OpenRouter** — Ask which model they want to use (e.g., `anthropic/claude-sonnet-4.5`, `google/gemini-2.5-pro`, `meta-llama/llama-3.3-70b-instruct`, etc.). Requires `OPENROUTER_API_KEY` in environment.
2. **Self-hosted vLLM endpoint** — Ask for their base URL (e.g., `http://localhost:8000/v1`) and model name. Set `--openai.server_type vllm`.
3. **Other OpenAI-compatible API** — Ask for the base URL, model name, and any required API key. Set `--openai.server_type openai` and `--openai.health_check false`.
4. **Local Atropos training server** — For `serve` mode with a live training loop. Default `http://localhost:8000/v1`.

Once the user tells you their setup, use those values in all CLI commands for that session. Example prompts:

> "Before I run this, how would you like to handle inference?
> 1. OpenRouter (I'll need your preferred model, e.g. claude-sonnet-4.5)
> 2. A self-hosted vLLM endpoint (give me the URL and model name)
> 3. Another OpenAI-compatible API (give me the URL, model, and any auth details)
> 4. Local Atropos training server (serve mode)"

### Key flags by provider:

| Provider | `--openai.server_type` | `--openai.health_check` | `--openai.api_key` |
|----------|----------------------|------------------------|-------------------|
| OpenRouter | `openai` | `false` | `$OPENROUTER_API_KEY` |
| vLLM (self-hosted) | `vllm` | (default) | (not needed) |
| Other OpenAI-compatible | `openai` | `false` | As needed |
| Local Atropos | (default) | (default) | (not needed) |

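
The table can also be read as a lookup from provider to CLI arguments. A minimal sketch, assuming a hypothetical helper `provider_flags` and hypothetical provider keys (`openrouter`, `vllm`, `openai_compatible`, `local_atropos`); none of these names exist in hermes-agent or Atropos, and only the three `--openai.*` flag names are taken from the table:

```python
import os
from typing import Dict, List, Optional

# Hypothetical lookup mirroring the table; None means "leave the default".
PROVIDER_FLAGS: Dict[str, Dict[str, Optional[str]]] = {
    "openrouter": {"server_type": "openai", "health_check": "false",
                   "api_key_env": "OPENROUTER_API_KEY"},
    "vllm": {"server_type": "vllm", "health_check": None, "api_key_env": None},
    "openai_compatible": {"server_type": "openai", "health_check": "false",
                          "api_key_env": None},
    "local_atropos": {"server_type": None, "health_check": None,
                      "api_key_env": None},
}


def provider_flags(provider: str, api_key: Optional[str] = None) -> List[str]:
    """Build the --openai.* arguments for one provider (illustrative only)."""
    spec = PROVIDER_FLAGS[provider]
    args: List[str] = []
    if spec["server_type"] is not None:
        args += ["--openai.server_type", spec["server_type"]]
    if spec["health_check"] is not None:
        args += ["--openai.health_check", spec["health_check"]]
    # Prefer an explicit key, then the provider's environment variable.
    env_name = spec["api_key_env"]
    key = api_key or (os.environ.get(env_name) if env_name else None)
    if key:
        args += ["--openai.api_key", key]
    return args
```

The returned list can be appended to whatever command line is being assembled for the session; providers that rely on defaults simply contribute no arguments.
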

## Required Methods

### 1. `setup()` — Load dataset and initialize state

```python
async def setup(self) -> None:
    """Called once at startup. Load datasets, initialize state."""
    # Try HuggingFace first, fallback to built-in samples
    try:
        from datasets import load_dataset
        ds = load_dataset("your/dataset", split="test")
        self._items = [...]
    except Exception:
        self._items = BUILTIN_SAMPLES

    # Always split into train/eval
    random.shuffle(self._items)
    eval_size = max(20, int(len(self._items) * 0.1))
    self._eval_items = self._items[:eval_size]
    self._items = self._items[eval_size:]
```
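
One thing to watch in the split above: `max(20, ...)` reserves at least 20 items for eval, so a dataset with 20 or fewer items leaves the training list empty, and the modulo in `get_next_item()` would then divide by zero. A minimal sketch of a guarded split, assuming a hypothetical helper `split_train_eval` and an illustrative cap of half the dataset (neither is part of the source):

```python
import random
from typing import List, Sequence, Tuple


def split_train_eval(
    items: Sequence[dict],
    eval_fraction: float = 0.1,
    min_eval: int = 20,
    seed: int = 0,
) -> Tuple[List[dict], List[dict]]:
    """Shuffle and split items, never letting eval take more than half."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    wanted = max(min_eval, int(len(shuffled) * eval_fraction))
    # Hypothetical guard: cap eval at half the data so the train split stays
    # non-empty whenever there is at least one item.
    eval_size = min(wanted, len(shuffled) // 2)
    return shuffled[eval_size:], shuffled[:eval_size]
```

With this guard a 10-item fallback sample set yields five train and five eval items, while large datasets keep the original 10% / minimum-20 behaviour.
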

### 2. `get_next_item()` — Return next training item

```python
async def get_next_item(self) -> dict:
    """Return next item, cycling through dataset."""
    item = self._items[self._index % len(self._items)]
    self._index += 1
    return item
```

### 3. `format_prompt(item)` — Convert item to user message

```python
def format_prompt(self, item: dict) -> str:
    """Convert a dataset item into the user-facing prompt."""
    return f"Research this question: {item['question']}"
```

### 4. `compute_reward(item, result, ctx)` — Score the rollout

**CRITICAL**: `result` is an `AgentResult`, NOT a dict. It has these attributes:
- `result.messages` — List of message dicts (OpenAI format)
- `result.turns_used` — Number of LLM calls made
- `result.finished_naturally` — True if model stopped voluntarily
- `result.tool_errors` — List of ToolError objects

**AgentResult does NOT have**: `final_response`, `tool_calls`, `tools_used`.
You must extract these from `result.messages`:

```python
async def compute_reward(self, item, result: AgentResult, ctx: ToolContext) -> float:
    # Extract final response (last assistant message with content)
    final_response = ""
    tools_used = []
    for msg in reversed(result.messages):
        if msg.get("role") == "assistant" and msg.get("content") and not final_response:
            final_response = msg["content"]
        if msg.get("role") == "assistant" and msg.get("tool_calls"):
            for tc in msg["tool_calls"]:
                fn = tc.get("function", {}) if isinstance(tc, dict) else {}
                name = fn.get("name", "")
                if name:
                    tools_used.append(name)

    # Score using LLM judge, heuristic, or ToolContext verification
    correctness = await self._llm_judge(item, final_response)
    return correctness
```
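
Because the loop above walks `result.messages` in reverse, `tools_used` ends up ordered from the latest assistant turn to the earliest. If call order matters for scoring, a forward pass collects both values in chronological order. A minimal sketch, assuming OpenAI-format message dicts as described above; `extract_final_and_tools` is a hypothetical helper name, not part of the environments package:

```python
from typing import Any, Dict, List, Tuple


def extract_final_and_tools(messages: List[Dict[str, Any]]) -> Tuple[str, List[str]]:
    """Return (last non-empty assistant content, tool names in call order)."""
    final_response = ""
    tools_used: List[str] = []
    for msg in messages:
        if msg.get("role") != "assistant":
            continue
        if msg.get("content"):
            final_response = msg["content"]  # later turns overwrite earlier ones
        for tc in msg.get("tool_calls") or []:
            fn = tc.get("function", {}) if isinstance(tc, dict) else {}
            name = fn.get("name", "")
            if name:
                tools_used.append(name)
    return final_response, tools_used


# Illustrative transcript; the tool names are placeholders.
messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "",
     "tool_calls": [{"function": {"name": "web_search"}}]},
    {"role": "tool", "content": "..."},
    {"role": "assistant", "content": "",
     "tool_calls": [{"function": {"name": "terminal"}}]},
    {"role": "tool", "content": "4"},
    {"role": "assistant", "content": "The answer is 4."},
]
final, tools = extract_final_and_tools(messages)
```

Inside `compute_reward` the same helper would be called with `result.messages`.
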
|
||||
|
||||
`ctx` (ToolContext) gives you terminal/file access to the agent's sandbox for verification:
|
||||
```python
|
||||
# Run tests in the agent's sandbox
|
||||
result = ctx.terminal("pytest /workspace/test.py")
|
||||
return 1.0 if result["exit_code"] == 0 else 0.0
|
||||
```
|
||||
|
||||
### 5. `evaluate()` — Periodic evaluation with full agent loop
|
||||
|
||||
**MUST use the full agent loop with tools**, not single-turn chat_completion.
|
||||
The whole point of hermes-agent environments is agentic evaluation:
|
||||
|
||||
```python
async def evaluate(self, *args, **kwargs) -> None:
    import time, uuid
    from environments.agent_loop import HermesAgentLoop
    from environments.tool_context import ToolContext

    start_time = time.time()
    tools, valid_names = self._resolve_tools_for_group()
    samples = []

    for item in self._eval_items[:self.config.eval_size]:
        task_id = str(uuid.uuid4())
        messages = []
        if self.config.system_prompt:
            messages.append({"role": "system", "content": self.config.system_prompt})
        messages.append({"role": "user", "content": self.format_prompt(item)})

        agent = HermesAgentLoop(
            server=self.server,
            tool_schemas=tools,
            valid_tool_names=valid_names,
            max_turns=self.config.max_agent_turns,
            task_id=task_id,
            temperature=0.0,  # Deterministic for eval
            max_tokens=self.config.max_token_length,
            extra_body=self.config.extra_body,
        )
        result = await agent.run(messages)

        ctx = ToolContext(task_id)
        try:
            reward = await self.compute_reward(item, result, ctx)
        finally:
            ctx.cleanup()

        samples.append({"prompt": ..., "response": ..., "reward": reward})

    eval_metrics = {"eval/mean_reward": ...}
    await self.evaluate_log(metrics=eval_metrics, samples=samples,
                            start_time=start_time, end_time=time.time())
```
|
||||
|
||||
### 6. `wandb_log()` — Custom metrics logging
|
||||
|
||||
Always call `super().wandb_log()` at the end:
|
||||
|
||||
```python
async def wandb_log(self, wandb_metrics=None):
    if wandb_metrics is None:
        wandb_metrics = {}
    if self._reward_buffer:
        n = len(self._reward_buffer)
        wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
        self._reward_buffer.clear()
    await super().wandb_log(wandb_metrics)  # MUST call super
```
|
||||
|
||||
**Pitfall**: `compute_reward` appends to metric buffers. During eval, this pollutes training metrics. Roll back buffer entries added during eval.
|
||||
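One way to do the rollback is to note the buffer length before scoring an eval item and truncate back to it afterwards. The sketch below is a minimal illustration under assumptions: `BufferedEnv` and `eval_reward` are hypothetical stand-ins, not part of hermes-agent or Atropos, and only the `_reward_buffer` name comes from the example above.

```python
import asyncio


class BufferedEnv:
    """Hypothetical stand-in for an environment with one metric buffer."""

    def __init__(self):
        self._reward_buffer = []

    async def compute_reward(self, item):
        # Mirrors the pitfall: scoring appends to the buffer as a side effect.
        reward = float(item["reward"])
        self._reward_buffer.append(reward)
        return reward

    async def eval_reward(self, item):
        # Remember the buffer length, score, then drop whatever scoring appended.
        mark = len(self._reward_buffer)
        try:
            return await self.compute_reward(item)
        finally:
            del self._reward_buffer[mark:]


async def demo():
    env = BufferedEnv()
    await env.compute_reward({"reward": 1.0})              # training rollout
    eval_score = await env.eval_reward({"reward": 0.25})   # eval rollout
    return eval_score, list(env._reward_buffer)


eval_score, buffer_after = asyncio.run(demo())
```

Truncating by length assumes nothing else appends to the buffer while the eval item is being scored. If training rollouts can run concurrently with eval, a flag that makes `compute_reward` skip the append is likely the safer design.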
|
||||
## Config Class
|
||||
|
||||
Always create a custom config subclass with Pydantic Field descriptors. Key inherited fields you can tune: `enabled_toolsets`, `max_agent_turns`, `agent_temperature`, `system_prompt`, `terminal_backend`, `group_size`, `steps_per_eval`, `total_steps`.
|
||||
|
||||
## config_init() — Default Configuration
|
||||
|
||||
Classmethod returning `(YourEnvConfig, [APIServerConfig(...)])`. Set server_type to "openai" for OpenRouter/external APIs. Load API key from environment variable.
|
||||
|
||||
## Three CLI Modes
|
||||
|
||||
```bash
# SERVE — Full training loop (connects to Atropos API server)
python environments/my_env.py serve --openai.base_url http://localhost:8000/v1

# PROCESS — Offline data generation (saves JSONL)
python environments/my_env.py process --env.total_steps 10 --env.group_size 1 \
    --env.use_wandb false --env.data_path_to_save_groups output.jsonl \
    --openai.base_url "<USER_BASE_URL>" \
    --openai.model_name "<USER_MODEL>" \
    --openai.server_type <USER_SERVER_TYPE> --openai.health_check false

# EVALUATE — Standalone eval (runs setup + evaluate only)
python environments/my_env.py evaluate --env.eval_size 20 \
    --env.data_dir_to_save_evals /tmp/eval_results \
    --openai.base_url "<USER_BASE_URL>" \
    --openai.model_name "<USER_MODEL>" \
    --openai.server_type <USER_SERVER_TYPE> --openai.health_check false
```
|
||||
|
||||
Config priority: CLI args > YAML file > config_init() defaults.
|
||||
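The priority order can be pictured as a layered merge in which later layers win. This is only an illustration of the ordering, not how Atropos resolves its config internally; `resolve_config` and the sample values are hypothetical.

```python
def resolve_config(defaults, yaml_values=None, cli_values=None):
    """Merge three flat dicts; later layers override earlier ones."""
    merged = dict(defaults)
    for layer in (yaml_values or {}, cli_values or {}):
        # Skip None so an unset option does not erase a lower-priority value.
        merged.update({k: v for k, v in layer.items() if v is not None})
    return merged


defaults = {"group_size": 4, "eval_size": 50, "use_wandb": True}   # config_init()
yaml_values = {"group_size": 8, "eval_size": 20}                   # YAML file
cli_values = {"group_size": 1, "use_wandb": None}                  # CLI args
resolved = resolve_config(defaults, yaml_values, cli_values)
```

Here `group_size` comes from the CLI, `eval_size` from the YAML layer, and `use_wandb` falls through to the default because the CLI left it unset.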
|
||||
## Common Pitfalls
|
||||
|
||||
1. **AgentResult has .messages, not .final_response** — Extract the final response by iterating reversed(result.messages) looking for the last assistant message with content.
|
||||
|
||||
2. **evaluate() must use HermesAgentLoop, not chat_completion** — Single-turn chat_completion has no tools. The whole point of hermes-agent benchmarks is agentic evaluation with tool use.
|
||||
|
||||
3. **Don't call _llm_judge twice** — If compute_reward already calls it, extract the score from the buffer instead of calling judge separately in evaluate().
|
||||
|
||||
4. **Eval pollutes training buffers** — compute_reward appends to metric buffers. During eval, roll back buffer entries to keep training metrics clean.
|
||||
|
||||
5. **Always set health_check=false for OpenRouter** — OpenRouter has no /health endpoint.
|
||||
|
||||
6. **Set data_dir_to_save_evals in evaluate mode** — Without it, results aren't saved.
|
||||
|
||||
7. **default_toolsets class variable vs enabled_toolsets config** — The class variable is a hint; the config field is what actually controls tool resolution.
|
||||
|
||||
8. **Tool call parsing in messages** — Tool calls are dicts with `{"function": {"name": ..., "arguments": ...}}`. Always check `isinstance(tc, dict)`.
|
||||
|
||||
9. **ToolContext.cleanup()** — Always call in a finally block to release sandbox resources.
|
||||
|
||||
10. **server_type must be "openai" for external APIs** — Without it, Atropos assumes a local vLLM server.
|
||||
|
||||
11. **Always ask the user for their inference setup** — Never hardcode or assume a specific provider/model. See the "Inference Setup" section above.
|
||||
|
||||
## Reward Function Patterns
|
||||
|
||||
### LLM Judge (for open-ended tasks)
|
||||
Use `self.server.chat_completion()` with a scoring prompt. Parse JSON response for score float. Always include a heuristic fallback (keyword overlap) for when the judge call fails.
|
||||
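The parsing and fallback half of this pattern can be sketched without any network call. The helpers below are hypothetical: they assume the judge was asked to reply with a JSON object containing a `score` key, and that `judge_reply` is `None` when the `chat_completion()` call itself failed.

```python
import json
import re


def parse_judge_score(raw_text):
    """Return a float in [0, 1] from a judge reply, or None if none is found."""
    match = re.search(r"\{.*\}", raw_text, re.DOTALL)
    if not match:
        return None
    try:
        score = float(json.loads(match.group(0))["score"])
    except (ValueError, KeyError, TypeError):
        return None
    return min(1.0, max(0.0, score))


def keyword_overlap(reference, response):
    """Fraction of distinct reference words that also appear in the response."""
    ref_words = set(re.findall(r"\w+", reference.lower()))
    resp_words = set(re.findall(r"\w+", response.lower()))
    if not ref_words:
        return 0.0
    return len(ref_words & resp_words) / len(ref_words)


def score_response(judge_reply, reference, response):
    # Fall back to the heuristic when the judge failed or returned no usable score.
    score = parse_judge_score(judge_reply) if judge_reply else None
    return score if score is not None else keyword_overlap(reference, response)
```

For example, `score_response(None, "paris is the capital", "The capital is Paris")` skips the judge path entirely and scores by overlap.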
|
||||
### Binary Verification (for code/terminal tasks)
|
||||
Use `ctx.terminal("pytest test.py -q")` to run tests in the agent's sandbox. Return 1.0 for pass, 0.0 for fail.
|
||||
|
||||
### Multi-Signal (combine multiple indicators)
|
||||
Weight correctness (0.6) + tool usage (0.2) + efficiency (0.2) + optional bonuses. Clamp to [0, 1].
|
||||
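A minimal sketch of that weighting follows. Only the weights and the clamp come from the text above. How tool usage and efficiency are measured here (coverage of an expected tool set, and a linear penalty on turns) is an assumption for illustration, and `multi_signal_reward` is a hypothetical helper.

```python
def multi_signal_reward(correctness, tools_used, expected_tools,
                        turns_used, max_turns, bonus=0.0):
    # Tool usage: fraction of the expected tools that were actually called.
    expected = set(expected_tools)
    tool_score = len(expected & set(tools_used)) / len(expected) if expected else 1.0

    # Efficiency: 1.0 for a single turn, falling linearly to 0.0 at max_turns.
    if max_turns > 1:
        efficiency = 1.0 - (min(turns_used, max_turns) - 1) / (max_turns - 1)
    else:
        efficiency = 1.0

    total = 0.6 * correctness + 0.2 * tool_score + 0.2 * efficiency + bonus
    return min(1.0, max(0.0, total))  # Clamp to [0, 1]
```

The clamp matters once bonuses or penalties are added, since the weighted sum alone already stays within [0, 1] when each signal does.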
|
||||
## Testing Your Environment
|
||||
|
||||
1. **Import test**: `python -c "from environments.my_env import MyEnv; print('OK')"`
|
||||
2. **Ask the user for inference setup** (see "Inference Setup" section above)
|
||||
3. **Process mode** (1 item): Verify JSONL output has valid tokens, masks, scores
|
||||
4. **Evaluate mode**: Verify full agent loop runs with tools, metrics logged correctly
|
||||
5. **Check reward range**: Scores should be in [0, 1], not all identical
|
||||
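Steps 3 and 5 can be partly automated. The checker below is a sketch under one loud assumption: each JSONL line is a single group object with `tokens`, `masks`, and `scores` keys, as in the ScoredDataGroup structure. The actual file layout written by process mode should be confirmed before relying on it, and `check_groups` is a hypothetical helper.

```python
import json


def check_groups(jsonl_lines):
    """Return a list of human-readable problems found in process-mode output."""
    problems = []
    all_scores = []
    for i, line in enumerate(jsonl_lines):
        group = json.loads(line)
        tokens, masks, scores = group["tokens"], group["masks"], group["scores"]
        if not (len(tokens) == len(masks) == len(scores)):
            problems.append(f"group {i}: tokens/masks/scores differ in rollout count")
            continue
        for toks, mask in zip(tokens, masks):
            if not toks or len(toks) != len(mask):
                problems.append(f"group {i}: empty rollout or token/mask length mismatch")
        for score in scores:
            if not 0.0 <= score <= 1.0:
                problems.append(f"group {i}: score {score} outside [0, 1]")
        all_scores.extend(scores)
    # Identical scores across the whole file usually point at a broken reward.
    if len(all_scores) > 1 and len(set(all_scores)) == 1:
        problems.append("all scores are identical")
    return problems
```

Usage would look like `check_groups(open("output.jsonl"))`, with an empty list meaning no problems were detected by these particular checks.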
|
||||
## Minimum Implementation Checklist
|
||||
|
||||
```python
class MyEnv(HermesAgentBaseEnv):
    name = "my-env"
    env_config_cls = MyEnvConfig

    @classmethod
    def config_init(cls): ...                               # Default server + env config
    async def setup(self): ...                              # Load dataset + train/eval split
    async def get_next_item(self): ...                      # Cycle through training items
    def format_prompt(self, item): ...                      # Item → user message string
    async def compute_reward(self, item, result, ctx): ...  # Score rollout
    async def evaluate(self, *args, **kwargs): ...          # Full agent loop eval
    async def wandb_log(self, wandb_metrics=None): ...      # Custom metrics + super()

if __name__ == "__main__":
    MyEnv.cli()
```
|
||||
@@ -1,59 +0,0 @@
|
||||
# AgentResult Fields Reference
|
||||
|
||||
`AgentResult` is defined in `environments/agent_loop.py` as a dataclass.
|
||||
|
||||
## Fields
|
||||
|
||||
| Field | Type | Description |
|-------|------|-------------|
| `messages` | `List[Dict[str, Any]]` | Full conversation history in OpenAI message format |
| `managed_state` | `Optional[Dict]` | ManagedServer.get_state() if Phase 2, else None |
| `turns_used` | `int` | Number of LLM calls made during the loop |
| `finished_naturally` | `bool` | True if model stopped calling tools on its own |
| `reasoning_per_turn` | `List[Optional[str]]` | Extracted reasoning content per turn |
| `tool_errors` | `List[ToolError]` | Tool errors encountered during the loop |
|
||||
|
||||
## ToolError Fields
|
||||
|
||||
| Field | Type | Description |
|-------|------|-------------|
| `turn` | `int` | Turn in which the error occurred |
| `tool_name` | `str` | Name of the tool that failed |
| `arguments` | `str` | Arguments passed to the tool |
| `error` | `str` | Error message |
| `tool_result` | `str` | The result returned to the model |
|
||||
|
||||
## Extracting Data from Messages
|
||||
|
||||
Messages follow OpenAI format. Common patterns:
|
||||
|
||||
```python
# Get final assistant response (default to "" if no assistant message has content)
final_response = ""
for msg in reversed(result.messages):
    if msg.get("role") == "assistant" and msg.get("content"):
        final_response = msg["content"]
        break

# Get all tool names used
tools = []
for msg in result.messages:
    if msg.get("role") == "assistant" and msg.get("tool_calls"):
        for tc in msg["tool_calls"]:
            fn = tc.get("function", {}) if isinstance(tc, dict) else {}
            tools.append(fn.get("name", ""))

# Get tool results
for msg in result.messages:
    if msg.get("role") == "tool":
        tool_output = msg.get("content", "")
        call_id = msg.get("tool_call_id", "")
```
|
||||
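The patterns above can be combined to line up each tool call with the output it produced. This sketch assumes the usual OpenAI conventions: each entry in `tool_calls` carries an `id`, the matching `tool` message repeats it as `tool_call_id`, and `arguments` is a JSON-encoded string. `pair_tool_calls` is a hypothetical helper, not part of `environments/agent_loop.py`.

```python
import json


def pair_tool_calls(messages):
    """Return (tool_name, parsed_arguments, tool_output) triples in call order."""
    # Index tool outputs by the call ID they answer.
    outputs = {
        msg.get("tool_call_id"): msg.get("content", "")
        for msg in messages
        if msg.get("role") == "tool"
    }
    pairs = []
    for msg in messages:
        if msg.get("role") != "assistant":
            continue
        for tc in msg.get("tool_calls") or []:
            if not isinstance(tc, dict):
                continue
            fn = tc.get("function", {})
            raw = fn.get("arguments") or "{}"
            try:
                args = json.loads(raw) if isinstance(raw, str) else raw
            except json.JSONDecodeError:
                args = {}  # Malformed arguments: keep the call, drop the args.
            pairs.append((fn.get("name", ""), args, outputs.get(tc.get("id"))))
    return pairs
```

A call with no matching `tool` message yields `None` as its output, which can itself be a useful signal when scoring a rollout.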
|
||||
## Fields that DO NOT EXIST
|
||||
|
||||
These are common mistakes — AgentResult does NOT have:
|
||||
- `final_response` — extract from messages
|
||||
- `tool_calls` — extract from messages
|
||||
- `tools_used` — extract from messages
|
||||
- `output` — extract from messages
|
||||
- `response` — extract from messages
|
||||
@@ -1,65 +0,0 @@
|
||||
# Atropos BaseEnv Reference
|
||||
|
||||
Source: `atroposlib/envs/base.py` (~2124 lines)
|
||||
|
||||
## Abstract Methods (MUST implement)
|
||||
|
||||
| Method | Signature | Description |
|--------|-----------|-------------|
| `get_next_item()` | `async def get_next_item(self) -> Item` | Return next item for trajectory. Return None to pause. |
| `evaluate()` | `async def evaluate(self, *args, **kwargs)` | Called every steps_per_eval steps. |
| `setup()` | `async def setup(self)` | Called once at start. Load datasets, init models. |
| `collect_trajectory()` | `async def collect_trajectory(self, item) -> Tuple[Optional[ScoredDataItem], List[Item]]` | Single rollout. Or override collect_trajectories instead. |
|
||||
|
||||
## Overridable Methods
|
||||
|
||||
| Method | Default Behavior | Override When |
|--------|-----------------|---------------|
| `collect_trajectories()` | Runs collect_trajectory group_size times in parallel | Batch generation, MCTS, coupled rollouts |
| `wandb_log()` | Logs completion lengths, rollout table, perf stats | Add custom metrics (always call super) |
| `config_init()` | Returns (env_config_cls(), ServerBaseline()) | Custom defaults + server configs |
| `postprocess_histories()` | Passthrough | Final processing before sending to trainer |
| `save_checkpoint()` | Saves JSON to checkpoint_dir | Custom serialization |
| `cleanup()` | No-op | Release resources after each rollout |
|
||||
|
||||
## ScoredDataGroup Structure
|
||||
|
||||
```python
ScoredDataGroup = TypedDict with:
    tokens: List[List[int]]            # Token IDs per rollout
    masks: List[List[int]]             # -100=prompt, token_id=completion
    scores: List[float]                # Score per rollout
    advantages: Optional[...]          # Per-token advantages
    ref_logprobs: Optional[...]        # Reference model logprobs
    messages: Optional[...]            # OpenAI-format messages
    inference_logprobs: Optional[...]  # Inference logprobs
```
|
||||
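The `masks` convention is easiest to see on a single rollout. The sketch below assumes one contiguous prompt followed by one completion. Multi-turn agent rollouts with tool outputs would need several masked spans, which this hypothetical `build_mask` helper does not handle.

```python
IGNORE_INDEX = -100  # Marks prompt positions, per the masks comment above


def build_mask(tokens, prompt_len):
    """Mask prompt positions with -100 and keep completion token IDs."""
    if not 0 <= prompt_len <= len(tokens):
        raise ValueError("prompt_len must be between 0 and len(tokens)")
    return [IGNORE_INDEX] * prompt_len + list(tokens[prompt_len:])


tokens = [101, 102, 103, 7, 8, 9]   # 3 prompt tokens, then 3 completion tokens
mask = build_mask(tokens, prompt_len=3)
```

The mask always has the same length as `tokens`, which is the invariant a trainer consuming the group would rely on.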
|
||||
## BaseEnvConfig Key Fields
|
||||
|
||||
| Field | Default | Description |
|-------|---------|-------------|
| `group_size` | 4 | Responses grouped for scoring |
| `steps_per_eval` | 100 | Steps between evaluations |
| `max_token_length` | 2048 | Max token length for generations |
| `total_steps` | 1000 | Total training steps |
| `use_wandb` | True | Enable wandb logging |
| `tokenizer_name` | DeepHermes-3 | Tokenizer for token encoding |
| `ensure_scores_are_not_same` | True | Skip groups with identical scores |
| `worker_timeout` | 600 | Task timeout in seconds |
|
||||
|
||||
## Data Flow
|
||||
|
||||
```
env_manager() → add_train_workers() → handle_env()
    → collect_trajectories() → postprocess_histories()
    → handle_send_to_api() → training server
```
|
||||
|
||||
## Atropos Environment Statistics (82 environments analyzed)
|
||||
|
||||
- 95% implement setup, collect_trajectories, evaluate, get_next_item
|
||||
- 76% override wandb_log
|
||||
- 54% have custom config class
|
||||
- Most use collect_trajectories (plural), not collect_trajectory (singular)
|
||||
- Common reward patterns: LLM-judge (~40), regex-extract (~35), code-exec (~12)
|
||||
Some files were not shown because too many files have changed in this diff.